Gracefully shutdown program

format the code
This commit is contained in:
hibobmaster 2023-07-21 10:41:47 +08:00
parent f20b1d5d40
commit f6292b1a13
Signed by: bobmaster
SSH key fingerprint: SHA256:5ZYgd8fg+PcNZNy4SzcSKu5JtqZyBF8kUhY7/k2viDk
3 changed files with 204 additions and 134 deletions

275
bot.py
View file

@ -1,4 +1,5 @@
import os
import signal
import sys
import traceback
from typing import Union, Optional
@ -6,12 +7,27 @@ import aiofiles
import asyncio
import uuid
import json
from nio import (AsyncClient, AsyncClientConfig, InviteMemberEvent, JoinError,
KeyVerificationCancel, KeyVerificationEvent, DownloadError,
KeyVerificationKey, KeyVerificationMac, KeyVerificationStart,
LocalProtocolError, LoginResponse, MatrixRoom, MegolmEvent,
RoomMessageAudio, RoomEncryptedAudio, ToDeviceError, crypto,
EncryptionError)
from nio import (
AsyncClient,
AsyncClientConfig,
InviteMemberEvent,
JoinError,
KeyVerificationCancel,
KeyVerificationEvent,
DownloadError,
KeyVerificationKey,
KeyVerificationMac,
KeyVerificationStart,
LocalProtocolError,
LoginResponse,
MatrixRoom,
MegolmEvent,
RoomMessageAudio,
RoomEncryptedAudio,
ToDeviceError,
crypto,
EncryptionError,
)
from nio.store.database import SqliteStore
from faster_whisper import WhisperModel
@ -40,12 +56,11 @@ class Bot:
num_workers: int = 1,
download_root: str = "models",
):
if (homeserver is None or user_id is None
or device_id is None):
if homeserver is None or user_id is None or device_id is None:
logger.warning("homeserver && user_id && device_id is required")
sys.exit(1)
if (password is None and access_token is None):
if password is None and access_token is None:
logger.warning("password or access_toekn is required")
sys.exit(1)
@ -87,26 +102,36 @@ class Bot:
# initialize AsyncClient object
self.store_path = os.getcwd()
self.config = AsyncClientConfig(store=SqliteStore,
store_name="db",
store_sync_tokens=True,
encryption_enabled=True,
)
self.client = AsyncClient(homeserver=self.homeserver, user=self.user_id, device_id=self.device_id,
config=self.config, store_path=self.store_path,)
self.config = AsyncClientConfig(
store=SqliteStore,
store_name="db",
store_sync_tokens=True,
encryption_enabled=True,
)
self.client = AsyncClient(
homeserver=self.homeserver,
user=self.user_id,
device_id=self.device_id,
config=self.config,
store_path=self.store_path,
)
if self.access_token is not None:
self.client.access_token = self.access_token
# setup event callbacks
self.client.add_event_callback(
self.message_callback, (RoomMessageAudio, RoomEncryptedAudio, ))
self.client.add_event_callback(
self.decryption_failure, (MegolmEvent, ))
self.client.add_event_callback(
self.invite_callback, (InviteMemberEvent, ))
self.message_callback,
(
RoomMessageAudio,
RoomEncryptedAudio,
),
)
self.client.add_event_callback(self.decryption_failure, (MegolmEvent,))
self.client.add_event_callback(self.invite_callback, (InviteMemberEvent,))
self.client.add_to_device_callback(
self.to_device_callback, (KeyVerificationEvent, ))
self.to_device_callback, (KeyVerificationEvent,)
)
# intialize whisper model
self.model = WhisperModel(
@ -115,23 +140,19 @@ class Bot:
compute_type=self.compute_type,
cpu_threads=self.cpu_threads,
num_workers=self.num_workers,
download_root=self.download_root,)
download_root=self.download_root,
)
def __del__(self):
try:
loop = asyncio.get_running_loop()
except RuntimeError as e:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self._close())
async def _close(self):
async def close(self, task: asyncio.Task = None) -> None:
await self.client.close()
logger.info("Bot stopped!")
task.cancel()
logger.info("Bot closed!")
# message_callback event
async def message_callback(self, room: MatrixRoom,
event: Union[RoomMessageAudio, RoomEncryptedAudio]) -> None:
# message_callback event
async def message_callback(
self, room: MatrixRoom, event: Union[RoomMessageAudio, RoomEncryptedAudio]
) -> None:
if self.room_id is None:
room_id = room.room_id
else:
@ -178,12 +199,8 @@ class Bot:
await f.write(
crypto.attachments.decrypt_attachment(
media_data,
event.source["content"]["file"]["key"][
"k"
],
event.source["content"]["file"]["hashes"][
"sha256"
],
event.source["content"]["file"]["key"]["k"],
event.source["content"]["file"]["hashes"]["sha256"],
event.source["content"]["file"]["iv"],
)
)
@ -216,8 +233,8 @@ class Bot:
return
logger.error(
f"Failed to decrypt message: {event.event_id} from {event.sender} in {room.room_id}\n" +
"Please make sure the bot current session is verified"
f"Failed to decrypt message: {event.event_id} from {event.sender} in {room.room_id}\n"
+ "Please make sure the bot current session is verified"
)
# invite_callback event
@ -233,7 +250,8 @@ class Bot:
if type(result) == JoinError:
logger.error(
f"Error joining room {room.room_id} (attempt %d): %s",
attempt, result.message,
attempt,
result.message,
)
else:
break
@ -255,11 +273,11 @@ class Bot:
try:
client = self.client
logger.debug(
f"Device Event of type {type(event)} received in "
"to_device_cb().")
f"Device Event of type {type(event)} received in " "to_device_cb()."
)
if isinstance(event, KeyVerificationStart): # first step
""" first step: receive KeyVerificationStart
"""first step: receive KeyVerificationStart
KeyVerificationStart(
source={'content':
{'method': 'm.sas.v1',
@ -289,13 +307,14 @@ class Bot:
"""
if "emoji" not in event.short_authentication_string:
estr = ("Other device does not support emoji verification "
f"{event.short_authentication_string}. Aborting.")
estr = (
"Other device does not support emoji verification "
f"{event.short_authentication_string}. Aborting."
)
print(estr)
logger.info(estr)
return
resp = await client.accept_key_verification(
event.transaction_id)
resp = await client.accept_key_verification(event.transaction_id)
if isinstance(resp, ToDeviceError):
estr = f"accept_key_verification() failed with {resp}"
print(estr)
@ -311,7 +330,7 @@ class Bot:
logger.info(estr)
elif isinstance(event, KeyVerificationCancel): # anytime
""" at any time: receive KeyVerificationCancel
"""at any time: receive KeyVerificationCancel
KeyVerificationCancel(source={
'content': {'code': 'm.mismatched_sas',
'reason': 'Mismatched authentication string',
@ -328,13 +347,15 @@ class Bot:
# client.cancel_key_verification(tx_id, reject=False)
# here. The SAS flow is already cancelled.
# We only need to inform the user.
estr = (f"Verification has been cancelled by {event.sender} "
f"for reason \"{event.reason}\".")
estr = (
f"Verification has been cancelled by {event.sender} "
f'for reason "{event.reason}".'
)
print(estr)
logger.info(estr)
elif isinstance(event, KeyVerificationKey): # second step
""" Second step is to receive KeyVerificationKey
"""Second step is to receive KeyVerificationKey
KeyVerificationKey(
source={'content': {
'key': 'SomeCryptoKey',
@ -359,42 +380,44 @@ class Bot:
# automatic match, so we use y
yn = "y"
if yn.lower() == "y":
estr = ("Match! The verification for this "
"device will be accepted.")
estr = (
"Match! The verification for this " "device will be accepted."
)
print(estr)
logger.info(estr)
resp = await client.confirm_short_auth_string(
event.transaction_id)
resp = await client.confirm_short_auth_string(event.transaction_id)
if isinstance(resp, ToDeviceError):
estr = ("confirm_short_auth_string() "
f"failed with {resp}")
estr = "confirm_short_auth_string() " f"failed with {resp}"
print(estr)
logger.info(estr)
elif yn.lower() == "n": # no, don't match, reject
estr = ("No match! Device will NOT be verified "
"by rejecting verification.")
estr = (
"No match! Device will NOT be verified "
"by rejecting verification."
)
print(estr)
logger.info(estr)
resp = await client.cancel_key_verification(
event.transaction_id, reject=True)
event.transaction_id, reject=True
)
if isinstance(resp, ToDeviceError):
estr = (f"cancel_key_verification failed with {resp}")
estr = f"cancel_key_verification failed with {resp}"
print(estr)
logger.info(estr)
else: # C or anything for cancel
estr = ("Cancelled by user! Verification will be "
"cancelled.")
estr = "Cancelled by user! Verification will be " "cancelled."
print(estr)
logger.info(estr)
resp = await client.cancel_key_verification(
event.transaction_id, reject=False)
event.transaction_id, reject=False
)
if isinstance(resp, ToDeviceError):
estr = (f"cancel_key_verification failed with {resp}")
estr = f"cancel_key_verification failed with {resp}"
print(estr)
logger.info(estr)
elif isinstance(event, KeyVerificationMac): # third step
""" Third step is to receive KeyVerificationMac
"""Third step is to receive KeyVerificationMac
KeyVerificationMac(
source={'content': {
'mac': {'ed25519:DEVICEIDXY': 'SomeKey1',
@ -414,9 +437,11 @@ class Bot:
todevice_msg = sas.get_mac()
except LocalProtocolError as e:
# e.g. it might have been cancelled by ourselves
estr = (f"Cancelled or protocol error: Reason: {e}.\n"
f"Verification with {event.sender} not concluded. "
"Try again?")
estr = (
f"Cancelled or protocol error: Reason: {e}.\n"
f"Verification with {event.sender} not concluded. "
"Try again?"
)
print(estr)
logger.info(estr)
else:
@ -425,25 +450,31 @@ class Bot:
estr = f"to_device failed with {resp}"
print(estr)
logger.info(estr)
estr = (f"sas.we_started_it = {sas.we_started_it}\n"
f"sas.sas_accepted = {sas.sas_accepted}\n"
f"sas.canceled = {sas.canceled}\n"
f"sas.timed_out = {sas.timed_out}\n"
f"sas.verified = {sas.verified}\n"
f"sas.verified_devices = {sas.verified_devices}\n")
estr = (
f"sas.we_started_it = {sas.we_started_it}\n"
f"sas.sas_accepted = {sas.sas_accepted}\n"
f"sas.canceled = {sas.canceled}\n"
f"sas.timed_out = {sas.timed_out}\n"
f"sas.verified = {sas.verified}\n"
f"sas.verified_devices = {sas.verified_devices}\n"
)
print(estr)
logger.info(estr)
estr = ("Emoji verification was successful!\n"
"Initiate another Emoji verification from "
"another device or room if desired. "
"Or if done verifying, hit Control-C to stop the "
"bot in order to restart it as a service or to "
"run it in the background.")
estr = (
"Emoji verification was successful!\n"
"Initiate another Emoji verification from "
"another device or room if desired. "
"Or if done verifying, hit Control-C to stop the "
"bot in order to restart it as a service or to "
"run it in the background."
)
print(estr)
logger.info(estr)
else:
estr = (f"Received unexpected event type {type(event)}. "
f"Event is {event}. Event will be ignored.")
estr = (
f"Received unexpected event type {type(event)}. "
f"Event is {event}. Event will be ignored."
)
print(estr)
logger.info(estr)
except BaseException:
@ -452,7 +483,6 @@ class Bot:
logger.info(estr)
# bot login
async def login(self) -> None:
if self.access_token is not None:
logger.info("Login via access_token")
@ -480,14 +510,14 @@ class Bot:
# import keys
async def import_keys(self):
resp = await self.client.import_keys(
self.import_keys_path,
self.import_keys_password
self.import_keys_path, self.import_keys_password
)
if isinstance(resp, EncryptionError):
logger.error(f"import_keys failed with {resp}")
else:
logger.info(
f"import_keys success, please remove import_keys configuration!!!")
f"import_keys success, please remove import_keys configuration!!!"
)
# whisper function
def transcribe(self, filename: str) -> str:
@ -507,30 +537,33 @@ async def main():
config = json.load(fp)
bot = Bot(
homeserver=config.get('homeserver'),
user_id=config.get('user_id'),
password=config.get('password'),
device_id=config.get('device_id'),
room_id=config.get('room_id'),
access_token=config.get('access_token'),
import_keys_path=config.get('import_keys_path'),
import_keys_password=config.get('import_keys_password'),
model_size=config.get('model_size'),
device=config.get('device'),
compute_type=config.get('compute_type'),
cpu_threads=config.get('cpu_threads'),
num_workers=config.get('num_workers'),
download_root=config.get('download_root'),
homeserver=config.get("homeserver"),
user_id=config.get("user_id"),
password=config.get("password"),
device_id=config.get("device_id"),
room_id=config.get("room_id"),
access_token=config.get("access_token"),
import_keys_path=config.get("import_keys_path"),
import_keys_password=config.get("import_keys_password"),
model_size=config.get("model_size"),
device=config.get("device"),
compute_type=config.get("compute_type"),
cpu_threads=config.get("cpu_threads"),
num_workers=config.get("num_workers"),
download_root=config.get("download_root"),
)
if config.get('import_keys_path') and config.get('import_keys_password') is not None:
if (
config.get("import_keys_path")
and config.get("import_keys_password") is not None
):
need_import_keys = True
else:
bot = Bot(
homeserver=os.environ.get('HOMESERVER'),
user_id=os.environ.get('USER_ID'),
password=os.environ.get('PASSWORD'),
homeserver=os.environ.get("HOMESERVER"),
user_id=os.environ.get("USER_ID"),
password=os.environ.get("PASSWORD"),
device_id=os.environ.get("DEVICE_ID"),
room_id=os.environ.get("ROOM_ID"),
access_token=os.environ.get("ACCESS_TOKEN"),
@ -543,7 +576,10 @@ async def main():
num_workers=os.environ.get("NUM_WORKERS"),
download_root=os.environ.get("DOWNLOAD_ROOT"),
)
if os.environ.get("IMPORT_KEYS_PATH") and os.environ.get("IMPORT_KEYS_PASSWORD") is not None:
if (
os.environ.get("IMPORT_KEYS_PATH")
and os.environ.get("IMPORT_KEYS_PASSWORD") is not None
):
need_import_keys = True
await bot.login()
@ -551,8 +587,21 @@ async def main():
logger.info("start import_keys process, this may take a while...")
await bot.import_keys()
await bot.sync_forever()
sync_task = asyncio.create_task(bot.sync_forever())
if __name__ == '__main__':
# handle signal interrupt
loop = asyncio.get_running_loop()
for signame in (
"SIGINT",
"SIGTERM",
):
loop.add_signal_handler(
getattr(signal, signame), lambda: asyncio.create_task(bot.close(sync_task))
)
await sync_task
if __name__ == "__main__":
logger.info("Bot started!")
asyncio.run(main())

10
log.py
View file

@ -9,17 +9,19 @@ def getlogger():
# create handlers
warn_handler = logging.StreamHandler()
info_handler = logging.StreamHandler()
error_handler = logging.FileHandler('bot.log', mode='a')
error_handler = logging.FileHandler("bot.log", mode="a")
warn_handler.setLevel(logging.WARNING)
error_handler.setLevel(logging.ERROR)
info_handler.setLevel(logging.INFO)
# create formatters
warn_format = logging.Formatter(
'%(asctime)s - %(funcName)s - %(levelname)s - %(message)s')
"%(asctime)s - %(funcName)s - %(levelname)s - %(message)s"
)
error_format = logging.Formatter(
'%(asctime)s - %(name)s - %(funcName)s - %(levelname)s - %(message)s')
info_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
"%(asctime)s - %(name)s - %(funcName)s - %(levelname)s - %(message)s"
)
info_format = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
# set formatter
warn_handler.setFormatter(warn_format)

View file

@ -1,23 +1,42 @@
from nio import AsyncClient
async def send_room_message(client: AsyncClient,
room_id: str,
reply_message: str,
sender_id: str = '',
reply_to_event_id: str = '',
) -> None:
NORMAL_BODY = content = {"msgtype": "m.text", "body": reply_message, }
if reply_to_event_id == '':
content = NORMAL_BODY
else:
body = r'> <' + sender_id + r'> sent an audio file.\n\n' + reply_message
format = r'org.matrix.custom.html'
formatted_body = r'<mx-reply><blockquote><a href="https://matrix.to/#/' + room_id + r'/' + reply_to_event_id \
+ r'">In reply to</a> <a href="https://matrix.to/#/' + sender_id + r'">' + sender_id \
+ r'</a><br>sent an audio file.</blockquote></mx-reply>' + reply_message
content = {"msgtype": "m.text", "body": body, "format": format, "formatted_body": formatted_body,
"m.relates_to": {"m.in_reply_to": {"event_id": reply_to_event_id}}, }
async def send_room_message(
client: AsyncClient,
room_id: str,
reply_message: str,
sender_id: str = "",
reply_to_event_id: str = "",
) -> None:
NORMAL_BODY = content = {
"msgtype": "m.text",
"body": reply_message,
}
if reply_to_event_id == "":
content = NORMAL_BODY
else:
body = r"> <" + sender_id + r"> sent an audio file.\n\n" + reply_message
format = r"org.matrix.custom.html"
formatted_body = (
r'<mx-reply><blockquote><a href="https://matrix.to/#/'
+ room_id
+ r"/"
+ reply_to_event_id
+ r'">In reply to</a> <a href="https://matrix.to/#/'
+ sender_id
+ r'">'
+ sender_id
+ r"</a><br>sent an audio file.</blockquote></mx-reply>"
+ reply_message
)
content = {
"msgtype": "m.text",
"body": body,
"format": format,
"formatted_body": formatted_body,
"m.relates_to": {"m.in_reply_to": {"event_id": reply_to_event_id}},
}
await client.room_send(
room_id,
message_type="m.room.message",