File size: 4,237 Bytes
6a26cdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fcadbd
 
cee1f89
6a26cdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32cefa9
 
6a26cdf
 
eb8405f
6a26cdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# https://docs.gpt4all.io/gpt4all_python.html#quickstart

from fastapi import FastAPI
from pydantic import BaseModel
from abc import ABC, abstractmethod
from gpt4all import GPT4All
from typing import Dict, List, Union
from pathlib import Path


class ChatBot(ABC):
    """Abstract base for chat bots that keep an in-memory conversation log."""

    def __init__(self, model: str):
        self.model_name = model
        # Running transcript of {'role': ..., 'content': ...} entries.
        self._convo: List[Dict] = []

    @abstractmethod
    def add_system_message(self, **kwargs):
        """Record a system message, formatted as the concrete model expects."""

    @abstractmethod
    def add_user_message(self, **kwargs):
        """Record a user message, formatted as the concrete model expects."""

    @abstractmethod
    def chat(self, **kwargs):
        """Send a prompt to the model and return its completion."""

    @abstractmethod
    def replay_messages(self, **kwargs):
        """Return the stored conversation in a readable form."""


class GPT4ALLChatBot(ChatBot):
    """ChatBot backed by a local GPT4All .gguf model.

    Mirrors every chat turn into the inherited ``self._convo`` transcript as
    ``{'role': ..., 'content': ...}`` dicts.
    """

    def __init__(self, model: str, user_prompt_template: str, system_prompt_template: str = '', allow_download=False):
        """Load the model file from /code (the docker build context).

        :param model: model file name, e.g. "foo.Q4_0.gguf"
        :param user_prompt_template: per-turn prompt template passed to chat_session
        :param system_prompt_template: system prompt passed to chat_session
        :param allow_download: let gpt4all fetch the model if missing (off by default)
        """
        super().__init__(model)
        # https://github.com/nomic-ai/gpt4all/issues/1235
        # https://docs.gpt4all.io/gpt4all_python.html#gpt4all.gpt4all.GPT4All
        self._model = GPT4All(model_name=model, model_path="/code", allow_download=allow_download)  # same folder as docker file
        self._user_prompt_template = user_prompt_template
        self._system_prompt_template = system_prompt_template

    def print_model_attribs(self):
        """Debug helper: dump the wrapped model object's attribute names."""
        print(dir(self._model))

    def add_system_message(self, msg):
        # Intentionally a no-op: the system prompt is supplied per-session in chat().
        pass

    def add_user_message(self, msg):
        # Intentionally a no-op: user messages are appended to _convo inside chat().
        pass

    def chat(self, user_prompt: str, max_tokens=128, temp=0.7):
        """Run one turn against the model and record it in the transcript.

        A fresh chat_session is opened per call, so only ``self._convo`` —
        not the model — carries history across turns.
        """
        with self._model.chat_session(system_prompt=self._system_prompt_template,
                                      prompt_template=self._user_prompt_template):
            res = self._model.generate(prompt=user_prompt, max_tokens=max_tokens, temp=temp)
            self._convo.append({'role': 'user', 'content': user_prompt})
            self._convo.append({'role': 'assistant', 'content': res})
        return res

    def __call__(self, user_prompt: str, max_tokens=128, temp=0.7):
        """Alias for chat() so the bot instance is directly callable."""
        return self.chat(user_prompt, max_tokens=max_tokens, temp=temp)

    def summarise(self, max_words=50):
        """Ask the model to summarise the conversation held so far.

        Note: the summary request itself goes through chat(), so it is also
        appended to the transcript.
        """
        # FIX: the original prompt read "less than {MAX_WORDS}." with the unit missing.
        summarise_prompt = """
        Given the CONTEXT below, summarize it to less than {MAX_WORDS} words.

        ### CONTEXT
        {CONTEXT}
        """
        prompt = summarise_prompt.format(MAX_WORDS=max_words, CONTEXT=self._convo)
        # temp=0 for a (near-)deterministic summary.
        res = self.chat(user_prompt=prompt, max_tokens=128, temp=0)
        return res

    def replay_messages(self):
        """Return the transcript as readable "role:\\tcontent" lines.

        FIX: the original f-string quoted its separators literally
        (f"{k}, ':\\t', {v}, '\\n'"), producing lines like
        "user, ':\\t', hello, '\\n'" instead of "user:\\thello\\n".
        """
        output = []
        for c in self._convo:
            for k, v in c.items():
                output.append(f"{k}:\t{v}\n")
        return output

    def get_convo(self) -> List[Dict]:
        """Return the raw conversation transcript (mutable — handle with care)."""
        return self._convo


cwd = Path.cwd()
# Passed straight through to chat_session's system_prompt; '{0}' is presumably
# filled by gpt4all's templating — TODO confirm against the gpt4all docs.
SYSTEM_PROMPT_TEMPLATE = '{0}'
# ChatML-style turn template; presumably {0} = user message, {1} = assistant
# reply per gpt4all's prompt_template convention — verify against the library docs.
USER_PROMPT_TEMPLATE = '<|im_start|>user\n{0}<|im_end|>\n<|im_start|>assistant\n{1}<|im_end|>\n'
model_name = "Nous-Hermes-2-Mistral-7B-DPO.Q4_0.gguf"
model_path_file = cwd / model_name  # NOTE(review): unused below — the class hard-codes model_path="/code"
# model_path_file_str = str(model_path_file)
# Module-level singleton: the .gguf model is loaded once, at import time.
cbot = GPT4ALLChatBot(model=model_name,
                      user_prompt_template=USER_PROMPT_TEMPLATE,
                      system_prompt_template=SYSTEM_PROMPT_TEMPLATE,
                      allow_download=False)

# test code
# print(cbot.chat("hello!"))
# print(cbot("can you make me a coffee?"))
# print(cbot.get_convo())


# fast api
app = FastAPI()


class RequestModel(BaseModel):
    """Request body for POST /chat."""

    input_string: str  # raw user prompt; the endpoint strips surrounding whitespace


@app.post("/chat")
def chat(prompt: RequestModel):
    """Run one chat turn: strip the prompt, query the model, echo both back."""
    cleaned = prompt.input_string.strip()
    return {"prompt": cleaned, "completion": cbot(cleaned)}


@app.get("/replay")
def replay():
    """Return the recorded conversation as a list of formatted lines."""
    transcript = cbot.replay_messages()
    return transcript


@app.get("/summarise")
def summarise():
    """Ask the model for a summary (at most ~50 words) of the conversation."""
    summary = cbot.summarise(max_words=50)
    return summary


"""
To run
------

uvicorn main:app --reload

Use --reload while debugging: uvicorn restarts automatically when it detects a
code change. Otherwise --reload is not needed:

uvicorn main:app

The default port is 8000 unless the --port option is used to change it:

uvicorn main:app --port 8000

"""