Compare commits
No commits in common. "main" and "v1.0.0" have entirely different histories.
33 changed files with 1403 additions and 2526 deletions
|
@ -18,6 +18,4 @@ __pycache__
|
||||||
.env.example
|
.env.example
|
||||||
.github
|
.github
|
||||||
settings.js
|
settings.js
|
||||||
.vscode
|
|
||||||
Dockerfile_dev
|
|
||||||
LICENSE
|
|
||||||
|
|
13
.env.example
13
.env.example
|
@ -1,6 +1,11 @@
|
||||||
HOMESERVER="https://matrix-client.matrix.org" # required
|
# Please remove the option that is blank
|
||||||
|
HOMESERVER="https://matrix.xxxxxx.xxxx" # required
|
||||||
USER_ID="@lullap:xxxxxxxxxxxxx.xxx" # required
|
USER_ID="@lullap:xxxxxxxxxxxxx.xxx" # required
|
||||||
PASSWORD="xxxxxxxxxxxxxxx" # required
|
PASSWORD="xxxxxxxxxxxxxxx" # required
|
||||||
DEVICE_ID="MatrixChatGPTBot" # required
|
DEVICE_ID="xxxxxxxxxxxxxx" # required
|
||||||
ROOM_ID="!FYCmBSkCRUXXXXXXXXX:matrix.XXX.XXX" # Optional, if not set, bot will work on the room it is in
|
ROOM_ID="!FYCmBSkCRUXXXXXXXXX:matrix.XXX.XXX" # Optional, if the property is blank, bot will work on the room it is in (Unencrypted room only as for now)
|
||||||
OPENAI_API_KEY="xxxxxxxxxxxxxxxxx" # Optional
|
OPENAI_API_KEY="xxxxxxxxxxxxxxxxx" # Optional, for !chat and !gpt command
|
||||||
|
BING_API_ENDPOINT="xxxxxxxxxxxxxxx" # Optional, for !bing command
|
||||||
|
ACCESS_TOKEN="xxxxxxxxxxxxxxxxxxxxx" # Optional, use user_id and password is recommended
|
||||||
|
JAILBREAKENABLED="true" # Optional
|
||||||
|
BING_AUTH_COOKIE="xxxxxxxxxxxxxxxxxxx" # _U cookie, Optional, for Bing Image Creator
|
||||||
|
|
|
@ -1,27 +0,0 @@
|
||||||
HOMESERVER="https://matrix-client.matrix.org"
|
|
||||||
USER_ID="@lullap:xxxxxxxxxxxxx.xxx"
|
|
||||||
PASSWORD="xxxxxxxxxxxxxxx"
|
|
||||||
ACCESS_TOKEN="xxxxxxxxxxx"
|
|
||||||
DEVICE_ID="xxxxxxxxxxxxxx"
|
|
||||||
ROOM_ID="!FYCmBSkCRUXXXXXXXXX:matrix.XXX.XXX"
|
|
||||||
IMPORT_KEYS_PATH="element-keys.txt"
|
|
||||||
IMPORT_KEYS_PASSWORD="xxxxxxxxxxxx"
|
|
||||||
OPENAI_API_KEY="xxxxxxxxxxxxxxxxx"
|
|
||||||
GPT_API_ENDPOINT="https://api.openai.com/v1/chat/completions"
|
|
||||||
GPT_MODEL="gpt-3.5-turbo"
|
|
||||||
MAX_TOKENS=4000
|
|
||||||
TOP_P=1.0
|
|
||||||
PRESENCE_PENALTY=0.0
|
|
||||||
FREQUENCY_PENALTY=0.0
|
|
||||||
REPLY_COUNT=1
|
|
||||||
SYSTEM_PROMPT="You are ChatGPT, a large language model trained by OpenAI. Respond conversationally"
|
|
||||||
TEMPERATURE=0.8
|
|
||||||
LC_ADMIN="@admin:xxxxxx.xxx,@admin2:xxxxxx.xxx"
|
|
||||||
IMAGE_GENERATION_ENDPOINT="http://127.0.0.1:7860/sdapi/v1/txt2img"
|
|
||||||
IMAGE_GENERATION_BACKEND="sdwui" # openai or sdwui or localai
|
|
||||||
IMAGE_GENERATION_SIZE="512x512"
|
|
||||||
IMAGE_FORMAT="webp"
|
|
||||||
SDWUI_STEPS=20
|
|
||||||
SDWUI_SAMPLER_NAME="Euler a"
|
|
||||||
SDWUI_CFG_SCALE=7
|
|
||||||
TIMEOUT=120.0
|
|
58
.github/workflows/docker-release.yml
vendored
58
.github/workflows/docker-release.yml
vendored
|
@ -7,26 +7,21 @@ on:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
push_to_registry:
|
push_to_registry:
|
||||||
name: Push Docker image to registry
|
name: Push Docker image to Docker Hub
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Check out the repo
|
-
|
||||||
|
name: Check out the repo
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
-
|
||||||
- name: Log in to Docker Hub
|
name: Log in to Docker Hub
|
||||||
uses: docker/login-action@v2
|
uses: docker/login-action@v2
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKER_USERNAME }}
|
username: ${{ secrets.DOCKER_USERNAME }}
|
||||||
password: ${{ secrets.DOCKER_PASSWORD }}
|
password: ${{ secrets.DOCKER_PASSWORD }}
|
||||||
|
|
||||||
- name: Login to GitHub Container Registry
|
-
|
||||||
uses: docker/login-action@v2
|
name: Docker metadata
|
||||||
with:
|
|
||||||
registry: ghcr.io
|
|
||||||
username: ${{ github.repository_owner }}
|
|
||||||
password: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
|
|
||||||
- name: Docker metadata
|
|
||||||
id: meta
|
id: meta
|
||||||
uses: docker/metadata-action@v4
|
uses: docker/metadata-action@v4
|
||||||
with:
|
with:
|
||||||
|
@ -34,40 +29,23 @@ jobs:
|
||||||
tags: |
|
tags: |
|
||||||
type=raw,value=latest
|
type=raw,value=latest
|
||||||
type=ref,event=tag
|
type=ref,event=tag
|
||||||
|
|
||||||
- name: Set up QEMU
|
-
|
||||||
|
name: Set up QEMU
|
||||||
uses: docker/setup-qemu-action@v2
|
uses: docker/setup-qemu-action@v2
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
-
|
||||||
|
name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v2
|
uses: docker/setup-buildx-action@v2
|
||||||
|
|
||||||
- name: Build and push Docker image(dockerhub)
|
-
|
||||||
|
name: Build and push Docker image
|
||||||
uses: docker/build-push-action@v4
|
uses: docker/build-push-action@v4
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
platforms: linux/amd64,linux/arm64
|
platforms: linux/386,linux/amd64,linux/arm/v6,linux/arm/v7,linux/arm64/v8,linux/ppc64le,linux/s390x
|
||||||
push: true
|
push: true
|
||||||
tags: ${{ steps.meta.outputs.tags }}
|
tags: ${{ steps.meta.outputs.tags }}
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
cache-from: type=gha
|
cache-from: type=gha
|
||||||
cache-to: type=gha,mode=max
|
cache-to: type=gha,mode=max
|
||||||
|
|
||||||
- name: Docker metadata(ghcr)
|
|
||||||
id: meta2
|
|
||||||
uses: docker/metadata-action@v4
|
|
||||||
with:
|
|
||||||
images: ghcr.io/hibobmaster/matrixchatgptbot
|
|
||||||
tags: |
|
|
||||||
type=raw,value=latest
|
|
||||||
type=sha,format=long
|
|
||||||
|
|
||||||
- name: Build and push Docker image(ghcr)
|
|
||||||
uses: docker/build-push-action@v4
|
|
||||||
with:
|
|
||||||
context: .
|
|
||||||
platforms: linux/amd64,linux/arm64
|
|
||||||
push: true
|
|
||||||
tags: ${{ steps.meta2.outputs.tags }}
|
|
||||||
labels: ${{ steps.meta2.outputs.labels }}
|
|
||||||
cache-from: type=gha
|
|
||||||
cache-to: type=gha,mode=max
|
|
6
.gitignore
vendored
6
.gitignore
vendored
|
@ -28,7 +28,6 @@ share/python-wheels/
|
||||||
MANIFEST
|
MANIFEST
|
||||||
db
|
db
|
||||||
bot.log
|
bot.log
|
||||||
Dockerfile_dev
|
|
||||||
|
|
||||||
# image generation folder
|
# image generation folder
|
||||||
images/
|
images/
|
||||||
|
@ -168,8 +167,3 @@ cython_debug/
|
||||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
.idea/
|
.idea/
|
||||||
|
|
||||||
# Custom
|
|
||||||
sync_db
|
|
||||||
manage_db
|
|
||||||
element-keys.txt
|
|
||||||
|
|
|
@ -1,16 +0,0 @@
|
||||||
repos:
|
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
|
||||||
rev: v4.5.0
|
|
||||||
hooks:
|
|
||||||
- id: trailing-whitespace
|
|
||||||
- id: end-of-file-fixer
|
|
||||||
- id: check-yaml
|
|
||||||
- repo: https://github.com/psf/black
|
|
||||||
rev: 23.12.0
|
|
||||||
hooks:
|
|
||||||
- id: black
|
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
|
||||||
rev: v0.1.7
|
|
||||||
hooks:
|
|
||||||
- id: ruff
|
|
||||||
args: [--fix, --exit-non-zero-on-fix]
|
|
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
{
|
||||||
|
"python.analysis.typeCheckingMode": "off"
|
||||||
|
}
|
315
BingImageGen.py
Normal file
315
BingImageGen.py
Normal file
|
@ -0,0 +1,315 @@
|
||||||
|
"""
|
||||||
|
Code derived from:
|
||||||
|
https://github.com/acheong08/EdgeGPT/blob/f940cecd24a4818015a8b42a2443dd97c3c2a8f4/src/ImageGen.py
|
||||||
|
"""
|
||||||
|
from log import getlogger
|
||||||
|
|
||||||
|
from typing import Union
|
||||||
|
from uuid import uuid4
|
||||||
|
import os
|
||||||
|
import contextlib
|
||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
import requests
|
||||||
|
import regex
|
||||||
|
|
||||||
|
logger = getlogger()
|
||||||
|
|
||||||
|
BING_URL = "https://www.bing.com"
|
||||||
|
# Generate random IP between range 13.104.0.0/14
|
||||||
|
FORWARDED_IP = (
|
||||||
|
f"13.{random.randint(104, 107)}.{random.randint(0, 255)}.{random.randint(0, 255)}"
|
||||||
|
)
|
||||||
|
HEADERS = {
|
||||||
|
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
|
||||||
|
"accept-language": "en-US,en;q=0.9",
|
||||||
|
"cache-control": "max-age=0",
|
||||||
|
"content-type": "application/x-www-form-urlencoded",
|
||||||
|
"referrer": "https://www.bing.com/images/create/",
|
||||||
|
"origin": "https://www.bing.com",
|
||||||
|
"user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.0.0 Safari/537.36 Edg/110.0.1587.63",
|
||||||
|
"x-forwarded-for": FORWARDED_IP,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Error messages
|
||||||
|
error_timeout = "Your request has timed out."
|
||||||
|
error_redirect = "Redirect failed"
|
||||||
|
error_blocked_prompt = (
|
||||||
|
"Your prompt has been blocked by Bing. Try to change any bad words and try again."
|
||||||
|
)
|
||||||
|
error_noresults = "Could not get results"
|
||||||
|
error_unsupported_lang = "\nthis language is currently not supported by bing"
|
||||||
|
error_bad_images = "Bad images"
|
||||||
|
error_no_images = "No images"
|
||||||
|
#
|
||||||
|
sending_message = "Sending request..."
|
||||||
|
wait_message = "Waiting for results..."
|
||||||
|
download_message = "\nDownloading images..."
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def debug(debug_file, text_var):
|
||||||
|
"""helper function for debug"""
|
||||||
|
with open(f"{debug_file}", "a") as f:
|
||||||
|
f.write(str(text_var))
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGen:
|
||||||
|
"""
|
||||||
|
Image generation by Microsoft Bing
|
||||||
|
Parameters:3
|
||||||
|
auth_cookie: str
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, auth_cookie: str, debug_file: Union[str, None] = None, quiet: bool = False
|
||||||
|
) -> None:
|
||||||
|
self.session: requests.Session = requests.Session()
|
||||||
|
self.session.headers = HEADERS
|
||||||
|
self.session.cookies.set("_U", auth_cookie)
|
||||||
|
self.quiet = quiet
|
||||||
|
self.debug_file = debug_file
|
||||||
|
if self.debug_file:
|
||||||
|
self.debug = partial(debug, self.debug_file)
|
||||||
|
|
||||||
|
|
||||||
|
def get_images(self, prompt: str) -> list:
|
||||||
|
"""
|
||||||
|
Fetches image links from Bing
|
||||||
|
Parameters:
|
||||||
|
prompt: str
|
||||||
|
"""
|
||||||
|
if not self.quiet:
|
||||||
|
print(sending_message)
|
||||||
|
if self.debug_file:
|
||||||
|
self.debug(sending_message)
|
||||||
|
url_encoded_prompt = requests.utils.quote(prompt)
|
||||||
|
# https://www.bing.com/images/create?q=<PROMPT>&rt=3&FORM=GENCRE
|
||||||
|
url = f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=4&FORM=GENCRE"
|
||||||
|
response = self.session.post(url, allow_redirects=False)
|
||||||
|
# check for content waring message
|
||||||
|
if "this prompt has been blocked" in response.text.lower():
|
||||||
|
if self.debug_file:
|
||||||
|
self.debug(f"ERROR: {error_blocked_prompt}")
|
||||||
|
raise Exception(
|
||||||
|
error_blocked_prompt,
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
"we're working hard to offer image creator in more languages"
|
||||||
|
in response.text.lower()
|
||||||
|
):
|
||||||
|
if self.debug_file:
|
||||||
|
self.debug(f"ERROR: {error_unsupported_lang}")
|
||||||
|
raise Exception(error_unsupported_lang)
|
||||||
|
if response.status_code != 302:
|
||||||
|
# if rt4 fails, try rt3
|
||||||
|
url = f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=3&FORM=GENCRE"
|
||||||
|
response3 = self.session.post(url, allow_redirects=False, timeout=200)
|
||||||
|
if response3.status_code != 302:
|
||||||
|
if self.debug_file:
|
||||||
|
self.debug(f"ERROR: {error_redirect}")
|
||||||
|
print(f"ERROR: {response3.text}")
|
||||||
|
raise Exception(error_redirect)
|
||||||
|
response = response3
|
||||||
|
# Get redirect URL
|
||||||
|
redirect_url = response.headers["Location"].replace("&nfy=1", "")
|
||||||
|
request_id = redirect_url.split("id=")[-1]
|
||||||
|
self.session.get(f"{BING_URL}{redirect_url}")
|
||||||
|
# https://www.bing.com/images/create/async/results/{ID}?q={PROMPT}
|
||||||
|
polling_url = f"{BING_URL}/images/create/async/results/{request_id}?q={url_encoded_prompt}"
|
||||||
|
# Poll for results
|
||||||
|
if self.debug_file:
|
||||||
|
self.debug("Polling and waiting for result")
|
||||||
|
if not self.quiet:
|
||||||
|
print("Waiting for results...")
|
||||||
|
start_wait = time.time()
|
||||||
|
while True:
|
||||||
|
if int(time.time() - start_wait) > 200:
|
||||||
|
if self.debug_file:
|
||||||
|
self.debug(f"ERROR: {error_timeout}")
|
||||||
|
raise Exception(error_timeout)
|
||||||
|
if not self.quiet:
|
||||||
|
print(".", end="", flush=True)
|
||||||
|
response = self.session.get(polling_url)
|
||||||
|
if response.status_code != 200:
|
||||||
|
if self.debug_file:
|
||||||
|
self.debug(f"ERROR: {error_noresults}")
|
||||||
|
raise Exception(error_noresults)
|
||||||
|
if not response.text or response.text.find("errorMessage") != -1:
|
||||||
|
time.sleep(1)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
# Use regex to search for src=""
|
||||||
|
image_links = regex.findall(r'src="([^"]+)"', response.text)
|
||||||
|
# Remove size limit
|
||||||
|
normal_image_links = [link.split("?w=")[0] for link in image_links]
|
||||||
|
# Remove duplicates
|
||||||
|
normal_image_links = list(set(normal_image_links))
|
||||||
|
|
||||||
|
# bad_images = [
|
||||||
|
# "https://r.bing.com/rp/in-2zU3AJUdkgFe7ZKv19yPBHVs.png",
|
||||||
|
# "https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg",
|
||||||
|
# ]
|
||||||
|
# for img in normal_image_links:
|
||||||
|
# if img in bad_images:
|
||||||
|
# raise Exception("Bad images")
|
||||||
|
# No images
|
||||||
|
if not normal_image_links:
|
||||||
|
raise Exception(error_no_images)
|
||||||
|
return normal_image_links
|
||||||
|
|
||||||
|
def save_images(self, links: list, output_dir: str) -> str:
|
||||||
|
"""
|
||||||
|
Saves images to output directory
|
||||||
|
"""
|
||||||
|
|
||||||
|
# image name
|
||||||
|
image_name = str(uuid4())
|
||||||
|
# since matrix only support one media attachment per message, we just need one link
|
||||||
|
if links:
|
||||||
|
link = links.pop()
|
||||||
|
|
||||||
|
image_path = os.path.join(output_dir, f"{image_name}.jpeg")
|
||||||
|
|
||||||
|
with contextlib.suppress(FileExistsError):
|
||||||
|
os.mkdir(output_dir)
|
||||||
|
try:
|
||||||
|
with self.session.get(link, stream=True) as response:
|
||||||
|
# save response to file
|
||||||
|
response.raise_for_status()
|
||||||
|
with open(
|
||||||
|
os.path.join(output_dir, image_path), "wb"
|
||||||
|
) as output_file:
|
||||||
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
|
output_file.write(chunk)
|
||||||
|
return image_path
|
||||||
|
except requests.exceptions.MissingSchema as url_exception:
|
||||||
|
raise Exception(
|
||||||
|
"Inappropriate contents found in the generated images. Please try again or try another prompt.",
|
||||||
|
) from url_exception
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGenAsync:
|
||||||
|
"""
|
||||||
|
Image generation by Microsoft Bing
|
||||||
|
Parameters:
|
||||||
|
auth_cookie: str
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, auth_cookie: str, quiet: bool = True) -> None:
|
||||||
|
self.session = aiohttp.ClientSession(
|
||||||
|
headers=HEADERS,
|
||||||
|
cookies={"_U": auth_cookie},
|
||||||
|
)
|
||||||
|
self.quiet = quiet
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *excinfo) -> None:
|
||||||
|
await self.session.close()
|
||||||
|
|
||||||
|
async def get_images(self, prompt: str) -> list:
|
||||||
|
"""
|
||||||
|
Fetches image links from Bing
|
||||||
|
Parameters:
|
||||||
|
prompt: str
|
||||||
|
"""
|
||||||
|
if not self.quiet:
|
||||||
|
print("Sending request...")
|
||||||
|
url_encoded_prompt = requests.utils.quote(prompt)
|
||||||
|
# https://www.bing.com/images/create?q=<PROMPT>&rt=3&FORM=GENCRE
|
||||||
|
url = f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=4&FORM=GENCRE"
|
||||||
|
async with self.session.post(url, allow_redirects=False) as response:
|
||||||
|
content = await response.text()
|
||||||
|
if "this prompt has been blocked" in content.lower():
|
||||||
|
raise Exception(
|
||||||
|
"Your prompt has been blocked by Bing. Try to change any bad words and try again.",
|
||||||
|
)
|
||||||
|
if response.status != 302:
|
||||||
|
# if rt4 fails, try rt3
|
||||||
|
url = (
|
||||||
|
f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=3&FORM=GENCRE"
|
||||||
|
)
|
||||||
|
async with self.session.post(
|
||||||
|
url,
|
||||||
|
allow_redirects=False,
|
||||||
|
timeout=200,
|
||||||
|
) as response3:
|
||||||
|
if response3.status != 302:
|
||||||
|
print(f"ERROR: {response3.text}")
|
||||||
|
raise Exception("Redirect failed")
|
||||||
|
response = response3
|
||||||
|
# Get redirect URL
|
||||||
|
redirect_url = response.headers["Location"].replace("&nfy=1", "")
|
||||||
|
request_id = redirect_url.split("id=")[-1]
|
||||||
|
await self.session.get(f"{BING_URL}{redirect_url}")
|
||||||
|
# https://www.bing.com/images/create/async/results/{ID}?q={PROMPT}
|
||||||
|
polling_url = f"{BING_URL}/images/create/async/results/{request_id}?q={url_encoded_prompt}"
|
||||||
|
# Poll for results
|
||||||
|
if not self.quiet:
|
||||||
|
print("Waiting for results...")
|
||||||
|
while True:
|
||||||
|
if not self.quiet:
|
||||||
|
print(".", end="", flush=True)
|
||||||
|
# By default, timeout is 300s, change as needed
|
||||||
|
response = await self.session.get(polling_url)
|
||||||
|
if response.status != 200:
|
||||||
|
raise Exception("Could not get results")
|
||||||
|
content = await response.text()
|
||||||
|
if content and content.find("errorMessage") == -1:
|
||||||
|
break
|
||||||
|
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
continue
|
||||||
|
# Use regex to search for src=""
|
||||||
|
image_links = regex.findall(r'src="([^"]+)"', content)
|
||||||
|
# Remove size limit
|
||||||
|
normal_image_links = [link.split("?w=")[0] for link in image_links]
|
||||||
|
# Remove duplicates
|
||||||
|
normal_image_links = list(set(normal_image_links))
|
||||||
|
|
||||||
|
# Bad images
|
||||||
|
bad_images = [
|
||||||
|
"https://r.bing.com/rp/in-2zU3AJUdkgFe7ZKv19yPBHVs.png",
|
||||||
|
"https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg",
|
||||||
|
]
|
||||||
|
for im in normal_image_links:
|
||||||
|
if im in bad_images:
|
||||||
|
raise Exception("Bad images")
|
||||||
|
# No images
|
||||||
|
if not normal_image_links:
|
||||||
|
raise Exception("No images")
|
||||||
|
return normal_image_links
|
||||||
|
|
||||||
|
async def save_images(self, links: list, output_dir: str) -> str:
|
||||||
|
"""
|
||||||
|
Saves images to output directory
|
||||||
|
"""
|
||||||
|
if not self.quiet:
|
||||||
|
print("\nDownloading images...")
|
||||||
|
with contextlib.suppress(FileExistsError):
|
||||||
|
os.mkdir(output_dir)
|
||||||
|
|
||||||
|
# image name
|
||||||
|
image_name = str(uuid4())
|
||||||
|
# since matrix only support one media attachment per message, we just need one link
|
||||||
|
if links:
|
||||||
|
link = links.pop()
|
||||||
|
|
||||||
|
image_path = os.path.join(output_dir, f"{image_name}.jpeg")
|
||||||
|
try:
|
||||||
|
async with self.session.get(link, raise_for_status=True) as response:
|
||||||
|
# save response to file
|
||||||
|
with open(image_path, "wb") as output_file:
|
||||||
|
async for chunk in response.content.iter_chunked(8192):
|
||||||
|
output_file.write(chunk)
|
||||||
|
return f"{output_dir}/{image_name}.jpeg"
|
||||||
|
|
||||||
|
except aiohttp.client_exceptions.InvalidURL as url_exception:
|
||||||
|
raise Exception(
|
||||||
|
"Inappropriate contents found in the generated images. Please try again or try another prompt.",
|
||||||
|
) from url_exception
|
37
CHANGELOG.md
37
CHANGELOG.md
|
@ -1,37 +0,0 @@
|
||||||
# Changelog
|
|
||||||
|
|
||||||
## 1.5.3
|
|
||||||
- Make gptbot more compatible by using non-streaming method
|
|
||||||
|
|
||||||
## 1.5.2
|
|
||||||
- Expose more stable diffusion webui api parameters
|
|
||||||
|
|
||||||
## 1.5.1
|
|
||||||
- fix: set timeout not work in image generation
|
|
||||||
|
|
||||||
## 1.5.0
|
|
||||||
- Fix localai v2.0+ image generation
|
|
||||||
- Fallback to gpt-3.5-turbo when caculate tokens using custom model
|
|
||||||
|
|
||||||
## 1.4.1
|
|
||||||
- Fix variable type imported from environment variable
|
|
||||||
- Bump pre-commit hook version
|
|
||||||
|
|
||||||
## 1.4.0
|
|
||||||
- Fix access_token login method not work in E2EE Room
|
|
||||||
|
|
||||||
## 1.3.0
|
|
||||||
- remove support for bing,bard,pandora
|
|
||||||
- refactor chat logic, add self host model support
|
|
||||||
- support new image generation endpoint
|
|
||||||
- admin system to manage langchain(flowise backend)
|
|
||||||
|
|
||||||
## 1.2.0
|
|
||||||
- rename `api_key` to `openai_api_key` in `config.json`
|
|
||||||
- rename `bing_api_endpoint` to `api_endpoint` in `config.json` and `env` file
|
|
||||||
- add `temperature` option to control ChatGPT model temperature
|
|
||||||
- remove `jailbreakEnabled` option
|
|
||||||
- session isolation for `!chat`, `!bing`, `!bard` command
|
|
||||||
- `!new + {chat,bing,bard,talk}` now can be used to create new conversation
|
|
||||||
- send some error message to user
|
|
||||||
- bug fix and code cleanup
|
|
14
Dockerfile
14
Dockerfile
|
@ -1,16 +1,20 @@
|
||||||
FROM python:3.11-alpine as base
|
FROM python:3.11-alpine as base
|
||||||
|
|
||||||
FROM base as pybuilder
|
FROM base as pybuilder
|
||||||
# RUN sed -i 's|v3\.\d*|edge|' /etc/apk/repositories
|
RUN sed -i 's|v3\.\d*|edge|' /etc/apk/repositories
|
||||||
RUN apk update && apk add --no-cache olm-dev gcc musl-dev libmagic libffi-dev cmake make g++ git python3-dev
|
RUN apk update && apk add olm-dev gcc musl-dev libmagic
|
||||||
COPY requirements.txt /requirements.txt
|
COPY requirements.txt /requirements.txt
|
||||||
RUN pip install -U pip setuptools wheel && pip install --user -r /requirements.txt && rm /requirements.txt
|
RUN pip3 install --user -r /requirements.txt && rm /requirements.txt
|
||||||
|
|
||||||
|
|
||||||
FROM base as runner
|
FROM base as runner
|
||||||
RUN apk update && apk add --no-cache olm-dev libmagic libffi-dev
|
LABEL "org.opencontainers.image.source"="https://github.com/hibobmaster/matrix_chatgpt_bot"
|
||||||
|
RUN apk update && apk add olm-dev libmagic
|
||||||
COPY --from=pybuilder /root/.local /usr/local
|
COPY --from=pybuilder /root/.local /usr/local
|
||||||
COPY . /app
|
COPY . /app
|
||||||
|
|
||||||
|
|
||||||
FROM runner
|
FROM runner
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
CMD ["python", "src/main.py"]
|
CMD ["python", "main.py"]
|
||||||
|
|
||||||
|
|
21
LICENSE
21
LICENSE
|
@ -1,21 +0,0 @@
|
||||||
MIT License
|
|
||||||
|
|
||||||
Copyright (c) 2023 hibobmaster
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
112
README.md
112
README.md
|
@ -1,117 +1,79 @@
|
||||||
## Introduction
|
## Introduction
|
||||||
|
This is a simple Matrix bot that uses OpenAI's GPT API and Bing AI to generate responses to user inputs. The bot responds to four types of prompts: `!gpt`, `!chat` and `!bing` and `!pic` depending on the first word of the prompt.
|
||||||
This is a simple Matrix bot that support using OpenAI API, Langchain to generate responses from user inputs. The bot responds to these commands: `!gpt`, `!chat`, `!pic`, `!new`, `!lc` and `!help` depending on the first word of the prompt.
|
![demo](https://i.imgur.com/kK4rnPf.jpeg "demo")
|
||||||
![ChatGPT](https://i.imgur.com/kK4rnPf.jpeg)
|
|
||||||
|
|
||||||
## Feature
|
## Feature
|
||||||
|
1. Support openai and Bing AI
|
||||||
1. Support official openai api and self host models([LocalAI](https://localai.io/model-compatibility/))
|
2. Support Bing Image Creator
|
||||||
2. Support E2E Encrypted Room
|
3. Support E2E Encrypted Room
|
||||||
3. Colorful code blocks
|
|
||||||
4. Langchain([Flowise](https://github.com/FlowiseAI/Flowise))
|
|
||||||
5. Image Generation with [DALL·E](https://platform.openai.com/docs/api-reference/images/create) or [LocalAI](https://localai.io/features/image-generation/) or [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)
|
|
||||||
|
|
||||||
|
|
||||||
## Installation and Setup
|
## Installation and Setup
|
||||||
|
|
||||||
Docker method(Recommended):<br>
|
Docker method(Recommended):<br>
|
||||||
Edit `config.json` or `.env` with proper values <br>
|
Edit `config.json` or `.env` with proper values <br>
|
||||||
For explainations and complete parameter list see: https://github.com/hibobmaster/matrix_chatgpt_bot/wiki <br>
|
Create an empty file, for persist database only<br>
|
||||||
Create two empty file, for persist database only<br>
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
touch sync_db manage_db
|
touch db
|
||||||
sudo docker compose up -d
|
sudo docker compose up -d
|
||||||
```
|
```
|
||||||
manage_db(can be ignored) is for langchain agent, sync_db is for matrix sync database<br>
|
|
||||||
<hr>
|
<hr>
|
||||||
Normal Method:<br>
|
|
||||||
system dependece: <code>libolm-dev</code>
|
|
||||||
|
|
||||||
|
To run this application, follow the steps below:<br>
|
||||||
1. Clone the repository and create virtual environment:
|
1. Clone the repository and create virtual environment:
|
||||||
|
|
||||||
```
|
```
|
||||||
git clone https://github.com/hibobmaster/matrix_chatgpt_bot.git
|
git clone https://github.com/hibobmaster/matrix_chatgpt_bot.git
|
||||||
|
|
||||||
python -m venv venv
|
python -m venv venv
|
||||||
source venv/bin/activate
|
source venv/bin/activate
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Install the required dependencies:<br>
|
2. Install the required dependencies:<br>
|
||||||
|
|
||||||
```
|
```
|
||||||
pip install -U pip setuptools wheel
|
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
3. Create a new config.json file and fill it with the necessary information:<br>
|
||||||
3. Create a new config.json file and complete it with the necessary information:<br>
|
Use password to login(recommended) or provide `access_token` <br>
|
||||||
If not set:<br>
|
If not set:<br>
|
||||||
`room_id`: bot will work in the room where it is in <br>
|
`room_id`: bot will work in the room where it is in <br>
|
||||||
|
`api_key`: `!chat` command will not work <br>
|
||||||
|
`bing_api_endpoint`: `!bing` command will not work <br>
|
||||||
|
`bing_auth_cookie`: `!pic` command will not work
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"homeserver": "YOUR_HOMESERVER",
|
"homeserver": "YOUR_HOMESERVER",
|
||||||
"user_id": "YOUR_USER_ID",
|
"user_id": "YOUR_USER_ID",
|
||||||
"password": "YOUR_PASSWORD",
|
"password": "YOUR_PASSWORD",
|
||||||
"device_id": "YOUR_DEVICE_ID",
|
"device_id": "YOUR_DEVICE_ID",
|
||||||
"room_id": "YOUR_ROOM_ID",
|
"room_id": "YOUR_ROOM_ID",
|
||||||
"openai_api_key": "YOUR_API_KEY",
|
"api_key": "YOUR_API_KEY",
|
||||||
"gpt_api_endpoint": "xxxxxxxxx"
|
"access_token": "xxxxxxxxxxxxxx",
|
||||||
|
"bing_api_endpoint": "xxxxxxxxx",
|
||||||
|
"bing_auth_cookie": "xxxxxxxxxx"
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
4. Start the bot:
|
||||||
4. Launch the bot:
|
|
||||||
|
|
||||||
```
|
```
|
||||||
python src/main.py
|
python main.py
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
To interact with the bot, simply send a message to the bot in the Matrix room with one of the two prompts:<br>
|
||||||
To interact with the bot, simply send a message to the bot in the Matrix room with one of the following prompts:<br>
|
- `!gpt` To generate a response using free_endpoint API:
|
||||||
- `!help` help message
|
|
||||||
|
|
||||||
- `!gpt` To generate a one time response:
|
|
||||||
|
|
||||||
```
|
```
|
||||||
!gpt What is the meaning of life?
|
!gpt What is the meaning of life?
|
||||||
```
|
```
|
||||||
|
|
||||||
- `!chat` To chat using official api with context conversation
|
- `!chat` To chat using official api with context conversation
|
||||||
|
|
||||||
```
|
```
|
||||||
!chat Can you tell me a joke?
|
!chat Can you tell me a joke?
|
||||||
```
|
```
|
||||||
|
- `!bing` To chat with Bing AI with context conversation
|
||||||
- `!lc` To chat using langchain api endpoint
|
|
||||||
```
|
```
|
||||||
!lc All the world is a stage
|
!bing Do you know Victor Marie Hugo?
|
||||||
```
|
```
|
||||||
- `!pic` To generate an image using openai DALL·E or LocalAI
|
- `!pic` To generate an image from Microsoft Bing
|
||||||
|
|
||||||
```
|
```
|
||||||
!pic A bridal bouquet made of succulents
|
!pic A bridal bouquet made of succulents
|
||||||
```
|
```
|
||||||
- `!agent` display or set langchain agent
|
## Bing AI and Image Generation
|
||||||
```
|
https://github.com/waylaidwanderer/node-chatgpt-api <br>
|
||||||
!agent list
|
https://github.com/hibobmaster/matrix_chatgpt_bot/wiki/Bing-AI <br>
|
||||||
!agent use {agent_name}
|
https://github.com/acheong08/EdgeGPT/blob/master/src/ImageGen.py
|
||||||
```
|
![](https://i.imgur.com/KuYddd5.jpg)
|
||||||
- `!new + {chat}` Start a new converstaion
|
![](https://i.imgur.com/3SRQdN0.jpg)
|
||||||
|
|
||||||
LangChain(flowise) admin: https://github.com/hibobmaster/matrix_chatgpt_bot/wiki/Langchain-(flowise)
|
|
||||||
|
|
||||||
## Image Generation
|
|
||||||
![demo1](https://i.imgur.com/voeomsF.jpg)
|
|
||||||
![demo2](https://i.imgur.com/BKZktWd.jpg)
|
|
||||||
https://github.com/hibobmaster/matrix_chatgpt_bot/wiki/ <br>
|
|
||||||
|
|
||||||
|
|
||||||
## Thanks
|
|
||||||
1. [matrix-nio](https://github.com/poljar/matrix-nio)
|
|
||||||
2. [acheong08](https://github.com/acheong08)
|
|
||||||
3. [8go](https://github.com/8go/)
|
|
||||||
|
|
||||||
<a href="https://jb.gg/OpenSourceSupport" target="_blank">
|
|
||||||
<img src="https://resources.jetbrains.com/storage/products/company/brand/logos/jb_beam.png" alt="JetBrains Logo (Main) logo." width="200" height="200">
|
|
||||||
</a>
|
|
||||||
|
|
38
askgpt.py
Normal file
38
askgpt.py
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from log import getlogger
|
||||||
|
logger = getlogger()
|
||||||
|
|
||||||
|
class askGPT:
|
||||||
|
def __init__(self):
|
||||||
|
self.session = aiohttp.ClientSession()
|
||||||
|
|
||||||
|
async def oneTimeAsk(self, prompt: str, api_endpoint: str, headers: dict) -> str:
|
||||||
|
jsons = {
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
max_try = 3
|
||||||
|
while max_try > 0:
|
||||||
|
try:
|
||||||
|
async with self.session.post(url=api_endpoint,
|
||||||
|
json=jsons, headers=headers, timeout=60) as response:
|
||||||
|
status_code = response.status
|
||||||
|
if not status_code == 200:
|
||||||
|
# print failed reason
|
||||||
|
logger.warning(str(response.reason))
|
||||||
|
max_try = max_try - 1
|
||||||
|
# wait 2s
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
continue
|
||||||
|
|
||||||
|
resp = await response.read()
|
||||||
|
return json.loads(resp)['choices'][0]['message']['content']
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error Exception", exc_info=True)
|
52
bing.py
Normal file
52
bing.py
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
import aiohttp
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
from log import getlogger
|
||||||
|
# api_endpoint = "http://localhost:3000/conversation"
|
||||||
|
logger = getlogger()
|
||||||
|
|
||||||
|
|
||||||
|
class BingBot:
|
||||||
|
def __init__(self, bing_api_endpoint: str, jailbreakEnabled: bool = False):
|
||||||
|
self.data = {
|
||||||
|
'clientOptions.clientToUse': 'bing',
|
||||||
|
}
|
||||||
|
self.bing_api_endpoint = bing_api_endpoint
|
||||||
|
|
||||||
|
self.session = aiohttp.ClientSession()
|
||||||
|
|
||||||
|
self.jailbreakEnabled = jailbreakEnabled
|
||||||
|
|
||||||
|
if self.jailbreakEnabled:
|
||||||
|
self.data['jailbreakConversationId'] = json.dumps(True)
|
||||||
|
|
||||||
|
async def ask_bing(self, prompt) -> str:
|
||||||
|
self.data['message'] = prompt
|
||||||
|
max_try = 3
|
||||||
|
while max_try > 0:
|
||||||
|
try:
|
||||||
|
resp = await self.session.post(url=self.bing_api_endpoint, json=self.data, timeout=60)
|
||||||
|
status_code = resp.status
|
||||||
|
body = await resp.read()
|
||||||
|
if not status_code == 200:
|
||||||
|
# print failed reason
|
||||||
|
logger.warning(str(resp.reason))
|
||||||
|
max_try = max_try - 1
|
||||||
|
# print(await resp.text())
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
continue
|
||||||
|
json_body = json.loads(body)
|
||||||
|
if self.jailbreakEnabled:
|
||||||
|
self.data['jailbreakConversationId'] = json_body['jailbreakConversationId']
|
||||||
|
self.data['parentMessageId'] = json_body['messageId']
|
||||||
|
else:
|
||||||
|
self.data['conversationSignature'] = json_body['conversationSignature']
|
||||||
|
self.data['conversationId'] = json_body['conversationId']
|
||||||
|
self.data['clientId'] = json_body['clientId']
|
||||||
|
self.data['invocationId'] = json_body['invocationId']
|
||||||
|
return json_body['response']
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error Exception", exc_info=True)
|
||||||
|
print(f"Error: {e}")
|
||||||
|
pass
|
||||||
|
return "Error, please retry"
|
600
bot.py
Normal file
600
bot.py
Normal file
|
@ -0,0 +1,600 @@
|
||||||
|
import sys
|
||||||
|
import asyncio
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
from functools import partial
|
||||||
|
import traceback
|
||||||
|
from typing import Optional, Union
|
||||||
|
from nio import (
|
||||||
|
AsyncClient,
|
||||||
|
MatrixRoom,
|
||||||
|
RoomMessageText,
|
||||||
|
InviteMemberEvent,
|
||||||
|
MegolmEvent,
|
||||||
|
LoginResponse,
|
||||||
|
JoinError,
|
||||||
|
ToDeviceError,
|
||||||
|
LocalProtocolError,
|
||||||
|
KeyVerificationEvent,
|
||||||
|
KeyVerificationStart,
|
||||||
|
KeyVerificationCancel,
|
||||||
|
KeyVerificationKey,
|
||||||
|
KeyVerificationMac,
|
||||||
|
AsyncClientConfig
|
||||||
|
)
|
||||||
|
from nio.store.database import SqliteStore
|
||||||
|
from askgpt import askGPT
|
||||||
|
from send_message import send_room_message
|
||||||
|
from v3 import Chatbot
|
||||||
|
from log import getlogger
|
||||||
|
from bing import BingBot
|
||||||
|
from BingImageGen import ImageGenAsync
|
||||||
|
from send_image import send_room_image
|
||||||
|
|
||||||
|
logger = getlogger()
|
||||||
|
|
||||||
|
|
||||||
|
class Bot:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
homeserver: str,
|
||||||
|
user_id: str,
|
||||||
|
device_id: str,
|
||||||
|
chatgpt_api_endpoint: str = os.environ.get(
|
||||||
|
"CHATGPT_API_ENDPOINT") or "https://api.openai.com/v1/chat/completions",
|
||||||
|
api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") or "",
|
||||||
|
room_id: Union[str, None] = None,
|
||||||
|
bing_api_endpoint: Optional[str] = '',
|
||||||
|
password: Union[str, None] = None,
|
||||||
|
access_token: Union[str, None] = None,
|
||||||
|
jailbreakEnabled: Optional[bool] = True,
|
||||||
|
bing_auth_cookie: Optional[str] = '',
|
||||||
|
):
|
||||||
|
if (homeserver is None or user_id is None
|
||||||
|
or device_id is None):
|
||||||
|
logger.warning("homeserver && user_id && device_id is required")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if (password is None and access_token is None):
|
||||||
|
logger.warning("password and access_toekn is required")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
self.homeserver = homeserver
|
||||||
|
self.user_id = user_id
|
||||||
|
self.password = password
|
||||||
|
self.access_token = access_token
|
||||||
|
self.device_id = device_id
|
||||||
|
self.room_id = room_id
|
||||||
|
self.api_key = api_key
|
||||||
|
self.chatgpt_api_endpoint = chatgpt_api_endpoint
|
||||||
|
self.bing_api_endpoint = bing_api_endpoint
|
||||||
|
self.jailbreakEnabled = jailbreakEnabled
|
||||||
|
self.bing_auth_cookie = bing_auth_cookie
|
||||||
|
# initialize AsyncClient object
|
||||||
|
self.store_path = os.getcwd()
|
||||||
|
self.config = AsyncClientConfig(store=SqliteStore,
|
||||||
|
store_name="db",
|
||||||
|
store_sync_tokens=True,
|
||||||
|
encryption_enabled=True,
|
||||||
|
)
|
||||||
|
self.client = AsyncClient(homeserver=self.homeserver, user=self.user_id, device_id=self.device_id,
|
||||||
|
config=self.config, store_path=self.store_path,)
|
||||||
|
|
||||||
|
if self.access_token is not None:
|
||||||
|
self.client.access_token = self.access_token
|
||||||
|
|
||||||
|
# setup event callbacks
|
||||||
|
self.client.add_event_callback(
|
||||||
|
self.message_callback, (RoomMessageText, ))
|
||||||
|
self.client.add_event_callback(
|
||||||
|
self.decryption_failure, (MegolmEvent, ))
|
||||||
|
self.client.add_event_callback(
|
||||||
|
self.invite_callback, (InviteMemberEvent, ))
|
||||||
|
self.client.add_to_device_callback(
|
||||||
|
self.to_device_callback, (KeyVerificationEvent, ))
|
||||||
|
|
||||||
|
# regular expression to match keyword [!gpt {prompt}] [!chat {prompt}]
|
||||||
|
self.gpt_prog = re.compile(r"^\s*!gpt\s*(.+)$")
|
||||||
|
self.chat_prog = re.compile(r"^\s*!chat\s*(.+)$")
|
||||||
|
self.bing_prog = re.compile(r"^\s*!bing\s*(.+)$")
|
||||||
|
self.pic_prog = re.compile(r"^\s*!pic\s*(.+)$")
|
||||||
|
self.help_prog = re.compile(r"^\s*!help\s*.*$")
|
||||||
|
|
||||||
|
# initialize chatbot and chatgpt_api_endpoint
|
||||||
|
if self.api_key != '':
|
||||||
|
self.chatbot = Chatbot(api_key=self.api_key, timeout=60)
|
||||||
|
|
||||||
|
self.chatgpt_api_endpoint = self.chatgpt_api_endpoint
|
||||||
|
# request header for !gpt command
|
||||||
|
self.headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
self.chatgpt_api_endpoint = self.chatgpt_api_endpoint
|
||||||
|
self.headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
# initialize askGPT class
|
||||||
|
self.askgpt = askGPT()
|
||||||
|
|
||||||
|
# initialize bingbot
|
||||||
|
if self.bing_api_endpoint != '':
|
||||||
|
self.bingbot = BingBot(
|
||||||
|
bing_api_endpoint, jailbreakEnabled=self.jailbreakEnabled)
|
||||||
|
|
||||||
|
# initialize BingImageGenAsync
|
||||||
|
if self.bing_auth_cookie != '':
|
||||||
|
self.imageGen = ImageGenAsync(self.bing_auth_cookie, quiet=True)
|
||||||
|
|
||||||
|
# get current event loop
|
||||||
|
self.loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
# message_callback RoomMessageText event
|
||||||
|
async def message_callback(self, room: MatrixRoom, event: RoomMessageText) -> None:
|
||||||
|
if self.room_id is None:
|
||||||
|
room_id = room.room_id
|
||||||
|
else:
|
||||||
|
# if event room id does not match the room id in config, return
|
||||||
|
if room.room_id != self.room_id:
|
||||||
|
return
|
||||||
|
room_id = self.room_id
|
||||||
|
|
||||||
|
# reply event_id
|
||||||
|
reply_to_event_id = event.event_id
|
||||||
|
|
||||||
|
# sender_id
|
||||||
|
sender_id = event.sender
|
||||||
|
|
||||||
|
# user_message
|
||||||
|
raw_user_message = event.body
|
||||||
|
|
||||||
|
# print info to console
|
||||||
|
print(
|
||||||
|
f"Message received in room {room.display_name}\n"
|
||||||
|
f"{room.user_name(event.sender)} | {raw_user_message}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# prevent command trigger loop
|
||||||
|
if self.user_id != event.sender:
|
||||||
|
# remove newline character from event.body
|
||||||
|
content_body = re.sub("\r\n|\r|\n", " ", raw_user_message)
|
||||||
|
|
||||||
|
# chatgpt
|
||||||
|
n = self.chat_prog.match(content_body)
|
||||||
|
if n:
|
||||||
|
prompt = n.group(1)
|
||||||
|
if self.api_key != '':
|
||||||
|
try:
|
||||||
|
await self.chat(room_id,
|
||||||
|
reply_to_event_id,
|
||||||
|
prompt,
|
||||||
|
sender_id,
|
||||||
|
raw_user_message
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
await send_room_message(self.client, room_id, reply_message=str(e))
|
||||||
|
else:
|
||||||
|
logger.warning("No API_KEY provided")
|
||||||
|
await send_room_message(self.client, room_id, reply_message="API_KEY not provided")
|
||||||
|
|
||||||
|
m = self.gpt_prog.match(content_body)
|
||||||
|
if m:
|
||||||
|
prompt = m.group(1)
|
||||||
|
try:
|
||||||
|
await self.gpt(
|
||||||
|
room_id,
|
||||||
|
reply_to_event_id,
|
||||||
|
prompt, sender_id,
|
||||||
|
raw_user_message
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
await send_room_message(self.client, room_id, reply_message=str(e))
|
||||||
|
|
||||||
|
# bing ai
|
||||||
|
if self.bing_api_endpoint != '':
|
||||||
|
b = self.bing_prog.match(content_body)
|
||||||
|
if b:
|
||||||
|
prompt = b.group(1)
|
||||||
|
# raw_content_body used for construct formatted_body
|
||||||
|
try:
|
||||||
|
await self.bing(
|
||||||
|
room_id,
|
||||||
|
reply_to_event_id,
|
||||||
|
prompt,
|
||||||
|
sender_id,
|
||||||
|
raw_user_message
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await send_room_message(self.client, room_id, reply_message=str(e))
|
||||||
|
|
||||||
|
# Image Generation by Microsoft Bing
|
||||||
|
if self.bing_auth_cookie != '':
|
||||||
|
i = self.pic_prog.match(content_body)
|
||||||
|
if i:
|
||||||
|
prompt = i.group(1)
|
||||||
|
try:
|
||||||
|
await self.pic(room_id, prompt)
|
||||||
|
except Exception as e:
|
||||||
|
await send_room_message(self.client, room_id, reply_message=str(e))
|
||||||
|
|
||||||
|
# help command
|
||||||
|
h = self.help_prog.match(content_body)
|
||||||
|
if h:
|
||||||
|
await self.help(room_id)
|
||||||
|
|
||||||
|
# message_callback decryption_failure event
|
||||||
|
async def decryption_failure(self, room: MatrixRoom, event: MegolmEvent) -> None:
|
||||||
|
if not isinstance(event, MegolmEvent):
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.error(
|
||||||
|
f"Failed to decrypt message: {event.event_id} from {event.sender} in {room.room_id}\n" +
|
||||||
|
"Please make sure the bot current session is verified"
|
||||||
|
)
|
||||||
|
|
||||||
|
# invite_callback event
|
||||||
|
async def invite_callback(self, room: MatrixRoom, event: InviteMemberEvent) -> None:
|
||||||
|
"""Handle an incoming invite event.
|
||||||
|
https://github.com/8go/matrix-eno-bot/blob/ad037e02bd2960941109e9526c1033dd157bb212/callbacks.py#L104
|
||||||
|
If an invite is received, then join the room specified in the invite.
|
||||||
|
code copied from:
|
||||||
|
"""
|
||||||
|
logger.debug(f"Got invite to {room.room_id} from {event.sender}.")
|
||||||
|
# Attempt to join 3 times before giving up
|
||||||
|
for attempt in range(3):
|
||||||
|
result = await self.client.join(room.room_id)
|
||||||
|
if type(result) == JoinError:
|
||||||
|
logger.error(
|
||||||
|
f"Error joining room {room.room_id} (attempt %d): %s",
|
||||||
|
attempt, result.message,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logger.error("Unable to join room: %s", room.room_id)
|
||||||
|
|
||||||
|
# Successfully joined room
|
||||||
|
logger.info(f"Joined {room.room_id}")
|
||||||
|
|
||||||
|
# to_device_callback event
|
||||||
|
async def to_device_callback(self, event: KeyVerificationEvent) -> None:
|
||||||
|
"""Handle events sent to device.
|
||||||
|
|
||||||
|
Specifically this will perform Emoji verification.
|
||||||
|
It will accept an incoming Emoji verification requests
|
||||||
|
and follow the verification protocol.
|
||||||
|
code copied from: https://github.com/8go/matrix-eno-bot/blob/ad037e02bd2960941109e9526c1033dd157bb212/callbacks.py#L127
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client = self.client
|
||||||
|
logger.debug(
|
||||||
|
f"Device Event of type {type(event)} received in "
|
||||||
|
"to_device_cb().")
|
||||||
|
|
||||||
|
if isinstance(event, KeyVerificationStart): # first step
|
||||||
|
""" first step: receive KeyVerificationStart
|
||||||
|
KeyVerificationStart(
|
||||||
|
source={'content':
|
||||||
|
{'method': 'm.sas.v1',
|
||||||
|
'from_device': 'DEVICEIDXY',
|
||||||
|
'key_agreement_protocols':
|
||||||
|
['curve25519-hkdf-sha256', 'curve25519'],
|
||||||
|
'hashes': ['sha256'],
|
||||||
|
'message_authentication_codes':
|
||||||
|
['hkdf-hmac-sha256', 'hmac-sha256'],
|
||||||
|
'short_authentication_string':
|
||||||
|
['decimal', 'emoji'],
|
||||||
|
'transaction_id': 'SomeTxId'
|
||||||
|
},
|
||||||
|
'type': 'm.key.verification.start',
|
||||||
|
'sender': '@user2:example.org'
|
||||||
|
},
|
||||||
|
sender='@user2:example.org',
|
||||||
|
transaction_id='SomeTxId',
|
||||||
|
from_device='DEVICEIDXY',
|
||||||
|
method='m.sas.v1',
|
||||||
|
key_agreement_protocols=[
|
||||||
|
'curve25519-hkdf-sha256', 'curve25519'],
|
||||||
|
hashes=['sha256'],
|
||||||
|
message_authentication_codes=[
|
||||||
|
'hkdf-hmac-sha256', 'hmac-sha256'],
|
||||||
|
short_authentication_string=['decimal', 'emoji'])
|
||||||
|
"""
|
||||||
|
|
||||||
|
if "emoji" not in event.short_authentication_string:
|
||||||
|
estr = ("Other device does not support emoji verification "
|
||||||
|
f"{event.short_authentication_string}. Aborting.")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
return
|
||||||
|
resp = await client.accept_key_verification(
|
||||||
|
event.transaction_id)
|
||||||
|
if isinstance(resp, ToDeviceError):
|
||||||
|
estr = f"accept_key_verification() failed with {resp}"
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
|
||||||
|
sas = client.key_verifications[event.transaction_id]
|
||||||
|
|
||||||
|
todevice_msg = sas.share_key()
|
||||||
|
resp = await client.to_device(todevice_msg)
|
||||||
|
if isinstance(resp, ToDeviceError):
|
||||||
|
estr = f"to_device() failed with {resp}"
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
|
||||||
|
elif isinstance(event, KeyVerificationCancel): # anytime
|
||||||
|
""" at any time: receive KeyVerificationCancel
|
||||||
|
KeyVerificationCancel(source={
|
||||||
|
'content': {'code': 'm.mismatched_sas',
|
||||||
|
'reason': 'Mismatched authentication string',
|
||||||
|
'transaction_id': 'SomeTxId'},
|
||||||
|
'type': 'm.key.verification.cancel',
|
||||||
|
'sender': '@user2:example.org'},
|
||||||
|
sender='@user2:example.org',
|
||||||
|
transaction_id='SomeTxId',
|
||||||
|
code='m.mismatched_sas',
|
||||||
|
reason='Mismatched short authentication string')
|
||||||
|
"""
|
||||||
|
|
||||||
|
# There is no need to issue a
|
||||||
|
# client.cancel_key_verification(tx_id, reject=False)
|
||||||
|
# here. The SAS flow is already cancelled.
|
||||||
|
# We only need to inform the user.
|
||||||
|
estr = (f"Verification has been cancelled by {event.sender} "
|
||||||
|
f"for reason \"{event.reason}\".")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
|
||||||
|
elif isinstance(event, KeyVerificationKey): # second step
|
||||||
|
""" Second step is to receive KeyVerificationKey
|
||||||
|
KeyVerificationKey(
|
||||||
|
source={'content': {
|
||||||
|
'key': 'SomeCryptoKey',
|
||||||
|
'transaction_id': 'SomeTxId'},
|
||||||
|
'type': 'm.key.verification.key',
|
||||||
|
'sender': '@user2:example.org'
|
||||||
|
},
|
||||||
|
sender='@user2:example.org',
|
||||||
|
transaction_id='SomeTxId',
|
||||||
|
key='SomeCryptoKey')
|
||||||
|
"""
|
||||||
|
sas = client.key_verifications[event.transaction_id]
|
||||||
|
|
||||||
|
print(f"{sas.get_emoji()}")
|
||||||
|
# don't log the emojis
|
||||||
|
|
||||||
|
# The bot process must run in forground with a screen and
|
||||||
|
# keyboard so that user can accept/reject via keyboard.
|
||||||
|
# For emoji verification bot must not run as service or
|
||||||
|
# in background.
|
||||||
|
# yn = input("Do the emojis match? (Y/N) (C for Cancel) ")
|
||||||
|
# automatic match, so we use y
|
||||||
|
yn = "y"
|
||||||
|
if yn.lower() == "y":
|
||||||
|
estr = ("Match! The verification for this "
|
||||||
|
"device will be accepted.")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
resp = await client.confirm_short_auth_string(
|
||||||
|
event.transaction_id)
|
||||||
|
if isinstance(resp, ToDeviceError):
|
||||||
|
estr = ("confirm_short_auth_string() "
|
||||||
|
f"failed with {resp}")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
elif yn.lower() == "n": # no, don't match, reject
|
||||||
|
estr = ("No match! Device will NOT be verified "
|
||||||
|
"by rejecting verification.")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
resp = await client.cancel_key_verification(
|
||||||
|
event.transaction_id, reject=True)
|
||||||
|
if isinstance(resp, ToDeviceError):
|
||||||
|
estr = (f"cancel_key_verification failed with {resp}")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
else: # C or anything for cancel
|
||||||
|
estr = ("Cancelled by user! Verification will be "
|
||||||
|
"cancelled.")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
resp = await client.cancel_key_verification(
|
||||||
|
event.transaction_id, reject=False)
|
||||||
|
if isinstance(resp, ToDeviceError):
|
||||||
|
estr = (f"cancel_key_verification failed with {resp}")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
|
||||||
|
elif isinstance(event, KeyVerificationMac): # third step
|
||||||
|
""" Third step is to receive KeyVerificationMac
|
||||||
|
KeyVerificationMac(
|
||||||
|
source={'content': {
|
||||||
|
'mac': {'ed25519:DEVICEIDXY': 'SomeKey1',
|
||||||
|
'ed25519:SomeKey2': 'SomeKey3'},
|
||||||
|
'keys': 'SomeCryptoKey4',
|
||||||
|
'transaction_id': 'SomeTxId'},
|
||||||
|
'type': 'm.key.verification.mac',
|
||||||
|
'sender': '@user2:example.org'},
|
||||||
|
sender='@user2:example.org',
|
||||||
|
transaction_id='SomeTxId',
|
||||||
|
mac={'ed25519:DEVICEIDXY': 'SomeKey1',
|
||||||
|
'ed25519:SomeKey2': 'SomeKey3'},
|
||||||
|
keys='SomeCryptoKey4')
|
||||||
|
"""
|
||||||
|
sas = client.key_verifications[event.transaction_id]
|
||||||
|
try:
|
||||||
|
todevice_msg = sas.get_mac()
|
||||||
|
except LocalProtocolError as e:
|
||||||
|
# e.g. it might have been cancelled by ourselves
|
||||||
|
estr = (f"Cancelled or protocol error: Reason: {e}.\n"
|
||||||
|
f"Verification with {event.sender} not concluded. "
|
||||||
|
"Try again?")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
else:
|
||||||
|
resp = await client.to_device(todevice_msg)
|
||||||
|
if isinstance(resp, ToDeviceError):
|
||||||
|
estr = f"to_device failed with {resp}"
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
estr = (f"sas.we_started_it = {sas.we_started_it}\n"
|
||||||
|
f"sas.sas_accepted = {sas.sas_accepted}\n"
|
||||||
|
f"sas.canceled = {sas.canceled}\n"
|
||||||
|
f"sas.timed_out = {sas.timed_out}\n"
|
||||||
|
f"sas.verified = {sas.verified}\n"
|
||||||
|
f"sas.verified_devices = {sas.verified_devices}\n")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
estr = ("Emoji verification was successful!\n"
|
||||||
|
"Initiate another Emoji verification from "
|
||||||
|
"another device or room if desired. "
|
||||||
|
"Or if done verifying, hit Control-C to stop the "
|
||||||
|
"bot in order to restart it as a service or to "
|
||||||
|
"run it in the background.")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
else:
|
||||||
|
estr = (f"Received unexpected event type {type(event)}. "
|
||||||
|
f"Event is {event}. Event will be ignored.")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
except BaseException:
|
||||||
|
estr = traceback.format_exc()
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
|
||||||
|
# !chat command
|
||||||
|
async def chat(self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message):
|
||||||
|
await self.client.room_typing(room_id, timeout=180000)
|
||||||
|
try:
|
||||||
|
text = await asyncio.wait_for(self.chatbot.ask_async(prompt), timeout=180)
|
||||||
|
except TimeoutError as e:
|
||||||
|
logger.error("timeoutException", exc_info=True)
|
||||||
|
text = "Timeout error"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error", exc_info=True)
|
||||||
|
print(f"Error: {e}")
|
||||||
|
|
||||||
|
text = text.strip()
|
||||||
|
try:
|
||||||
|
await send_room_message(self.client, room_id, reply_message=text,
|
||||||
|
reply_to_event_id="", sender_id=sender_id, user_message=raw_user_message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# !gpt command
|
||||||
|
async def gpt(self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message):
|
||||||
|
try:
|
||||||
|
# sending typing state
|
||||||
|
await self.client.room_typing(room_id, timeout=180000)
|
||||||
|
# timeout 120s
|
||||||
|
text = await asyncio.wait_for(self.askgpt.oneTimeAsk(prompt, self.chatgpt_api_endpoint, self.headers), timeout=180)
|
||||||
|
except TimeoutError:
|
||||||
|
logger.error("timeoutException", exc_info=True)
|
||||||
|
text = "Timeout error"
|
||||||
|
|
||||||
|
text = text.strip()
|
||||||
|
try:
|
||||||
|
await send_room_message(self.client, room_id, reply_message=text,
|
||||||
|
reply_to_event_id="", sender_id=sender_id, user_message=raw_user_message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# !bing command
|
||||||
|
async def bing(self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message):
|
||||||
|
try:
|
||||||
|
# sending typing state
|
||||||
|
await self.client.room_typing(room_id, timeout=180000)
|
||||||
|
# timeout 120s
|
||||||
|
text = await asyncio.wait_for(self.bingbot.ask_bing(prompt), timeout=180)
|
||||||
|
except TimeoutError:
|
||||||
|
logger.error("timeoutException", exc_info=True)
|
||||||
|
text = "Timeout error"
|
||||||
|
text = text.strip()
|
||||||
|
try:
|
||||||
|
await send_room_message(self.client, room_id, reply_message=text,
|
||||||
|
reply_to_event_id="", sender_id=sender_id, user_message=raw_user_message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# !pic command
|
||||||
|
async def pic(self, room_id, prompt):
|
||||||
|
try:
|
||||||
|
await self.client.room_typing(room_id, timeout=180000)
|
||||||
|
# generate image
|
||||||
|
try:
|
||||||
|
|
||||||
|
links = await self.imageGen.get_images(prompt)
|
||||||
|
image_path = await self.imageGen.save_images(links, "images")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Image Generation error: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# send image
|
||||||
|
try:
|
||||||
|
await send_room_image(self.client, room_id, image_path)
|
||||||
|
await self.client.room_typing(room_id, typing_state=False)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e, exc_info=True)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# !help command
|
||||||
|
async def help(self, room_id):
|
||||||
|
try:
|
||||||
|
# sending typing state
|
||||||
|
await self.client.room_typing(room_id)
|
||||||
|
help_info = "!gpt [content], generate response without context conversation\n" + \
|
||||||
|
"!chat [content], chat with context conversation\n" + \
|
||||||
|
"!bing [content], chat with context conversation powered by Bing AI\n" + \
|
||||||
|
"!pic [prompt], Image generation by Microsoft Bing\n" + \
|
||||||
|
"!help, help message"
|
||||||
|
|
||||||
|
await send_room_message(self.client, room_id, reply_message=help_info)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# bot login
|
||||||
|
async def login(self) -> None:
|
||||||
|
if self.access_token is not None:
|
||||||
|
logger.info("Login via access_token")
|
||||||
|
else:
|
||||||
|
logger.info("Login via password")
|
||||||
|
try:
|
||||||
|
resp = await self.client.login(password=self.password)
|
||||||
|
if not isinstance(resp, LoginResponse):
|
||||||
|
logger.error("Login Failed")
|
||||||
|
print(f"Login Failed: {resp}")
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# sync messages in the room
|
||||||
|
async def sync_forever(self, timeout=30000, full_state=True) -> None:
|
||||||
|
|
||||||
|
await self.client.sync_forever(timeout=timeout, full_state=full_state)
|
||||||
|
|
||||||
|
# Sync encryption keys with the server
|
||||||
|
async def sync_encryption_key(self) -> None:
|
||||||
|
if self.client.should_upload_keys:
|
||||||
|
await self.client.keys_upload()
|
||||||
|
|
||||||
|
# Trust own devices
|
||||||
|
async def trust_own_devices(self) -> None:
|
||||||
|
await self.client.sync(timeout=30000, full_state=True)
|
||||||
|
for device_id, olm_device in self.client.device_store[
|
||||||
|
self.user_id].items():
|
||||||
|
logger.debug("My other devices are: "
|
||||||
|
f"device_id={device_id}, "
|
||||||
|
f"olm_device={olm_device}.")
|
||||||
|
logger.info("Setting up trust for my own "
|
||||||
|
f"device {device_id} and session key "
|
||||||
|
f"{olm_device.keys['ed25519']}.")
|
||||||
|
self.client.verify_device(olm_device)
|
19
compose.yaml
19
compose.yaml
|
@ -2,7 +2,7 @@ services:
|
||||||
app:
|
app:
|
||||||
image: hibobmaster/matrixchatgptbot:latest
|
image: hibobmaster/matrixchatgptbot:latest
|
||||||
container_name: matrix_chatgpt_bot
|
container_name: matrix_chatgpt_bot
|
||||||
restart: unless-stopped
|
restart: always
|
||||||
# build:
|
# build:
|
||||||
# context: .
|
# context: .
|
||||||
# dockerfile: ./Dockerfile
|
# dockerfile: ./Dockerfile
|
||||||
|
@ -11,14 +11,19 @@ services:
|
||||||
volumes:
|
volumes:
|
||||||
# use env file or config.json
|
# use env file or config.json
|
||||||
# - ./config.json:/app/config.json
|
# - ./config.json:/app/config.json
|
||||||
# use touch to create empty db file, for persist database only
|
# use touch to create an empty file db, for persist database only
|
||||||
# manage_db(can be ignored) is for langchain agent, sync_db is for matrix sync database
|
- ./db:/app/db
|
||||||
- ./sync_db:/app/sync_db
|
|
||||||
# - ./manage_db:/app/manage_db
|
|
||||||
# import_keys path
|
|
||||||
# - ./element-keys.txt:/app/element-keys.txt
|
|
||||||
networks:
|
networks:
|
||||||
- matrix_network
|
- matrix_network
|
||||||
|
# api:
|
||||||
|
# # bing api
|
||||||
|
# image: hibobmaster/node-chatgpt-api:latest
|
||||||
|
# container_name: node-chatgpt-api
|
||||||
|
# restart: always
|
||||||
|
# volumes:
|
||||||
|
# - ./settings.js:/var/chatgpt-api/settings.js
|
||||||
|
# networks:
|
||||||
|
# - matrix_network
|
||||||
|
|
||||||
networks:
|
networks:
|
||||||
matrix_network:
|
matrix_network:
|
||||||
|
|
|
@ -1,7 +0,0 @@
|
||||||
{
|
|
||||||
"homeserver": "https://matrix-client.matrix.org",
|
|
||||||
"user_id": "@lullap:xxxxx.org",
|
|
||||||
"password": "xxxxxxxxxxxxxxxxxx",
|
|
||||||
"device_id": "MatrixChatGPTBot",
|
|
||||||
"openai_api_key": "xxxxxxxxxxxxxxxxxxxxxxxx"
|
|
||||||
}
|
|
9
config.json.sample
Normal file
9
config.json.sample
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
{
|
||||||
|
"homeserver": "https://matrix.qqs.tw",
|
||||||
|
"user_id": "@lullap:xxxxx.org",
|
||||||
|
"password": "xxxxxxxxxxxxxxxxxx",
|
||||||
|
"device_id": "ECYEOKVPLG",
|
||||||
|
"room_id": "!FYCmBSkCRUNvZDBaDQ:matrix.qqs.tw",
|
||||||
|
"api_key": "xxxxxxxxxxxxxxxxxxxxxxxx",
|
||||||
|
"access_token": "xxxxxxx"
|
||||||
|
}
|
|
@ -1,29 +0,0 @@
|
||||||
{
|
|
||||||
"homeserver": "https://matrix-client.matrix.org",
|
|
||||||
"user_id": "@lullap:xxxxx.org",
|
|
||||||
"password": "xxxxxxxxxxxxxxxxxx",
|
|
||||||
"access_token": "xxxxxxxxxxxxxx",
|
|
||||||
"device_id": "MatrixChatGPTBot",
|
|
||||||
"room_id": "!xxxxxxxxxxxxxxxxxxxxxx:xxxxx.org",
|
|
||||||
"import_keys_path": "element-keys.txt",
|
|
||||||
"import_keys_password": "xxxxxxxxxxxxxxxxxxxx",
|
|
||||||
"openai_api_key": "xxxxxxxxxxxxxxxxxxxxxxxx",
|
|
||||||
"gpt_api_endpoint": "https://api.openai.com/v1/chat/completions",
|
|
||||||
"gpt_model": "gpt-3.5-turbo",
|
|
||||||
"max_tokens": 4000,
|
|
||||||
"top_p": 1.0,
|
|
||||||
"presence_penalty": 0.0,
|
|
||||||
"frequency_penalty": 0.0,
|
|
||||||
"reply_count": 1,
|
|
||||||
"temperature": 0.8,
|
|
||||||
"system_prompt": "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
|
|
||||||
"lc_admin": ["@admin:xxxxx.org"],
|
|
||||||
"image_generation_endpoint": "http://localai:8080/v1/images/generations",
|
|
||||||
"image_generation_backend": "localai",
|
|
||||||
"image_generation_size": "512x512",
|
|
||||||
"sdwui_steps": 20,
|
|
||||||
"sdwui_sampler_name": "Euler a",
|
|
||||||
"sdwui_cfg_scale": 7,
|
|
||||||
"image_format": "webp",
|
|
||||||
"timeout": 120.0
|
|
||||||
}
|
|
31
log.py
Normal file
31
log.py
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
def getlogger():
|
||||||
|
# create a custom logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# create handlers
|
||||||
|
warn_handler = logging.StreamHandler()
|
||||||
|
info_handler = logging.StreamHandler()
|
||||||
|
error_handler = logging.FileHandler('bot.log', mode='a')
|
||||||
|
warn_handler.setLevel(logging.WARNING)
|
||||||
|
error_handler.setLevel(logging.ERROR)
|
||||||
|
info_handler.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
# create formatters
|
||||||
|
warn_format = logging.Formatter('%(name)s - %(funcName)s - %(levelname)s - %(message)s')
|
||||||
|
error_format = logging.Formatter('%(asctime)s - %(name)s - %(funcName)s - %(levelname)s - %(message)s')
|
||||||
|
info_format = logging.Formatter('%(message)s')
|
||||||
|
|
||||||
|
# set formatter
|
||||||
|
warn_handler.setFormatter(warn_format)
|
||||||
|
error_handler.setFormatter(error_format)
|
||||||
|
info_handler.setFormatter(info_format)
|
||||||
|
|
||||||
|
# add handlers to logger
|
||||||
|
logger.addHandler(warn_handler)
|
||||||
|
logger.addHandler(error_handler)
|
||||||
|
logger.addHandler(info_handler)
|
||||||
|
|
||||||
|
return logger
|
52
main.py
Normal file
52
main.py
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
from bot import Bot
|
||||||
|
from log import getlogger
|
||||||
|
|
||||||
|
logger = getlogger()
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
|
||||||
|
if os.path.exists('config.json'):
|
||||||
|
fp = open('config.json', 'r', encoding="utf8")
|
||||||
|
config = json.load(fp)
|
||||||
|
|
||||||
|
matrix_bot = Bot(homeserver=config.get('homeserver'),
|
||||||
|
user_id=config.get('user_id') ,
|
||||||
|
password=config.get('password'),
|
||||||
|
device_id=config.get('device_id'),
|
||||||
|
room_id=config.get('room_id'),
|
||||||
|
api_key=config.get('api_key'),
|
||||||
|
bing_api_endpoint=config.get('bing_api_endpoint'),
|
||||||
|
access_token=config.get('access_token'),
|
||||||
|
jailbreakEnabled=config.get('jailbreakEnabled'),
|
||||||
|
bing_auth_cookie=config.get('bing_auth_cookie'),
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
matrix_bot = Bot(homeserver=os.environ.get('HOMESERVER'),
|
||||||
|
user_id=os.environ.get('USER_ID') ,
|
||||||
|
password=os.environ.get('PASSWORD'),
|
||||||
|
device_id=os.environ.get("DEVICE_ID"),
|
||||||
|
room_id=os.environ.get("ROOM_ID"),
|
||||||
|
api_key=os.environ.get("OPENAI_API_KEY"),
|
||||||
|
bing_api_endpoint=os.environ.get("BING_API_ENDPOINT"),
|
||||||
|
access_token=os.environ.get("ACCESS_TOKEN"),
|
||||||
|
jailbreakEnabled=os.environ.get("JAILBREAKENABLED"),
|
||||||
|
bing_auth_cookie=os.environ.get("BING_AUTH_COOKIE"),
|
||||||
|
)
|
||||||
|
|
||||||
|
await matrix_bot.login()
|
||||||
|
await matrix_bot.sync_encryption_key()
|
||||||
|
|
||||||
|
# await matrix_bot.trust_own_devices()
|
||||||
|
|
||||||
|
await matrix_bot.sync_forever(timeout=30000, full_state=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logger.info("matrix chatgpt bot start.....")
|
||||||
|
asyncio.run(main())
|
||||||
|
|
||||||
|
|
|
@ -1,9 +0,0 @@
|
||||||
aiofiles
|
|
||||||
httpx
|
|
||||||
Markdown
|
|
||||||
matrix-nio[e2e]
|
|
||||||
Pillow
|
|
||||||
tiktoken
|
|
||||||
tenacity
|
|
||||||
python-magic
|
|
||||||
pytest
|
|
|
@ -1,8 +1,51 @@
|
||||||
aiofiles
|
aiofiles==23.1.0
|
||||||
httpx
|
aiohttp==3.8.4
|
||||||
Markdown
|
aiohttp-socks==0.7.1
|
||||||
matrix-nio[e2e]
|
aiosignal==1.3.1
|
||||||
Pillow
|
anyio==3.6.2
|
||||||
tiktoken
|
async-timeout==4.0.2
|
||||||
tenacity
|
atomicwrites==1.4.1
|
||||||
python-magic
|
attrs==22.2.0
|
||||||
|
blobfile==2.0.1
|
||||||
|
cachetools==4.2.4
|
||||||
|
certifi==2022.12.7
|
||||||
|
cffi==1.15.1
|
||||||
|
charset-normalizer==3.1.0
|
||||||
|
cryptography==40.0.1
|
||||||
|
filelock==3.11.0
|
||||||
|
frozenlist==1.3.3
|
||||||
|
future==0.18.3
|
||||||
|
h11==0.14.0
|
||||||
|
h2==4.1.0
|
||||||
|
hpack==4.0.0
|
||||||
|
httpcore==0.16.3
|
||||||
|
httpx==0.23.3
|
||||||
|
hyperframe==6.0.1
|
||||||
|
idna==3.4
|
||||||
|
jsonschema==4.17.3
|
||||||
|
Logbook==1.5.3
|
||||||
|
lxml==4.9.2
|
||||||
|
Markdown==3.4.3
|
||||||
|
matrix-nio[e2e]==0.20.2
|
||||||
|
multidict==6.0.4
|
||||||
|
peewee==3.16.0
|
||||||
|
Pillow==9.5.0
|
||||||
|
pycparser==2.21
|
||||||
|
pycryptodome==3.17
|
||||||
|
pycryptodomex==3.17
|
||||||
|
pyrsistent==0.19.3
|
||||||
|
python-cryptography-fernet-wrapper==1.0.4
|
||||||
|
python-magic==0.4.27
|
||||||
|
python-olm==3.1.3
|
||||||
|
python-socks==2.2.0
|
||||||
|
regex==2023.3.23
|
||||||
|
requests==2.28.2
|
||||||
|
rfc3986==1.5.0
|
||||||
|
six==1.16.0
|
||||||
|
sniffio==1.3.0
|
||||||
|
tiktoken==0.3.3
|
||||||
|
toml==0.10.2
|
||||||
|
unpaddedbase64==2.1.0
|
||||||
|
urllib3==1.26.15
|
||||||
|
wcwidth==0.2.6
|
||||||
|
yarl==1.8.2
|
||||||
|
|
|
@ -3,18 +3,17 @@ code derived from:
|
||||||
https://matrix-nio.readthedocs.io/en/latest/examples.html#sending-an-image
|
https://matrix-nio.readthedocs.io/en/latest/examples.html#sending-an-image
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import aiofiles.os
|
import aiofiles.os
|
||||||
import magic
|
import magic
|
||||||
from log import getlogger
|
|
||||||
from nio import AsyncClient
|
|
||||||
from nio import UploadResponse
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from nio import AsyncClient, UploadResponse
|
||||||
|
from log import getlogger
|
||||||
|
|
||||||
logger = getlogger()
|
logger = getlogger()
|
||||||
|
|
||||||
|
|
||||||
async def send_room_image(client: AsyncClient, room_id: str, image: str):
|
async def send_room_image(client: AsyncClient,
|
||||||
|
room_id: str, image: str):
|
||||||
"""
|
"""
|
||||||
image: image path
|
image: image path
|
||||||
"""
|
"""
|
||||||
|
@ -33,14 +32,11 @@ async def send_room_image(client: AsyncClient, room_id: str, image: str):
|
||||||
filesize=file_stat.st_size,
|
filesize=file_stat.st_size,
|
||||||
)
|
)
|
||||||
if not isinstance(resp, UploadResponse):
|
if not isinstance(resp, UploadResponse):
|
||||||
logger.warning(f"Failed to upload image. Failure response: {resp}")
|
logger.warning(f"Failed to generate image. Failure response: {resp}")
|
||||||
await client.room_send(
|
await client.room_send(
|
||||||
room_id,
|
room_id,
|
||||||
message_type="m.room.message",
|
message_type="m.room.message",
|
||||||
content={
|
content={"msgtype": "m.text", "body": f"Failed to generate image. Failure response: {resp}", },
|
||||||
"msgtype": "m.text",
|
|
||||||
"body": f"Failed to upload image. Failure response: {resp}",
|
|
||||||
},
|
|
||||||
ignore_unverified_devices=True,
|
ignore_unverified_devices=True,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
@ -61,4 +57,3 @@ async def send_room_image(client: AsyncClient, room_id: str, image: str):
|
||||||
await client.room_send(room_id, message_type="m.room.message", content=content)
|
await client.room_send(room_id, message_type="m.room.message", content=content)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Image send of file {image} failed.\n Error: {e}", exc_info=True)
|
logger.error(f"Image send of file {image} failed.\n Error: {e}", exc_info=True)
|
||||||
raise Exception(e)
|
|
26
send_message.py
Normal file
26
send_message.py
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
from nio import AsyncClient
|
||||||
|
|
||||||
|
async def send_room_message(client: AsyncClient,
|
||||||
|
room_id: str,
|
||||||
|
reply_message: str,
|
||||||
|
sender_id: str = '',
|
||||||
|
user_message: str = '',
|
||||||
|
reply_to_event_id: str = '') -> None:
|
||||||
|
if reply_to_event_id == '':
|
||||||
|
content = {"msgtype": "m.text", "body": reply_message, }
|
||||||
|
else:
|
||||||
|
body = r'> <' + sender_id + r'> ' + user_message + r'\n\n' + reply_message
|
||||||
|
format = r'org.matrix.custom.html'
|
||||||
|
formatted_body = r'<mx-reply><blockquote><a href="https://matrix.to/#/' + room_id + r'/' + reply_to_event_id \
|
||||||
|
+ r'">In reply to</a> <a href="https://matrix.to/#/' + sender_id + r'">' + sender_id \
|
||||||
|
+ r'</a><br>' + user_message + r'</blockquote></mx-reply>' + reply_message
|
||||||
|
|
||||||
|
content={"msgtype": "m.text", "body": body, "format": format, "formatted_body": formatted_body,
|
||||||
|
"m.relates_to": {"m.in_reply_to": {"event_id": reply_to_event_id}}, }
|
||||||
|
await client.room_send(
|
||||||
|
room_id,
|
||||||
|
message_type="m.room.message",
|
||||||
|
content=content,
|
||||||
|
ignore_unverified_devices=True,
|
||||||
|
)
|
||||||
|
await client.room_typing(room_id, typing_state=False)
|
1496
src/bot.py
1496
src/bot.py
File diff suppressed because it is too large
Load diff
|
@ -1,41 +0,0 @@
|
||||||
import httpx
|
|
||||||
|
|
||||||
|
|
||||||
async def flowise_query(
|
|
||||||
api_url: str, prompt: str, session: httpx.AsyncClient, headers: dict = None
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Sends a query to the Flowise API and returns the response.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
api_url (str): The URL of the Flowise API.
|
|
||||||
prompt (str): The question to ask the API.
|
|
||||||
session (httpx.AsyncClient): The httpx session to use.
|
|
||||||
headers (dict, optional): The headers to use. Defaults to None.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The response from the API.
|
|
||||||
"""
|
|
||||||
if headers:
|
|
||||||
response = await session.post(
|
|
||||||
api_url,
|
|
||||||
json={"question": prompt},
|
|
||||||
headers=headers,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
response = await session.post(api_url, json={"question": prompt})
|
|
||||||
return response.text
|
|
||||||
|
|
||||||
|
|
||||||
async def test():
|
|
||||||
async with httpx.AsyncClient() as session:
|
|
||||||
api_url = "http://127.0.0.1:3000/api/v1/prediction/683f9ea8-e670-4d51-b657-0886eab9cea1"
|
|
||||||
prompt = "What is the capital of France?"
|
|
||||||
response = await flowise_query(api_url, prompt, session)
|
|
||||||
print(response)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
asyncio.run(test())
|
|
106
src/imagegen.py
106
src/imagegen.py
|
@ -1,106 +0,0 @@
|
||||||
import httpx
|
|
||||||
from pathlib import Path
|
|
||||||
import uuid
|
|
||||||
import base64
|
|
||||||
import io
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
async def get_images(
|
|
||||||
aclient: httpx.AsyncClient,
|
|
||||||
url: str,
|
|
||||||
prompt: str,
|
|
||||||
backend_type: str,
|
|
||||||
output_path: str,
|
|
||||||
**kwargs,
|
|
||||||
) -> list[str]:
|
|
||||||
timeout = kwargs.get("timeout", 180.0)
|
|
||||||
if backend_type == "openai":
|
|
||||||
resp = await aclient.post(
|
|
||||||
url,
|
|
||||||
headers={
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {kwargs.get('api_key')}",
|
|
||||||
},
|
|
||||||
json={
|
|
||||||
"prompt": prompt,
|
|
||||||
"n": kwargs.get("n", 1),
|
|
||||||
"size": kwargs.get("size", "512x512"),
|
|
||||||
"response_format": "b64_json",
|
|
||||||
},
|
|
||||||
timeout=timeout,
|
|
||||||
)
|
|
||||||
if resp.status_code == 200:
|
|
||||||
b64_datas = []
|
|
||||||
for data in resp.json()["data"]:
|
|
||||||
b64_datas.append(data["b64_json"])
|
|
||||||
return save_images_b64(b64_datas, output_path, **kwargs)
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
f"{resp.status_code} {resp.reason_phrase} {resp.text}",
|
|
||||||
)
|
|
||||||
elif backend_type == "sdwui":
|
|
||||||
resp = await aclient.post(
|
|
||||||
url,
|
|
||||||
headers={
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
json={
|
|
||||||
"prompt": prompt,
|
|
||||||
"sampler_name": kwargs.get("sampler_name", "Euler a"),
|
|
||||||
"cfg_scale": kwargs.get("cfg_scale", 7),
|
|
||||||
"batch_size": kwargs.get("n", 1),
|
|
||||||
"steps": kwargs.get("steps", 20),
|
|
||||||
"width": kwargs.get("width", 512),
|
|
||||||
"height": kwargs.get("height", 512),
|
|
||||||
},
|
|
||||||
timeout=timeout,
|
|
||||||
)
|
|
||||||
if resp.status_code == 200:
|
|
||||||
b64_datas = resp.json()["images"]
|
|
||||||
return save_images_b64(b64_datas, output_path, **kwargs)
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
f"{resp.status_code} {resp.reason_phrase} {resp.text}",
|
|
||||||
)
|
|
||||||
elif backend_type == "localai":
|
|
||||||
resp = await aclient.post(
|
|
||||||
url,
|
|
||||||
headers={
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {kwargs.get('api_key')}",
|
|
||||||
},
|
|
||||||
json={
|
|
||||||
"prompt": prompt,
|
|
||||||
"size": kwargs.get("size", "512x512"),
|
|
||||||
},
|
|
||||||
timeout=timeout,
|
|
||||||
)
|
|
||||||
if resp.status_code == 200:
|
|
||||||
image_url = resp.json()["data"][0]["url"]
|
|
||||||
return await save_image_url(image_url, aclient, output_path, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def save_images_b64(b64_datas: list[str], path: Path, **kwargs) -> list[str]:
|
|
||||||
images_path_list = []
|
|
||||||
for b64_data in b64_datas:
|
|
||||||
image_path = path / (
|
|
||||||
str(uuid.uuid4()) + "." + kwargs.get("image_format", "jpeg")
|
|
||||||
)
|
|
||||||
img = Image.open(io.BytesIO(base64.decodebytes(bytes(b64_data, "utf-8"))))
|
|
||||||
img.save(image_path)
|
|
||||||
images_path_list.append(image_path)
|
|
||||||
return images_path_list
|
|
||||||
|
|
||||||
|
|
||||||
async def save_image_url(
|
|
||||||
url: str, aclient: httpx.AsyncClient, path: Path, **kwargs
|
|
||||||
) -> list[str]:
|
|
||||||
images_path_list = []
|
|
||||||
r = await aclient.get(url)
|
|
||||||
image_path = path / (str(uuid.uuid4()) + "." + kwargs.get("image_format", "jpeg"))
|
|
||||||
if r.status_code == 200:
|
|
||||||
img = Image.open(io.BytesIO(r.content))
|
|
||||||
img.save(image_path)
|
|
||||||
images_path_list.append(image_path)
|
|
||||||
return images_path_list
|
|
|
@ -1,200 +0,0 @@
|
||||||
import sqlite3
|
|
||||||
import sys
|
|
||||||
from log import getlogger
|
|
||||||
|
|
||||||
logger = getlogger()
|
|
||||||
|
|
||||||
|
|
||||||
class LCManager:
|
|
||||||
def __init__(self):
|
|
||||||
try:
|
|
||||||
self.conn = sqlite3.connect("manage_db")
|
|
||||||
self.c = self.conn.cursor()
|
|
||||||
self.c.execute(
|
|
||||||
"""
|
|
||||||
CREATE TABLE IF NOT EXISTS lc_commands (
|
|
||||||
command_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
username TEXT NOT NULL,
|
|
||||||
agent TEXT NOT NULL,
|
|
||||||
api_url TEXT NOT NULL,
|
|
||||||
api_key TEXT,
|
|
||||||
permission INTEGER NOT NULL
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e, exc_info=True)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
def add_command(
|
|
||||||
self,
|
|
||||||
username: str,
|
|
||||||
agent: str,
|
|
||||||
api_url: str,
|
|
||||||
api_key: str = None,
|
|
||||||
permission: int = 0,
|
|
||||||
) -> None:
|
|
||||||
# check if username and agent already exists
|
|
||||||
self.c.execute(
|
|
||||||
"""
|
|
||||||
SELECT username, agent FROM lc_commands
|
|
||||||
WHERE username = ? AND agent = ?
|
|
||||||
""",
|
|
||||||
(username, agent),
|
|
||||||
)
|
|
||||||
if self.c.fetchone() is not None:
|
|
||||||
raise Exception("agent already exists")
|
|
||||||
|
|
||||||
self.c.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO lc_commands (username, agent, api_url, api_key, permission)
|
|
||||||
VALUES (?, ?, ?, ?, ?)
|
|
||||||
""",
|
|
||||||
(username, agent, api_url, api_key, permission),
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
|
|
||||||
def get_command_api_url(self, username: str, agent: str) -> list[any]:
|
|
||||||
self.c.execute(
|
|
||||||
"""
|
|
||||||
SELECT api_url FROM lc_commands
|
|
||||||
WHERE username = ? AND agent = ?
|
|
||||||
""",
|
|
||||||
(username, agent),
|
|
||||||
)
|
|
||||||
return self.c.fetchall()
|
|
||||||
|
|
||||||
def get_command_api_key(self, username: str, agent: str) -> list[any]:
|
|
||||||
self.c.execute(
|
|
||||||
"""
|
|
||||||
SELECT api_key FROM lc_commands
|
|
||||||
WHERE username = ? AND agent = ?
|
|
||||||
""",
|
|
||||||
(username, agent),
|
|
||||||
)
|
|
||||||
return self.c.fetchall()
|
|
||||||
|
|
||||||
def get_command_permission(self, username: str, agent: str) -> list[any]:
|
|
||||||
self.c.execute(
|
|
||||||
"""
|
|
||||||
SELECT permission FROM lc_commands
|
|
||||||
WHERE username = ? AND agent = ?
|
|
||||||
""",
|
|
||||||
(username, agent),
|
|
||||||
)
|
|
||||||
return self.c.fetchall()
|
|
||||||
|
|
||||||
def get_command_agent(self, username: str) -> list[any]:
|
|
||||||
self.c.execute(
|
|
||||||
"""
|
|
||||||
SELECT agent FROM lc_commands
|
|
||||||
WHERE username = ?
|
|
||||||
""",
|
|
||||||
(username,),
|
|
||||||
)
|
|
||||||
return self.c.fetchall()
|
|
||||||
|
|
||||||
def get_specific_by_username(self, username: str) -> list[any]:
|
|
||||||
self.c.execute(
|
|
||||||
"""
|
|
||||||
SELECT * FROM lc_commands
|
|
||||||
WHERE username = ?
|
|
||||||
""",
|
|
||||||
(username,),
|
|
||||||
)
|
|
||||||
return self.c.fetchall()
|
|
||||||
|
|
||||||
def get_specific_by_agent(self, agent: str) -> list[any]:
|
|
||||||
self.c.execute(
|
|
||||||
"""
|
|
||||||
SELECT * FROM lc_commands
|
|
||||||
WHERE agent = ?
|
|
||||||
""",
|
|
||||||
(agent,),
|
|
||||||
)
|
|
||||||
return self.c.fetchall()
|
|
||||||
|
|
||||||
def get_all(self) -> list[any]:
|
|
||||||
self.c.execute(
|
|
||||||
"""
|
|
||||||
SELECT * FROM lc_commands
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
return self.c.fetchall()
|
|
||||||
|
|
||||||
def update_command_api_url(self, username: str, agent: str, api_url: str) -> None:
|
|
||||||
self.c.execute(
|
|
||||||
"""
|
|
||||||
UPDATE lc_commands
|
|
||||||
SET api_url = ?
|
|
||||||
WHERE username = ? AND agent = ?
|
|
||||||
""",
|
|
||||||
(api_url, username, agent),
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
|
|
||||||
def update_command_api_key(self, username: str, agent: str, api_key: str) -> None:
|
|
||||||
self.c.execute(
|
|
||||||
"""
|
|
||||||
UPDATE lc_commands
|
|
||||||
SET api_key = ?
|
|
||||||
WHERE username = ? AND agent = ?
|
|
||||||
""",
|
|
||||||
(api_key, username, agent),
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
|
|
||||||
def update_command_permission(
|
|
||||||
self, username: str, agent: str, permission: int
|
|
||||||
) -> None:
|
|
||||||
self.c.execute(
|
|
||||||
"""
|
|
||||||
UPDATE lc_commands
|
|
||||||
SET permission = ?
|
|
||||||
WHERE username = ? AND agent = ?
|
|
||||||
""",
|
|
||||||
(permission, username, agent),
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
|
|
||||||
def update_command_agent(self, username: str, agent: str, api_url: str) -> None:
|
|
||||||
# check if agent already exists
|
|
||||||
self.c.execute(
|
|
||||||
"""
|
|
||||||
SELECT agent FROM lc_commands
|
|
||||||
WHERE agent = ?
|
|
||||||
""",
|
|
||||||
(agent,),
|
|
||||||
)
|
|
||||||
if self.c.fetchone() is not None:
|
|
||||||
raise Exception("agent already exists")
|
|
||||||
self.c.execute(
|
|
||||||
"""
|
|
||||||
UPDATE lc_commands
|
|
||||||
SET agent = ?
|
|
||||||
WHERE username = ? AND api_url = ?
|
|
||||||
""",
|
|
||||||
(agent, username, api_url),
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
|
|
||||||
def delete_command(self, username: str, agent: str) -> None:
|
|
||||||
self.c.execute(
|
|
||||||
"""
|
|
||||||
DELETE FROM lc_commands
|
|
||||||
WHERE username = ? AND agent = ?
|
|
||||||
""",
|
|
||||||
(username, agent),
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
|
|
||||||
def delete_commands(self, username: str) -> None:
|
|
||||||
self.c.execute(
|
|
||||||
"""
|
|
||||||
DELETE FROM lc_commands
|
|
||||||
WHERE username = ?
|
|
||||||
""",
|
|
||||||
(username,),
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
40
src/log.py
40
src/log.py
|
@ -1,40 +0,0 @@
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
log_path = Path(os.path.dirname(__file__)).parent / "bot.log"
|
|
||||||
|
|
||||||
|
|
||||||
def getlogger():
|
|
||||||
# create a custom logger if no log handler
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
if not logger.hasHandlers():
|
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
# create handlers
|
|
||||||
warn_handler = logging.StreamHandler()
|
|
||||||
info_handler = logging.StreamHandler()
|
|
||||||
error_handler = logging.FileHandler("bot.log", mode="a")
|
|
||||||
warn_handler.setLevel(logging.WARNING)
|
|
||||||
error_handler.setLevel(logging.ERROR)
|
|
||||||
info_handler.setLevel(logging.INFO)
|
|
||||||
|
|
||||||
# create formatters
|
|
||||||
warn_format = logging.Formatter(
|
|
||||||
"%(asctime)s - %(funcName)s - %(levelname)s - %(message)s",
|
|
||||||
)
|
|
||||||
error_format = logging.Formatter(
|
|
||||||
"%(asctime)s - %(name)s - %(funcName)s - %(levelname)s - %(message)s",
|
|
||||||
)
|
|
||||||
info_format = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
|
||||||
|
|
||||||
# set formatter
|
|
||||||
warn_handler.setFormatter(warn_format)
|
|
||||||
error_handler.setFormatter(error_format)
|
|
||||||
info_handler.setFormatter(info_format)
|
|
||||||
|
|
||||||
# add handlers to logger
|
|
||||||
logger.addHandler(warn_handler)
|
|
||||||
logger.addHandler(error_handler)
|
|
||||||
logger.addHandler(info_handler)
|
|
||||||
|
|
||||||
return logger
|
|
121
src/main.py
121
src/main.py
|
@ -1,121 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
import signal
|
|
||||||
import sys
|
|
||||||
|
|
||||||
from bot import Bot
|
|
||||||
from log import getlogger
|
|
||||||
|
|
||||||
logger = getlogger()
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
need_import_keys = False
|
|
||||||
config_path = Path(os.path.dirname(__file__)).parent / "config.json"
|
|
||||||
if os.path.isfile(config_path):
|
|
||||||
try:
|
|
||||||
fp = open(config_path, encoding="utf8")
|
|
||||||
config = json.load(fp)
|
|
||||||
except Exception:
|
|
||||||
logger.error("config.json load error, please check the file")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
matrix_bot = Bot(
|
|
||||||
homeserver=config.get("homeserver"),
|
|
||||||
user_id=config.get("user_id"),
|
|
||||||
password=config.get("password"),
|
|
||||||
access_token=config.get("access_token"),
|
|
||||||
device_id=config.get("device_id"),
|
|
||||||
room_id=config.get("room_id"),
|
|
||||||
import_keys_path=config.get("import_keys_path"),
|
|
||||||
import_keys_password=config.get("import_keys_password"),
|
|
||||||
openai_api_key=config.get("openai_api_key"),
|
|
||||||
gpt_api_endpoint=config.get("gpt_api_endpoint"),
|
|
||||||
gpt_model=config.get("gpt_model"),
|
|
||||||
max_tokens=config.get("max_tokens"),
|
|
||||||
top_p=config.get("top_p"),
|
|
||||||
presence_penalty=config.get("presence_penalty"),
|
|
||||||
frequency_penalty=config.get("frequency_penalty"),
|
|
||||||
reply_count=config.get("reply_count"),
|
|
||||||
system_prompt=config.get("system_prompt"),
|
|
||||||
temperature=config.get("temperature"),
|
|
||||||
lc_admin=config.get("lc_admin"),
|
|
||||||
image_generation_endpoint=config.get("image_generation_endpoint"),
|
|
||||||
image_generation_backend=config.get("image_generation_backend"),
|
|
||||||
image_generation_size=config.get("image_generation_size"),
|
|
||||||
sdwui_steps=config.get("sdwui_steps"),
|
|
||||||
sdwui_sampler_name=config.get("sdwui_sampler_name"),
|
|
||||||
sdwui_cfg_scale=config.get("sdwui_cfg_scale"),
|
|
||||||
image_format=config.get("image_format"),
|
|
||||||
timeout=config.get("timeout"),
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
config.get("import_keys_path")
|
|
||||||
and config.get("import_keys_password") is not None
|
|
||||||
):
|
|
||||||
need_import_keys = True
|
|
||||||
|
|
||||||
else:
|
|
||||||
matrix_bot = Bot(
|
|
||||||
homeserver=os.environ.get("HOMESERVER"),
|
|
||||||
user_id=os.environ.get("USER_ID"),
|
|
||||||
password=os.environ.get("PASSWORD"),
|
|
||||||
access_token=os.environ.get("ACCESS_TOKEN"),
|
|
||||||
device_id=os.environ.get("DEVICE_ID"),
|
|
||||||
room_id=os.environ.get("ROOM_ID"),
|
|
||||||
import_keys_path=os.environ.get("IMPORT_KEYS_PATH"),
|
|
||||||
import_keys_password=os.environ.get("IMPORT_KEYS_PASSWORD"),
|
|
||||||
openai_api_key=os.environ.get("OPENAI_API_KEY"),
|
|
||||||
gpt_api_endpoint=os.environ.get("GPT_API_ENDPOINT"),
|
|
||||||
gpt_model=os.environ.get("GPT_MODEL"),
|
|
||||||
max_tokens=int(os.environ.get("MAX_TOKENS", 4000)),
|
|
||||||
top_p=float(os.environ.get("TOP_P", 1.0)),
|
|
||||||
presence_penalty=float(os.environ.get("PRESENCE_PENALTY", 0.0)),
|
|
||||||
frequency_penalty=float(os.environ.get("FREQUENCY_PENALTY", 0.0)),
|
|
||||||
reply_count=int(os.environ.get("REPLY_COUNT", 1)),
|
|
||||||
system_prompt=os.environ.get("SYSTEM_PROMPT"),
|
|
||||||
temperature=float(os.environ.get("TEMPERATURE", 0.8)),
|
|
||||||
lc_admin=os.environ.get("LC_ADMIN"),
|
|
||||||
image_generation_endpoint=os.environ.get("IMAGE_GENERATION_ENDPOINT"),
|
|
||||||
image_generation_backend=os.environ.get("IMAGE_GENERATION_BACKEND"),
|
|
||||||
image_generation_size=os.environ.get("IMAGE_GENERATION_SIZE"),
|
|
||||||
sdwui_steps=int(os.environ.get("SDWUI_STEPS", 20)),
|
|
||||||
sdwui_sampler_name=os.environ.get("SDWUI_SAMPLER_NAME"),
|
|
||||||
sdwui_cfg_scale=float(os.environ.get("SDWUI_CFG_SCALE", 7)),
|
|
||||||
image_format=os.environ.get("IMAGE_FORMAT"),
|
|
||||||
timeout=float(os.environ.get("TIMEOUT", 120.0)),
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
os.environ.get("IMPORT_KEYS_PATH")
|
|
||||||
and os.environ.get("IMPORT_KEYS_PASSWORD") is not None
|
|
||||||
):
|
|
||||||
need_import_keys = True
|
|
||||||
|
|
||||||
await matrix_bot.login()
|
|
||||||
if need_import_keys:
|
|
||||||
logger.info("start import_keys process, this may take a while...")
|
|
||||||
await matrix_bot.import_keys()
|
|
||||||
|
|
||||||
sync_task = asyncio.create_task(
|
|
||||||
matrix_bot.sync_forever(timeout=30000, full_state=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
# handle signal interrupt
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
for signame in ("SIGINT", "SIGTERM"):
|
|
||||||
loop.add_signal_handler(
|
|
||||||
getattr(signal, signame),
|
|
||||||
lambda: asyncio.create_task(matrix_bot.close(sync_task)),
|
|
||||||
)
|
|
||||||
|
|
||||||
if matrix_bot.client.should_upload_keys:
|
|
||||||
await matrix_bot.client.keys_upload()
|
|
||||||
|
|
||||||
await sync_task
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
logger.info("matrix chatgpt bot start.....")
|
|
||||||
asyncio.run(main())
|
|
|
@ -1,63 +0,0 @@
|
||||||
import markdown
|
|
||||||
from log import getlogger
|
|
||||||
from nio import AsyncClient
|
|
||||||
|
|
||||||
logger = getlogger()
|
|
||||||
|
|
||||||
|
|
||||||
async def send_room_message(
|
|
||||||
client: AsyncClient,
|
|
||||||
room_id: str,
|
|
||||||
reply_message: str,
|
|
||||||
sender_id: str = "",
|
|
||||||
user_message: str = "",
|
|
||||||
reply_to_event_id: str = "",
|
|
||||||
) -> None:
|
|
||||||
if reply_to_event_id == "":
|
|
||||||
content = {
|
|
||||||
"msgtype": "m.text",
|
|
||||||
"body": reply_message,
|
|
||||||
"format": "org.matrix.custom.html",
|
|
||||||
"formatted_body": markdown.markdown(
|
|
||||||
reply_message,
|
|
||||||
extensions=["nl2br", "tables", "fenced_code"],
|
|
||||||
),
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
body = "> <" + sender_id + "> " + user_message + "\n\n" + reply_message
|
|
||||||
format = r"org.matrix.custom.html"
|
|
||||||
formatted_body = (
|
|
||||||
r'<mx-reply><blockquote><a href="https://matrix.to/#/'
|
|
||||||
+ room_id
|
|
||||||
+ r"/"
|
|
||||||
+ reply_to_event_id
|
|
||||||
+ r'">In reply to</a> <a href="https://matrix.to/#/'
|
|
||||||
+ sender_id
|
|
||||||
+ r'">'
|
|
||||||
+ sender_id
|
|
||||||
+ r"</a><br>"
|
|
||||||
+ user_message
|
|
||||||
+ r"</blockquote></mx-reply>"
|
|
||||||
+ markdown.markdown(
|
|
||||||
reply_message,
|
|
||||||
extensions=["nl2br", "tables", "fenced_code"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
content = {
|
|
||||||
"msgtype": "m.text",
|
|
||||||
"body": body,
|
|
||||||
"format": format,
|
|
||||||
"formatted_body": formatted_body,
|
|
||||||
"m.relates_to": {"m.in_reply_to": {"event_id": reply_to_event_id}},
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
await client.room_send(
|
|
||||||
room_id,
|
|
||||||
message_type="m.room.message",
|
|
||||||
content=content,
|
|
||||||
ignore_unverified_devices=True,
|
|
||||||
)
|
|
||||||
await client.room_typing(room_id, typing_state=False)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e)
|
|
|
@ -1,26 +1,12 @@
|
||||||
"""
|
|
||||||
Code derived from https://github.com/acheong08/ChatGPT/blob/main/src/revChatGPT/V3.py
|
|
||||||
A simple wrapper for the official ChatGPT API
|
|
||||||
"""
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
from tenacity import retry, wait_random_exponential, stop_after_attempt
|
|
||||||
import httpx
|
import httpx
|
||||||
|
import requests
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
|
|
||||||
ENGINES = [
|
|
||||||
"gpt-3.5-turbo",
|
|
||||||
"gpt-3.5-turbo-16k",
|
|
||||||
"gpt-3.5-turbo-0613",
|
|
||||||
"gpt-3.5-turbo-16k-0613",
|
|
||||||
"gpt-4",
|
|
||||||
"gpt-4-32k",
|
|
||||||
"gpt-4-0613",
|
|
||||||
"gpt-4-32k-0613",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class Chatbot:
|
class Chatbot:
|
||||||
"""
|
"""
|
||||||
Official ChatGPT API
|
Official ChatGPT API
|
||||||
|
@ -28,48 +14,29 @@ class Chatbot:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
aclient: httpx.AsyncClient,
|
|
||||||
api_key: str,
|
api_key: str,
|
||||||
api_url: str = None,
|
engine: str = os.environ.get("GPT_ENGINE") or "gpt-3.5-turbo",
|
||||||
engine: str = None,
|
proxy: str = None,
|
||||||
timeout: float = None,
|
timeout: float = None,
|
||||||
max_tokens: int = None,
|
max_tokens: int = None,
|
||||||
temperature: float = 0.8,
|
temperature: float = 0.5,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
reply_count: int = 1,
|
reply_count: int = 1,
|
||||||
truncate_limit: int = None,
|
system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
|
||||||
system_prompt: str = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)
|
Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)
|
||||||
"""
|
"""
|
||||||
self.engine: str = engine or "gpt-3.5-turbo"
|
self.engine: str = engine
|
||||||
self.api_key: str = api_key
|
self.api_key: str = api_key
|
||||||
self.api_url: str = api_url or "https://api.openai.com/v1/chat/completions"
|
self.system_prompt: str = system_prompt
|
||||||
self.system_prompt: str = (
|
|
||||||
system_prompt
|
|
||||||
or "You are ChatGPT, \
|
|
||||||
a large language model trained by OpenAI. Respond conversationally"
|
|
||||||
)
|
|
||||||
self.max_tokens: int = max_tokens or (
|
self.max_tokens: int = max_tokens or (
|
||||||
31000
|
31000 if engine == "gpt-4-32k" else 7000 if engine == "gpt-4" else 4000
|
||||||
if "gpt-4-32k" in engine
|
|
||||||
else 7000
|
|
||||||
if "gpt-4" in engine
|
|
||||||
else 15000
|
|
||||||
if "gpt-3.5-turbo-16k" in engine
|
|
||||||
else 4000
|
|
||||||
)
|
)
|
||||||
self.truncate_limit: int = truncate_limit or (
|
self.truncate_limit: int = (
|
||||||
30500
|
30500 if engine == "gpt-4-32k" else 6500 if engine == "gpt-4" else 3500
|
||||||
if "gpt-4-32k" in engine
|
|
||||||
else 6500
|
|
||||||
if "gpt-4" in engine
|
|
||||||
else 14500
|
|
||||||
if "gpt-3.5-turbo-16k" in engine
|
|
||||||
else 3500
|
|
||||||
)
|
)
|
||||||
self.temperature: float = temperature
|
self.temperature: float = temperature
|
||||||
self.top_p: float = top_p
|
self.top_p: float = top_p
|
||||||
|
@ -77,8 +44,31 @@ class Chatbot:
|
||||||
self.frequency_penalty: float = frequency_penalty
|
self.frequency_penalty: float = frequency_penalty
|
||||||
self.reply_count: int = reply_count
|
self.reply_count: int = reply_count
|
||||||
self.timeout: float = timeout
|
self.timeout: float = timeout
|
||||||
|
self.proxy = proxy
|
||||||
|
self.session = requests.Session()
|
||||||
|
self.session.proxies.update(
|
||||||
|
{
|
||||||
|
"http": proxy,
|
||||||
|
"https": proxy,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
proxy = (
|
||||||
|
proxy or os.environ.get("all_proxy") or os.environ.get("ALL_PROXY") or None
|
||||||
|
)
|
||||||
|
|
||||||
self.aclient = aclient
|
if proxy:
|
||||||
|
if "socks5h" not in proxy:
|
||||||
|
self.aclient = httpx.AsyncClient(
|
||||||
|
follow_redirects=True,
|
||||||
|
proxies=proxy,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.aclient = httpx.AsyncClient(
|
||||||
|
follow_redirects=True,
|
||||||
|
proxies=proxy,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
self.conversation: dict[str, list[dict]] = {
|
self.conversation: dict[str, list[dict]] = {
|
||||||
"default": [
|
"default": [
|
||||||
|
@ -89,8 +79,6 @@ class Chatbot:
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.get_token_count("default") > self.max_tokens:
|
|
||||||
raise Exception("System prompt is too long")
|
|
||||||
|
|
||||||
def add_to_conversation(
|
def add_to_conversation(
|
||||||
self,
|
self,
|
||||||
|
@ -117,26 +105,31 @@ class Chatbot:
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
# https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
|
||||||
def get_token_count(self, convo_id: str = "default") -> int:
|
def get_token_count(self, convo_id: str = "default") -> int:
|
||||||
"""
|
"""
|
||||||
Get token count
|
Get token count
|
||||||
"""
|
"""
|
||||||
_engine = self.engine
|
if self.engine not in [
|
||||||
if self.engine not in ENGINES:
|
"gpt-3.5-turbo",
|
||||||
# use gpt-3.5-turbo to caculate token
|
"gpt-3.5-turbo-0301",
|
||||||
_engine = "gpt-3.5-turbo"
|
"gpt-4",
|
||||||
|
"gpt-4-0314",
|
||||||
|
"gpt-4-32k",
|
||||||
|
"gpt-4-32k-0314",
|
||||||
|
]:
|
||||||
|
raise NotImplementedError("Unsupported engine {self.engine}")
|
||||||
|
|
||||||
tiktoken.model.MODEL_TO_ENCODING["gpt-4"] = "cl100k_base"
|
tiktoken.model.MODEL_TO_ENCODING["gpt-4"] = "cl100k_base"
|
||||||
|
|
||||||
encoding = tiktoken.encoding_for_model(_engine)
|
encoding = tiktoken.encoding_for_model(self.engine)
|
||||||
|
|
||||||
num_tokens = 0
|
num_tokens = 0
|
||||||
for message in self.conversation[convo_id]:
|
for message in self.conversation[convo_id]:
|
||||||
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||||
num_tokens += 5
|
num_tokens += 5
|
||||||
for key, value in message.items():
|
for key, value in message.items():
|
||||||
if value:
|
num_tokens += len(encoding.encode(value))
|
||||||
num_tokens += len(encoding.encode(value))
|
|
||||||
if key == "name": # if there's a name, the role is omitted
|
if key == "name": # if there's a name, the role is omitted
|
||||||
num_tokens += 5 # role is always required and always 1 token
|
num_tokens += 5 # role is always required and always 1 token
|
||||||
num_tokens += 5 # every reply is primed with <im_start>assistant
|
num_tokens += 5 # every reply is primed with <im_start>assistant
|
||||||
|
@ -148,15 +141,13 @@ class Chatbot:
|
||||||
"""
|
"""
|
||||||
return self.max_tokens - self.get_token_count(convo_id)
|
return self.max_tokens - self.get_token_count(convo_id)
|
||||||
|
|
||||||
async def ask_stream_async(
|
def ask_stream(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
role: str = "user",
|
role: str = "user",
|
||||||
convo_id: str = "default",
|
convo_id: str = "default",
|
||||||
model: str = None,
|
|
||||||
pass_history: bool = True,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> AsyncGenerator[str, None]:
|
):
|
||||||
"""
|
"""
|
||||||
Ask a question
|
Ask a question
|
||||||
"""
|
"""
|
||||||
|
@ -166,13 +157,12 @@ class Chatbot:
|
||||||
self.add_to_conversation(prompt, "user", convo_id=convo_id)
|
self.add_to_conversation(prompt, "user", convo_id=convo_id)
|
||||||
self.__truncate_conversation(convo_id=convo_id)
|
self.__truncate_conversation(convo_id=convo_id)
|
||||||
# Get response
|
# Get response
|
||||||
async with self.aclient.stream(
|
response = self.session.post(
|
||||||
"post",
|
os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions",
|
||||||
self.api_url,
|
|
||||||
headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
|
headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
|
||||||
json={
|
json={
|
||||||
"model": model or self.engine,
|
"model": self.engine,
|
||||||
"messages": self.conversation[convo_id] if pass_history else [prompt],
|
"messages": self.conversation[convo_id],
|
||||||
"stream": True,
|
"stream": True,
|
||||||
# kwargs
|
# kwargs
|
||||||
"temperature": kwargs.get("temperature", self.temperature),
|
"temperature": kwargs.get("temperature", self.temperature),
|
||||||
|
@ -187,18 +177,79 @@ class Chatbot:
|
||||||
),
|
),
|
||||||
"n": kwargs.get("n", self.reply_count),
|
"n": kwargs.get("n", self.reply_count),
|
||||||
"user": role,
|
"user": role,
|
||||||
"max_tokens": min(
|
"max_tokens": self.get_max_tokens(convo_id=convo_id),
|
||||||
self.get_max_tokens(convo_id=convo_id),
|
},
|
||||||
kwargs.get("max_tokens", self.max_tokens),
|
timeout=kwargs.get("timeout", self.timeout),
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
response_role: str = None
|
||||||
|
full_response: str = ""
|
||||||
|
for line in response.iter_lines():
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
# Remove "data: "
|
||||||
|
line = line.decode("utf-8")[6:]
|
||||||
|
if line == "[DONE]":
|
||||||
|
break
|
||||||
|
resp: dict = json.loads(line)
|
||||||
|
choices = resp.get("choices")
|
||||||
|
if not choices:
|
||||||
|
continue
|
||||||
|
delta = choices[0].get("delta")
|
||||||
|
if not delta:
|
||||||
|
continue
|
||||||
|
if "role" in delta:
|
||||||
|
response_role = delta["role"]
|
||||||
|
if "content" in delta:
|
||||||
|
content = delta["content"]
|
||||||
|
full_response += content
|
||||||
|
yield content
|
||||||
|
self.add_to_conversation(full_response, response_role, convo_id=convo_id)
|
||||||
|
|
||||||
|
async def ask_stream_async(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
role: str = "user",
|
||||||
|
convo_id: str = "default",
|
||||||
|
**kwargs,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""
|
||||||
|
Ask a question
|
||||||
|
"""
|
||||||
|
# Make conversation if it doesn't exist
|
||||||
|
if convo_id not in self.conversation:
|
||||||
|
self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
|
||||||
|
self.add_to_conversation(prompt, "user", convo_id=convo_id)
|
||||||
|
self.__truncate_conversation(convo_id=convo_id)
|
||||||
|
# Get response
|
||||||
|
async with self.aclient.stream(
|
||||||
|
"post",
|
||||||
|
os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions",
|
||||||
|
headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
|
||||||
|
json={
|
||||||
|
"model": self.engine,
|
||||||
|
"messages": self.conversation[convo_id],
|
||||||
|
"stream": True,
|
||||||
|
# kwargs
|
||||||
|
"temperature": kwargs.get("temperature", self.temperature),
|
||||||
|
"top_p": kwargs.get("top_p", self.top_p),
|
||||||
|
"presence_penalty": kwargs.get(
|
||||||
|
"presence_penalty",
|
||||||
|
self.presence_penalty,
|
||||||
),
|
),
|
||||||
|
"frequency_penalty": kwargs.get(
|
||||||
|
"frequency_penalty",
|
||||||
|
self.frequency_penalty,
|
||||||
|
),
|
||||||
|
"n": kwargs.get("n", self.reply_count),
|
||||||
|
"user": role,
|
||||||
|
"max_tokens": self.get_max_tokens(convo_id=convo_id),
|
||||||
},
|
},
|
||||||
timeout=kwargs.get("timeout", self.timeout),
|
timeout=kwargs.get("timeout", self.timeout),
|
||||||
) as response:
|
) as response:
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
await response.aread()
|
await response.aread()
|
||||||
raise Exception(
|
|
||||||
f"{response.status_code} {response.reason_phrase} {response.text}",
|
|
||||||
)
|
|
||||||
|
|
||||||
response_role: str = ""
|
response_role: str = ""
|
||||||
full_response: str = ""
|
full_response: str = ""
|
||||||
|
@ -211,8 +262,6 @@ class Chatbot:
|
||||||
if line == "[DONE]":
|
if line == "[DONE]":
|
||||||
break
|
break
|
||||||
resp: dict = json.loads(line)
|
resp: dict = json.loads(line)
|
||||||
if "error" in resp:
|
|
||||||
raise Exception(f"{resp['error']}")
|
|
||||||
choices = resp.get("choices")
|
choices = resp.get("choices")
|
||||||
if not choices:
|
if not choices:
|
||||||
continue
|
continue
|
||||||
|
@ -232,8 +281,6 @@ class Chatbot:
|
||||||
prompt: str,
|
prompt: str,
|
||||||
role: str = "user",
|
role: str = "user",
|
||||||
convo_id: str = "default",
|
convo_id: str = "default",
|
||||||
model: str = None,
|
|
||||||
pass_history: bool = True,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
|
@ -243,102 +290,35 @@ class Chatbot:
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
role=role,
|
role=role,
|
||||||
convo_id=convo_id,
|
convo_id=convo_id,
|
||||||
model=model,
|
|
||||||
pass_history=pass_history,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
full_response: str = "".join([r async for r in response])
|
full_response: str = "".join([r async for r in response])
|
||||||
return full_response
|
return full_response
|
||||||
|
|
||||||
async def ask_async_v2(
|
def ask(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
role: str = "user",
|
role: str = "user",
|
||||||
convo_id: str = "default",
|
convo_id: str = "default",
|
||||||
model: str = None,
|
|
||||||
pass_history: bool = True,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
# Make conversation if it doesn't exist
|
"""
|
||||||
if convo_id not in self.conversation:
|
Non-streaming ask
|
||||||
self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
|
"""
|
||||||
self.add_to_conversation(prompt, "user", convo_id=convo_id)
|
response = self.ask_stream(
|
||||||
self.__truncate_conversation(convo_id=convo_id)
|
prompt=prompt,
|
||||||
# Get response
|
role=role,
|
||||||
response = await self.aclient.post(
|
convo_id=convo_id,
|
||||||
url=self.api_url,
|
**kwargs,
|
||||||
headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
|
|
||||||
json={
|
|
||||||
"model": model or self.engine,
|
|
||||||
"messages": self.conversation[convo_id] if pass_history else [prompt],
|
|
||||||
# kwargs
|
|
||||||
"temperature": kwargs.get("temperature", self.temperature),
|
|
||||||
"top_p": kwargs.get("top_p", self.top_p),
|
|
||||||
"presence_penalty": kwargs.get(
|
|
||||||
"presence_penalty",
|
|
||||||
self.presence_penalty,
|
|
||||||
),
|
|
||||||
"frequency_penalty": kwargs.get(
|
|
||||||
"frequency_penalty",
|
|
||||||
self.frequency_penalty,
|
|
||||||
),
|
|
||||||
"n": kwargs.get("n", self.reply_count),
|
|
||||||
"user": role,
|
|
||||||
"max_tokens": min(
|
|
||||||
self.get_max_tokens(convo_id=convo_id),
|
|
||||||
kwargs.get("max_tokens", self.max_tokens),
|
|
||||||
),
|
|
||||||
},
|
|
||||||
timeout=kwargs.get("timeout", self.timeout),
|
|
||||||
)
|
|
||||||
resp = response.json()
|
|
||||||
full_response = resp["choices"][0]["message"]["content"]
|
|
||||||
self.add_to_conversation(
|
|
||||||
full_response, resp["choices"][0]["message"]["role"], convo_id=convo_id
|
|
||||||
)
|
)
|
||||||
|
full_response: str = "".join(response)
|
||||||
return full_response
|
return full_response
|
||||||
|
|
||||||
|
|
||||||
def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
|
def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
|
||||||
"""
|
"""
|
||||||
Reset the conversation
|
Reset the conversation
|
||||||
"""
|
"""
|
||||||
self.conversation[convo_id] = [
|
self.conversation[convo_id] = [
|
||||||
{"role": "system", "content": system_prompt or self.system_prompt},
|
{"role": "system", "content": system_prompt or self.system_prompt},
|
||||||
]
|
]
|
||||||
|
|
||||||
@retry(wait=wait_random_exponential(min=2, max=5), stop=stop_after_attempt(3))
|
|
||||||
async def oneTimeAsk(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
role: str = "user",
|
|
||||||
model: str = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> str:
|
|
||||||
response = await self.aclient.post(
|
|
||||||
url=self.api_url,
|
|
||||||
json={
|
|
||||||
"model": model or self.engine,
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": role,
|
|
||||||
"content": prompt,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
# kwargs
|
|
||||||
"temperature": kwargs.get("temperature", self.temperature),
|
|
||||||
"top_p": kwargs.get("top_p", self.top_p),
|
|
||||||
"presence_penalty": kwargs.get(
|
|
||||||
"presence_penalty",
|
|
||||||
self.presence_penalty,
|
|
||||||
),
|
|
||||||
"frequency_penalty": kwargs.get(
|
|
||||||
"frequency_penalty",
|
|
||||||
self.frequency_penalty,
|
|
||||||
),
|
|
||||||
"user": role,
|
|
||||||
},
|
|
||||||
headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
|
|
||||||
timeout=kwargs.get("timeout", self.timeout),
|
|
||||||
)
|
|
||||||
resp = response.json()
|
|
||||||
return resp["choices"][0]["message"]["content"]
|
|
Loading…
Reference in a new issue