# https://docs.gpt4all.io/gpt4all_python.html#quickstart
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List

from fastapi import FastAPI
from gpt4all import GPT4All
from pydantic import BaseModel


class ChatBot(ABC):
    """Minimal chat interface; subclasses handle model-specific prompt formatting."""

    def __init__(self, model: str):
        self.model_name = model
        self._convo: List[Dict] = []  # conversation memory

    @abstractmethod
    def add_system_message(self, **kwargs):
        # Subclasses format the message to what their model expects.
        ...

    @abstractmethod
    def add_user_message(self, **kwargs):
        # Subclasses format the message to what their model expects.
        ...

    @abstractmethod
    def chat(self, **kwargs):
        ...

    @abstractmethod
    def replay_messages(self, **kwargs):
        ...


class GPT4ALLChatBot(ChatBot):
    def __init__(self, model: str, user_prompt_template: str, system_prompt_template: str = '', allow_download=False):
        super().__init__(model)
        # https://github.com/nomic-ai/gpt4all/issues/1235
        # https://docs.gpt4all.io/gpt4all_python.html#gpt4all.gpt4all.GPT4All
        # model_path points at /code, the same folder the Dockerfile copies the model into.
        self._model = GPT4All(model_name=model, model_path="/code", allow_download=allow_download)
        self._user_prompt_template = user_prompt_template
        self._system_prompt_template = system_prompt_template

    def print_model_attribs(self):
        print(dir(self._model))

    def add_system_message(self, msg):
        # No-op here: chat_session() formats messages to what the model expects.
        pass

    def add_user_message(self, msg):
        # No-op here: chat_session() formats messages to what the model expects.
        pass

    def chat(self, user_prompt: str, max_tokens=128, temp=0.7):
        # Note: each call opens a fresh chat_session, so the model itself sees no
        # prior turns; self._convo is this class's own record of the conversation.
        with self._model.chat_session(system_prompt=self._system_prompt_template,
                                      prompt_template=self._user_prompt_template):
            res = self._model.generate(prompt=user_prompt, max_tokens=max_tokens, temp=temp)
        self._convo.append({'role': 'user', 'content': user_prompt})
        self._convo.append({'role': 'assistant', 'content': res})
        return res

    def __call__(self, user_prompt: str, max_tokens=128, temp=0.7):
        return self.chat(user_prompt, max_tokens=max_tokens, temp=temp)

    def summarise(self, max_words=50):
        summarise_prompt = """
        Given the CONTEXT below, summarize it in fewer than {MAX_WORDS} words.

        ### CONTEXT
        {CONTEXT}
        """
        # Render the stored conversation as plain text rather than a repr of dicts.
        context = "\n".join(f"{c['role']}: {c['content']}" for c in self._convo)
        prompt = summarise_prompt.format(MAX_WORDS=max_words, CONTEXT=context)
        return self.chat(user_prompt=prompt, max_tokens=128, temp=0)

    def replay_messages(self):
        output = []
        for c in self._convo:
            for k, v in c.items():
                output.append(f"{k}:\t{v}")
        return output

    def get_convo(self) -> List[Dict]:
        return self._convo


cwd = Path.cwd()
SYSTEM_PROMPT_TEMPLATE = '{0}'
USER_PROMPT_TEMPLATE = '<|im_start|>user\n{0}<|im_end|>\n<|im_start|>assistant\n{1}<|im_end|>\n'
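# How gpt4all fills the ChatML template above (per the docs linked earlier;
# exact placeholder handling can vary between binding versions): {0} is replaced
# with the user's message and {1}, where supported, marks where the model's
# reply goes. For the user message "hello!" the rendered prompt would be:
#   <|im_start|>user
#   hello!<|im_end|>
#   <|im_start|>assistant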

model_name = "Nous-Hermes-2-Mistral-7B-DPO.Q4_0.gguf"
model_path_file = cwd / model_name
# model_path_file_str = str(model_path_file)

cbot = GPT4ALLChatBot(model=model_name,
                      user_prompt_template=USER_PROMPT_TEMPLATE,
                      system_prompt_template=SYSTEM_PROMPT_TEMPLATE,
                      allow_download=False)

# test code
# print(cbot.chat("hello!"))
# print(cbot("can you make me a coffee?"))
# print(cbot.get_convo())

# FastAPI app
app = FastAPI()


class RequestModel(BaseModel):
    input_string: str


@app.post("/chat")
def chat(prompt: RequestModel):
    prompt_str = prompt.input_string.strip()
    completion = cbot(prompt_str)
    return {"prompt": prompt_str, "completion": completion}


@app.get("/replay")
def replay():
    return cbot.replay_messages()


@app.get("/summarise")
def summarise():
    return cbot.summarise(max_words=50)
""" | |
to run | |
------ | |
uvicorn main:app --reload | |
for debugging. --reload, reloads uvicorn when it detects code change. | |
otherwise no need --reload | |
default is port 8000 unless --port option is used to change it. | |
uvicorn main:app --port 8000 | |
""" | |