import logging
from typing import Dict, List, Optional
from dataclasses import dataclass
from huggingface_hub import AsyncInferenceClient

from config.settings import Settings

# Configure logger for detailed debugging; guard against attaching duplicate
# handlers if this module is imported more than once (e.g. during reloads)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
if not logger.handlers:
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)

@dataclass
class LLMConfig:
    api_key: str
    model_name: str
    temperature: float = 0.01
    max_tokens: int = 512

class LLMService:
    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: Optional[str] = None,
    ):
        """LLMService that wraps the HuggingFace AsyncInferenceClient for chat completions."""
        settings = Settings()
        
        # Use provided values or fall back to settings
        key = api_key or settings.hf_token
        name = model_name or settings.effective_model_name
        
        self.config = LLMConfig(
            api_key=key,
            model_name=name,
            temperature=settings.hf_temperature,
            max_tokens=settings.hf_max_new_tokens,
        )
        
        # Initialize the async InferenceClient so get_chat_completion can await it
        self.client = AsyncInferenceClient(token=self.config.api_key)

    async def get_chat_completion(self, messages: List[Dict[str, str]]) -> str:
        """

        Return the assistant response for a chat-style messages array.

        """
        logger.debug(f"Chat completion request with model: {self.config.model_name}")
        
        try:
            # chat_completion is a coroutine on AsyncInferenceClient, so await it
            response = await self.client.chat_completion(
                messages=messages,
                model=self.config.model_name,
                max_tokens=self.config.max_tokens,
                temperature=self.config.temperature,
            )
            
            # Extract the content from the response
            content = response.choices[0].message.content
            logger.debug(f"Chat completion response: {content[:200]}")
            
            return content
            
        except Exception as e:
            logger.error(f"Chat completion error: {e}")
            # Re-raise with context while preserving the original traceback
            raise RuntimeError(f"HF chat completion error: {e}") from e
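
# --- Usage sketch (assumption: config.settings.Settings supplies hf_token,
# --- effective_model_name, hf_temperature, and hf_max_new_tokens, and a valid
# --- HF token is configured; message content is illustrative only) ---
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        # Fall back entirely to Settings for the key and model name
        service = LLMService()
        reply = await service.get_chat_completion(
            [{"role": "user", "content": "Say hello in one sentence."}]
        )
        print(reply)

    asyncio.run(_demo())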