tech-envision committed
Commit bedb8e2 · Parent: 0e02b97

Add database support and persist chat history

Files changed (9)
  1. README.md +11 -1
  2. requirements.txt +2 -1
  3. run.py +3 -3
  4. src/__init__.py +1 -1
  5. src/chat.py +35 -2
  6. src/config.py +1 -1
  7. src/db.py +46 -0
  8. src/schema.py +1 -1
  9. src/tools.py +2 -2
README.md CHANGED
@@ -1 +1,11 @@
- # llm-backend
+ # llm-backend
+
+ This project provides a simple async interface to interact with an Ollama model and demonstrates basic tool usage. Chat histories are stored in a local SQLite database using Peewee.
+
+ ## Usage
+
+ ```bash
+ python run.py
+ ```
+
+ The script will ask the model to compute an arithmetic expression and print the answer. Conversations are automatically persisted to `chat.db`.
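
To verify the persistence the README describes, a minimal sketch (not part of the commit) can read `chat.db` back through the Peewee models added in `src/db.py` below; the loop and print formatting here are illustrative.

```python
# Minimal sketch: list what was persisted to chat.db, using the models
# added in src/db.py in this commit (print formatting is illustrative).
from src.db import Conversation, Message, init_db

init_db()  # connects to chat.db and creates the tables if they are missing

for conv in Conversation.select():
    print(f"Conversation {conv.id} started at {conv.started_at}")
    for msg in conv.messages.order_by(Message.created_at):
        print(f"  [{msg.role}] {msg.content}")
```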
requirements.txt CHANGED
@@ -1,2 +1,3 @@
  colorlog
- ollama
+ ollama
+ peewee
run.py CHANGED
@@ -6,9 +6,9 @@ from src.chat import ChatSession


  async def _main() -> None:
-     chat = ChatSession()
-     answer = await chat.chat("What is 10 + 23?")
-     print("\n>>>", answer)
+     async with ChatSession() as chat:
+         answer = await chat.chat("What is 10 + 23?")
+         print("\n>>>", answer)


  if __name__ == "__main__":
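
Since `ChatSession` now initialises the database in `__init__` and closes it in `__aexit__` (see the `src/chat.py` hunks below), one session can serve several prompts. A hypothetical variant of `_main`, with illustrative prompts, might look like this:

```python
# Sketch of a multi-prompt session; each chat() call is persisted as its own
# Conversation row (per the src/chat.py changes below). Prompts are examples.
import asyncio

from src.chat import ChatSession


async def _demo() -> None:
    async with ChatSession() as chat:       # ChatSession() calls init_db()
        print(await chat.chat("What is 10 + 23?"))
        print(await chat.chat("What is 7 + 35?"))
    # __aexit__ has closed the SQLite connection here


if __name__ == "__main__":
    asyncio.run(_demo())
```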
src/__init__.py CHANGED
@@ -1,4 +1,4 @@
  from .chat import ChatSession
  from .tools import add_two_numbers

- __all__: list[str] = ["ChatSession", "add_two_numbers"]
+ __all__ = ["ChatSession", "add_two_numbers"]
src/chat.py CHANGED
@@ -5,6 +5,7 @@ from typing import List
  from ollama import AsyncClient, ChatResponse

  from .config import MAX_TOOL_CALL_DEPTH, MODEL_NAME, OLLAMA_HOST
+ from .db import Conversation, Message, _db, init_db
  from .log import get_logger
  from .schema import Msg
  from .tools import add_two_numbers
@@ -14,12 +15,17 @@ _LOG = get_logger(__name__)

  class ChatSession:
      def __init__(self, host: str = OLLAMA_HOST, model: str = MODEL_NAME) -> None:
+         init_db()
          self._client = AsyncClient(host=host)
          self._model = model

      async def __aenter__(self) -> "ChatSession":
          return self

+     async def __aexit__(self, exc_type, exc, tb) -> None:
+         if not _db.is_closed():
+             _db.close()
+
      async def ask(self, messages: List[Msg], *, think: bool = True) -> ChatResponse:
          return await self._client.chat(
              self._model,
@@ -32,6 +38,7 @@ class ChatSession:
          self,
          messages: List[Msg],
          response: ChatResponse,
+         conversation: Conversation,
          depth: int = 0,
      ) -> ChatResponse:
          if depth >= MAX_TOOL_CALL_DEPTH or not response.message.tool_calls:
@@ -47,17 +54,43 @@
                      "content": str(result),
                  }
              )
+             Message.create(
+                 conversation=conversation,
+                 role="tool",
+                 content=str(result),
+             )
              nxt = await self.ask(messages, think=True)
-             return await self._handle_tool_calls(messages, nxt, depth + 1)
+             Message.create(
+                 conversation=conversation,
+                 role="assistant",
+                 content=nxt.message.content,
+             )
+             return await self._handle_tool_calls(
+                 messages, nxt, conversation, depth + 1
+             )

          return response

      async def chat(self, prompt: str) -> str:
+         conversation = Conversation.create()
+         Message.create(conversation=conversation, role="user", content=prompt)
          messages: List[Msg] = [{"role": "user", "content": prompt}]
          response = await self.ask(messages)
          messages.append(response.message.model_dump())
+         Message.create(
+             conversation=conversation,
+             role="assistant",
+             content=response.message.content,
+         )

          _LOG.info("Thinking:\n%s", response.message.thinking or "<no thinking trace>")

-         final_resp = await self._handle_tool_calls(messages, response)
+         final_resp = await self._handle_tool_calls(messages, response, conversation)
+         if final_resp is not response:
+             # final response after handling tool calls
+             Message.create(
+                 conversation=conversation,
+                 role="assistant",
+                 content=final_resp.message.content,
+             )
          return final_resp.message.content
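
To see what this persistence path records, the following sketch (a hypothetical helper, not part of the commit) prints the roles stored for the most recent conversation; rows are created for the user prompt, each tool result, and the assistant turns, and the exact sequence depends on how many tool-call rounds occur.

```python
# Sketch: print the roles stored for the most recent conversation.
# The role sequence depends on the model's tool-call behaviour.
from src.db import Conversation, Message, init_db

init_db()
last = Conversation.select().order_by(Conversation.id.desc()).first()
if last is not None:
    for msg in last.messages.order_by(Message.id):
        print(msg.role, "->", (msg.content or "")[:60])
```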
src/config.py CHANGED
@@ -5,4 +5,4 @@ from typing import Final

  MODEL_NAME: Final[str] = os.getenv("OLLAMA_MODEL", "qwen3")
  OLLAMA_HOST: Final[str] = os.getenv("OLLAMA_HOST", "http://localhost:11434")
- MAX_TOOL_CALL_DEPTH: Final[int] = 5
+ MAX_TOOL_CALL_DEPTH: Final[int] = 5
src/db.py ADDED
@@ -0,0 +1,46 @@
+ from __future__ import annotations
+
+ from datetime import datetime
+ from pathlib import Path
+
+ from peewee import (
+     AutoField,
+     CharField,
+     DateTimeField,
+     ForeignKeyField,
+     Model,
+     SqliteDatabase,
+     TextField,
+ )
+
+
+ _DB_PATH = Path(__file__).resolve().parent.parent / "chat.db"
+ _db = SqliteDatabase(_DB_PATH)
+
+
+ class BaseModel(Model):
+     class Meta:
+         database = _db
+
+
+ class Conversation(BaseModel):
+     id = AutoField()
+     started_at = DateTimeField(default=datetime.utcnow)
+
+
+ class Message(BaseModel):
+     id = AutoField()
+     conversation = ForeignKeyField(Conversation, backref="messages")
+     role = CharField()
+     content = TextField()
+     created_at = DateTimeField(default=datetime.utcnow)
+
+
+ __all__ = ["_db", "Conversation", "Message"]
+
+
+ def init_db() -> None:
+     """Initialise the database and create tables if they do not exist."""
+     if _db.is_closed():
+         _db.connect()
+     _db.create_tables([Conversation, Message])
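
The models above can also be used on their own. Below is a minimal sketch with made-up sample content; the per-conversation count uses Peewee's standard `fn.COUNT` aggregation, which is not something this commit adds.

```python
# Minimal sketch: exercise the models standalone. The sample content and
# the count query are illustrative, not part of the commit.
from peewee import fn

from src.db import Conversation, Message, _db, init_db

init_db()
conv = Conversation.create()
Message.create(conversation=conv, role="user", content="hello")

per_conversation = (
    Message.select(Message.conversation, fn.COUNT(Message.id).alias("n"))
    .group_by(Message.conversation)
)
for row in per_conversation:
    print(f"conversation {row.conversation_id}: {row.n} message(s)")

_db.close()
```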
src/schema.py CHANGED
@@ -9,4 +9,4 @@ class Msg(TypedDict, total=False):
      role: Literal["user", "assistant", "tool"]
      content: str
      name: Optional[str]
-     tool_calls: Optional[List[Message.ToolCall]]
+     tool_calls: Optional[List[Message.ToolCall]]
src/tools.py CHANGED
@@ -1,6 +1,6 @@
  from __future__ import annotations

- __all__: list[str] = ["add_two_numbers"]
+ __all__ = ["add_two_numbers"]


  def add_two_numbers(a: int, b: int) -> int: # noqa: D401
@@ -13,4 +13,4 @@ def add_two_numbers(a: int, b: int) -> int: # noqa: D401
      Returns:
          int: The sum of the two numbers.
      """
-     return a + b
+     return a + b