Upload 7 files

- README.md +52 -10
- app.py +262 -0
- audio_utils.py +110 -0
- gitattributes (2).txt +35 -0
- requirements.txt +11 -0
- segmenter.py +139 -0
- tts_engine.py +110 -0
README.md
CHANGED
@@ -1,14 +1,56 @@

Removed front matter:

---
title: PodXplainClone
emoji: 🎙
colorFrom: red
colorTo: yellow
sdk: gradio
sdk_version: 5.31.0
app_file: app.py
pinned: false
license: mit
---

Added:

---
license: mit
title: 🎙️ PodXplain
sdk: gradio
emoji: 🎙
colorFrom: red
colorTo: blue
pinned: true
short_description: PodXplain is a Hugging Face-hosted application that converts
---

# 🎙️ PodXplain

**From script to story – voice it like never before.**

PodXplain is a Hugging Face-hosted application that converts long-form text into engaging multi-speaker, podcast-style audio. Input your script and get a professional-sounding MP3 podcast with automatic speaker detection and assignment.

## ✨ Features

- **📝 Long-form Support**: Handles up to 50,000 characters of text
- **👥 Multi-speaker Audio**: Automatic speaker detection and assignment
- **🔀 Smart Segmentation**: Intelligent text splitting with progress tracking
- **🎵 High-quality Output**: MP3 format for optimal file size and compatibility
- **📊 Real-time Progress**: Live updates during generation
- **🎨 Modern UI**: Clean, intuitive Gradio interface

## 🛠️ Tech Stack

- **Frontend**: Gradio for the interactive web interface
- **TTS Engine**: Nari DIA 1.6B for natural voice synthesis
- **Audio Processing**: pydub for audio manipulation and MP3 conversion
- **Hosting**: Hugging Face Spaces with GPU support

## 📋 How to Use

1. **Input Text**: Paste or type your podcast script (up to 50,000 characters)
2. **Choose Mode**: Select a speaker detection mode:
   * **Auto**: Smart detection based on content structure
   * **Paragraph**: Speaker changes at paragraph breaks
   * **Dialogue**: Detection based on dialogue markers
3. **Generate**: Click "Generate Podcast" and watch the progress
4. **Download**: Get your MP3 file and listen to your podcast!

## 🚀 Quick Start

### Local Development

```bash
# Clone the repository
git clone https://github.com/yourusername/podxplain.git  # Replace with your actual repo URL
cd podxplain

# Install dependencies
pip install -r requirements.txt

# Run the application
python app.py
```
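As a rough illustration of calling the hosted Space programmatically, the sketch below uses the `gradio_client` package. It is an assumption-heavy example: the Space id and the `api_name` endpoint are placeholders and must be checked against the Space's "Use via API" page for the real deployment.

```python
# Hypothetical sketch: calling a deployed PodXplain Space from Python.
# The Space id and api_name are placeholders, not confirmed values.
from gradio_client import Client

client = Client("your-username/PodXplain")  # replace with the actual Space id
audio_path, status = client.predict(
    "Welcome to the show.\n\nToday we talk about text-to-speech.",  # script text
    "auto",                                   # speaker detection mode
    api_name="/generate_podcast",             # verify the endpoint name in the Space's API docs
)
print(status)
print("Downloaded audio:", audio_path)
```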
app.py
ADDED
@@ -0,0 +1,262 @@

# app.py - Main Gradio application
import gradio as gr
import os
import tempfile
import shutil
from pathlib import Path
import asyncio
from typing import List, Tuple, Generator
import logging
from datetime import datetime

# Import our custom modules
from segmenter import TextSegmenter
from tts_engine import NariDIAEngine
from audio_utils import AudioProcessor

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class PodXplainApp:
    def __init__(self):
        self.segmenter = TextSegmenter()
        self.tts_engine = NariDIAEngine()
        self.audio_processor = AudioProcessor()
        self.temp_dir = None

    def create_temp_directory(self) -> str:
        """Create a temporary directory for processing."""
        if self.temp_dir:
            shutil.rmtree(self.temp_dir, ignore_errors=True)
        self.temp_dir = tempfile.mkdtemp(prefix="podxplain_")
        return self.temp_dir

    def cleanup_temp_directory(self):
        """Clean up temporary files."""
        if self.temp_dir and os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir, ignore_errors=True)
            self.temp_dir = None

    def generate_podcast(
        self,
        text: str,
        speaker_detection_mode: str = "auto",
        progress=gr.Progress()
    ) -> Tuple[str, str]:
        """
        Main function to convert text to podcast audio.

        Args:
            text: Input text (up to 50,000 characters)
            speaker_detection_mode: How to detect speaker changes
            progress: Gradio progress tracker

        Returns:
            Tuple of (audio_path, status_message)
        """
        try:
            # Validate input
            if not text or len(text.strip()) == 0:
                return None, "❌ Please provide some text to convert."

            if len(text) > 50000:
                return None, f"❌ Text too long ({len(text)} chars). Maximum is 50,000 characters."

            # Create temporary directory
            temp_dir = self.create_temp_directory()
            progress(0, desc="🚀 Starting podcast generation...")

            # Step 1: Segment text and assign speakers
            progress(0.1, desc="📝 Analyzing text and assigning speakers...")
            segments = self.segmenter.segment_and_assign_speakers(
                text, mode=speaker_detection_mode
            )

            if not segments:
                return None, "❌ Could not process the text. Please check the input."

            logger.info(f"Generated {len(segments)} segments")

            # Step 2: Generate audio for each segment
            progress(0.2, desc="🎤 Generating audio segments...")
            audio_files = []

            for i, (speaker, segment_text) in enumerate(segments):
                progress(
                    0.2 + (0.7 * i / len(segments)),
                    desc=f"🎵 Processing segment {i+1}/{len(segments)} (Speaker {speaker})"
                )

                # Generate audio for this segment
                audio_path = self.tts_engine.synthesize_segment(
                    segment_text,
                    speaker,
                    os.path.join(temp_dir, f"segment_{i:03d}.wav")
                )

                if audio_path:
                    audio_files.append(audio_path)
                else:
                    logger.warning(f"Failed to generate audio for segment {i}")

            if not audio_files:
                return None, "❌ Failed to generate any audio segments."

            # Step 3: Merge audio files and convert to MP3
            progress(0.9, desc="🔧 Merging segments and converting to MP3...")
            final_audio_path = self.audio_processor.merge_and_convert_to_mp3(
                audio_files,
                os.path.join(temp_dir, "podcast_output.mp3")
            )

            if not final_audio_path:
                return None, "❌ Failed to merge audio segments."

            progress(1.0, desc="✅ Podcast generated successfully!")

            # Generate summary
            total_segments = len(segments)
            speakers_used = len(set(speaker for speaker, _ in segments))
            duration_estimate = len(text) / 1000  # Rough estimate: ~1,000 characters per minute of audio

            status_message = f"""
✅ **Podcast Generated Successfully!**

📊 **Statistics:**
- Total segments: {total_segments}
- Speakers used: {speakers_used}
- Estimated duration: {duration_estimate:.1f} minutes
- Character count: {len(text):,}

🎧 **Your podcast is ready for download!**
"""

            return final_audio_path, status_message

        except Exception as e:
            logger.error(f"Error generating podcast: {str(e)}")
            return None, f"❌ Error: {str(e)}"

        finally:
            # Clean up temporary files (except the final output)
            # Note: We keep the final MP3 for download
            pass

def create_gradio_interface():
    """Create the Gradio interface."""
    app = PodXplainApp()

    # Custom CSS for better styling
    css = """
    .main-container {
        max-width: 1200px;
        margin: 0 auto;
    }
    .header {
        text-align: center;
        padding: 20px;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        border-radius: 10px;
        margin-bottom: 20px;
    }
    .footer {
        text-align: center;
        padding: 20px;
        color: #666;
        font-size: 0.9em;
    }
    """

    with gr.Blocks(css=css, title="PodXplain - Text to Podcast") as interface:
        # Header
        gr.HTML("""
        <div class="header">
            <h1>🎙️ PodXplain</h1>
            <p><em>From script to story – voice it like never before.</em></p>
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=2):
                # Input section
                gr.Markdown("## 📝 Input Your Script")

                text_input = gr.Textbox(
                    label="Podcast Script",
                    placeholder="Enter your podcast script here (up to 50,000 characters).\n\nTip: Use paragraph breaks to help with speaker detection.",
                    lines=15,
                    max_lines=20,
                    show_label=True
                )

                char_count = gr.HTML("Characters: 0 / 50,000")

                # Options
                speaker_mode = gr.Radio(
                    choices=["auto", "paragraph", "dialogue"],
                    value="auto",
                    label="Speaker Detection Mode",
                    info="How to detect when speakers change"
                )

                generate_btn = gr.Button(
                    "🎤 Generate Podcast",
                    variant="primary",
                    size="lg"
                )

            with gr.Column(scale=1):
                # Output section
                gr.Markdown("## 🎧 Your Podcast")

                status_output = gr.Markdown("Ready to generate your podcast!")

                audio_output = gr.Audio(
                    label="Generated Podcast",
                    show_download_button=True,
                    interactive=False
                )

        # Footer with instructions
        gr.HTML("""
        <div class="footer">
            <h3>📋 How to Use PodXplain</h3>
            <ol>
                <li><strong>Write your script:</strong> Enter up to 50,000 characters of text</li>
                <li><strong>Choose speaker mode:</strong> Auto-detect, paragraph-based, or dialogue-based</li>
                <li><strong>Generate:</strong> Click the button and wait for processing</li>
                <li><strong>Listen & Download:</strong> Your MP3 podcast will be ready!</li>
            </ol>
            <p><strong>💡 Tips:</strong> Use clear paragraph breaks for better speaker detection.
            Write naturally as if speaking to an audience.</p>
        </div>
        """)

        # Live character counter
        text_input.change(
            fn=lambda text: f"Characters: {len(text) if text else 0:,} / 50,000",
            inputs=[text_input],
            outputs=[char_count]
        )

        # Main generation function
        generate_btn.click(
            fn=app.generate_podcast,
            inputs=[text_input, speaker_mode],
            outputs=[audio_output, status_output],
            show_progress=True
        )

    return interface

if __name__ == "__main__":
    # Create and launch the interface
    interface = create_gradio_interface()
    interface.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )
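For quick local testing without the Gradio UI, the pipeline can be driven through `PodXplainApp.generate_podcast` directly. This is a minimal sketch, assuming the three modules import cleanly and the TTS model loads; the plain `progress` callable simply stands in for `gr.Progress` outside a Gradio event.

```python
# Hypothetical smoke test (not part of the repo): run the pipeline headlessly.
from app import PodXplainApp

def progress(fraction, desc=""):
    # Stand-in for gr.Progress when running outside the Gradio UI.
    print(f"[{fraction:>4.0%}] {desc}")

app = PodXplainApp()
script = "Speaker one introduces the topic.\n\nSpeaker two responds with a question."
audio_path, status = app.generate_podcast(script, "paragraph", progress=progress)
print(status if audio_path is None else f"Podcast written to {audio_path}")
```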
audio_utils.py
ADDED
@@ -0,0 +1,110 @@

# audio_utils.py - Audio processing utilities
import logging
from typing import List, Optional
import os
import tempfile
from pydub import AudioSegment
from pydub.utils import which

logger = logging.getLogger(__name__)

class AudioProcessor:
    def __init__(self):
        self._check_dependencies()

    def _check_dependencies(self):
        """Check if required audio processing tools are available."""
        # Check for ffmpeg
        if not which("ffmpeg"):
            logger.warning("ffmpeg not found. Some audio operations may fail.")

    def merge_and_convert_to_mp3(
        self,
        audio_files: List[str],
        output_path: str
    ) -> Optional[str]:
        """
        Merge multiple audio files and convert to MP3.

        Args:
            audio_files: List of paths to audio files to merge
            output_path: Path for the output MP3 file

        Returns:
            Path to the merged MP3 file, or None if failed
        """
        try:
            if not audio_files:
                logger.error("No audio files to merge")
                return None

            logger.info(f"Merging {len(audio_files)} audio files...")

            # Start with empty audio
            merged_audio = AudioSegment.empty()

            for i, audio_file in enumerate(audio_files):
                if not os.path.exists(audio_file):
                    logger.warning(f"Audio file not found: {audio_file}")
                    continue

                try:
                    # Load audio segment
                    segment = AudioSegment.from_wav(audio_file)

                    # Add a small pause between segments (500ms)
                    if i > 0:
                        pause = AudioSegment.silent(duration=500)
                        merged_audio += pause

                    # Add the segment
                    merged_audio += segment

                    logger.info(f"Added segment {i+1}/{len(audio_files)}")

                except Exception as e:
                    logger.error(f"Failed to process audio file {audio_file}: {e}")
                    continue

            if len(merged_audio) == 0:
                logger.error("No audio content to export")
                return None

            # Normalize audio levels
            merged_audio = self._normalize_audio(merged_audio)

            # Export as MP3
            logger.info(f"Exporting to MP3: {output_path}")
            merged_audio.export(
                output_path,
                format="mp3",
                bitrate="128k",
                parameters=["-q:a", "2"]  # Good quality
            )

            # Verify the file was created
            if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
                duration = len(merged_audio) / 1000.0  # Convert to seconds
                logger.info(f"Successfully created MP3: {duration:.1f} seconds")
                return output_path
            else:
                logger.error("Failed to create MP3 file")
                return None

        except Exception as e:
            logger.error(f"Failed to merge audio files: {e}")
            return None

    def _normalize_audio(self, audio: AudioSegment) -> AudioSegment:
        """Normalize audio levels."""
        try:
            # Normalize to -6 dBFS to avoid clipping
            target_dBFS = -6.0
            change_in_dBFS = target_dBFS - audio.dBFS
            normalized_audio = audio.apply_gain(change_in_dBFS)

            return normalized_audio
        except Exception as e:
            logger.warning(f"Failed to normalize audio: {e}")
            return audio
gitattributes (2).txt
ADDED
@@ -0,0 +1,35 @@

*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
requirements.txt
ADDED
@@ -0,0 +1,11 @@

gradio>=4.0.0
transformers>=4.30.0
torch>=2.0.0
torchaudio>=2.0.0
numpy>=1.21.0
soundfile>=0.12.0
pydub>=0.25.0
librosa>=0.10.0
datasets>=2.10.0
accelerate>=0.20.0
git+https://github.com/nari-labs/dia.git  # Nari DIA
segmenter.py
ADDED
@@ -0,0 +1,139 @@

# segmenter.py - Text segmentation and speaker assignment
import re
from typing import List, Tuple
import logging

logger = logging.getLogger(__name__)

class TextSegmenter:
    def __init__(self):
        # Speaker labels use Nari DIA's expected tags ("S1"/"S2")
        self.speakers = ["S1", "S2"]
        self.current_speaker_index = 0

    def segment_and_assign_speakers(
        self,
        text: str,
        mode: str = "auto"
    ) -> List[Tuple[str, str]]:
        """
        Segment text and assign speakers.

        Args:
            text: Input text to segment
            mode: Segmentation mode ("auto", "paragraph", "dialogue")

        Returns:
            List of (speaker, text) tuples
        """
        if mode == "paragraph":
            return self._segment_by_paragraphs(text)
        elif mode == "dialogue":
            return self._segment_by_dialogue(text)
        else:  # auto mode
            return self._segment_auto(text)

    def _segment_by_paragraphs(self, text: str) -> List[Tuple[str, str]]:
        """Segment by paragraphs, alternating speakers."""
        paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
        segments = []

        for i, paragraph in enumerate(paragraphs):
            speaker = self.speakers[i % len(self.speakers)]
            segments.append((speaker, paragraph))

        return segments

    def _segment_by_dialogue(self, text: str) -> List[Tuple[str, str]]:
        """Segment by detecting dialogue patterns."""
        lines = text.split('\n')
        segments = []
        current_segment = []
        # Start with the first speaker in the list
        current_speaker = self.speakers[0]

        for line in lines:
            line = line.strip()
            if not line:
                continue

            # Check for dialogue markers
            if (line.startswith('"') or line.startswith("'") or
                    line.startswith('-') or line.startswith('—')):

                # Save previous segment
                if current_segment:
                    segments.append((current_speaker, ' '.join(current_segment)))

                # Switch speaker and start new segment
                self.current_speaker_index = (self.current_speaker_index + 1) % len(self.speakers)
                current_speaker = self.speakers[self.current_speaker_index]
                current_segment = [line]
            else:
                current_segment.append(line)

        # Add final segment
        if current_segment:
            segments.append((current_speaker, ' '.join(current_segment)))

        return segments

    def _segment_auto(self, text: str) -> List[Tuple[str, str]]:
        """Automatic segmentation using multiple heuristics."""
        segments = []

        paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]

        if len(paragraphs) > 1:
            return self._segment_by_paragraphs(text)

        sentences = self._split_into_sentences(text)
        if len(sentences) > 10:
            return self._segment_by_sentence_groups(sentences)

        return self._segment_simple(text)

    def _split_into_sentences(self, text: str) -> List[str]:
        """Split text into sentences."""
        # Simple sentence splitting: break after ., !, or ? followed by whitespace.
        # Note: abbreviations (e.g., "Mr.") can still cause false splits;
        # a full NLP library is better for complex cases.
        sentences = re.split(r'(?<=[.!?])\s+', text)
        return [s.strip() for s in sentences if s.strip()]

    def _segment_by_sentence_groups(self, sentences: List[str]) -> List[Tuple[str, str]]:
        """Group sentences and assign to different speakers."""
        segments = []
        group_size = max(2, len(sentences) // 8)

        for i in range(0, len(sentences), group_size):
            group = sentences[i:i + group_size]
            speaker = self.speakers[i // group_size % len(self.speakers)]
            text_segment = ' '.join(group)  # Sentences already carry their terminal punctuation
            segments.append((speaker, text_segment))

        return segments

    def _segment_simple(self, text: str) -> List[Tuple[str, str]]:
        """Simple segmentation for short texts."""
        words = text.split()
        total_words = len(words)

        if total_words < 50:
            return [(self.speakers[0], text)]  # Assign the whole text to S1

        num_segments = min(len(self.speakers), max(2, total_words // 100))  # Limit segments by available speakers
        segment_size = total_words // num_segments

        segments = []
        for i in range(num_segments):
            start_idx = i * segment_size
            end_idx = (i + 1) * segment_size if i < num_segments - 1 else total_words

            segment_words = words[start_idx:end_idx]
            segment_text = ' '.join(segment_words)
            speaker = self.speakers[i % len(self.speakers)]

            segments.append((speaker, segment_text))

        return segments
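To see how the three modes differ, here is a small illustrative run of `TextSegmenter` on a two-paragraph script (the sample text is made up; the exact grouping follows the heuristics above):

```python
# Illustrative example: speaker assignment for a short two-paragraph script.
from segmenter import TextSegmenter

script = (
    "Welcome to the show. Today we look at text-to-speech.\n\n"
    "Thanks for having me. Let's start with how segmentation works."
)

segmenter = TextSegmenter()
for mode in ("auto", "paragraph", "dialogue"):
    segments = segmenter.segment_and_assign_speakers(script, mode=mode)
    print(mode, [(speaker, text[:30] + "...") for speaker, text in segments])

# With two paragraphs, "auto" and "paragraph" alternate S1/S2 per paragraph;
# "dialogue" only switches speakers when a line starts with a quote or dash.
```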
tts_engine.py
ADDED
@@ -0,0 +1,110 @@

# tts_engine.py - TTS engine wrapper for Nari DIA
import logging
import os
from typing import Optional
import tempfile
import numpy as np
import soundfile as sf
import torch  # Needed for device placement and no_grad inference

# Import the actual Nari DIA model
try:
    from dia.model import Dia
except ImportError:
    logging.error("Nari DIA library not found. Please ensure 'git+https://github.com/nari-labs/dia.git' is in your requirements.txt and installed.")
    Dia = None  # Set to None to prevent further errors

logger = logging.getLogger(__name__)

class NariDIAEngine:
    def __init__(self):
        self.model = None
        # Dia has no separate processor object; it handles processing internally
        self._initialize_model()

    def _initialize_model(self):
        """Initialize the Nari DIA 1.6B model."""
        if Dia is None:
            logger.error("Nari DIA library is not available. Cannot initialize model.")
            return

        try:
            logger.info("Initializing Nari DIA 1.6B model from nari-labs/Dia-1.6B...")

            # Load the Nari DIA model.
            # compute_dtype="float16" reduces memory use and can improve speed on GPU;
            # expect to need roughly 10 GB of VRAM for this model.
            self.model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")

            # Move model to GPU if available
            if torch.cuda.is_available():
                self.model.to("cuda")
                logger.info("Nari DIA model moved to GPU (CUDA).")
            else:
                logger.warning("CUDA not available. Nari DIA model will run on CPU, which is not officially supported and will be very slow.")

            logger.info("Nari DIA model initialized successfully.")

        except Exception as e:
            logger.error(f"Failed to initialize Nari DIA model: {e}", exc_info=True)
            self.model = None

    def synthesize_segment(
        self,
        text: str,
        speaker: str,  # 'S1' or 'S2' from the segmenter
        output_path: str
    ) -> Optional[str]:
        """
        Synthesize speech for a text segment using Nari DIA.

        Args:
            text: Text to synthesize
            speaker: Speaker identifier ('S1' or 'S2' expected from the segmenter)
            output_path: Path to save the audio file

        Returns:
            Path to the generated audio file, or None if failed
        """
        if not self.model:
            logger.error("Nari DIA model not initialized. Cannot synthesize speech.")
            return None

        try:
            # Nari DIA expects [S1] or [S2] tags; the segmenter outputs "S1"/"S2",
            # so we only need to wrap the value in brackets.
            if speaker in ["S1", "S2"]:
                dia_speaker_tag = f"[{speaker}]"
            else:
                # Fallback in case the segmenter outputs something unexpected
                logger.warning(f"Unexpected speaker tag '{speaker}' from segmenter. Defaulting to [S1].")
                dia_speaker_tag = "[S1]"

            # Nari DIA expects the speaker tag at the beginning of the segment
            full_text_input = f"{dia_speaker_tag} {text}"

            # Generate audio using the Nari DIA model
            logger.info(f"Synthesizing with Nari DIA: {full_text_input[:100]}...")  # Log the beginning of the text

            # Pass the text directly to the model's generate method;
            # Dia handles internal processing/tokenization.
            with torch.no_grad():
                # generate() should return the audio waveform as a PyTorch tensor
                audio_waveform_tensor = self.model.generate(full_text_input)
                audio_waveform = audio_waveform_tensor.cpu().numpy().squeeze()

            # Nari DIA's sampling rate is typically 22050 Hz.
            # If the Dia model exposes a sampling_rate attribute, use it;
            # otherwise fall back to 22050, a common default for TTS models.
            sampling_rate = getattr(self.model, 'sampling_rate', 22050)

            # Save as a WAV file
            sf.write(output_path, audio_waveform, sampling_rate)

            logger.info(f"Generated audio for {speaker} ({dia_speaker_tag}): {len(text)} characters to {output_path}")
            return output_path

        except Exception as e:
            logger.error(f"Failed to synthesize segment with Nari DIA: {e}", exc_info=True)  # Full traceback in the log
            return None
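And a minimal, hypothetical call into the engine for a single segment. This assumes the Dia checkpoint downloads successfully and a CUDA GPU with enough memory is available, as noted above; the output path is a placeholder.

```python
# Hypothetical example: synthesize one segment to a WAV file.
from tts_engine import NariDIAEngine

engine = NariDIAEngine()  # downloads and loads nari-labs/Dia-1.6B on first use
wav_path = engine.synthesize_segment(
    "Hello and welcome to PodXplain.",
    "S1",               # speaker tag as produced by the segmenter
    "segment_000.wav",  # placeholder output path
)
if wav_path:
    print("Saved:", wav_path)
else:
    print("Synthesis failed; see the log for details.")
```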