Nick021402 committed
Commit 1f6c376 · verified · 1 Parent(s): 579c072

Upload 7 files

Files changed (7):
  1. README.md +52 -10
  2. app.py +262 -0
  3. audio_utils.py +110 -0
  4. gitattributes (2).txt +35 -0
  5. requirements.txt +11 -0
  6. segmenter.py +139 -0
  7. tts_engine.py +110 -0
README.md CHANGED
@@ -1,14 +1,56 @@
 ---
-title: PodXplainClone
-emoji: 👀
-colorFrom: red
-colorTo: yellow
-sdk: gradio
-sdk_version: 5.31.0
-app_file: app.py
-pinned: false
 license: mit
-short_description: 'PodXplain '
+title: 🎙️ PodXplain
+sdk: gradio
+emoji: 📚
+colorFrom: red
+colorTo: blue
+pinned: true
+short_description: PodXplain is a Hugging Face-hosted app that converts long-form text into multi-speaker podcast audio
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# 🎙️ PodXplain
+
+**From script to story — voice it like never before.**
+
+PodXplain is a Hugging Face-hosted application that converts long-form text into engaging multi-speaker podcast-style audio. Simply input your script, and get a professional-sounding MP3 podcast with automatic speaker detection and assignment.
+
+## ✨ Features
+
+- **📝 Long-form Support**: Handle up to 50,000 characters of text
+- **🎭 Multi-speaker Audio**: Automatic speaker detection and assignment
+- **🔄 Smart Segmentation**: Intelligent text splitting with progress tracking
+- **🎵 High-quality Output**: MP3 format for optimal file size and compatibility
+- **🚀 Real-time Progress**: Live updates during generation
+- **🎨 Modern UI**: Clean, intuitive Gradio interface
+
+## 🛠️ Tech Stack
+
+- **Frontend**: Gradio for the interactive web interface
+- **TTS Engine**: Nari DIA 1.6B for natural voice synthesis
+- **Audio Processing**: pydub for audio manipulation and MP3 conversion
+- **Hosting**: Hugging Face Spaces with GPU support
+
+## 📋 How to Use
+
+1. **Input Text**: Paste or type your podcast script (up to 50,000 characters)
+2. **Choose Mode**: Select a speaker detection mode:
+   * **Auto**: Smart detection based on content structure
+   * **Paragraph**: Speaker changes at paragraph breaks
+   * **Dialogue**: Detection based on dialogue markers
+3. **Generate**: Click "Generate Podcast" and watch the progress
+4. **Download**: Get your MP3 file and listen to your podcast!
+
+## 🚀 Quick Start
+
+### Local Development
+
+```bash
+# Clone the repository
+git clone https://github.com/yourusername/podxplain.git  # Replace with your actual repo URL
+cd podxplain
+
+# Install dependencies
+pip install -r requirements.txt
+
+# Run the application
+python app.py
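
The three detection modes described under "How to Use" map directly onto `TextSegmenter.segment_and_assign_speakers` in `segmenter.py` below. A minimal sketch of exercising them outside the UI, assuming the repository modules are on the import path:

```python
from segmenter import TextSegmenter

segmenter = TextSegmenter()
sample = "Welcome to the show.\n\nThanks for having me."

for mode in ("auto", "paragraph", "dialogue"):
    # Each call returns a list of (speaker, text) tuples,
    # e.g. [("S1", "Welcome to the show."), ("S2", "Thanks for having me.")]
    print(mode, segmenter.segment_and_assign_speakers(sample, mode=mode))
```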
app.py ADDED
@@ -0,0 +1,262 @@
+# app.py - Main Gradio application
+import gradio as gr
+import os
+import tempfile
+import shutil
+from pathlib import Path
+import asyncio
+from typing import List, Tuple, Generator
+import logging
+from datetime import datetime
+
+# Import our custom modules
+from segmenter import TextSegmenter
+from tts_engine import NariDIAEngine
+from audio_utils import AudioProcessor
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class PodXplainApp:
+    def __init__(self):
+        self.segmenter = TextSegmenter()
+        self.tts_engine = NariDIAEngine()
+        self.audio_processor = AudioProcessor()
+        self.temp_dir = None
+
+    def create_temp_directory(self) -> str:
+        """Create a temporary directory for processing."""
+        if self.temp_dir:
+            shutil.rmtree(self.temp_dir, ignore_errors=True)
+        self.temp_dir = tempfile.mkdtemp(prefix="podxplain_")
+        return self.temp_dir
+
+    def cleanup_temp_directory(self):
+        """Clean up temporary files."""
+        if self.temp_dir and os.path.exists(self.temp_dir):
+            shutil.rmtree(self.temp_dir, ignore_errors=True)
+        self.temp_dir = None
+
+    def generate_podcast(
+        self,
+        text: str,
+        speaker_detection_mode: str = "auto",
+        progress=gr.Progress()
+    ) -> Tuple[str, str]:
+        """
+        Main function to convert text to podcast audio.
+
+        Args:
+            text: Input text (up to 50,000 characters)
+            speaker_detection_mode: How to detect speaker changes
+            progress: Gradio progress tracker
+
+        Returns:
+            Tuple of (audio_path, status_message)
+        """
+        try:
+            # Validate input
+            if not text or len(text.strip()) == 0:
+                return None, "❌ Please provide some text to convert."
+
+            if len(text) > 50000:
+                return None, f"❌ Text too long ({len(text)} chars). Maximum is 50,000 characters."
+
+            # Create temporary directory
+            temp_dir = self.create_temp_directory()
+            progress(0, desc="🚀 Starting podcast generation...")
+
+            # Step 1: Segment text and assign speakers
+            progress(0.1, desc="📝 Analyzing text and assigning speakers...")
+            segments = self.segmenter.segment_and_assign_speakers(
+                text, mode=speaker_detection_mode
+            )
+
+            if not segments:
+                return None, "❌ Could not process the text. Please check the input."
+
+            logger.info(f"Generated {len(segments)} segments")
+
+            # Step 2: Generate audio for each segment
+            progress(0.2, desc="🎤 Generating audio segments...")
+            audio_files = []
+
+            for i, (speaker, segment_text) in enumerate(segments):
+                progress(
+                    0.2 + (0.7 * i / len(segments)),
+                    desc=f"🎵 Processing segment {i+1}/{len(segments)} (Speaker {speaker})"
+                )
+
+                # Generate audio for this segment
+                audio_path = self.tts_engine.synthesize_segment(
+                    segment_text,
+                    speaker,
+                    os.path.join(temp_dir, f"segment_{i:03d}.wav")
+                )
+
+                if audio_path:
+                    audio_files.append(audio_path)
+                else:
+                    logger.warning(f"Failed to generate audio for segment {i}")
+
+            if not audio_files:
+                return None, "❌ Failed to generate any audio segments."
+
+            # Step 3: Merge audio files and convert to MP3
+            progress(0.9, desc="🔧 Merging segments and converting to MP3...")
+            final_audio_path = self.audio_processor.merge_and_convert_to_mp3(
+                audio_files,
+                os.path.join(temp_dir, "podcast_output.mp3")
+            )
+
+            if not final_audio_path:
+                return None, "❌ Failed to merge audio segments."
+
+            progress(1.0, desc="✅ Podcast generated successfully!")
+
+            # Generate summary
+            total_segments = len(segments)
+            speakers_used = len(set(speaker for speaker, _ in segments))
+            duration_estimate = len(text) / 1000  # Rough estimate: 1,000 chars ≈ 1 minute
+
+            status_message = f"""
+✅ **Podcast Generated Successfully!**
+
+📊 **Statistics:**
+- Total segments: {total_segments}
+- Speakers used: {speakers_used}
+- Estimated duration: {duration_estimate:.1f} minutes
+- Character count: {len(text):,}
+
+🎧 **Your podcast is ready for download!**
+"""
+
+            return final_audio_path, status_message
+
+        except Exception as e:
+            logger.error(f"Error generating podcast: {str(e)}")
+            return None, f"❌ Error: {str(e)}"
+
+        finally:
+            # Clean up temporary files (except the final output)
+            # Note: we keep the final MP3 so it stays available for download
+            pass
+
+def create_gradio_interface():
+    """Create the Gradio interface."""
+    app = PodXplainApp()
+
+    # Custom CSS for better styling
+    css = """
+    .main-container {
+        max-width: 1200px;
+        margin: 0 auto;
+    }
+    .header {
+        text-align: center;
+        padding: 20px;
+        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+        color: white;
+        border-radius: 10px;
+        margin-bottom: 20px;
+    }
+    .footer {
+        text-align: center;
+        padding: 20px;
+        color: #666;
+        font-size: 0.9em;
+    }
+    """
+
+    with gr.Blocks(css=css, title="PodXplain - Text to Podcast") as interface:
+        # Header
+        gr.HTML("""
+        <div class="header">
+            <h1>🎙️ PodXplain</h1>
+            <p><em>From script to story — voice it like never before.</em></p>
+        </div>
+        """)
+
+        with gr.Row():
+            with gr.Column(scale=2):
+                # Input section
+                gr.Markdown("## 📝 Input Your Script")
+
+                text_input = gr.Textbox(
+                    label="Podcast Script",
+                    placeholder="Enter your podcast script here (up to 50,000 characters).\n\nTip: Use paragraph breaks to help with speaker detection.",
+                    lines=15,
+                    max_lines=20,
+                    show_label=True
+                )
+
+                char_count = gr.HTML("Characters: 0 / 50,000")
+
+                # Options
+                speaker_mode = gr.Radio(
+                    choices=["auto", "paragraph", "dialogue"],
+                    value="auto",
+                    label="Speaker Detection Mode",
+                    info="How to detect when speakers change"
+                )
+
+                generate_btn = gr.Button(
+                    "🎤 Generate Podcast",
+                    variant="primary",
+                    size="lg"
+                )
+
+            with gr.Column(scale=1):
+                # Output section
+                gr.Markdown("## 🎧 Your Podcast")
+
+                status_output = gr.Markdown("Ready to generate your podcast!")
+
+                audio_output = gr.Audio(
+                    label="Generated Podcast",
+                    show_download_button=True,
+                    interactive=False
+                )
+
+        # Footer with instructions
+        gr.HTML("""
+        <div class="footer">
+            <h3>📋 How to Use PodXplain</h3>
+            <ol>
+                <li><strong>Write your script:</strong> Enter up to 50,000 characters of text</li>
+                <li><strong>Choose speaker mode:</strong> Auto-detect, paragraph-based, or dialogue-based</li>
+                <li><strong>Generate:</strong> Click the button and wait for processing</li>
+                <li><strong>Listen & Download:</strong> Your MP3 podcast will be ready!</li>
+            </ol>
+            <p><strong>💡 Tips:</strong> Use clear paragraph breaks for better speaker detection.
+            Write naturally as if speaking to an audience.</p>
+        </div>
+        """)
+
+        # Live character counter (a Python change handler, not JavaScript)
+        text_input.change(
+            fn=lambda text: f"Characters: {len(text) if text else 0:,} / 50,000",
+            inputs=[text_input],
+            outputs=[char_count]
+        )
+
+        # Main generation function
+        generate_btn.click(
+            fn=app.generate_podcast,
+            inputs=[text_input, speaker_mode],
+            outputs=[audio_output, status_output],
+            show_progress=True
+        )
+
+    return interface
+
+if __name__ == "__main__":
+    # Create and launch the interface
+    interface = create_gradio_interface()
+    interface.launch(
+        share=True,
+        server_name="0.0.0.0",
+        server_port=7860,
+        show_error=True
+    )
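
Because `generate_podcast` is an ordinary method, the pipeline can be smoke-tested without launching the UI. A hedged sketch: the no-op lambda stands in for `gr.Progress`, and a working TTS model plus ffmpeg are needed for a non-None result.

```python
from app import PodXplainApp

app = PodXplainApp()
audio_path, status = app.generate_podcast(
    "Hello and welcome.\n\nToday we talk about text-to-speech pipelines.",
    speaker_detection_mode="paragraph",
    progress=lambda *args, **kwargs: None,  # no-op progress callback outside Gradio
)
print(status)
app.cleanup_temp_directory()
```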
audio_utils.py ADDED
@@ -0,0 +1,110 @@
+# audio_utils.py - Audio processing utilities
+import logging
+from typing import List, Optional
+import os
+import tempfile
+from pydub import AudioSegment
+from pydub.utils import which
+
+logger = logging.getLogger(__name__)
+
+class AudioProcessor:
+    def __init__(self):
+        self._check_dependencies()
+
+    def _check_dependencies(self):
+        """Check if required audio processing tools are available."""
+        # Check for ffmpeg
+        if not which("ffmpeg"):
+            logger.warning("ffmpeg not found. Some audio operations may fail.")
+
+    def merge_and_convert_to_mp3(
+        self,
+        audio_files: List[str],
+        output_path: str
+    ) -> Optional[str]:
+        """
+        Merge multiple audio files and convert to MP3.
+
+        Args:
+            audio_files: List of paths to audio files to merge
+            output_path: Path for the output MP3 file
+
+        Returns:
+            Path to the merged MP3 file, or None if failed
+        """
+        try:
+            if not audio_files:
+                logger.error("No audio files to merge")
+                return None
+
+            logger.info(f"Merging {len(audio_files)} audio files...")
+
+            # Start with empty audio
+            merged_audio = AudioSegment.empty()
+
+            for i, audio_file in enumerate(audio_files):
+                if not os.path.exists(audio_file):
+                    logger.warning(f"Audio file not found: {audio_file}")
+                    continue
+
+                try:
+                    # Load audio segment
+                    segment = AudioSegment.from_wav(audio_file)
+
+                    # Add a small pause between segments (500 ms)
+                    if i > 0:
+                        pause = AudioSegment.silent(duration=500)
+                        merged_audio += pause
+
+                    # Add the segment
+                    merged_audio += segment
+
+                    logger.info(f"Added segment {i+1}/{len(audio_files)}")
+
+                except Exception as e:
+                    logger.error(f"Failed to process audio file {audio_file}: {e}")
+                    continue
+
+            if len(merged_audio) == 0:
+                logger.error("No audio content to export")
+                return None
+
+            # Normalize audio levels
+            merged_audio = self._normalize_audio(merged_audio)
+
+            # Export as MP3
+            logger.info(f"Exporting to MP3: {output_path}")
+            merged_audio.export(
+                output_path,
+                format="mp3",
+                bitrate="128k",
+                parameters=["-q:a", "2"]  # Good quality
+            )
+
+            # Verify the file was created
+            if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
+                duration = len(merged_audio) / 1000.0  # Convert to seconds
+                logger.info(f"Successfully created MP3: {duration:.1f} seconds")
+                return output_path
+            else:
+                logger.error("Failed to create MP3 file")
+                return None
+
+        except Exception as e:
+            logger.error(f"Failed to merge audio files: {e}")
+            return None
+
+    def _normalize_audio(self, audio: AudioSegment) -> AudioSegment:
+        """Normalize audio levels."""
+        try:
+            # Apply basic audio processing:
+            # normalize to -6 dBFS to avoid clipping
+            target_dBFS = -6.0
+            change_in_dBFS = target_dBFS - audio.dBFS
+            normalized_audio = audio.apply_gain(change_in_dBFS)
+
+            return normalized_audio
+        except Exception as e:
+            logger.warning(f"Failed to normalize audio: {e}")
+            return audio
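
The 500 ms inter-segment pause and the -6 dBFS normalization can be reproduced in isolation. A small pydub sketch; the two WAV paths are hypothetical and ffmpeg must be installed for the MP3 export:

```python
from pydub import AudioSegment

a = AudioSegment.from_wav("segment_000.wav")  # hypothetical input files
b = AudioSegment.from_wav("segment_001.wav")

# Join with a 500 ms pause, as in merge_and_convert_to_mp3
merged = a + AudioSegment.silent(duration=500) + b

# Same normalization as _normalize_audio: shift the level to -6 dBFS
merged = merged.apply_gain(-6.0 - merged.dBFS)
merged.export("podcast_output.mp3", format="mp3", bitrate="128k")
```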
gitattributes (2).txt ADDED
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
requirements.txt ADDED
@@ -0,0 +1,11 @@
+gradio>=4.0.0
+transformers>=4.30.0
+torch>=2.0.0
+torchaudio>=2.0.0
+numpy>=1.21.0
+soundfile>=0.12.0
+pydub>=0.25.0
+librosa>=0.10.0
+datasets>=2.10.0
+accelerate>=0.20.0
+git+https://github.com/nari-labs/dia.git  # Nari DIA TTS engine
segmenter.py ADDED
@@ -0,0 +1,139 @@
+# segmenter.py - Text segmentation and speaker assignment
+import re
+from typing import List, Tuple
+import logging
+
+logger = logging.getLogger(__name__)
+
+class TextSegmenter:
+    def __init__(self):
+        # Speaker labels match Nari DIA's expected [S1]/[S2] tags
+        self.speakers = ["S1", "S2"]
+        self.current_speaker_index = 0
+
+    def segment_and_assign_speakers(
+        self,
+        text: str,
+        mode: str = "auto"
+    ) -> List[Tuple[str, str]]:
+        """
+        Segment text and assign speakers.
+
+        Args:
+            text: Input text to segment
+            mode: Segmentation mode ("auto", "paragraph", "dialogue")
+
+        Returns:
+            List of (speaker, text) tuples
+        """
+        if mode == "paragraph":
+            return self._segment_by_paragraphs(text)
+        elif mode == "dialogue":
+            return self._segment_by_dialogue(text)
+        else:  # auto mode
+            return self._segment_auto(text)
+
+    def _segment_by_paragraphs(self, text: str) -> List[Tuple[str, str]]:
+        """Segment by paragraphs, alternating speakers."""
+        paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
+        segments = []
+
+        for i, paragraph in enumerate(paragraphs):
+            speaker = self.speakers[i % len(self.speakers)]
+            segments.append((speaker, paragraph))
+
+        return segments
+
+    def _segment_by_dialogue(self, text: str) -> List[Tuple[str, str]]:
+        """Segment by detecting dialogue patterns."""
+        lines = text.split('\n')
+        segments = []
+        current_segment = []
+        # Start with the first speaker in the list
+        current_speaker = self.speakers[0]
+
+        for line in lines:
+            line = line.strip()
+            if not line:
+                continue
+
+            # Check for dialogue markers
+            if (line.startswith('"') or line.startswith("'") or
+                    line.startswith('-') or line.startswith('—')):
+
+                # Save previous segment
+                if current_segment:
+                    segments.append((current_speaker, ' '.join(current_segment)))
+
+                # Switch speaker and start new segment
+                self.current_speaker_index = (self.current_speaker_index + 1) % len(self.speakers)
+                current_speaker = self.speakers[self.current_speaker_index]
+                current_segment = [line]
+            else:
+                current_segment.append(line)
+
+        # Add final segment
+        if current_segment:
+            segments.append((current_speaker, ' '.join(current_segment)))
+
+        return segments
+
+    def _segment_auto(self, text: str) -> List[Tuple[str, str]]:
+        """Automatic segmentation using multiple heuristics."""
+        segments = []
+
+        paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
+
+        if len(paragraphs) > 1:
+            return self._segment_by_paragraphs(text)
+
+        sentences = self._split_into_sentences(text)
+        if len(sentences) > 10:
+            return self._segment_by_sentence_groups(sentences)
+
+        return self._segment_simple(text)
+
+    def _split_into_sentences(self, text: str) -> List[str]:
+        """Split text into sentences."""
+        # Simple sentence splitting on terminal punctuation.
+        # Note: this still splits after abbreviations such as "Mr.";
+        # use a full NLP library for robust boundaries in complex cases.
+        sentences = re.split(r'(?<=[.!?])\s+', text)  # Split after . ! ? followed by whitespace
+        return [s.strip() for s in sentences if s.strip()]
+
+    def _segment_by_sentence_groups(self, sentences: List[str]) -> List[Tuple[str, str]]:
+        """Group sentences and assign to different speakers."""
+        segments = []
+        group_size = max(2, len(sentences) // 8)
+
+        for i in range(0, len(sentences), group_size):
+            group = sentences[i:i + group_size]
+            speaker = self.speakers[i // group_size % len(self.speakers)]
+            text_segment = ' '.join(group)  # Sentences already carry their punctuation
+            segments.append((speaker, text_segment))
+
+        return segments
+
+    def _segment_simple(self, text: str) -> List[Tuple[str, str]]:
+        """Simple segmentation for short texts."""
+        words = text.split()
+        total_words = len(words)
+
+        if total_words < 50:
+            return [(self.speakers[0], text)]  # Assign everything to S1
+
+        num_segments = min(len(self.speakers), max(2, total_words // 100))  # Limit segments to available speakers
+        segment_size = total_words // num_segments
+
+        segments = []
+        for i in range(num_segments):
+            start_idx = i * segment_size
+            end_idx = (i + 1) * segment_size if i < num_segments - 1 else total_words
+
+            segment_words = words[start_idx:end_idx]
+            segment_text = ' '.join(segment_words)
+            speaker = self.speakers[i % len(self.speakers)]
+
+            segments.append((speaker, segment_text))
+
+        return segments
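
The abbreviation caveat noted in `_split_into_sentences` is easy to confirm: the lookbehind splits after any terminal punctuation followed by whitespace, including "Mr.":

```python
import re

sentences = re.split(r'(?<=[.!?])\s+', "Mr. Smith arrived. He sat down!")
print(sentences)
# ['Mr.', 'Smith arrived.', 'He sat down!'] - abbreviations are still split
```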
tts_engine.py ADDED
@@ -0,0 +1,110 @@
+# tts_engine.py - TTS engine wrapper for Nari DIA
+import logging
+import os
+from typing import Optional
+import tempfile
+import numpy as np
+import soundfile as sf
+import torch  # Needed for device placement and no_grad inference
+
+# Import the actual Nari DIA model
+try:
+    from dia.model import Dia
+except ImportError:
+    logging.error("Nari DIA library not found. Please ensure 'git+https://github.com/nari-labs/dia.git' is in your requirements.txt and installed.")
+    Dia = None  # Set to None to prevent further errors
+
+logger = logging.getLogger(__name__)
+
+class NariDIAEngine:
+    def __init__(self):
+        self.model = None
+        # No separate processor object for Dia; it handles processing internally
+        self._initialize_model()
+
+    def _initialize_model(self):
+        """Initialize the Nari DIA 1.6B model."""
+        if Dia is None:
+            logger.error("Nari DIA library is not available. Cannot initialize model.")
+            return
+
+        try:
+            logger.info("Initializing Nari DIA 1.6B model from nari-labs/Dia-1.6B...")
+
+            # Load the Nari DIA model.
+            # compute_dtype="float16" gives better performance/memory use on GPU;
+            # expect to need roughly 10 GB of VRAM.
+            self.model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")
+
+            # Move model to GPU if available
+            if torch.cuda.is_available():
+                self.model.to("cuda")
+                logger.info("Nari DIA model moved to GPU (CUDA).")
+            else:
+                logger.warning("CUDA not available. Nari DIA will run on CPU, which is not officially supported and will be very slow.")
+
+            logger.info("Nari DIA model initialized successfully.")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize Nari DIA model: {e}", exc_info=True)
+            self.model = None
+
+    def synthesize_segment(
+        self,
+        text: str,
+        speaker: str,  # 'S1' or 'S2' from the segmenter
+        output_path: str
+    ) -> Optional[str]:
+        """
+        Synthesize speech for a text segment using Nari DIA.
+
+        Args:
+            text: Text to synthesize
+            speaker: Speaker identifier ('S1' or 'S2' expected from the segmenter)
+            output_path: Path to save the audio file
+
+        Returns:
+            Path to the generated audio file, or None if failed
+        """
+        if not self.model:
+            logger.error("Nari DIA model not initialized. Cannot synthesize speech.")
+            return None
+
+        try:
+            # Nari DIA expects [S1] or [S2] tags.
+            # The segmenter outputs bare "S1"/"S2",
+            # so we just wrap the label in brackets.
+            if speaker in ["S1", "S2"]:
+                dia_speaker_tag = f"[{speaker}]"
+            else:
+                # Fallback in case the segmenter outputs something unexpected
+                logger.warning(f"Unexpected speaker tag '{speaker}' from segmenter. Defaulting to [S1].")
+                dia_speaker_tag = "[S1]"
+
+            # Nari DIA expects the speaker tag at the beginning of the segment
+            full_text_input = f"{dia_speaker_tag} {text}"
+
+            # Generate audio using the Nari DIA model
+            logger.info(f"Synthesizing with Nari DIA: {full_text_input[:100]}...")  # Log beginning of text
+
+            # Pass the text directly to the model's generate method;
+            # Nari DIA's Dia class handles internal processing/tokenization
+            with torch.no_grad():
+                # The .generate method should return the audio waveform as a PyTorch tensor
+                audio_waveform_tensor = self.model.generate(full_text_input)
+                audio_waveform = audio_waveform_tensor.cpu().numpy().squeeze()
+
+            # Nari DIA's sampling rate is typically 22050 Hz.
+            # If the Dia model object exposes a sampling_rate attribute, use it;
+            # otherwise default to 22050, a common rate for TTS models.
+            sampling_rate = getattr(self.model, 'sampling_rate', 22050)
+
+            # Save as WAV file
+            sf.write(output_path, audio_waveform, sampling_rate)
+
+            logger.info(f"Generated audio for {speaker} ({dia_speaker_tag}): {len(text)} characters to {output_path}")
+            return output_path
+
+        except Exception as e:
+            logger.error(f"Failed to synthesize segment with Nari DIA: {e}", exc_info=True)  # Full traceback in the log
+            return None
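
For a single-segment smoke test of the engine, a sketch along these lines should work, with the caveats above: the first call downloads nari-labs/Dia-1.6B, and a CUDA GPU is realistically required.

```python
from tts_engine import NariDIAEngine

engine = NariDIAEngine()
wav_path = engine.synthesize_segment("Welcome to the show.", "S1", "intro.wav")
print(wav_path)  # "intro.wav" on success, None if the model failed to load or generate
```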