mattermost_bot/v3.py

325 lines
11 KiB
Python
Raw Normal View History

2023-04-17 17:16:03 +00:00
"""
Code derived from: https://github.com/acheong08/ChatGPT/blob/main/src/revChatGPT/V3.py
"""
import json
import os
from typing import AsyncGenerator
import httpx
import requests
import tiktoken
class Chatbot:
    """
    Official ChatGPT API

    Small client for a chat-completions HTTP endpoint. It keeps one message
    history per conversation id in ``self.conversation`` and offers blocking
    (``ask`` / ``ask_stream``) and asyncio (``ask_async`` /
    ``ask_stream_async``) entry points that stream the reply.

    The endpoint defaults to ``https://api.openai.com/v1/chat/completions``
    and can be overridden with the ``API_URL`` environment variable.
    """

    # Engines for which get_token_count() is willing to estimate token usage.
    _SUPPORTED_ENGINES = (
        "gpt-3.5-turbo",
        "gpt-3.5-turbo-0301",
        "gpt-4",
        "gpt-4-0314",
        "gpt-4-32k",
        "gpt-4-32k-0314",
    )

    # Endpoint used when the API_URL environment variable is unset or empty.
    _DEFAULT_API_URL = "https://api.openai.com/v1/chat/completions"

    def __init__(
        self,
        api_key: str,
        engine: str = os.environ.get("GPT_ENGINE") or "gpt-3.5-turbo",
        proxy: str = None,
        timeout: float = None,
        max_tokens: int = None,
        temperature: float = 0.5,
        top_p: float = 1.0,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        reply_count: int = 1,
        system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
    ) -> None:
        """
        Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)

        :param api_key: bearer token sent in the ``Authorization`` header.
        :param engine: model name; the default is read from the ``GPT_ENGINE``
            environment variable once, when the class is defined.
        :param proxy: proxy URL for the blocking session. For the async client
            the ``all_proxy`` / ``ALL_PROXY`` environment variables are used
            as a fallback when this is not given.
        :param timeout: default request timeout for both clients.
        :param max_tokens: token budget per request; when falsy it is derived
            from the engine name (31000 / 7000 / 4000).
        :param temperature: default sampling temperature.
        :param top_p: default nucleus-sampling value.
        :param presence_penalty: default presence penalty.
        :param frequency_penalty: default frequency penalty.
        :param reply_count: default number of completions requested (``n``).
        :param system_prompt: first message of every new conversation.
        """
        self.engine: str = engine
        self.api_key: str = api_key
        self.system_prompt: str = system_prompt
        self.max_tokens: int = max_tokens or (
            31000 if engine == "gpt-4-32k" else 7000 if engine == "gpt-4" else 4000
        )
        # History is trimmed once its estimated size exceeds this limit.
        self.truncate_limit: int = (
            30500 if engine == "gpt-4-32k" else 6500 if engine == "gpt-4" else 3500
        )
        self.temperature: float = temperature
        self.top_p: float = top_p
        self.presence_penalty: float = presence_penalty
        self.frequency_penalty: float = frequency_penalty
        self.reply_count: int = reply_count
        self.timeout: float = timeout
        self.proxy = proxy

        # Blocking client: only the explicitly passed proxy is configured here.
        self.session = requests.Session()
        self.session.proxies.update(
            {
                "http": proxy,
                "https": proxy,
            },
        )

        # Async client: fall back to the conventional environment variables.
        proxy = (
            proxy or os.environ.get("all_proxy") or os.environ.get("ALL_PROXY") or None
        )
        if proxy and "socks5h" in proxy:
            # No async client is built for socks5h proxies. Keep the attribute
            # defined so ask_stream_async() can fail with a clear message
            # instead of an AttributeError.
            self.aclient = None
        else:
            self.aclient = httpx.AsyncClient(
                follow_redirects=True,
                proxies=proxy,
                timeout=timeout,
            )

        self.conversation: dict[str, list[dict]] = {
            "default": [
                {
                    "role": "system",
                    "content": system_prompt,
                },
            ],
        }

    def add_to_conversation(
        self,
        message: str,
        role: str,
        convo_id: str = "default",
    ) -> None:
        """
        Add a message to the conversation

        :raises KeyError: if ``convo_id`` has not been created (see ``reset``).
        """
        self.conversation[convo_id].append({"role": role, "content": message})

    def __truncate_conversation(self, convo_id: str = "default") -> None:
        """
        Truncate the conversation

        Drops the oldest non-system messages until the estimated token count
        is within ``self.truncate_limit`` or only the first message remains.
        """
        history = self.conversation[convo_id]
        while (
            len(history) > 1
            and self.get_token_count(convo_id) > self.truncate_limit
        ):
            # Don't remove the first message
            history.pop(1)

    def get_token_count(self, convo_id: str = "default") -> int:
        """
        Get token count

        Estimates the tokens used by the stored messages of ``convo_id``.

        :raises NotImplementedError: if ``self.engine`` is not a supported engine.
        """
        if self.engine not in self._SUPPORTED_ENGINES:
            raise NotImplementedError(f"Unsupported engine {self.engine}")

        tiktoken.model.MODEL_TO_ENCODING["gpt-4"] = "cl100k_base"
        encoding = tiktoken.encoding_for_model(self.engine)

        num_tokens = 0
        for message in self.conversation[convo_id]:
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            # (a flat overhead of 5 is charged per message by this estimator)
            num_tokens += 5
            for key, value in message.items():
                num_tokens += len(encoding.encode(value))
                if key == "name":
                    # extra flat overhead charged when a name field is present
                    num_tokens += 5
        num_tokens += 5  # every reply is primed with <im_start>assistant
        return num_tokens

    def get_max_tokens(self, convo_id: str) -> int:
        """
        Get max tokens

        Returns the budget left for the reply: ``self.max_tokens`` minus the
        estimated size of the conversation.
        """
        return self.max_tokens - self.get_token_count(convo_id)

    def _request_body(self, role: str, convo_id: str, options: dict) -> dict:
        """
        Build the JSON payload shared by the blocking and async requests.

        Per-call ``options`` override the instance defaults for sampling
        parameters and ``n``.
        """
        return {
            "model": self.engine,
            "messages": self.conversation[convo_id],
            "stream": True,
            # kwargs
            "temperature": options.get("temperature", self.temperature),
            "top_p": options.get("top_p", self.top_p),
            "presence_penalty": options.get(
                "presence_penalty",
                self.presence_penalty,
            ),
            "frequency_penalty": options.get(
                "frequency_penalty",
                self.frequency_penalty,
            ),
            "n": options.get("n", self.reply_count),
            # NOTE(review): `role` is sent as the API "user" field while the
            # prompt itself is always stored with role "user" -- confirm
            # that this is the intended meaning of the parameter.
            "user": role,
            "max_tokens": self.get_max_tokens(convo_id=convo_id),
        }

    @staticmethod
    def _parse_stream_line(line: str):
        """
        Decode one line of the streamed response.

        Returns a ``(finished, delta)`` tuple: ``finished`` is True for the
        ``[DONE]`` marker; ``delta`` is the first choice's delta dict, or
        None when the line carries nothing to process (blank line, non-data
        line, empty payload, no choices, empty delta).
        """
        line = line.strip()
        if not line.startswith("data:"):
            # Blank lines, comments / keep-alives and other fields carry no payload.
            return False, None
        payload = line[len("data:"):].strip()
        if not payload:
            return False, None
        if payload == "[DONE]":
            return True, None
        choices = json.loads(payload).get("choices")
        if not choices:
            return False, None
        return False, choices[0].get("delta") or None

    def ask_stream(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        **kwargs,
    ):
        """
        Ask a question

        Generator that yields the reply piece by piece as it is streamed and,
        once the stream ends, appends the full reply to the conversation.

        Recognised ``kwargs``: ``api_key``, ``temperature``, ``top_p``,
        ``presence_penalty``, ``frequency_penalty``, ``n``, ``timeout``.

        :raises requests.HTTPError: if the endpoint answers with a non-200 status.
        """
        # Make conversation if it doesn't exist
        if convo_id not in self.conversation:
            self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
        self.add_to_conversation(prompt, "user", convo_id=convo_id)
        self.__truncate_conversation(convo_id=convo_id)

        response_role = None
        pieces = []
        # Get response; the context manager releases the streamed connection.
        with self.session.post(
            os.environ.get("API_URL") or self._DEFAULT_API_URL,
            headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
            json=self._request_body(role, convo_id, kwargs),
            timeout=kwargs.get("timeout", self.timeout),
            stream=True,
        ) as response:
            if response.status_code != 200:
                # An error body is not a data stream; surface it directly.
                raise requests.HTTPError(
                    f"{response.status_code} {response.reason} {response.text}",
                    response=response,
                )
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                finished, delta = self._parse_stream_line(raw_line.decode("utf-8"))
                if finished:
                    break
                if not delta:
                    continue
                if "role" in delta:
                    response_role = delta["role"]
                content = delta.get("content")
                if content is not None:
                    pieces.append(content)
                    yield content

        # Fall back to "assistant" when the stream never announced a role so
        # that the stored message stays valid for follow-up requests.
        self.add_to_conversation(
            "".join(pieces),
            response_role or "assistant",
            convo_id=convo_id,
        )

    async def ask_stream_async(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """
        Ask a question

        Async counterpart of ``ask_stream``: yields reply pieces as they
        arrive and stores the full reply in the conversation afterwards.

        :raises RuntimeError: if no async client was configured (socks5h proxy).
        :raises httpx.HTTPStatusError: if the endpoint answers with a non-200 status.
        """
        if self.aclient is None:
            raise RuntimeError(
                "Async client is not configured: socks5h proxies are skipped "
                "for the async client; use ask()/ask_stream() or another proxy scheme",
            )
        # Make conversation if it doesn't exist
        if convo_id not in self.conversation:
            self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
        self.add_to_conversation(prompt, "user", convo_id=convo_id)
        self.__truncate_conversation(convo_id=convo_id)

        response_role = ""
        pieces = []
        # Get response
        async with self.aclient.stream(
            "post",
            os.environ.get("API_URL") or self._DEFAULT_API_URL,
            headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
            json=self._request_body(role, convo_id, kwargs),
            timeout=kwargs.get("timeout", self.timeout),
        ) as response:
            if response.status_code != 200:
                # The body must be read before .text is available on a stream.
                await response.aread()
                raise httpx.HTTPStatusError(
                    f"{response.status_code} {response.reason_phrase} {response.text}",
                    request=response.request,
                    response=response,
                )
            async for line in response.aiter_lines():
                finished, delta = self._parse_stream_line(line)
                if finished:
                    break
                if not delta:
                    continue
                if "role" in delta:
                    response_role = delta["role"]
                content = delta.get("content")
                if content is not None:
                    pieces.append(content)
                    yield content

        # Same role fallback as the blocking variant.
        self.add_to_conversation(
            "".join(pieces),
            response_role or "assistant",
            convo_id=convo_id,
        )

    async def ask_async(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        **kwargs,
    ) -> str:
        """
        Non-streaming ask

        Consumes ``ask_stream_async`` and returns the concatenated reply.
        """
        response = self.ask_stream_async(
            prompt=prompt,
            role=role,
            convo_id=convo_id,
            **kwargs,
        )
        full_response: str = "".join([r async for r in response])
        return full_response

    def ask(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        **kwargs,
    ) -> str:
        """
        Non-streaming ask

        Consumes ``ask_stream`` and returns the concatenated reply.
        """
        response = self.ask_stream(
            prompt=prompt,
            role=role,
            convo_id=convo_id,
            **kwargs,
        )
        full_response: str = "".join(response)
        return full_response

    def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
        """
        Reset the conversation

        Replaces the history of ``convo_id`` with a single system message,
        using ``system_prompt`` when given and ``self.system_prompt`` otherwise.
        """
        self.conversation[convo_id] = [
            {"role": "system", "content": system_prompt or self.system_prompt},
        ]