"""
Code derived from https://github.com/acheong08/ChatGPT/blob/main/src/revChatGPT/V3.py
A simple wrapper for the official ChatGPT API
"""
import json
from typing import AsyncGenerator, Optional

import httpx
import tiktoken
from tenacity import retry, stop_after_attempt, wait_random_exponential

# Engines for which get_token_count() is willing to estimate token usage.
ENGINES = [
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-16k",
    "gpt-3.5-turbo-0613",
    "gpt-3.5-turbo-16k-0613",
    "gpt-4",
    "gpt-4-32k",
    "gpt-4-0613",
    "gpt-4-32k-0613",
]


class Chatbot:
    """
    Official ChatGPT API

    Keeps one message history per conversation id in ``self.conversation``
    and talks to a chat-completions endpoint through the ``httpx.AsyncClient``
    supplied by the caller.
    """

    def __init__(
        self,
        aclient: httpx.AsyncClient,
        api_key: str,
        api_url: Optional[str] = None,
        engine: Optional[str] = None,
        timeout: Optional[float] = None,
        max_tokens: Optional[int] = None,
        temperature: float = 0.8,
        top_p: float = 1.0,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        reply_count: int = 1,
        truncate_limit: Optional[int] = None,
        system_prompt: Optional[str] = None,
    ) -> None:
        """
        Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)

        Args:
            aclient: HTTP client used for every request.
            api_key: Default bearer token sent in the Authorization header.
            api_url: Endpoint URL; defaults to the OpenAI chat completions URL.
            engine: Model name; defaults to "gpt-3.5-turbo".
            timeout: Default per-request timeout handed to the HTTP client.
            max_tokens: Token budget; when falsy it is derived from the engine.
            temperature, top_p, presence_penalty, frequency_penalty:
                Default sampling parameters sent with each request.
            reply_count: Default value of the "n" request field.
            truncate_limit: History size (in estimated tokens) above which old
                messages are dropped; when falsy it is derived from the engine.
            system_prompt: Content of the initial "system" message.

        Raises:
            NotImplementedError: If the engine is not listed in ENGINES.
            Exception: If the system prompt alone exceeds ``max_tokens``.
        """
        self.engine: str = engine or "gpt-3.5-turbo"
        self.api_key: str = api_key
        self.api_url: str = api_url or "https://api.openai.com/v1/chat/completions"
        # Built with implicit concatenation so no stray continuation
        # characters or indentation end up inside the prompt text.
        self.system_prompt: str = system_prompt or (
            "You are ChatGPT, a large language model trained by OpenAI. "
            "Respond conversationally"
        )
        # The engine-dependent defaults must look at self.engine: the raw
        # `engine` argument may be None, and `"..." in None` raises TypeError.
        self.max_tokens: int = max_tokens or (
            31000
            if "gpt-4-32k" in self.engine
            else 7000
            if "gpt-4" in self.engine
            else 15000
            if "gpt-3.5-turbo-16k" in self.engine
            else 4000
        )
        self.truncate_limit: int = truncate_limit or (
            30500
            if "gpt-4-32k" in self.engine
            else 6500
            if "gpt-4" in self.engine
            else 14500
            if "gpt-3.5-turbo-16k" in self.engine
            else 3500
        )
        self.temperature: float = temperature
        self.top_p: float = top_p
        self.presence_penalty: float = presence_penalty
        self.frequency_penalty: float = frequency_penalty
        self.reply_count: int = reply_count
        self.timeout: Optional[float] = timeout
        self.aclient = aclient
        # Seed the default conversation with the resolved prompt, not the raw
        # argument, so the default prompt is used when none was passed in.
        self.conversation: dict[str, list[dict]] = {
            "default": [
                {
                    "role": "system",
                    "content": self.system_prompt,
                },
            ],
        }

        if self.get_token_count("default") > self.max_tokens:
            raise Exception("System prompt is too long")

    def add_to_conversation(
        self,
        message: str,
        role: str,
        convo_id: str = "default",
    ) -> None:
        """
        Add a message to the conversation

        Raises KeyError if ``convo_id`` has not been created yet.
        """
        self.conversation[convo_id].append({"role": role, "content": message})

    def __truncate_conversation(self, convo_id: str = "default") -> None:
        """
        Truncate the conversation

        Drops the oldest message after the first one until the estimated
        token count is within ``truncate_limit`` or only the first message
        is left.
        """
        history = self.conversation[convo_id]
        while len(history) > 1 and self.get_token_count(convo_id) > self.truncate_limit:
            # Don't remove the first message
            history.pop(1)

    # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
    def get_token_count(self, convo_id: str = "default") -> int:
        """
        Get token count

        Returns an estimate of the tokens used by every message stored for
        ``convo_id``: the encoded length of each truthy field value plus fixed
        overheads.

        Raises:
            NotImplementedError: If ``self.engine`` is not listed in ENGINES.
        """
        if self.engine not in ENGINES:
            raise NotImplementedError(
                f"Engine {self.engine} is not supported. Select from {ENGINES}",
            )
        tiktoken.model.MODEL_TO_ENCODING["gpt-4"] = "cl100k_base"

        encoding = tiktoken.encoding_for_model(self.engine)

        # NOTE(review): the per-message, per-name and reply overheads are all
        # 5 here; this looks like a deliberately rough estimate - confirm
        # before relying on exact counts.
        num_tokens = 0
        for message in self.conversation[convo_id]:
            # fixed overhead charged for every message
            num_tokens += 5
            for key, value in message.items():
                if value:
                    num_tokens += len(encoding.encode(value))
                if key == "name":
                    # extra overhead charged when a message carries a name
                    num_tokens += 5
        # fixed overhead charged once for the upcoming reply
        num_tokens += 5
        return num_tokens

    def get_max_tokens(self, convo_id: str) -> int:
        """
        Get max tokens

        Returns ``max_tokens`` minus the estimated size of the conversation.
        """
        return self.max_tokens - self.get_token_count(convo_id)

    async def ask_stream_async(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        model: Optional[str] = None,
        pass_history: bool = True,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """
        Ask a question

        Appends the prompt to the conversation, truncates the history, sends
        a streaming request and yields each content fragment as it arrives.
        Once the stream has been fully consumed the complete reply is added
        to the conversation.

        Args:
            prompt: Text of the new message.
            role: Sent as the "user" field of the request payload.
            convo_id: Conversation to use; created on first use.
            model: Overrides ``self.engine`` in the request only; token
                estimates still use ``self.engine``.
            pass_history: When False only the new prompt is sent.
            **kwargs: Optional per-call overrides: api_key, temperature,
                top_p, presence_penalty, frequency_penalty, n, max_tokens,
                timeout.

        Raises:
            Exception: On a non-200 status or an "error" object in the stream.
        """
        # Make conversation if it doesn't exist
        if convo_id not in self.conversation:
            self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
        # NOTE(review): the history entry is always stored with role "user",
        # regardless of `role` - kept as-is; confirm this is intended.
        self.add_to_conversation(prompt, "user", convo_id=convo_id)
        self.__truncate_conversation(convo_id=convo_id)

        if pass_history:
            messages = self.conversation[convo_id]
        else:
            # The endpoint expects message objects, not bare strings.
            messages = [{"role": "user", "content": prompt}]

        payload = {
            "model": model or self.engine,
            "messages": messages,
            "stream": True,
            # kwargs
            "temperature": kwargs.get("temperature", self.temperature),
            "top_p": kwargs.get("top_p", self.top_p),
            "presence_penalty": kwargs.get(
                "presence_penalty",
                self.presence_penalty,
            ),
            "frequency_penalty": kwargs.get(
                "frequency_penalty",
                self.frequency_penalty,
            ),
            "n": kwargs.get("n", self.reply_count),
            "user": role,
            # Never ask for more than what is left of the token budget.
            "max_tokens": min(
                self.get_max_tokens(convo_id=convo_id),
                kwargs.get("max_tokens", self.max_tokens),
            ),
        }

        # Get response
        async with self.aclient.stream(
            "post",
            self.api_url,
            headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
            json=payload,
            timeout=kwargs.get("timeout", self.timeout),
        ) as response:
            if response.status_code != 200:
                # The body must be read before .text is available on a stream.
                await response.aread()
                raise Exception(
                    f"{response.status_code} {response.reason_phrase} {response.text}",
                )

            response_role: str = ""
            fragments: list = []
            async for line in response.aiter_lines():
                line = line.strip()
                if not line:
                    continue
                if line.startswith(":"):
                    # server-sent-events comment / keep-alive line
                    continue
                # Remove "data:" only when it is really there
                if line.startswith("data:"):
                    line = line[5:].strip()
                if line == "[DONE]":
                    break
                resp: dict = json.loads(line)
                if "error" in resp:
                    raise Exception(f"{resp['error']}")
                choices = resp.get("choices")
                if not choices:
                    continue
                delta: dict = choices[0].get("delta")
                if not delta:
                    continue
                if "role" in delta:
                    response_role = delta["role"]
                if "content" in delta:
                    content: str = delta["content"]
                    fragments.append(content)
                    yield content

        # Fall back to "assistant" so an empty role is never stored in (and
        # later re-sent with) the history.
        self.add_to_conversation(
            "".join(fragments),
            response_role or "assistant",
            convo_id=convo_id,
        )

    async def ask_async(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        model: Optional[str] = None,
        pass_history: bool = True,
        **kwargs,
    ) -> str:
        """
        Non-streaming ask

        Consumes ask_stream_async() and returns the concatenated reply.
        """
        response = self.ask_stream_async(
            prompt=prompt,
            role=role,
            convo_id=convo_id,
            model=model,
            pass_history=pass_history,
            **kwargs,
        )
        full_response: str = "".join([r async for r in response])
        return full_response

    def reset(
        self,
        convo_id: str = "default",
        system_prompt: Optional[str] = None,
    ) -> None:
        """
        Reset the conversation

        Replaces the history of ``convo_id`` with a single system message
        (``system_prompt`` if given, otherwise ``self.system_prompt``).
        """
        self.conversation[convo_id] = [
            {"role": "system", "content": system_prompt or self.system_prompt},
        ]

    @retry(wait=wait_random_exponential(min=2, max=5), stop=stop_after_attempt(3))
    async def oneTimeAsk(
        self,
        prompt: str,
        role: str = "user",
        model: Optional[str] = None,
        **kwargs,
    ) -> str:
        """
        Send a single request that neither reads nor updates any stored
        conversation, and return the content of the first choice.

        Args:
            prompt: Message text. A ready-made list of message objects is
                passed through unchanged.
            role: Role of the wrapped message; also sent as the "user" field.
            model: Overrides ``self.engine`` for this request.
            **kwargs: Optional per-call overrides: api_key, temperature,
                top_p, presence_penalty, frequency_penalty, timeout.

        Raises:
            Exception: On a non-200 status. Failures are retried by the
                ``retry`` decorator (up to 3 attempts).
        """
        if isinstance(prompt, str):
            # The endpoint expects a list of message objects.
            messages = [{"role": role, "content": prompt}]
        else:
            messages = prompt

        # AsyncClient.post() is a coroutine returning a fully-read Response;
        # it is not an async context manager.
        response = await self.aclient.post(
            url=self.api_url,
            json={
                "model": model or self.engine,
                "messages": messages,
                # kwargs
                "temperature": kwargs.get("temperature", self.temperature),
                "top_p": kwargs.get("top_p", self.top_p),
                "presence_penalty": kwargs.get(
                    "presence_penalty",
                    self.presence_penalty,
                ),
                "frequency_penalty": kwargs.get(
                    "frequency_penalty",
                    self.frequency_penalty,
                ),
                "user": role,
            },
            headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
            timeout=kwargs.get("timeout", self.timeout),
        )
        if response.status_code != 200:
            raise Exception(
                f"{response.status_code} {response.reason_phrase} {response.text}",
            )
        return response.json()["choices"][0]["message"]["content"]