# NOTE: scraped HuggingFace Spaces page chrome (status lines, commit hashes,
# line-number gutter) removed — it was not part of the Python source.
# https://docs.gpt4all.io/gpt4all_python.html#quickstart
from fastapi import FastAPI
from pydantic import BaseModel
from abc import ABC, abstractmethod
from gpt4all import GPT4All
from typing import Dict, List, Union
from pathlib import Path
class ChatBot(ABC):
    """Abstract base for chat bots that keep an in-memory conversation log."""

    def __init__(self, model: str):
        self.model_name = model
        # Conversation memory: a list of {'role': ..., 'content': ...} dicts.
        self._convo: List[Dict] = []

    @abstractmethod
    def add_system_message(self, **kwargs):
        """Record a system message, formatted however the model expects."""

    @abstractmethod
    def add_user_message(self, **kwargs):
        """Record a user message, formatted however the model expects."""

    @abstractmethod
    def chat(self, **kwargs):
        """Send a prompt to the model and return its reply."""

    @abstractmethod
    def replay_messages(self, **kwargs):
        """Return a human-readable rendering of the conversation so far."""
class GPT4ALLChatBot(ChatBot):
    """ChatBot backed by a local GPT4All model.

    The model weights are loaded from /code (the Docker build context), so
    with allow_download=False nothing is fetched at runtime.
    """

    def __init__(self, model: str, user_prompt_template: str, system_prompt_template: str = '', allow_download=False):
        super().__init__(model)
        # https://github.com/nomic-ai/gpt4all/issues/1235
        # https://docs.gpt4all.io/gpt4all_python.html#gpt4all.gpt4all.GPT4All
        self._model = GPT4All(model_name=model, model_path="/code", allow_download=allow_download)  # same folder as docker file
        self._user_prompt_template = user_prompt_template
        self._system_prompt_template = system_prompt_template

    def print_model_attribs(self):
        # Debug helper: dump every attribute exposed by the underlying model.
        print(dir(self._model))

    def add_system_message(self, msg):
        # Not implemented: the system prompt is supplied per-session in chat().
        pass

    def add_user_message(self, msg):
        # Not implemented: user messages are appended to _convo inside chat().
        pass

    def chat(self, user_prompt: str, max_tokens=128, temp=0.7):
        """Send user_prompt to the model and return the completion.

        Each call opens a fresh chat_session (no model-side carry-over);
        the exchange is recorded in self._convo for replay/summarisation.
        """
        with self._model.chat_session(system_prompt=self._system_prompt_template,
                                      prompt_template=self._user_prompt_template):
            res = self._model.generate(prompt=user_prompt, max_tokens=max_tokens, temp=temp)
            self._convo.append({'role': 'user', 'content': user_prompt})
            self._convo.append({'role': 'assistant', 'content': res})
            return res

    def __call__(self, user_prompt: str, max_tokens=128, temp=0.7):
        """Shorthand for chat()."""
        return self.chat(user_prompt, max_tokens=max_tokens, temp=temp)

    def summarise(self, max_words=50):
        """Ask the model to summarise the recorded conversation.

        Bug fix: the prompt previously read "less than {MAX_WORDS}" with no
        unit — the word count was meaningless to the model.
        """
        summarise_prompt = """
        Given the CONTEXT below, summarize it to less than {MAX_WORDS} words.
        ### CONTEXT
        {CONTEXT}
        """
        # NOTE(review): _convo is interpolated via its list repr — presumably
        # good enough for the model, but a joined transcript may work better.
        prompt = summarise_prompt.format(MAX_WORDS=max_words, CONTEXT=self._convo)
        res = self.chat(user_prompt=prompt, max_tokens=128, temp=0)
        return res

    def replay_messages(self):
        """Return the conversation as readable "role:<TAB>content" lines.

        Bug fix: the old f-string f"{k}, ':\\t', {v}, '\\n'" interpolated the
        quotes and commas literally (e.g. "user, ':\\t', hi, '\\n'") instead
        of producing a tab-separated line.
        """
        output = []
        for c in self._convo:
            for k, v in c.items():
                output.append(f"{k}:\t{v}\n")
        return output

    def get_convo(self) -> List[Dict]:
        """Expose the raw conversation list (mutable reference)."""
        return self._convo
# --- module-level model setup -------------------------------------------
cwd = Path.cwd()
# System prompt passes through verbatim; the user template wraps prompts in
# the ChatML-style <|im_start|>/<|im_end|> tags this model family uses.
SYSTEM_PROMPT_TEMPLATE = '{0}'
USER_PROMPT_TEMPLATE = '<|im_start|>user\n{0}<|im_end|>\n<|im_start|>assistant\n{1}<|im_end|>\n'
model_name = "Nous-Hermes-2-Mistral-7B-DPO.Q4_0.gguf"
# Local path of the .gguf weights (GPT4ALLChatBot loads from /code, not here).
model_path_file = cwd / model_name
# model_path_file_str = str(model_path_file)
# Single shared bot instance used by every endpoint below; allow_download=False
# means the weights must already be present in the image.
cbot = GPT4ALLChatBot(model=model_name,
                      user_prompt_template=USER_PROMPT_TEMPLATE,
                      system_prompt_template=SYSTEM_PROMPT_TEMPLATE,
                      allow_download=False)
# test code
# print(cbot.chat("hello!"))
# print(cbot("can you make me a coffee?"))
# print(cbot.get_convo())
# fast api
app = FastAPI()


class RequestModel(BaseModel):
    """Request body for POST /chat: a single free-form prompt string."""
    # input_string: the raw user prompt; whitespace is stripped by the endpoint.
    input_string: str
@app.post("/chat")
def chat(prompt: RequestModel):
    """Run one chat turn: strip the incoming prompt, return prompt + completion."""
    cleaned = prompt.input_string.strip()
    return {"prompt": cleaned, "completion": cbot(cleaned)}
@app.get("/replay")
def replay():
    """Return the recorded conversation as a list of formatted lines."""
    return cbot.replay_messages()
@app.get("/summarise")
def summarise():
    """Return a model-generated summary (target: under 50 words) of the conversation."""
    return cbot.summarise(max_words=50)
"""
to run
------
uvicorn main:app --reload
for debugging. --reload, reloads uvicorn when it detects code change.
otherwise no need --reload
default is port 8000 unless --port option is used to change it.
uvicorn main:app --port 8000
"""
|