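"""Chat session management for llmOS-Agent.

``ChatSession`` wraps an Ollama model with per-user VM tool execution and a
database-backed conversation history.  A minimal usage sketch (assuming an
Ollama server is reachable at ``OLLAMA_HOST``):

    async def main() -> None:
        async with ChatSession(user="alice", session="demo") as chat:
            async for chunk in chat.chat_stream("List the files in /data"):
                print(chunk, end="", flush=True)
"""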
from __future__ import annotations
from typing import AsyncIterator, Callable, List
from dataclasses import dataclass, field
import json
import asyncio
import shutil
from pathlib import Path
from ollama import AsyncClient, ChatResponse, Message
from .config import (
MAX_TOOL_CALL_DEPTH,
MODEL_NAME,
NUM_CTX,
OLLAMA_HOST,
SYSTEM_PROMPT,
UPLOAD_DIR,
)
from .db import (
Conversation,
Message as DBMessage,
User,
_db,
init_db,
add_document,
)
from .log import get_logger
from .schema import Msg
from .tools import execute_terminal, execute_terminal_async, set_vm
from .vm import VMRegistry
@dataclass
class _SessionData:
"""Shared state for each conversation session."""
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
state: str = "idle"
tool_task: asyncio.Task | None = None
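
# Per-conversation shared state, keyed by Conversation.id, so concurrent
# ChatSession instances for the same conversation share one lock and task.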
_SESSION_DATA: dict[int, _SessionData] = {}
def _get_session_data(conv_id: int) -> _SessionData:
data = _SESSION_DATA.get(conv_id)
if data is None:
data = _SessionData()
_SESSION_DATA[conv_id] = data
return data
_LOG = get_logger(__name__)
class ChatSession:
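    """Database-backed chat session that streams model output and tool results."""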
def __init__(
self,
user: str = "default",
session: str = "default",
host: str = OLLAMA_HOST,
model: str = MODEL_NAME,
*,
system_prompt: str = SYSTEM_PROMPT,
        tools: list[Callable] | None = None,
) -> None:
init_db()
self._client = AsyncClient(host=host)
self._model = model
self._user, _ = User.get_or_create(username=user)
self._conversation, _ = Conversation.get_or_create(
user=self._user, session_name=session
)
self._vm = None
self._system_prompt = system_prompt
self._tools = tools or [execute_terminal]
self._tool_funcs = {func.__name__: func for func in self._tools}
self._current_tool_name: str | None = None
self._messages: List[Msg] = self._load_history()
self._data = _get_session_data(self._conversation.id)
self._lock = self._data.lock
self._prompt_queue: asyncio.Queue[
tuple[str, asyncio.Queue[str | None]]
] = asyncio.Queue()
self._worker: asyncio.Task | None = None
# Shared state properties -------------------------------------------------
@property
def _state(self) -> str:
return self._data.state
@_state.setter
def _state(self, value: str) -> None:
self._data.state = value
@property
def _tool_task(self) -> asyncio.Task | None:
return self._data.tool_task
@_tool_task.setter
def _tool_task(self, task: asyncio.Task | None) -> None:
self._data.tool_task = task
async def __aenter__(self) -> "ChatSession":
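        """Acquire this user's VM and make it the active target for tools."""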
self._vm = VMRegistry.acquire(self._user.username)
set_vm(self._vm)
return self
async def __aexit__(self, exc_type, exc, tb) -> None:
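        """Detach the VM, release it to the registry and close the database."""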
set_vm(None)
if self._vm:
VMRegistry.release(self._user.username)
if not _db.is_closed():
_db.close()
def upload_document(self, file_path: str) -> str:
"""Save a document for later access inside the VM.
        The file is copied into a per-user folder under ``UPLOAD_DIR`` and
        recorded in the database.  The returned path is the file's location
        inside the VM (mounted under ``/data``).
"""
src = Path(file_path)
if not src.exists():
raise FileNotFoundError(file_path)
dest = Path(UPLOAD_DIR) / self._user.username
dest.mkdir(parents=True, exist_ok=True)
target = dest / src.name
shutil.copy(src, target)
add_document(self._user.username, str(target), src.name)
return f"/data/{src.name}"
def _load_history(self) -> List[Msg]:
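        """Rebuild the in-memory message list from the persisted conversation."""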
messages: List[Msg] = []
for msg in self._conversation.messages.order_by(DBMessage.created_at):
if msg.role == "system":
# Skip persisted system prompts from older versions
continue
if msg.role == "assistant":
try:
calls = json.loads(msg.content)
except json.JSONDecodeError:
messages.append({"role": "assistant", "content": msg.content})
else:
messages.append(
{
"role": "assistant",
"tool_calls": [Message.ToolCall(**c) for c in calls],
}
)
elif msg.role == "user":
messages.append({"role": "user", "content": msg.content})
else:
messages.append({"role": "tool", "content": msg.content})
return messages
# ------------------------------------------------------------------
@staticmethod
def _serialize_tool_calls(calls: List[Message.ToolCall]) -> str:
"""Convert tool calls to a JSON string for storage or output."""
return json.dumps([c.model_dump() for c in calls])
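
    @staticmethod
    def _tool_display_name(tool_name: str) -> str:
        """Return the transcript label for a tool (``send_to_junior`` shows as ``junior``)."""
        return "junior" if tool_name == "send_to_junior" else tool_name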
    @staticmethod
    def _format_output(message: Message) -> str:
        """Return the displayable content of ``message``.

        Raw tool-call payloads are deliberately not rendered; only assistant
        text is streamed back to the caller.
        """
        return message.content or ""
@staticmethod
def _remove_tool_placeholder(messages: List[Msg]) -> None:
"""Remove the pending placeholder tool message if present."""
for i in range(len(messages) - 1, -1, -1):
msg = messages[i]
if (
msg.get("role") == "tool"
and msg.get("content") == "Awaiting tool response..."
):
messages.pop(i)
break
@staticmethod
def _store_assistant_message(conversation: Conversation, message: Message) -> None:
"""Persist assistant messages, storing tool calls when present."""
if message.tool_calls:
content = ChatSession._serialize_tool_calls(message.tool_calls)
else:
content = message.content or ""
if content.strip():
DBMessage.create(
conversation=conversation, role="assistant", content=content
)
async def ask(self, messages: List[Msg], *, think: bool = True) -> ChatResponse:
"""Send a chat request, automatically prepending the system prompt."""
if not messages or messages[0].get("role") != "system":
payload = [{"role": "system", "content": self._system_prompt}, *messages]
else:
payload = messages
return await self._client.chat(
self._model,
messages=payload,
think=think,
tools=self._tools,
options={"num_ctx": NUM_CTX},
)
async def _run_tool_async(self, func, **kwargs) -> str:
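        """Await ``func`` if it is a coroutine, else run it in an executor."""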
if asyncio.iscoroutinefunction(func):
return await func(**kwargs)
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, lambda: func(**kwargs))
async def _handle_tool_calls_stream(
self,
messages: List[Msg],
response: ChatResponse,
conversation: Conversation,
depth: int = 0,
) -> AsyncIterator[ChatResponse]:
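        """Execute pending tool calls, yielding each follow-up model response.

        Tool execution is raced against a speculative follow-up request so the
        model can keep responding while a slow tool runs; whichever task
        finishes first selects the branch taken below.
        """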
if not response.message.tool_calls:
if response.message.content:
yield response
async with self._lock:
self._state = "idle"
return
while depth < MAX_TOOL_CALL_DEPTH and response.message.tool_calls:
for call in response.message.tool_calls:
func = self._tool_funcs.get(call.function.name)
if not func:
_LOG.warning("Unsupported tool call: %s", call.function.name)
result = f"Unsupported tool: {call.function.name}"
                    name = self._tool_display_name(call.function.name)
                    messages.append({"role": "tool", "name": name, "content": result})
DBMessage.create(
conversation=conversation,
role="tool",
content=result,
)
continue
exec_task = asyncio.create_task(
self._run_tool_async(func, **call.function.arguments)
)
self._current_tool_name = call.function.name
placeholder = {
"role": "tool",
"name": "junior" if call.function.name == "send_to_junior" else call.function.name,
"content": "Awaiting tool response...",
}
messages.append(placeholder)
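                # Speculatively ask the model to continue while the tool runs;
                # ``asyncio.wait`` below lets whichever task finishes first win.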
follow_task = asyncio.create_task(self.ask(messages, think=True))
async with self._lock:
self._state = "awaiting_tool"
self._tool_task = exec_task
done, _ = await asyncio.wait(
{exec_task, follow_task},
return_when=asyncio.FIRST_COMPLETED,
)
if exec_task in done:
follow_task.cancel()
try:
await follow_task
except asyncio.CancelledError:
pass
self._remove_tool_placeholder(messages)
result = await exec_task
                    name = self._tool_display_name(call.function.name)
                    messages.append({"role": "tool", "name": name, "content": result})
DBMessage.create(
conversation=conversation,
role="tool",
content=result,
)
async with self._lock:
self._state = "generating"
self._tool_task = None
nxt = await self.ask(messages, think=True)
self._store_assistant_message(conversation, nxt.message)
messages.append(nxt.message.model_dump())
response = nxt
yield nxt
else:
followup = await follow_task
self._store_assistant_message(conversation, followup.message)
messages.append(followup.message.model_dump())
yield followup
result = await exec_task
self._remove_tool_placeholder(messages)
                    name = self._tool_display_name(call.function.name)
                    messages.append({"role": "tool", "name": name, "content": result})
DBMessage.create(
conversation=conversation,
role="tool",
content=result,
)
async with self._lock:
self._state = "generating"
self._tool_task = None
nxt = await self.ask(messages, think=True)
self._store_assistant_message(conversation, nxt.message)
messages.append(nxt.message.model_dump())
response = nxt
yield nxt
depth += 1
async with self._lock:
self._state = "idle"
    async def _generate_stream(self, prompt: str) -> AsyncIterator[str]:
        """Generate a streamed reply to ``prompt``, handling any tool calls."""
        # Inspect shared state under the lock, but release it before doing any
        # work: ``_chat_during_tool`` and ``_handle_tool_calls_stream`` both
        # re-acquire the (non-reentrant) lock internally.
        async with self._lock:
            tool_pending = self._state == "awaiting_tool" and self._tool_task is not None
            if not tool_pending:
                self._state = "generating"
        if tool_pending:
            async for part in self._chat_during_tool(prompt):
                yield part
            return
        DBMessage.create(conversation=self._conversation, role="user", content=prompt)
        self._messages.append({"role": "user", "content": prompt})
response = await self.ask(self._messages)
self._messages.append(response.message.model_dump())
self._store_assistant_message(self._conversation, response.message)
async for resp in self._handle_tool_calls_stream(
self._messages, response, self._conversation
):
text = self._format_output(resp.message)
if text:
yield text
async def _process_prompt_queue(self) -> None:
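        """Drain queued prompts one at a time, fanning chunks out per result queue."""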
try:
while not self._prompt_queue.empty():
prompt, result_q = await self._prompt_queue.get()
try:
async for part in self._generate_stream(prompt):
await result_q.put(part)
except Exception as exc: # pragma: no cover - unforeseen errors
_LOG.exception("Error processing prompt: %s", exc)
await result_q.put(f"Error: {exc}")
finally:
await result_q.put(None)
finally:
self._worker = None
async def chat_stream(self, prompt: str) -> AsyncIterator[str]:
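        """Queue ``prompt`` for processing and stream the response chunks back."""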
result_q: asyncio.Queue[str | None] = asyncio.Queue()
await self._prompt_queue.put((prompt, result_q))
if not self._worker or self._worker.done():
self._worker = asyncio.create_task(self._process_prompt_queue())
while True:
part = await result_q.get()
if part is None:
break
yield part
async def continue_stream(self) -> AsyncIterator[str]:
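        """Resume generation from the stored history when the session is idle."""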
        # Only resume when idle; flip to ``generating`` and release the lock
        # before calling the model so the tool-call handler can re-acquire it.
        async with self._lock:
            if self._state != "idle":
                return
            self._state = "generating"
        response = await self.ask(self._messages)
self._messages.append(response.message.model_dump())
self._store_assistant_message(self._conversation, response.message)
async for resp in self._handle_tool_calls_stream(
self._messages, response, self._conversation
):
text = self._format_output(resp.message)
if text:
yield text
async def _chat_during_tool(self, prompt: str) -> AsyncIterator[str]:
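        """Handle a prompt that arrives while a tool call is still running.

        The new request races the in-flight tool task; whichever finishes
        first is handled, and the tool result is appended either way before
        generation continues.
        """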
DBMessage.create(conversation=self._conversation, role="user", content=prompt)
self._messages.append({"role": "user", "content": prompt})
user_task = asyncio.create_task(self.ask(self._messages))
exec_task = self._tool_task
done, _ = await asyncio.wait(
{exec_task, user_task},
return_when=asyncio.FIRST_COMPLETED,
)
if exec_task in done:
user_task.cancel()
try:
await user_task
except asyncio.CancelledError:
pass
self._remove_tool_placeholder(self._messages)
result = await exec_task
self._tool_task = None
name = self._current_tool_name or "tool"
self._current_tool_name = None
self._messages.append({"role": "tool", "name": name, "content": result})
DBMessage.create(
conversation=self._conversation, role="tool", content=result
)
async with self._lock:
self._state = "generating"
nxt = await self.ask(self._messages, think=True)
self._store_assistant_message(self._conversation, nxt.message)
self._messages.append(nxt.message.model_dump())
text = self._format_output(nxt.message)
if text:
yield text
async for part in self._handle_tool_calls_stream(
self._messages, nxt, self._conversation
):
text = self._format_output(part.message)
if text:
yield text
else:
resp = await user_task
self._store_assistant_message(self._conversation, resp.message)
self._messages.append(resp.message.model_dump())
async with self._lock:
self._state = "awaiting_tool"
text = self._format_output(resp.message)
if text:
yield text
result = await exec_task
self._tool_task = None
self._remove_tool_placeholder(self._messages)
name = self._current_tool_name or "tool"
self._current_tool_name = None
self._messages.append({"role": "tool", "name": name, "content": result})
DBMessage.create(
conversation=self._conversation, role="tool", content=result
)
async with self._lock:
self._state = "generating"
nxt = await self.ask(self._messages, think=True)
self._store_assistant_message(self._conversation, nxt.message)
self._messages.append(nxt.message.model_dump())
text = self._format_output(nxt.message)
if text:
yield text
async for part in self._handle_tool_calls_stream(
self._messages, nxt, self._conversation
):
text = self._format_output(part.message)
if text:
yield text