Compare commits

..

No commits in common. "main" and "v1.1.0" have entirely different histories.
main ... v1.1.0

3 changed files with 133 additions and 203 deletions

247
bot.py
View file

@ -1,5 +1,4 @@
import os import os
import signal
import sys import sys
import traceback import traceback
from typing import Union, Optional from typing import Union, Optional
@ -7,27 +6,12 @@ import aiofiles
import asyncio import asyncio
import uuid import uuid
import json import json
from nio import ( from nio import (AsyncClient, AsyncClientConfig, InviteMemberEvent, JoinError,
AsyncClient, KeyVerificationCancel, KeyVerificationEvent, DownloadError,
AsyncClientConfig, KeyVerificationKey, KeyVerificationMac, KeyVerificationStart,
InviteMemberEvent, LocalProtocolError, LoginResponse, MatrixRoom, MegolmEvent,
JoinError, RoomMessageAudio, RoomEncryptedAudio, ToDeviceError, crypto,
KeyVerificationCancel, EncryptionError)
KeyVerificationEvent,
DownloadError,
KeyVerificationKey,
KeyVerificationMac,
KeyVerificationStart,
LocalProtocolError,
LoginResponse,
MatrixRoom,
MegolmEvent,
RoomMessageAudio,
RoomEncryptedAudio,
ToDeviceError,
crypto,
EncryptionError,
)
from nio.store.database import SqliteStore from nio.store.database import SqliteStore
from faster_whisper import WhisperModel from faster_whisper import WhisperModel
@ -56,11 +40,12 @@ class Bot:
num_workers: int = 1, num_workers: int = 1,
download_root: str = "models", download_root: str = "models",
): ):
if homeserver is None or user_id is None or device_id is None: if (homeserver is None or user_id is None
or device_id is None):
logger.warning("homeserver && user_id && device_id is required") logger.warning("homeserver && user_id && device_id is required")
sys.exit(1) sys.exit(1)
if password is None and access_token is None: if (password is None and access_token is None):
logger.warning("password or access_toekn is required") logger.warning("password or access_toekn is required")
sys.exit(1) sys.exit(1)
@ -102,36 +87,26 @@ class Bot:
# initialize AsyncClient object # initialize AsyncClient object
self.store_path = os.getcwd() self.store_path = os.getcwd()
self.config = AsyncClientConfig( self.config = AsyncClientConfig(store=SqliteStore,
store=SqliteStore,
store_name="db", store_name="db",
store_sync_tokens=True, store_sync_tokens=True,
encryption_enabled=True, encryption_enabled=True,
) )
self.client = AsyncClient( self.client = AsyncClient(homeserver=self.homeserver, user=self.user_id, device_id=self.device_id,
homeserver=self.homeserver, config=self.config, store_path=self.store_path,)
user=self.user_id,
device_id=self.device_id,
config=self.config,
store_path=self.store_path,
)
if self.access_token is not None: if self.access_token is not None:
self.client.access_token = self.access_token self.client.access_token = self.access_token
# setup event callbacks # setup event callbacks
self.client.add_event_callback( self.client.add_event_callback(
self.message_callback, self.message_callback, (RoomMessageAudio, RoomEncryptedAudio, ))
( self.client.add_event_callback(
RoomMessageAudio, self.decryption_failure, (MegolmEvent, ))
RoomEncryptedAudio, self.client.add_event_callback(
), self.invite_callback, (InviteMemberEvent, ))
)
self.client.add_event_callback(self.decryption_failure, (MegolmEvent,))
self.client.add_event_callback(self.invite_callback, (InviteMemberEvent,))
self.client.add_to_device_callback( self.client.add_to_device_callback(
self.to_device_callback, (KeyVerificationEvent,) self.to_device_callback, (KeyVerificationEvent, ))
)
# intialize whisper model # intialize whisper model
self.model = WhisperModel( self.model = WhisperModel(
@ -140,19 +115,23 @@ class Bot:
compute_type=self.compute_type, compute_type=self.compute_type,
cpu_threads=self.cpu_threads, cpu_threads=self.cpu_threads,
num_workers=self.num_workers, num_workers=self.num_workers,
download_root=self.download_root, download_root=self.download_root,)
)
async def close(self, task: asyncio.Task = None) -> None: def __del__(self):
try:
loop = asyncio.get_running_loop()
except RuntimeError as e:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self._close())
async def _close(self):
await self.client.close() await self.client.close()
task.cancel() logger.info("Bot stopped!")
logger.info("Bot closed!")
# message_callback event # message_callback event
async def message_callback(self, room: MatrixRoom,
async def message_callback( event: Union[RoomMessageAudio, RoomEncryptedAudio]) -> None:
self, room: MatrixRoom, event: Union[RoomMessageAudio, RoomEncryptedAudio]
) -> None:
if self.room_id is None: if self.room_id is None:
room_id = room.room_id room_id = room.room_id
else: else:
@ -199,8 +178,12 @@ class Bot:
await f.write( await f.write(
crypto.attachments.decrypt_attachment( crypto.attachments.decrypt_attachment(
media_data, media_data,
event.source["content"]["file"]["key"]["k"], event.source["content"]["file"]["key"][
event.source["content"]["file"]["hashes"]["sha256"], "k"
],
event.source["content"]["file"]["hashes"][
"sha256"
],
event.source["content"]["file"]["iv"], event.source["content"]["file"]["iv"],
) )
) )
@ -233,8 +216,8 @@ class Bot:
return return
logger.error( logger.error(
f"Failed to decrypt message: {event.event_id} from {event.sender} in {room.room_id}\n" f"Failed to decrypt message: {event.event_id} from {event.sender} in {room.room_id}\n" +
+ "Please make sure the bot current session is verified" "Please make sure the bot current session is verified"
) )
# invite_callback event # invite_callback event
@ -250,8 +233,7 @@ class Bot:
if type(result) == JoinError: if type(result) == JoinError:
logger.error( logger.error(
f"Error joining room {room.room_id} (attempt %d): %s", f"Error joining room {room.room_id} (attempt %d): %s",
attempt, attempt, result.message,
result.message,
) )
else: else:
break break
@ -273,11 +255,11 @@ class Bot:
try: try:
client = self.client client = self.client
logger.debug( logger.debug(
f"Device Event of type {type(event)} received in " "to_device_cb()." f"Device Event of type {type(event)} received in "
) "to_device_cb().")
if isinstance(event, KeyVerificationStart): # first step if isinstance(event, KeyVerificationStart): # first step
"""first step: receive KeyVerificationStart """ first step: receive KeyVerificationStart
KeyVerificationStart( KeyVerificationStart(
source={'content': source={'content':
{'method': 'm.sas.v1', {'method': 'm.sas.v1',
@ -307,14 +289,13 @@ class Bot:
""" """
if "emoji" not in event.short_authentication_string: if "emoji" not in event.short_authentication_string:
estr = ( estr = ("Other device does not support emoji verification "
"Other device does not support emoji verification " f"{event.short_authentication_string}. Aborting.")
f"{event.short_authentication_string}. Aborting."
)
print(estr) print(estr)
logger.info(estr) logger.info(estr)
return return
resp = await client.accept_key_verification(event.transaction_id) resp = await client.accept_key_verification(
event.transaction_id)
if isinstance(resp, ToDeviceError): if isinstance(resp, ToDeviceError):
estr = f"accept_key_verification() failed with {resp}" estr = f"accept_key_verification() failed with {resp}"
print(estr) print(estr)
@ -330,7 +311,7 @@ class Bot:
logger.info(estr) logger.info(estr)
elif isinstance(event, KeyVerificationCancel): # anytime elif isinstance(event, KeyVerificationCancel): # anytime
"""at any time: receive KeyVerificationCancel """ at any time: receive KeyVerificationCancel
KeyVerificationCancel(source={ KeyVerificationCancel(source={
'content': {'code': 'm.mismatched_sas', 'content': {'code': 'm.mismatched_sas',
'reason': 'Mismatched authentication string', 'reason': 'Mismatched authentication string',
@ -347,15 +328,13 @@ class Bot:
# client.cancel_key_verification(tx_id, reject=False) # client.cancel_key_verification(tx_id, reject=False)
# here. The SAS flow is already cancelled. # here. The SAS flow is already cancelled.
# We only need to inform the user. # We only need to inform the user.
estr = ( estr = (f"Verification has been cancelled by {event.sender} "
f"Verification has been cancelled by {event.sender} " f"for reason \"{event.reason}\".")
f'for reason "{event.reason}".'
)
print(estr) print(estr)
logger.info(estr) logger.info(estr)
elif isinstance(event, KeyVerificationKey): # second step elif isinstance(event, KeyVerificationKey): # second step
"""Second step is to receive KeyVerificationKey """ Second step is to receive KeyVerificationKey
KeyVerificationKey( KeyVerificationKey(
source={'content': { source={'content': {
'key': 'SomeCryptoKey', 'key': 'SomeCryptoKey',
@ -380,44 +359,42 @@ class Bot:
# automatic match, so we use y # automatic match, so we use y
yn = "y" yn = "y"
if yn.lower() == "y": if yn.lower() == "y":
estr = ( estr = ("Match! The verification for this "
"Match! The verification for this " "device will be accepted." "device will be accepted.")
)
print(estr) print(estr)
logger.info(estr) logger.info(estr)
resp = await client.confirm_short_auth_string(event.transaction_id) resp = await client.confirm_short_auth_string(
event.transaction_id)
if isinstance(resp, ToDeviceError): if isinstance(resp, ToDeviceError):
estr = "confirm_short_auth_string() " f"failed with {resp}" estr = ("confirm_short_auth_string() "
f"failed with {resp}")
print(estr) print(estr)
logger.info(estr) logger.info(estr)
elif yn.lower() == "n": # no, don't match, reject elif yn.lower() == "n": # no, don't match, reject
estr = ( estr = ("No match! Device will NOT be verified "
"No match! Device will NOT be verified " "by rejecting verification.")
"by rejecting verification."
)
print(estr) print(estr)
logger.info(estr) logger.info(estr)
resp = await client.cancel_key_verification( resp = await client.cancel_key_verification(
event.transaction_id, reject=True event.transaction_id, reject=True)
)
if isinstance(resp, ToDeviceError): if isinstance(resp, ToDeviceError):
estr = f"cancel_key_verification failed with {resp}" estr = (f"cancel_key_verification failed with {resp}")
print(estr) print(estr)
logger.info(estr) logger.info(estr)
else: # C or anything for cancel else: # C or anything for cancel
estr = "Cancelled by user! Verification will be " "cancelled." estr = ("Cancelled by user! Verification will be "
"cancelled.")
print(estr) print(estr)
logger.info(estr) logger.info(estr)
resp = await client.cancel_key_verification( resp = await client.cancel_key_verification(
event.transaction_id, reject=False event.transaction_id, reject=False)
)
if isinstance(resp, ToDeviceError): if isinstance(resp, ToDeviceError):
estr = f"cancel_key_verification failed with {resp}" estr = (f"cancel_key_verification failed with {resp}")
print(estr) print(estr)
logger.info(estr) logger.info(estr)
elif isinstance(event, KeyVerificationMac): # third step elif isinstance(event, KeyVerificationMac): # third step
"""Third step is to receive KeyVerificationMac """ Third step is to receive KeyVerificationMac
KeyVerificationMac( KeyVerificationMac(
source={'content': { source={'content': {
'mac': {'ed25519:DEVICEIDXY': 'SomeKey1', 'mac': {'ed25519:DEVICEIDXY': 'SomeKey1',
@ -437,11 +414,9 @@ class Bot:
todevice_msg = sas.get_mac() todevice_msg = sas.get_mac()
except LocalProtocolError as e: except LocalProtocolError as e:
# e.g. it might have been cancelled by ourselves # e.g. it might have been cancelled by ourselves
estr = ( estr = (f"Cancelled or protocol error: Reason: {e}.\n"
f"Cancelled or protocol error: Reason: {e}.\n"
f"Verification with {event.sender} not concluded. " f"Verification with {event.sender} not concluded. "
"Try again?" "Try again?")
)
print(estr) print(estr)
logger.info(estr) logger.info(estr)
else: else:
@ -450,31 +425,25 @@ class Bot:
estr = f"to_device failed with {resp}" estr = f"to_device failed with {resp}"
print(estr) print(estr)
logger.info(estr) logger.info(estr)
estr = ( estr = (f"sas.we_started_it = {sas.we_started_it}\n"
f"sas.we_started_it = {sas.we_started_it}\n"
f"sas.sas_accepted = {sas.sas_accepted}\n" f"sas.sas_accepted = {sas.sas_accepted}\n"
f"sas.canceled = {sas.canceled}\n" f"sas.canceled = {sas.canceled}\n"
f"sas.timed_out = {sas.timed_out}\n" f"sas.timed_out = {sas.timed_out}\n"
f"sas.verified = {sas.verified}\n" f"sas.verified = {sas.verified}\n"
f"sas.verified_devices = {sas.verified_devices}\n" f"sas.verified_devices = {sas.verified_devices}\n")
)
print(estr) print(estr)
logger.info(estr) logger.info(estr)
estr = ( estr = ("Emoji verification was successful!\n"
"Emoji verification was successful!\n"
"Initiate another Emoji verification from " "Initiate another Emoji verification from "
"another device or room if desired. " "another device or room if desired. "
"Or if done verifying, hit Control-C to stop the " "Or if done verifying, hit Control-C to stop the "
"bot in order to restart it as a service or to " "bot in order to restart it as a service or to "
"run it in the background." "run it in the background.")
)
print(estr) print(estr)
logger.info(estr) logger.info(estr)
else: else:
estr = ( estr = (f"Received unexpected event type {type(event)}. "
f"Received unexpected event type {type(event)}. " f"Event is {event}. Event will be ignored.")
f"Event is {event}. Event will be ignored."
)
print(estr) print(estr)
logger.info(estr) logger.info(estr)
except BaseException: except BaseException:
@ -483,6 +452,7 @@ class Bot:
logger.info(estr) logger.info(estr)
# bot login # bot login
async def login(self) -> None: async def login(self) -> None:
if self.access_token is not None: if self.access_token is not None:
logger.info("Login via access_token") logger.info("Login via access_token")
@ -510,14 +480,14 @@ class Bot:
# import keys # import keys
async def import_keys(self): async def import_keys(self):
resp = await self.client.import_keys( resp = await self.client.import_keys(
self.import_keys_path, self.import_keys_password self.import_keys_path,
self.import_keys_password
) )
if isinstance(resp, EncryptionError): if isinstance(resp, EncryptionError):
logger.error(f"import_keys failed with {resp}") logger.error(f"import_keys failed with {resp}")
else: else:
logger.info( logger.info(
f"import_keys success, please remove import_keys configuration!!!" f"import_keys success, please remove import_keys configuration!!!")
)
# whisper function # whisper function
def transcribe(self, filename: str) -> str: def transcribe(self, filename: str) -> str:
@ -537,33 +507,30 @@ async def main():
config = json.load(fp) config = json.load(fp)
bot = Bot( bot = Bot(
homeserver=config.get("homeserver"), homeserver=config.get('homeserver'),
user_id=config.get("user_id"), user_id=config.get('user_id'),
password=config.get("password"), password=config.get('password'),
device_id=config.get("device_id"), device_id=config.get('device_id'),
room_id=config.get("room_id"), room_id=config.get('room_id'),
access_token=config.get("access_token"), access_token=config.get('access_token'),
import_keys_path=config.get("import_keys_path"), import_keys_path=config.get('import_keys_path'),
import_keys_password=config.get("import_keys_password"), import_keys_password=config.get('import_keys_password'),
model_size=config.get("model_size"), model_size=config.get('model_size'),
device=config.get("device"), device=config.get('device'),
compute_type=config.get("compute_type"), compute_type=config.get('compute_type'),
cpu_threads=config.get("cpu_threads"), cpu_threads=config.get('cpu_threads'),
num_workers=config.get("num_workers"), num_workers=config.get('num_workers'),
download_root=config.get("download_root"), download_root=config.get('download_root'),
) )
if ( if config.get('import_keys_path') and config.get('import_keys_password') is not None:
config.get("import_keys_path")
and config.get("import_keys_password") is not None
):
need_import_keys = True need_import_keys = True
else: else:
bot = Bot( bot = Bot(
homeserver=os.environ.get("HOMESERVER"), homeserver=os.environ.get('HOMESERVER'),
user_id=os.environ.get("USER_ID"), user_id=os.environ.get('USER_ID'),
password=os.environ.get("PASSWORD"), password=os.environ.get('PASSWORD'),
device_id=os.environ.get("DEVICE_ID"), device_id=os.environ.get("DEVICE_ID"),
room_id=os.environ.get("ROOM_ID"), room_id=os.environ.get("ROOM_ID"),
access_token=os.environ.get("ACCESS_TOKEN"), access_token=os.environ.get("ACCESS_TOKEN"),
@ -576,10 +543,7 @@ async def main():
num_workers=os.environ.get("NUM_WORKERS"), num_workers=os.environ.get("NUM_WORKERS"),
download_root=os.environ.get("DOWNLOAD_ROOT"), download_root=os.environ.get("DOWNLOAD_ROOT"),
) )
if ( if os.environ.get("IMPORT_KEYS_PATH") and os.environ.get("IMPORT_KEYS_PASSWORD") is not None:
os.environ.get("IMPORT_KEYS_PATH")
and os.environ.get("IMPORT_KEYS_PASSWORD") is not None
):
need_import_keys = True need_import_keys = True
await bot.login() await bot.login()
@ -587,21 +551,8 @@ async def main():
logger.info("start import_keys process, this may take a while...") logger.info("start import_keys process, this may take a while...")
await bot.import_keys() await bot.import_keys()
sync_task = asyncio.create_task(bot.sync_forever()) await bot.sync_forever()
# handle signal interrupt if __name__ == '__main__':
loop = asyncio.get_running_loop()
for signame in (
"SIGINT",
"SIGTERM",
):
loop.add_signal_handler(
getattr(signal, signame), lambda: asyncio.create_task(bot.close(sync_task))
)
await sync_task
if __name__ == "__main__":
logger.info("Bot started!") logger.info("Bot started!")
asyncio.run(main()) asyncio.run(main())

10
log.py
View file

@ -9,19 +9,17 @@ def getlogger():
# create handlers # create handlers
warn_handler = logging.StreamHandler() warn_handler = logging.StreamHandler()
info_handler = logging.StreamHandler() info_handler = logging.StreamHandler()
error_handler = logging.FileHandler("bot.log", mode="a") error_handler = logging.FileHandler('bot.log', mode='a')
warn_handler.setLevel(logging.WARNING) warn_handler.setLevel(logging.WARNING)
error_handler.setLevel(logging.ERROR) error_handler.setLevel(logging.ERROR)
info_handler.setLevel(logging.INFO) info_handler.setLevel(logging.INFO)
# create formatters # create formatters
warn_format = logging.Formatter( warn_format = logging.Formatter(
"%(asctime)s - %(funcName)s - %(levelname)s - %(message)s" '%(asctime)s - %(funcName)s - %(levelname)s - %(message)s')
)
error_format = logging.Formatter( error_format = logging.Formatter(
"%(asctime)s - %(name)s - %(funcName)s - %(levelname)s - %(message)s" '%(asctime)s - %(name)s - %(funcName)s - %(levelname)s - %(message)s')
) info_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
info_format = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
# set formatter # set formatter
warn_handler.setFormatter(warn_format) warn_handler.setFormatter(warn_format)

View file

@ -1,42 +1,23 @@
from nio import AsyncClient from nio import AsyncClient
async def send_room_message(client: AsyncClient,
async def send_room_message(
client: AsyncClient,
room_id: str, room_id: str,
reply_message: str, reply_message: str,
sender_id: str = "", sender_id: str = '',
reply_to_event_id: str = "", reply_to_event_id: str = '',
) -> None: ) -> None:
NORMAL_BODY = content = { NORMAL_BODY = content = {"msgtype": "m.text", "body": reply_message, }
"msgtype": "m.text", if reply_to_event_id == '':
"body": reply_message,
}
if reply_to_event_id == "":
content = NORMAL_BODY content = NORMAL_BODY
else: else:
body = r"> <" + sender_id + r"> sent an audio file.\n\n" + reply_message body = r'> <' + sender_id + r'> sent an audio file.\n\n' + reply_message
format = r"org.matrix.custom.html" format = r'org.matrix.custom.html'
formatted_body = ( formatted_body = r'<mx-reply><blockquote><a href="https://matrix.to/#/' + room_id + r'/' + reply_to_event_id \
r'<mx-reply><blockquote><a href="https://matrix.to/#/' + r'">In reply to</a> <a href="https://matrix.to/#/' + sender_id + r'">' + sender_id \
+ room_id + r'</a><br>sent an audio file.</blockquote></mx-reply>' + reply_message
+ r"/"
+ reply_to_event_id
+ r'">In reply to</a> <a href="https://matrix.to/#/'
+ sender_id
+ r'">'
+ sender_id
+ r"</a><br>sent an audio file.</blockquote></mx-reply>"
+ reply_message
)
content = { content = {"msgtype": "m.text", "body": body, "format": format, "formatted_body": formatted_body,
"msgtype": "m.text", "m.relates_to": {"m.in_reply_to": {"event_id": reply_to_event_id}}, }
"body": body,
"format": format,
"formatted_body": formatted_body,
"m.relates_to": {"m.in_reply_to": {"event_id": reply_to_event_id}},
}
await client.room_send( await client.room_send(
room_id, room_id,
message_type="m.room.message", message_type="m.room.message",