File size: 4,429 Bytes
031f9b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import io
import base64
import logging
import tempfile
import asyncio
from typing import Optional, Union
from pathlib import Path

from huggingface_hub import InferenceClient

from config.settings import Settings

# Module-level logger that emits DEBUG and above through its own stream handler.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
ch.setFormatter(formatter)
# getLogger(__name__) returns the same logger object every time this module is
# executed (e.g. via importlib.reload), so attach the handler only once;
# otherwise each re-execution adds another handler and every record repeats.
if not logger.handlers:
    logger.addHandler(ch)

class AudioService:
    """
    Async wrapper around Hugging Face ``InferenceClient`` for ASR and TTS.

    Speech-to-text uses a client bound to an explicit provider; text-to-speech
    uses a client constructed without a provider argument. The blocking client
    calls are run in the event loop's default executor.
    """

    def __init__(
        self,
        api_key: str,
        stt_provider: str = "fal-ai",
        stt_model: str = "openai/whisper-large-v3",
        tts_model: str = "canopylabs/orpheus-3b-0.1-ft",
    ):
        """
        AudioService with separate providers for ASR and TTS.

        :param api_key: Hugging Face API token
        :param stt_provider: Provider for speech-to-text (e.g., "fal-ai")
        :param stt_model: ASR model ID
        :param tts_model: TTS model ID
        """
        self.api_key = api_key
        self.stt_model = stt_model
        self.tts_model = tts_model

        # Speech-to-Text client, bound to the requested inference provider.
        logger.debug(f"Initializing ASR client with provider={stt_provider}")
        self.asr_client = InferenceClient(
            provider=stt_provider,
            api_key=self.api_key,
        )

        # Text-to-Speech client: no provider argument is passed, so whatever
        # default InferenceClient uses applies; the token is passed as `token`.
        logger.debug("Initializing TTS client with default provider")
        self.tts_client = InferenceClient(token=self.api_key)

        logger.info(
            f"AudioService configured: ASR model={self.stt_model} via {stt_provider}, "
            f"TTS model={self.tts_model} via default provider."
        )

    async def speech_to_text(self, audio_file: Union[str, bytes, io.BytesIO]) -> str:
        """
        Convert speech to text using the configured ASR provider.

        :param audio_file: Path to an existing audio file, raw audio bytes, or an
            ``io.BytesIO`` holding audio bytes. Non-path input is written to a
            temporary ``.wav`` file, which is deleted once the call finishes.
        :return: The transcript, or ``""`` when the ASR call fails or yields no text.
        """
        temp_path: Optional[Path] = None
        try:
            if isinstance(audio_file, str):
                input_path = audio_file
                logger.debug(f"Using existing file for ASR: {input_path}")
            else:
                data = audio_file.getvalue() if isinstance(audio_file, io.BytesIO) else audio_file
                # delete=False keeps the file on disk after it is closed so the
                # client can reopen it by path; it is removed in `finally` below.
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                    # Record the path before writing so a failed write is still cleaned up.
                    temp_path = Path(tmp.name)
                    tmp.write(data)
                input_path = tmp.name
                logger.debug(f"Wrote audio to temp file for ASR: {input_path}")

            return await self._run_asr(input_path)
        finally:
            if temp_path is not None:
                self._remove_temp_file(temp_path)

    async def _run_asr(self, input_path: str) -> str:
        """Run the blocking ASR call in the default executor; return "" on any failure."""
        try:
            logger.info(f"Calling ASR model={self.stt_model}")
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None,
                lambda: self.asr_client.automatic_speech_recognition(
                    input_path,
                    model=self.stt_model,
                ),
            )
            # Accept dict-shaped or attribute-style results, and normalize a
            # missing/None text to "" *before* measuring it: len(None) would
            # otherwise turn an empty response into a logged ASR error.
            if isinstance(result, dict):
                transcript = result.get("text")
            else:
                transcript = getattr(result, "text", "")
            transcript = transcript or ""
            logger.info(f"ASR success, transcript length={len(transcript)}")
            logger.debug(f"Transcript preview: {transcript[:100]}")
            return transcript
        except Exception as e:
            # Best-effort contract: callers receive "" rather than an exception.
            logger.error(f"ASR error: {e}", exc_info=True)
            return ""

    @staticmethod
    def _remove_temp_file(path: Path) -> None:
        """Delete a temporary audio file, logging (not raising) if deletion fails."""
        try:
            path.unlink(missing_ok=True)
        except OSError as e:
            logger.warning(f"Could not delete temp ASR file {path}: {e}")

    async def text_to_speech(self, text: str) -> Optional[bytes]:
        """
        Convert text to speech using the configured TTS provider.

        :param text: Text to synthesize. Empty, whitespace-only, or ``None``
            input is skipped.
        :return: Audio bytes from the TTS client, or ``None`` when the input is
            empty or the call fails.
        """
        if not text or not text.strip():
            logger.debug("Empty text input for TTS. Skipping generation.")
            return None

        def _call_tts():
            """
            Run the blocking TTS call, re-raising StopIteration as RuntimeError.

            asyncio futures refuse StopIteration (``Future.set_exception`` raises
            TypeError for it), so it would not propagate cleanly out of the executor.
            """
            try:
                return self.tts_client.text_to_speech(text, model=self.tts_model)
            except StopIteration as e:
                raise RuntimeError(f"StopIteration in TTS call: {e}") from e

        try:
            logger.info(f"Calling TTS model={self.tts_model}, text length={len(text)}")
            loop = asyncio.get_running_loop()
            audio = await loop.run_in_executor(None, _call_tts)
            logger.info(f"TTS success, received {len(audio)} bytes")
            return audio
        except Exception as e:
            # Best-effort contract: callers receive None rather than an exception.
            logger.error(f"TTS error: {e}", exc_info=True)
            return None