From fab7a36bc48ff28d1cea6d01547b5eac9df21f62 Mon Sep 17 00:00:00 2001 From: hibobmaster Date: Fri, 10 Mar 2023 18:53:51 +0800 Subject: [PATCH] add official api_endpoint, properly handle failing request --- ask_gpt.py | 26 +++++--------------------- bot.py | 23 +++++++++++++++++++++-- test.py | 28 ++++++++++++++++++++++++---- 3 files changed, 50 insertions(+), 27 deletions(-) diff --git a/ask_gpt.py b/ask_gpt.py index e94b19a..2c2ce2b 100644 --- a/ask_gpt.py +++ b/ask_gpt.py @@ -1,15 +1,9 @@ -""" -api_endpoint from https://github.com/ayaka14732/ChatGPTAPIFree -""" import aiohttp import asyncio import json -api_endpoint_free = "https://chatgpt-api.shn.hk/v1/" -headers = {'Content-Type': "application/json"} - -async def ask(prompt: str) -> str: +async def ask(prompt: str, api_endpoint: str, headers: dict) -> str: jsons = { "model": "gpt-3.5-turbo", "messages": [ @@ -20,13 +14,14 @@ async def ask(prompt: str) -> str: ], } async with aiohttp.ClientSession() as session: - - while True: + max_try = 5 + while max_try > 0: try: - async with session.post(url=api_endpoint_free, + async with session.post(url=api_endpoint, json=jsons, headers=headers, timeout=10) as response: status_code = response.status if not status_code == 200: + max_try = max_try - 1 # wait 2s await asyncio.sleep(2) continue @@ -37,14 +32,3 @@ async def ask(prompt: str) -> str: except Exception as e: print(e) pass - - -async def test() -> None: - resp = await ask("Hello World") - print(resp) - # type: str - print(type(resp)) - - -if __name__ == "__main__": - asyncio.run(test()) diff --git a/bot.py b/bot.py index b14f69b..fe73ddb 100644 --- a/bot.py +++ b/bot.py @@ -9,6 +9,14 @@ from ask_gpt import ask from send_message import send_room_message from v3 import Chatbot +""" +free api_endpoint from https://github.com/ayaka14732/ChatGPTAPIFree +""" +api_endpoint_list = { + "free": "https://chatgpt-api.shn.hk/v1/", + "paid": "https://api.openai.com/v1/chat/completions" +} + class Bot: def __init__( @@ -37,10 +45,21 @@ class Bot: # regular expression to match keyword [!gpt {prompt}] [!chat {prompt}] self.gpt_prog = re.compile(r"^\s*!gpt\s*(.+)$") self.chat_prog = re.compile(r"^\s*!chat\s*(.+)$") - # initialize chatbot + # initialize chatbot and api_endpoint if self.api_key != '': self.chatbot = Chatbot(api_key=self.api_key) + self.api_endpoint = api_endpoint_list['paid'] + # request header for !gpt command + self.headers = { + "Content-Type": "application/json", + "Authorization": "Bearer " + self.api_key, + } + else: + self.headers = { + "Content-Type": "application/json", + } + # message_callback event async def message_callback(self, room: MatrixRoom, event: RoomMessageText) -> None: # remove newline character from event.body @@ -55,7 +74,7 @@ class Bot: # sending typing state await self.client.room_typing(room_id) prompt = m.group(1) - text = await ask(prompt) + text = await ask(prompt, self.api_endpoint, self.headers) text = text.strip() await send_room_message(self.client, room_id, send_text=text) diff --git a/test.py b/test.py index 082bd54..825fe3c 100644 --- a/test.py +++ b/test.py @@ -4,17 +4,37 @@ from ask_gpt import ask import json fp = open("config.json", "r") config = json.load(fp) +api_key = config.get('api_key', '') +api_endpoint_list = { + "free": "https://chatgpt-api.shn.hk/v1/", + "paid": "https://api.openai.com/v1/chat/completions" +} def test_v3(prompt: str): - bot = Chatbot(api_key=config['api_key']) + bot = Chatbot(api_key=api_key) resp = bot.ask(prompt=prompt) print(resp) -async def test_ask(prompt: str): - print(await ask(prompt=prompt)) +async def test_ask_gpt_paid(prompt: str): + headers = { + "Content-Type": "application/json", + "Authorization": "Bearer " + api_key, + } + api_endpoint = api_endpoint_list['paid'] + # test ask_gpt.py ask() + print(await ask(prompt, api_endpoint, headers)) + + +async def test_ask_gpt_free(prompt: str): + headers = { + "Content-Type": "application/json", + } + api_endpoint = api_endpoint_list['free'] + print(await ask(prompt, api_endpoint, headers)) if __name__ == "__main__": test_v3("Hello World") - asyncio.run(test_ask("Hello World")) + asyncio.run(test_ask_gpt_paid("Hello World")) + asyncio.run(test_ask_gpt_free("Hello World"))