PodXplainClone / segmenter.py
Nick021402's picture
Upload 7 files
1f6c376 verified
# segmenter.py - Text segmentation and speaker assignment
import re
from typing import List, Tuple
import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class TextSegmenter:
    """Split input text into segments and assign each one a speaker tag.

    Speaker tags are taken from ``self.speakers`` ("S1" and "S2" by default)
    and every public result is a list of ``(speaker, text)`` tuples.
    """

    # Line prefixes that open a new dialogue turn in ``_segment_by_dialogue``:
    # double quote, single quote, hyphen, or em dash (U+2014). The em dash is
    # written as an escape so the literal survives source re-encoding.
    _DIALOGUE_MARKERS = ('"', "'", '-', '\u2014')

    # Sentence boundary: whitespace that follows '.', '!' or '?'. This is a
    # plain heuristic -- it DOES split after abbreviations such as "Mr.".
    _SENTENCE_BOUNDARY = re.compile(r'(?<=[.!?])\s+')

    def __init__(self):
        # Speaker tags expected by Nari DIA.
        self.speakers = ["S1", "S2"]
        # Kept for backward compatibility with code that may read it; the
        # segmentation methods do not mutate it, so results never depend on
        # earlier calls.
        self.current_speaker_index = 0

    def segment_and_assign_speakers(
        self,
        text: str,
        mode: str = "auto"
    ) -> List[Tuple[str, str]]:
        """
        Segment text and assign speakers.

        Args:
            text: Input text to segment.
            mode: Segmentation mode: "auto", "paragraph" or "dialogue".
                Any other value logs a warning and falls back to "auto".

        Returns:
            List of (speaker, text) tuples in input order. The list is empty
            when ``text`` is empty or whitespace-only.
        """
        if mode == "paragraph":
            return self._segment_by_paragraphs(text)
        if mode == "dialogue":
            return self._segment_by_dialogue(text)
        if mode != "auto":
            logging.getLogger(__name__).warning(
                "Unknown segmentation mode %r; falling back to 'auto'", mode
            )
        return self._segment_auto(text)

    @staticmethod
    def _split_paragraphs(text: str) -> List[str]:
        """Return the stripped, non-empty chunks of ``text`` between double newlines."""
        return [p.strip() for p in text.split('\n\n') if p.strip()]

    def _segment_by_paragraphs(self, text: str) -> List[Tuple[str, str]]:
        """Segment by paragraphs (double-newline separated), alternating speakers."""
        paragraphs = self._split_paragraphs(text)
        return [
            (self.speakers[i % len(self.speakers)], paragraph)
            for i, paragraph in enumerate(paragraphs)
        ]

    def _segment_by_dialogue(self, text: str) -> List[Tuple[str, str]]:
        """Segment by detecting dialogue patterns.

        A stripped line starting with one of ``_DIALOGUE_MARKERS`` opens a new
        segment; any other non-blank line is appended (space-joined) to the
        current segment. The first segment belongs to the first speaker and
        each subsequent dialogue turn moves to the next speaker in rotation.
        Blank lines are ignored.
        """
        segments = []
        current_segment = []
        # Local rotation index: rotating instance state here would make the
        # output of one call depend on every previous call.
        speaker_index = 0
        for raw_line in text.split('\n'):
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith(self._DIALOGUE_MARKERS):
                if current_segment:
                    # Close the previous turn, then hand over to the next
                    # speaker. With nothing open yet, the first speaker keeps
                    # the turn.
                    segments.append(
                        (self.speakers[speaker_index], ' '.join(current_segment))
                    )
                    speaker_index = (speaker_index + 1) % len(self.speakers)
                current_segment = [line]
            else:
                current_segment.append(line)
        if current_segment:
            segments.append((self.speakers[speaker_index], ' '.join(current_segment)))
        return segments

    def _segment_auto(self, text: str) -> List[Tuple[str, str]]:
        """Automatic segmentation using multiple heuristics.

        - blank input -> no segments;
        - more than one paragraph -> paragraph segmentation;
        - more than 10 sentences -> sentence groups;
        - otherwise -> word-count based split (``_segment_simple``).
        """
        if not text.strip():
            # Match the paragraph/dialogue modes, which also return [] here,
            # instead of emitting a segment whose text is empty.
            return []
        if len(self._split_paragraphs(text)) > 1:
            return self._segment_by_paragraphs(text)
        sentences = self._split_into_sentences(text)
        if len(sentences) > 10:
            return self._segment_by_sentence_groups(sentences)
        return self._segment_simple(text)

    def _split_into_sentences(self, text: str) -> List[str]:
        """Split text into sentences at whitespace following '.', '!' or '?'.

        This is a simple heuristic: abbreviations are not recognised, so
        "Mr. Smith" is split after "Mr.". Pieces are stripped and empty ones
        dropped.
        """
        return [s.strip() for s in self._SENTENCE_BOUNDARY.split(text) if s.strip()]

    def _segment_by_sentence_groups(self, sentences: List[str]) -> List[Tuple[str, str]]:
        """Group consecutive sentences and alternate speakers between groups.

        Each group holds ``max(2, len(sentences) // 8)`` sentences (the last
        group may be shorter).
        """
        group_size = max(2, len(sentences) // 8)
        segments = []
        for group_number, start in enumerate(range(0, len(sentences), group_size)):
            speaker = self.speakers[group_number % len(self.speakers)]
            # Sentences already carry their terminal punctuation from the split.
            segments.append((speaker, ' '.join(sentences[start:start + group_size])))
        return segments

    def _segment_simple(self, text: str) -> List[Tuple[str, str]]:
        """Simple segmentation for short texts.

        Texts under 50 words become a single segment for the first speaker,
        returned unmodified. Longer texts are split on whitespace into
        consecutive runs of equal word count; the last run absorbs the
        remainder.
        """
        words = text.split()
        total_words = len(words)
        if total_words < 50:
            return [(self.speakers[0], text)]
        # Target max(2, words // 100) segments, capped by the speaker count.
        num_segments = min(len(self.speakers), max(2, total_words // 100))
        segment_size = total_words // num_segments
        segments = []
        for i in range(num_segments):
            start = i * segment_size
            end = total_words if i == num_segments - 1 else start + segment_size
            speaker = self.speakers[i % len(self.speakers)]
            segments.append((speaker, ' '.join(words[start:end])))
        return segments