193 lines
5.9 KiB
Python
193 lines
5.9 KiB
Python
|
"""
|
||
|
A simple wrapper for the official ChatGPT API
|
||
|
https://github.com/acheong08/ChatGPT/blob/main/src/revChatGPT/V3.py
|
||
|
"""
|
||
|
|
||
|
import json
import os
from typing import Generator, Iterable

import requests
import tiktoken
|
||
|
|
||
|
|
||
|
# Default chat model; can be overridden with the GPT_ENGINE environment variable.
ENGINE = os.environ.get("GPT_ENGINE") or "gpt-3.5-turbo"
# Tokenizer used for all local token counting (truncation, budget checks).
# Chat models do not use the "gpt2" encoding (gpt-3.5-turbo / gpt-4 use
# "cl100k_base"), so resolve the encoding from the model name. Fall back to
# cl100k_base when the model name is unknown to tiktoken (KeyError) or the
# installed tiktoken predates encoding_for_model (AttributeError).
try:
    ENCODER = tiktoken.encoding_for_model(ENGINE)
except (AttributeError, KeyError):
    ENCODER = tiktoken.get_encoding("cl100k_base")
|
||
|
|
||
|
|
||
|
class Chatbot:
    """
    Thin client for the official OpenAI chat completions endpoint.

    Keeps named conversations (message lists keyed by ``convo_id``), each
    started with a system message, and drops old messages so the token count
    stays within ``max_tokens``.
    """

    def __init__(
        self,
        api_key: str = None,
        engine: str = None,
        proxy: str = None,
        max_tokens: int = 4096,
        temperature: float = 0.5,
        top_p: float = 1.0,
        reply_count: int = 1,
        system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
    ) -> None:
        """
        Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)

        :param api_key: OpenAI API key, sent as a Bearer token.
        :param engine: Model name; defaults to the module-level ``ENGINE``.
        :param proxy: Stored on ``self.proxy`` but NOT applied to the HTTP
            session (proxy wiring is disabled in this copy).
        :param max_tokens: Token threshold conversations are truncated to.
        :param temperature: Default sampling temperature for requests.
        :param top_p: Default nucleus-sampling value for requests.
        :param reply_count: Default number of choices requested (``n``).
        :param system_prompt: Content of the system message that starts
            every conversation.
        :raises Exception: If the system prompt alone exceeds ``max_tokens``.
        """
        self.engine = engine or ENGINE
        self.session = requests.Session()
        self.api_key = api_key
        # Kept for introspection only. To route traffic through it, set
        # self.session.proxies = {"http": proxy, "https": proxy}.
        self.proxy = proxy
        self.conversation: dict = {
            "default": [
                {
                    "role": "system",
                    "content": system_prompt,
                },
            ],
        }
        self.system_prompt = system_prompt
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.reply_count = reply_count

        initial_conversation = "\n".join(
            [x["content"] for x in self.conversation["default"]],
        )
        if len(ENCODER.encode(initial_conversation)) > self.max_tokens:
            raise Exception("System prompt is too long")

    def add_to_conversation(
        self, message: str, role: str, convo_id: str = "default"
    ) -> None:
        """
        Append ``{"role": role, "content": message}`` to conversation ``convo_id``.

        Raises KeyError if the conversation does not exist yet.
        """
        self.conversation[convo_id].append({"role": role, "content": message})

    def _count_tokens(self, convo_id: str = "default") -> int:
        """
        Return the ENCODER token count of conversation ``convo_id``, rendered
        as one "role: content" line per message.
        """
        full_conversation = "".join(
            message["role"] + ": " + message["content"] + "\n"
            for message in self.conversation[convo_id]
        )
        return len(ENCODER.encode(full_conversation))

    def __truncate_conversation(self, convo_id: str = "default") -> None:
        """
        Drop the oldest messages after the first one until the conversation
        fits within ``self.max_tokens`` tokens, or only one message remains.
        """
        messages = self.conversation[convo_id]
        # Index 0 (normally the system prompt) is never removed.
        while len(messages) > 1 and self._count_tokens(convo_id) > self.max_tokens:
            messages.pop(1)

    def get_max_tokens(self, convo_id: str) -> int:
        """
        Return what is left of a fixed 4000-token budget after the current
        conversation of ``convo_id``.

        NOTE(review): 4000 is hard-coded and differs from ``self.max_tokens``
        (the truncation threshold); confirm which limit is intended.
        """
        return 4000 - self._count_tokens(convo_id)

    @staticmethod
    def _iter_deltas(lines: Iterable[bytes]) -> Generator[dict, None, None]:
        """
        Yield the non-empty ``delta`` dict of choice 0 from each server-sent
        event line of a streamed chat completion, stopping at ``[DONE]``.

        Blank lines, lines that are not ``data:`` fields (e.g. keep-alive
        comments), chunks without choices, and chunks for other choice
        indices (when ``n > 1``) are skipped.
        """
        for raw_line in lines:
            if not raw_line:
                continue
            decoded = raw_line.decode("utf-8")
            if not decoded.startswith("data:"):
                continue
            payload = decoded[len("data:"):].strip()
            if payload == "[DONE]":
                break
            chunk: dict = json.loads(payload)
            choices = chunk.get("choices")
            if not choices:
                continue
            choice = choices[0]
            # With n > 1 every chunk carries one choice tagged by "index";
            # mixing them would interleave several replies into one string.
            if choice.get("index", 0) != 0:
                continue
            delta = choice.get("delta")
            if delta:
                yield delta

    def ask_stream(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        **kwargs,
    ) -> Generator[str, None, None]:
        """
        Send ``prompt`` and yield the reply text piece by piece as it streams in.

        The prompt is stored with message role "user". ``role`` is only sent
        as the request's ``user`` field. The conversation is truncated before
        the request. When the stream finishes, the full reply is appended to
        the conversation. Only choice 0 is yielded, even if ``n > 1``.

        Recognized ``kwargs``: ``api_key``, ``temperature``, ``top_p``, ``n``.

        :raises Exception: If the API responds with a non-200 status.
        """
        # Make conversation if it doesn't exist
        if convo_id not in self.conversation:
            self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
        self.add_to_conversation(prompt, "user", convo_id=convo_id)
        self.__truncate_conversation(convo_id=convo_id)
        response = self.session.post(
            "https://api.openai.com/v1/chat/completions",
            headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
            json={
                "model": self.engine,
                "messages": self.conversation[convo_id],
                "stream": True,
                "temperature": kwargs.get("temperature", self.temperature),
                "top_p": kwargs.get("top_p", self.top_p),
                "n": kwargs.get("n", self.reply_count),
                "user": role,
            },
            stream=True,
        )
        # Fallback so the stored reply never gets role None if the stream
        # omits a role delta.
        response_role = "assistant"
        parts = []
        # Closing the response releases the pooled connection even when the
        # caller stops iterating early or an error is raised mid-stream.
        with response:
            if response.status_code != 200:
                raise Exception(
                    f"Error: {response.status_code} {response.reason} {response.text}",
                )
            for delta in self._iter_deltas(response.iter_lines()):
                if "role" in delta:
                    response_role = delta["role"]
                content = delta.get("content")
                # Skip null/empty content deltas instead of failing on them.
                if content:
                    parts.append(content)
                    yield content
        self.add_to_conversation("".join(parts), response_role, convo_id=convo_id)

    def ask(
        self, prompt: str, role: str = "user", convo_id: str = "default", **kwargs
    ) -> str:
        """
        Non-streaming ask: consume ``ask_stream`` and return the joined reply.
        """
        response = self.ask_stream(
            prompt=prompt,
            role=role,
            convo_id=convo_id,
            **kwargs,
        )
        full_response: str = "".join(response)
        return full_response

    def rollback(self, n: int = 1, convo_id: str = "default") -> None:
        """
        Remove the last ``n`` messages from conversation ``convo_id``.
        """
        for _ in range(n):
            self.conversation[convo_id].pop()

    def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
        """
        Replace conversation ``convo_id`` with a single system message, using
        ``system_prompt`` or, if not given, ``self.system_prompt``.
        """
        self.conversation[convo_id] = [
            {"role": "system", "content": system_prompt or self.system_prompt},
        ]
|
||
|
|