# Source: richardchai's Hugging Face Space — "Upload main.py" (commit cee1f89, verified)
# https://docs.gpt4all.io/gpt4all_python.html#quickstart
from fastapi import FastAPI
from pydantic import BaseModel
from abc import ABC, abstractmethod
from gpt4all import GPT4All
from typing import Dict, List, Union
from pathlib import Path
class ChatBot(ABC):
    """Abstract base for chat bots that keep an in-memory conversation log."""

    def __init__(self, model: str):
        # Name of the underlying model; concrete subclasses decide how to load it.
        self.model_name = model
        # Running transcript of the conversation (role/content dicts) — the bot's memory.
        self._convo: List[Dict] = []

    @abstractmethod
    def add_system_message(self, **kwargs):
        """Record a system message, formatted as the concrete model expects."""

    @abstractmethod
    def add_user_message(self, **kwargs):
        """Record a user message, formatted as the concrete model expects."""

    @abstractmethod
    def chat(self, **kwargs):
        """Send a prompt to the model and return its reply."""

    @abstractmethod
    def replay_messages(self, **kwargs):
        """Return the stored conversation in a displayable form."""
class GPT4ALLChatBot(ChatBot):
    """ChatBot backed by a local GPT4All model.

    The model file is loaded eagerly from /code; the conversation memory
    lives in ``self._convo`` (inherited from ChatBot), not in the model.
    """

    def __init__(self, model: str, user_prompt_template: str, system_prompt_template: str = '', allow_download=False):
        super().__init__(model)
        # https://github.com/nomic-ai/gpt4all/issues/1235
        # https://docs.gpt4all.io/gpt4all_python.html#gpt4all.gpt4all.GPT4All
        # model_path="/code" is the container working dir (same folder as the Dockerfile);
        # with allow_download=False the .gguf file must already exist there.
        self._model = GPT4All(model_name=model, model_path="/code", allow_download=allow_download)
        self._user_prompt_template = user_prompt_template
        self._system_prompt_template = system_prompt_template

    def print_model_attribs(self):
        # Debug helper: dump the wrapped model object's attribute names.
        print(dir(self._model))

    def add_system_message(self, msg):
        # Intentionally a no-op: the system prompt is supplied per-session in chat().
        pass

    def add_user_message(self, msg):
        # Intentionally a no-op: user messages are recorded inside chat().
        pass

    def chat(self, user_prompt: str, max_tokens=128, temp=0.7):
        """Generate a reply for ``user_prompt`` and record the exchange.

        Each call opens a fresh chat_session, so the model itself carries no
        state between calls — ``self._convo`` is the only history.
        """
        with self._model.chat_session(system_prompt=self._system_prompt_template,
                                      prompt_template=self._user_prompt_template):
            res = self._model.generate(prompt=user_prompt, max_tokens=max_tokens, temp=temp)
            self._convo.append({'role': 'user', 'content': user_prompt})
            self._convo.append({'role': 'assistant', 'content': res})
            return res

    def __call__(self, user_prompt: str, max_tokens=128, temp=0.7):
        # Convenience alias so the bot instance is directly callable.
        return self.chat(user_prompt, max_tokens=max_tokens, temp=temp)

    def summarise(self, max_words=50):
        """Ask the model to summarise the conversation so far.

        NOTE: chat() appends this summarise exchange to ``_convo`` too, so
        repeated calls summarise their own previous summaries as well.
        """
        # Fix: the old prompt read "less than {MAX_WORDS}." with no unit.
        summarise_prompt = """
Given the CONTEXT below, summarize it to less than {MAX_WORDS} words.
### CONTEXT
{CONTEXT}
"""
        # Fix: render a readable "role: content" transcript instead of
        # interpolating the raw repr of the list of dicts.
        context = "\n".join(f"{c.get('role', '')}: {c.get('content', '')}" for c in self._convo)
        prompt = summarise_prompt.format(MAX_WORDS=max_words, CONTEXT=context)
        return self.chat(user_prompt=prompt, max_tokens=128, temp=0)

    def replay_messages(self):
        """Return the conversation as "key:<TAB>value" lines.

        Bug fix: the original f-string embedded the quotes and commas
        literally (e.g. ``role, ':\\t', user, '\\n'``) instead of
        formatting key and value.
        """
        output = []
        for entry in self._convo:
            for key, value in entry.items():
                output.append(f"{key}:\t{value}\n")
        return output

    def get_convo(self) -> List[Dict]:
        """Expose the raw conversation list (shared reference, not a copy)."""
        return self._convo
# --- Module-level setup: build one shared chat bot at import time. ---
cwd = Path.cwd()
# ChatML-style prompt templates for the model below.
# NOTE(review): USER_PROMPT_TEMPLATE carries two placeholders ({0} user turn,
# {1} assistant turn) — confirm this matches what the installed gpt4all
# version's prompt_template parameter expects.
SYSTEM_PROMPT_TEMPLATE = '{0}'
USER_PROMPT_TEMPLATE = '<|im_start|>user\n{0}<|im_end|>\n<|im_start|>assistant\n{1}<|im_end|>\n'
# Quantized Nous-Hermes-2 Mistral 7B model file (GGUF, Q4_0 quantization).
model_name = "Nous-Hermes-2-Mistral-7B-DPO.Q4_0.gguf"
model_path_file = cwd / model_name  # currently unused — GPT4ALLChatBot loads from /code
# model_path_file_str = str(model_path_file)
# Eagerly loads the model; allow_download=False means the file must already exist.
cbot = GPT4ALLChatBot(model=model_name,
user_prompt_template=USER_PROMPT_TEMPLATE,
system_prompt_template=SYSTEM_PROMPT_TEMPLATE,
allow_download=False)
# test code
# print(cbot.chat("hello!"))
# print(cbot("can you make me a coffee?"))
# print(cbot.get_convo())
# fast api
app = FastAPI()


class RequestModel(BaseModel):
    """JSON body for POST /chat."""
    # The raw user prompt to forward to the model.
    input_string: str
@app.post("/chat")
def chat(prompt: RequestModel):
    """POST /chat: run one prompt through the shared bot, echo prompt + reply."""
    cleaned = prompt.input_string.strip()
    return {"prompt": cleaned, "completion": cbot(cleaned)}
@app.get("/replay")
def replay():
    """GET /replay: return the stored conversation as formatted lines."""
    return cbot.replay_messages()
@app.get("/summarise")
def summarise():
    """GET /summarise: ask the model for a summary (max 50 words) of the chat so far."""
    return cbot.summarise(max_words=50)
"""
How to run
----------
uvicorn main:app --reload

Use --reload during development: uvicorn restarts automatically when it
detects a code change. Omit --reload otherwise.

The default port is 8000; use the --port option to change it:
uvicorn main:app --port 8000
"""