# tts_engine.py - TTS engine wrapper for CPU-friendly SpeechT5
import logging
import os
import tempfile
from typing import Optional

import numpy as np
import soundfile as sf
import torch
from datasets import load_dataset  # Source of pre-computed speaker x-vectors
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor

logger = logging.getLogger(__name__)


class CPUMultiSpeakerTTS:
    """Multi-speaker text-to-speech engine built on SpeechT5, pinned to CPU.

    Loads the SpeechT5 acoustic model, the HiFi-GAN vocoder, and two
    pre-computed speaker x-vectors, then exposes `synthesize_segment()` to
    render a text segment for a given speaker label ('S1' or 'S2') to a
    WAV file. If initialization fails, the engine degrades gracefully:
    `synthesize_segment()` logs an error and returns None.
    """

    def __init__(self):
        # All components stay None until _initialize_model() succeeds;
        # synthesize_segment() checks for None before doing any work.
        self.processor = None
        self.model = None
        self.vocoder = None
        # Maps segmenter speaker labels ('S1', 'S2') to x-vector tensors
        # of shape (1, 512).
        self.speaker_embeddings = {}
        self._initialize_model()

    def _initialize_model(self):
        """Initialize the SpeechT5 model, vocoder, and speaker embeddings on CPU.

        On any failure the processor/model/vocoder are reset to None so the
        instance is left in a consistent "unavailable" state rather than
        half-initialized.
        """
        try:
            logger.info("Initializing SpeechT5 model for CPU...")
            self.processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
            self.model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
            self.vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

            # Ensure all components are on CPU explicitly.
            self.model.to("cpu")
            self.vocoder.to("cpu")
            logger.info("SpeechT5 model and vocoder initialized successfully on CPU.")

            # Load speaker embeddings for the two voices.
            # NOTE(review): despite the log message below, this dataset is
            # the CMU ARCTIC x-vector set ("Matthijs/cmu-arctic-xvectors"),
            # not VCTK, and indices 0 and 1 are simply its first two rows --
            # two distinct voices, but not specific VCTK speakers like
            # p280/p272. Confirm the desired voices and pick indices
            # deliberately if voice identity matters.
            logger.info("Loading VCTK dataset for speaker embeddings...")
            embeddings_dataset = load_dataset(
                "Matthijs/cmu-arctic-xvectors", split="validation"
            )

            # Each row's "xvector" is a 512-dim list; unsqueeze to the
            # (1, 512) batch shape generate_speech() expects, and keep on CPU.
            self.speaker_embeddings["S1"] = (
                torch.tensor(embeddings_dataset[0]["xvector"]).unsqueeze(0).to("cpu")
            )
            self.speaker_embeddings["S2"] = (
                torch.tensor(embeddings_dataset[1]["xvector"]).unsqueeze(0).to("cpu")
            )
            logger.info("Speaker embeddings loaded for S1 and S2.")
        except Exception as e:
            logger.error(f"Failed to initialize TTS model (SpeechT5): {e}", exc_info=True)
            self.processor = None
            self.model = None
            self.vocoder = None

    def synthesize_segment(
        self,
        text: str,
        speaker: str,  # 'S1' or 'S2' from the segmenter
        output_path: str,
    ) -> Optional[str]:
        """
        Synthesize speech for a text segment using SpeechT5.

        Args:
            text: Text to synthesize
            speaker: Speaker identifier ('S1' or 'S2' expected from segmenter)
            output_path: Path to save the audio file

        Returns:
            Path to the generated audio file, or None if failed
        """
        if not self.model or not self.processor or not self.vocoder:
            logger.error(
                "SpeechT5 model, processor, or vocoder not initialized. Cannot synthesize speech."
            )
            return None

        try:
            # Resolve the speaker embedding; unknown labels fall back to S1.
            # S1 is guaranteed present here: if embedding loading had failed,
            # _initialize_model() would have nulled self.model and we would
            # have returned above.
            speaker_embedding = self.speaker_embeddings.get(speaker)
            if speaker_embedding is None:
                logger.warning(
                    f"Speaker '{speaker}' not found in pre-loaded embeddings. Defaulting to S1."
                )
                speaker_embedding = self.speaker_embeddings["S1"]

            logger.info(f"Synthesizing text for speaker {speaker}: {text[:100]}...")

            # Tokenize and keep inputs on CPU alongside the model.
            inputs = self.processor(text=text, return_tensors="pt")
            inputs = {k: v.to("cpu") for k, v in inputs.items()}

            with torch.no_grad():
                # generate_speech() produces mel features internally and,
                # because we pass the vocoder, returns the final waveform.
                speech = self.model.generate_speech(
                    inputs["input_ids"],
                    speaker_embedding,
                    vocoder=self.vocoder,
                )

            audio_waveform = speech.cpu().numpy().squeeze()

            # Sampling rate from the vocoder config; SpeechT5's HiFi-GAN
            # typically emits 16 kHz audio, so fall back to 16000.
            sampling_rate = (
                self.vocoder.config.sampling_rate
                if hasattr(self.vocoder.config, "sampling_rate")
                else 16000
            )

            sf.write(output_path, audio_waveform, sampling_rate)

            logger.info(
                f"Generated audio for {speaker}: {len(text)} characters to {output_path}"
            )
            return output_path
        except Exception as e:
            logger.error(f"Failed to synthesize segment with SpeechT5: {e}", exc_info=True)
            return None