import threading
import queue
import time
import base64
import io
import logging
from typing import Callable, Optional, List, Dict

import mss
import numpy as np
from PIL import Image

from openai import OpenAI

from config.settings import Settings

logger = logging.getLogger(__name__)
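

# ScreenService pairs a capture thread (producer) with an inference thread
# (consumer) over a small bounded queue: when inference is slower than capture,
# stale frames are dropped rather than allowed to pile up.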
class ScreenService:
    def __init__(
        self,
        prompt: str,
        model: str,
        fps: float = 0.5,
        queue_size: int = 2,
        monitor: int = 1,
        max_width: int = 3440,
        max_height: int = 1440,
        compression_quality: int = 100,
        image_format: str = "PNG",
    ):
        """
        :param prompt: Vision model instruction
        :param model: Nebius model name
        :param fps: Capture frames per second
        :param queue_size: Internal buffer size
        :param monitor: MSS monitor index
        :param max_width: Maximum frame width after resizing
        :param max_height: Maximum frame height after resizing
        :param compression_quality: JPEG quality (1-100)
        :param image_format: "JPEG" or "PNG" (PNG is lossless)
        """
        self.prompt = prompt
        self.model = model
        self.fps = fps
        self.queue: queue.Queue = queue.Queue(maxsize=queue_size)
        self.monitor = monitor
        self.max_width = max_width
        self.max_height = max_height
        self.compression_quality = compression_quality
        self.image_format = image_format.upper()

        self._stop_event = threading.Event()
        self._producer: Optional[threading.Thread] = None
        self._consumer: Optional[threading.Thread] = None

        self.client = OpenAI(
            base_url=Settings.NEBIUS_BASE_URL,
            api_key=Settings.NEBIUS_API_KEY,
        )

    def _process_image(self, img: Image.Image) -> Image.Image:
        """Convert to RGB and downscale to fit within max_width x max_height,
        preserving aspect ratio."""
        if img.mode != "RGB":
            img = img.convert("RGB")
        w, h = img.size

        if w > self.max_width or h > self.max_height:
            # Scale by the tighter constraint so both dimensions fit;
            # scaling only one axis can leave the other over its limit.
            scale = min(self.max_width / w, self.max_height / h)
            new_w = max(1, int(w * scale))
            new_h = max(1, int(h * scale))
            img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
        return img

    def _image_to_base64(self, img: Image.Image) -> str:
        """Encode the image in the configured format and return it as base64."""
        buf = io.BytesIO()
        if self.image_format == "PNG":
            img.save(buf, format="PNG")
        else:
            img.save(
                buf,
                format="JPEG",
                quality=self.compression_quality,
                optimize=True,
            )
        return base64.b64encode(buf.getvalue()).decode("utf-8")

    def _capture_loop(self):
        """Producer: grab frames at the target fps and push them onto the queue."""
        with mss.mss() as sct:
            mon = sct.monitors[self.monitor]
            interval = 1.0 / self.fps if self.fps > 0 else 0
            while not self._stop_event.is_set():
                t0 = time.time()
                # mss returns BGRA pixels; reorder the channels to RGB for PIL.
                frame = np.array(sct.grab(mon))
                pil = Image.fromarray(np.ascontiguousarray(frame[:, :, 2::-1]))
                pil = self._process_image(pil)
                b64 = self._image_to_base64(pil)
                try:
                    self.queue.put_nowait((t0, b64))
                except queue.Full:
                    # Drop the oldest frame so the consumer always sees the
                    # freshest screen state. The consumer may drain the queue
                    # between these calls, so guard both operations.
                    try:
                        self.queue.get_nowait()
                    except queue.Empty:
                        pass
                    try:
                        self.queue.put_nowait((t0, b64))
                    except queue.Full:
                        pass
                if interval:
                    # Sleep only for the time left in this frame's budget,
                    # so capture and encoding time do not lower the real fps.
                    time.sleep(max(0.0, interval - (time.time() - t0)))

    def _flatten_conversation_history(self, history: List[Dict[str, str]]) -> str:
        """Flatten conversation history into a readable format for the vision model."""
        if not history:
            return "No previous conversation."

        filtered_history = []
        for msg in history:
            role = msg.get('role', '')
            content = msg.get('content', '')

            # Skip system prompts, earlier vision-model injections, and
            # screen-sharing status chatter; none of it helps the model.
            if role == 'system':
                continue
            if content.startswith('VISION MODEL OUTPUT:'):
                continue
            if 'screen' in content.lower() and 'sharing' in content.lower():
                continue

            filtered_history.append(msg)

        # Keep only the 20 most recent messages.
        if len(filtered_history) > 20:
            filtered_history = filtered_history[-20:]

        formatted_lines = []
        for msg in filtered_history:
            role = msg.get('role', 'unknown')
            content = msg.get('content', '')

            # Truncate long messages to keep the prompt compact.
            if len(content) > 200:
                content = content[:200] + "..."

            if role == 'user':
                formatted_lines.append(f"User: {content}")
            elif role == 'assistant':
                formatted_lines.append(f"Assistant: {content}")

        return "\n".join(formatted_lines) if formatted_lines else "No relevant conversation history."

    def _inference_loop(
        self,
        callback: Callable[[Dict, float, str], None],
        history_getter: Callable[[], List[Dict[str, str]]],
    ):
        """Consumer: send queued frames to the vision model and invoke the callback."""
        while not self._stop_event.is_set():
            try:
                t0, frame_b64 = self.queue.get(timeout=1)
            except queue.Empty:
                continue

            history = history_getter()
            flattened_history = self._flatten_conversation_history(history)

            full_prompt = f"{self.prompt}\n\nCONVERSATION CONTEXT:\n{flattened_history}"

            if logger.isEnabledFor(logging.DEBUG):
                for i, msg in enumerate(history):
                    content = msg.get('content', '')
                    preview = content[:100] + "..." if len(content) > 100 else content
                    logger.debug("history[%d] %s: %s", i, msg.get('role', ''), preview)

            user_message = {
                "role": "user",
                "content": [
                    {"type": "text", "text": full_prompt},
                    {"type": "image_url", "image_url": {"url": f"data:image/{self.image_format.lower()};base64,{frame_b64}"}},
                ],
            }

            try:
                resp = self.client.chat.completions.create(
                    model=self.model,
                    messages=[user_message],
                )
                latency = time.time() - t0
                callback(resp, latency, frame_b64)
            except Exception as e:
                logger.error(f"Nebius inference error: {e}")

    def start(
        self,
        callback: Callable[[Dict, float, str], None],
        history_getter: Callable[[], List[Dict[str, str]]],
    ) -> None:
        """Start the capture and inference threads; no-op if already running."""
        if self._producer and self._producer.is_alive():
            return
        self._stop_event.clear()
        self._producer = threading.Thread(target=self._capture_loop, daemon=True)
        self._consumer = threading.Thread(
            target=self._inference_loop,
            args=(callback, history_getter),
            daemon=True,
        )
        self._producer.start()
        self._consumer.start()
        logger.info("ScreenService started.")

    def stop(self) -> None:
        """Signal both threads to stop and wait briefly for them to exit."""
        self._stop_event.set()
        if self._producer:
            self._producer.join(timeout=1.0)
        if self._consumer:
            self._consumer.join(timeout=1.0)
        logger.info("ScreenService stopped.")