Compare commits

...

1 commit
v1.1.0 ... main

Author SHA1 Message Date
f6292b1a13
v1.1.1
Gracefully shutdown program

format the code
2023-07-21 12:58:39 +08:00
3 changed files with 204 additions and 134 deletions

239
bot.py
View file

@ -1,4 +1,5 @@
import os
import signal
import sys
import traceback
from typing import Union, Optional
@ -6,12 +7,27 @@ import aiofiles
import asyncio
import uuid
import json
from nio import (AsyncClient, AsyncClientConfig, InviteMemberEvent, JoinError,
KeyVerificationCancel, KeyVerificationEvent, DownloadError,
KeyVerificationKey, KeyVerificationMac, KeyVerificationStart,
LocalProtocolError, LoginResponse, MatrixRoom, MegolmEvent,
RoomMessageAudio, RoomEncryptedAudio, ToDeviceError, crypto,
EncryptionError)
from nio import (
AsyncClient,
AsyncClientConfig,
InviteMemberEvent,
JoinError,
KeyVerificationCancel,
KeyVerificationEvent,
DownloadError,
KeyVerificationKey,
KeyVerificationMac,
KeyVerificationStart,
LocalProtocolError,
LoginResponse,
MatrixRoom,
MegolmEvent,
RoomMessageAudio,
RoomEncryptedAudio,
ToDeviceError,
crypto,
EncryptionError,
)
from nio.store.database import SqliteStore
from faster_whisper import WhisperModel
@ -40,12 +56,11 @@ class Bot:
num_workers: int = 1,
download_root: str = "models",
):
if (homeserver is None or user_id is None
or device_id is None):
if homeserver is None or user_id is None or device_id is None:
logger.warning("homeserver && user_id && device_id is required")
sys.exit(1)
if (password is None and access_token is None):
if password is None and access_token is None:
logger.warning("password or access_toekn is required")
sys.exit(1)
@ -87,26 +102,36 @@ class Bot:
# initialize AsyncClient object
self.store_path = os.getcwd()
self.config = AsyncClientConfig(store=SqliteStore,
self.config = AsyncClientConfig(
store=SqliteStore,
store_name="db",
store_sync_tokens=True,
encryption_enabled=True,
)
self.client = AsyncClient(homeserver=self.homeserver, user=self.user_id, device_id=self.device_id,
config=self.config, store_path=self.store_path,)
self.client = AsyncClient(
homeserver=self.homeserver,
user=self.user_id,
device_id=self.device_id,
config=self.config,
store_path=self.store_path,
)
if self.access_token is not None:
self.client.access_token = self.access_token
# setup event callbacks
self.client.add_event_callback(
self.message_callback, (RoomMessageAudio, RoomEncryptedAudio, ))
self.client.add_event_callback(
self.decryption_failure, (MegolmEvent, ))
self.client.add_event_callback(
self.invite_callback, (InviteMemberEvent, ))
self.message_callback,
(
RoomMessageAudio,
RoomEncryptedAudio,
),
)
self.client.add_event_callback(self.decryption_failure, (MegolmEvent,))
self.client.add_event_callback(self.invite_callback, (InviteMemberEvent,))
self.client.add_to_device_callback(
self.to_device_callback, (KeyVerificationEvent, ))
self.to_device_callback, (KeyVerificationEvent,)
)
# intialize whisper model
self.model = WhisperModel(
@ -115,23 +140,19 @@ class Bot:
compute_type=self.compute_type,
cpu_threads=self.cpu_threads,
num_workers=self.num_workers,
download_root=self.download_root,)
download_root=self.download_root,
)
def __del__(self):
try:
loop = asyncio.get_running_loop()
except RuntimeError as e:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self._close())
async def _close(self):
async def close(self, task: asyncio.Task = None) -> None:
await self.client.close()
logger.info("Bot stopped!")
task.cancel()
logger.info("Bot closed!")
# message_callback event
async def message_callback(self, room: MatrixRoom,
event: Union[RoomMessageAudio, RoomEncryptedAudio]) -> None:
async def message_callback(
self, room: MatrixRoom, event: Union[RoomMessageAudio, RoomEncryptedAudio]
) -> None:
if self.room_id is None:
room_id = room.room_id
else:
@ -178,12 +199,8 @@ class Bot:
await f.write(
crypto.attachments.decrypt_attachment(
media_data,
event.source["content"]["file"]["key"][
"k"
],
event.source["content"]["file"]["hashes"][
"sha256"
],
event.source["content"]["file"]["key"]["k"],
event.source["content"]["file"]["hashes"]["sha256"],
event.source["content"]["file"]["iv"],
)
)
@ -216,8 +233,8 @@ class Bot:
return
logger.error(
f"Failed to decrypt message: {event.event_id} from {event.sender} in {room.room_id}\n" +
"Please make sure the bot current session is verified"
f"Failed to decrypt message: {event.event_id} from {event.sender} in {room.room_id}\n"
+ "Please make sure the bot current session is verified"
)
# invite_callback event
@ -233,7 +250,8 @@ class Bot:
if type(result) == JoinError:
logger.error(
f"Error joining room {room.room_id} (attempt %d): %s",
attempt, result.message,
attempt,
result.message,
)
else:
break
@ -255,8 +273,8 @@ class Bot:
try:
client = self.client
logger.debug(
f"Device Event of type {type(event)} received in "
"to_device_cb().")
f"Device Event of type {type(event)} received in " "to_device_cb()."
)
if isinstance(event, KeyVerificationStart): # first step
"""first step: receive KeyVerificationStart
@ -289,13 +307,14 @@ class Bot:
"""
if "emoji" not in event.short_authentication_string:
estr = ("Other device does not support emoji verification "
f"{event.short_authentication_string}. Aborting.")
estr = (
"Other device does not support emoji verification "
f"{event.short_authentication_string}. Aborting."
)
print(estr)
logger.info(estr)
return
resp = await client.accept_key_verification(
event.transaction_id)
resp = await client.accept_key_verification(event.transaction_id)
if isinstance(resp, ToDeviceError):
estr = f"accept_key_verification() failed with {resp}"
print(estr)
@ -328,8 +347,10 @@ class Bot:
# client.cancel_key_verification(tx_id, reject=False)
# here. The SAS flow is already cancelled.
# We only need to inform the user.
estr = (f"Verification has been cancelled by {event.sender} "
f"for reason \"{event.reason}\".")
estr = (
f"Verification has been cancelled by {event.sender} "
f'for reason "{event.reason}".'
)
print(estr)
logger.info(estr)
@ -359,37 +380,39 @@ class Bot:
# automatic match, so we use y
yn = "y"
if yn.lower() == "y":
estr = ("Match! The verification for this "
"device will be accepted.")
estr = (
"Match! The verification for this " "device will be accepted."
)
print(estr)
logger.info(estr)
resp = await client.confirm_short_auth_string(
event.transaction_id)
resp = await client.confirm_short_auth_string(event.transaction_id)
if isinstance(resp, ToDeviceError):
estr = ("confirm_short_auth_string() "
f"failed with {resp}")
estr = "confirm_short_auth_string() " f"failed with {resp}"
print(estr)
logger.info(estr)
elif yn.lower() == "n": # no, don't match, reject
estr = ("No match! Device will NOT be verified "
"by rejecting verification.")
estr = (
"No match! Device will NOT be verified "
"by rejecting verification."
)
print(estr)
logger.info(estr)
resp = await client.cancel_key_verification(
event.transaction_id, reject=True)
event.transaction_id, reject=True
)
if isinstance(resp, ToDeviceError):
estr = (f"cancel_key_verification failed with {resp}")
estr = f"cancel_key_verification failed with {resp}"
print(estr)
logger.info(estr)
else: # C or anything for cancel
estr = ("Cancelled by user! Verification will be "
"cancelled.")
estr = "Cancelled by user! Verification will be " "cancelled."
print(estr)
logger.info(estr)
resp = await client.cancel_key_verification(
event.transaction_id, reject=False)
event.transaction_id, reject=False
)
if isinstance(resp, ToDeviceError):
estr = (f"cancel_key_verification failed with {resp}")
estr = f"cancel_key_verification failed with {resp}"
print(estr)
logger.info(estr)
@ -414,9 +437,11 @@ class Bot:
todevice_msg = sas.get_mac()
except LocalProtocolError as e:
# e.g. it might have been cancelled by ourselves
estr = (f"Cancelled or protocol error: Reason: {e}.\n"
estr = (
f"Cancelled or protocol error: Reason: {e}.\n"
f"Verification with {event.sender} not concluded. "
"Try again?")
"Try again?"
)
print(estr)
logger.info(estr)
else:
@ -425,25 +450,31 @@ class Bot:
estr = f"to_device failed with {resp}"
print(estr)
logger.info(estr)
estr = (f"sas.we_started_it = {sas.we_started_it}\n"
estr = (
f"sas.we_started_it = {sas.we_started_it}\n"
f"sas.sas_accepted = {sas.sas_accepted}\n"
f"sas.canceled = {sas.canceled}\n"
f"sas.timed_out = {sas.timed_out}\n"
f"sas.verified = {sas.verified}\n"
f"sas.verified_devices = {sas.verified_devices}\n")
f"sas.verified_devices = {sas.verified_devices}\n"
)
print(estr)
logger.info(estr)
estr = ("Emoji verification was successful!\n"
estr = (
"Emoji verification was successful!\n"
"Initiate another Emoji verification from "
"another device or room if desired. "
"Or if done verifying, hit Control-C to stop the "
"bot in order to restart it as a service or to "
"run it in the background.")
"run it in the background."
)
print(estr)
logger.info(estr)
else:
estr = (f"Received unexpected event type {type(event)}. "
f"Event is {event}. Event will be ignored.")
estr = (
f"Received unexpected event type {type(event)}. "
f"Event is {event}. Event will be ignored."
)
print(estr)
logger.info(estr)
except BaseException:
@ -452,7 +483,6 @@ class Bot:
logger.info(estr)
# bot login
async def login(self) -> None:
if self.access_token is not None:
logger.info("Login via access_token")
@ -480,14 +510,14 @@ class Bot:
# import keys
async def import_keys(self):
resp = await self.client.import_keys(
self.import_keys_path,
self.import_keys_password
self.import_keys_path, self.import_keys_password
)
if isinstance(resp, EncryptionError):
logger.error(f"import_keys failed with {resp}")
else:
logger.info(
f"import_keys success, please remove import_keys configuration!!!")
f"import_keys success, please remove import_keys configuration!!!"
)
# whisper function
def transcribe(self, filename: str) -> str:
@ -507,30 +537,33 @@ async def main():
config = json.load(fp)
bot = Bot(
homeserver=config.get('homeserver'),
user_id=config.get('user_id'),
password=config.get('password'),
device_id=config.get('device_id'),
room_id=config.get('room_id'),
access_token=config.get('access_token'),
import_keys_path=config.get('import_keys_path'),
import_keys_password=config.get('import_keys_password'),
model_size=config.get('model_size'),
device=config.get('device'),
compute_type=config.get('compute_type'),
cpu_threads=config.get('cpu_threads'),
num_workers=config.get('num_workers'),
download_root=config.get('download_root'),
homeserver=config.get("homeserver"),
user_id=config.get("user_id"),
password=config.get("password"),
device_id=config.get("device_id"),
room_id=config.get("room_id"),
access_token=config.get("access_token"),
import_keys_path=config.get("import_keys_path"),
import_keys_password=config.get("import_keys_password"),
model_size=config.get("model_size"),
device=config.get("device"),
compute_type=config.get("compute_type"),
cpu_threads=config.get("cpu_threads"),
num_workers=config.get("num_workers"),
download_root=config.get("download_root"),
)
if config.get('import_keys_path') and config.get('import_keys_password') is not None:
if (
config.get("import_keys_path")
and config.get("import_keys_password") is not None
):
need_import_keys = True
else:
bot = Bot(
homeserver=os.environ.get('HOMESERVER'),
user_id=os.environ.get('USER_ID'),
password=os.environ.get('PASSWORD'),
homeserver=os.environ.get("HOMESERVER"),
user_id=os.environ.get("USER_ID"),
password=os.environ.get("PASSWORD"),
device_id=os.environ.get("DEVICE_ID"),
room_id=os.environ.get("ROOM_ID"),
access_token=os.environ.get("ACCESS_TOKEN"),
@ -543,7 +576,10 @@ async def main():
num_workers=os.environ.get("NUM_WORKERS"),
download_root=os.environ.get("DOWNLOAD_ROOT"),
)
if os.environ.get("IMPORT_KEYS_PATH") and os.environ.get("IMPORT_KEYS_PASSWORD") is not None:
if (
os.environ.get("IMPORT_KEYS_PATH")
and os.environ.get("IMPORT_KEYS_PASSWORD") is not None
):
need_import_keys = True
await bot.login()
@ -551,8 +587,21 @@ async def main():
logger.info("start import_keys process, this may take a while...")
await bot.import_keys()
await bot.sync_forever()
sync_task = asyncio.create_task(bot.sync_forever())
if __name__ == '__main__':
# handle signal interrupt
loop = asyncio.get_running_loop()
for signame in (
"SIGINT",
"SIGTERM",
):
loop.add_signal_handler(
getattr(signal, signame), lambda: asyncio.create_task(bot.close(sync_task))
)
await sync_task
if __name__ == "__main__":
logger.info("Bot started!")
asyncio.run(main())

10
log.py
View file

@ -9,17 +9,19 @@ def getlogger():
# create handlers
warn_handler = logging.StreamHandler()
info_handler = logging.StreamHandler()
error_handler = logging.FileHandler('bot.log', mode='a')
error_handler = logging.FileHandler("bot.log", mode="a")
warn_handler.setLevel(logging.WARNING)
error_handler.setLevel(logging.ERROR)
info_handler.setLevel(logging.INFO)
# create formatters
warn_format = logging.Formatter(
'%(asctime)s - %(funcName)s - %(levelname)s - %(message)s')
"%(asctime)s - %(funcName)s - %(levelname)s - %(message)s"
)
error_format = logging.Formatter(
'%(asctime)s - %(name)s - %(funcName)s - %(levelname)s - %(message)s')
info_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
"%(asctime)s - %(name)s - %(funcName)s - %(levelname)s - %(message)s"
)
info_format = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
# set formatter
warn_handler.setFormatter(warn_format)

View file

@ -1,23 +1,42 @@
from nio import AsyncClient
async def send_room_message(client: AsyncClient,
async def send_room_message(
client: AsyncClient,
room_id: str,
reply_message: str,
sender_id: str = '',
reply_to_event_id: str = '',
sender_id: str = "",
reply_to_event_id: str = "",
) -> None:
NORMAL_BODY = content = {"msgtype": "m.text", "body": reply_message, }
if reply_to_event_id == '':
NORMAL_BODY = content = {
"msgtype": "m.text",
"body": reply_message,
}
if reply_to_event_id == "":
content = NORMAL_BODY
else:
body = r'> <' + sender_id + r'> sent an audio file.\n\n' + reply_message
format = r'org.matrix.custom.html'
formatted_body = r'<mx-reply><blockquote><a href="https://matrix.to/#/' + room_id + r'/' + reply_to_event_id \
+ r'">In reply to</a> <a href="https://matrix.to/#/' + sender_id + r'">' + sender_id \
+ r'</a><br>sent an audio file.</blockquote></mx-reply>' + reply_message
body = r"> <" + sender_id + r"> sent an audio file.\n\n" + reply_message
format = r"org.matrix.custom.html"
formatted_body = (
r'<mx-reply><blockquote><a href="https://matrix.to/#/'
+ room_id
+ r"/"
+ reply_to_event_id
+ r'">In reply to</a> <a href="https://matrix.to/#/'
+ sender_id
+ r'">'
+ sender_id
+ r"</a><br>sent an audio file.</blockquote></mx-reply>"
+ reply_message
)
content = {"msgtype": "m.text", "body": body, "format": format, "formatted_body": formatted_body,
"m.relates_to": {"m.in_reply_to": {"event_id": reply_to_event_id}}, }
content = {
"msgtype": "m.text",
"body": body,
"format": format,
"formatted_body": formatted_body,
"m.relates_to": {"m.in_reply_to": {"event_id": reply_to_event_id}},
}
await client.room_send(
room_id,
message_type="m.room.message",