Spaces:
Sleeping
Sleeping
# src/utils.py | |
import re | |
import datetime | |
import pandas as pd | |
from typing import Dict, List, Tuple, Set, Optional | |
from config import ALL_UG40_LANGUAGES, LANGUAGE_NAMES, GOOGLE_SUPPORTED_LANGUAGES, DISPLAY_CONFIG | |
def get_all_language_pairs() -> List[Tuple[str, str]]:
    """Return every ordered (source, target) pair of distinct UG40 languages.

    Both directions of each pair are included, and the order follows the
    iteration order of ``ALL_UG40_LANGUAGES`` (outer loop = source).
    """
    return [
        (source, target)
        for source in ALL_UG40_LANGUAGES
        for target in ALL_UG40_LANGUAGES
        if source != target
    ]
def get_google_comparable_pairs() -> List[Tuple[str, str]]:
    """Return ordered (source, target) pairs usable for a Google Translate comparison.

    Only languages listed in ``GOOGLE_SUPPORTED_LANGUAGES`` take part; a
    language is never paired with itself.
    """
    langs = GOOGLE_SUPPORTED_LANGUAGES
    return [(a, b) for a in langs for b in langs if a != b]
def format_language_pair(src: str, tgt: str) -> str:
    """Build a human-readable label such as ``"English → Luganda"``.

    Codes missing from ``LANGUAGE_NAMES`` fall back to their upper-cased form.
    """
    source_label, target_label = (
        LANGUAGE_NAMES.get(code, code.upper()) for code in (src, tgt)
    )
    return f"{source_label} → {target_label}"
def validate_language_code(lang: str) -> bool:
    """Return True when ``lang`` is one of the codes in ``ALL_UG40_LANGUAGES``."""
    is_supported = lang in ALL_UG40_LANGUAGES
    return is_supported
def create_submission_id() -> str:
    """Return a timestamp-based ID shaped like ``YYYYMMDD_HHMMSS_mmm``.

    The last group is the current millisecond (zero-padded to three digits),
    so two IDs generated within the same millisecond will be equal.
    """
    now = datetime.datetime.now()
    millis = now.microsecond // 1000
    return f"{now:%Y%m%d_%H%M%S}_{millis:03d}"
def sanitize_model_name(name: str) -> str:
    """Sanitize a model name for display and storage.

    Characters other than word characters, ``-`` and ``.`` become ``_``;
    runs of underscores collapse to one and leading/trailing underscores are
    removed. Names shorter than 3 characters get a ``Model_`` prefix, and the
    result is capped at 50 characters.

    Args:
        name: Raw, user-supplied model name.

    Returns:
        The cleaned name, or ``"Anonymous_Model"`` when ``name`` is empty,
        not a string, or contains nothing usable after cleaning.
    """
    if not name or not isinstance(name, str):
        return "Anonymous_Model"
    cleaned = re.sub(r'[^\w\-.]', '_', name.strip())
    cleaned = re.sub(r'_+', '_', cleaned).strip('_')
    # Nothing survived cleaning (e.g. "!!!" or whitespace): use the same
    # fallback as empty input instead of the dangling "Model_".
    if not cleaned:
        return "Anonymous_Model"
    if len(cleaned) < 3:
        cleaned = f"Model_{cleaned}"
    # Truncation can cut right after an underscore; strip it again.
    return cleaned[:50].rstrip('_')
def format_metric_value(value: float, metric: str) -> str:
    """Render a metric value as a display string.

    ``None``/NaN become ``"N/A"``. The number of decimals is looked up in
    ``DISPLAY_CONFIG['decimal_places']`` (default 4). ``coverage_rate`` is
    shown as a percentage; ``cer``/``wer`` values above 1 are displayed as
    1.0. If formatting raises ValueError/TypeError, ``str(value)`` is returned.
    """
    if pd.isna(value) or value is None:
        return "N/A"
    try:
        places = DISPLAY_CONFIG['decimal_places'].get(metric, 4)
        if metric == 'coverage_rate':
            return f"{value:.{places}%}"
        # Error rates above 1.0 are clipped for display only.
        clip = metric in ('cer', 'wer') and value > 1
        shown = 1.0 if clip else value
        return f"{shown:.{places}f}"
    except (ValueError, TypeError):
        return str(value)
def get_language_pair_stats(test_data: pd.DataFrame) -> Dict[str, Dict]:
    """Get statistics about language pair coverage in test data.

    Args:
        test_data: Frame with ``source_language`` and ``target_language``
            columns.

    Returns:
        A dict keyed ``"{src}_{tgt}"`` for every ordered pair of distinct UG40
        languages (including pairs with zero rows), each holding ``count``,
        ``google_comparable``, ``display_name``, ``source_language`` and
        ``target_language``. Returns ``{}`` for an empty frame or when an
        error occurs (the error is printed).
    """
    if test_data.empty:
        return {}
    stats = {}
    try:
        # Count rows per (source, target) once, instead of re-scanning the
        # whole frame with two boolean masks for every language pair.
        pair_counts = (
            test_data.groupby(['source_language', 'target_language'], observed=True)
            .size()
            .to_dict()
        )
        for src in ALL_UG40_LANGUAGES:
            for tgt in ALL_UG40_LANGUAGES:
                if src == tgt:
                    continue
                stats[f"{src}_{tgt}"] = {
                    'count': int(pair_counts.get((src, tgt), 0)),
                    'google_comparable': src in GOOGLE_SUPPORTED_LANGUAGES and tgt in GOOGLE_SUPPORTED_LANGUAGES,
                    'display_name': format_language_pair(src, tgt),
                    'source_language': src,
                    'target_language': tgt
                }
    except Exception as e:
        print(f"Error calculating language pair stats: {e}")
        return {}
    return stats
def validate_submission_completeness(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
    """Validate that a submission covers all required samples.

    Sample IDs are compared as strings on both sides.

    Args:
        predictions: Submitted frame with a ``sample_id`` column.
        test_set: Reference frame with a ``sample_id`` column.

    Returns:
        Dict with ``is_complete``, ``missing_count``, ``extra_count``,
        ``missing_ids`` (up to 10 missing IDs, sorted so the preview is
        stable) and ``coverage`` (fraction of required IDs provided). On an
        unexpected error the error is printed and a zeroed result returned.
    """
    if predictions.empty or test_set.empty:
        # Still list which test samples are missing, so the preview agrees
        # with a non-zero missing_count.
        preview = []
        if not test_set.empty and 'sample_id' in test_set.columns:
            preview = sorted(set(test_set['sample_id'].astype(str)))[:10]
        return {
            'is_complete': False,
            'missing_count': len(test_set) if not test_set.empty else 0,
            'extra_count': len(predictions) if not predictions.empty else 0,
            'missing_ids': preview,
            'coverage': 0.0
        }
    try:
        required_ids = set(test_set['sample_id'].astype(str))
        provided_ids = set(predictions['sample_id'].astype(str))
        missing_ids = required_ids - provided_ids
        extra_ids = provided_ids - required_ids
        return {
            'is_complete': not missing_ids,
            'missing_count': len(missing_ids),
            'extra_count': len(extra_ids),
            # Set iteration order of str varies between runs (hash seeding);
            # sort so the displayed first 10 are deterministic.
            'missing_ids': sorted(missing_ids)[:10],
            'coverage': len(provided_ids & required_ids) / len(required_ids) if required_ids else 0.0
        }
    except Exception as e:
        print(f"Error validating submission completeness: {e}")
        return {
            'is_complete': False,
            'missing_count': 0,
            'extra_count': 0,
            'missing_ids': [],
            'coverage': 0.0
        }
def calculate_language_pair_coverage(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
    """Calculate how many test samples per language pair received a prediction.

    Args:
        predictions: Submitted frame with ``sample_id`` and ``prediction``.
        test_set: Reference frame with ``sample_id``, ``source_language`` and
            ``target_language``.

    Returns:
        Dict keyed ``"{src}_{tgt}"`` (only pairs present in ``test_set``) with
        ``total``, ``predicted`` (non-null predictions), ``coverage``
        (predicted / total) and ``display_name``. Returns ``{}`` for empty
        input or on error (the error is printed).
    """
    if predictions.empty or test_set.empty:
        return {}
    try:
        # Compare IDs as strings, as validate_submission_completeness does, so
        # an int/str mismatch between the two frames still joins.
        reference = test_set.assign(sample_id=test_set['sample_id'].astype(str))
        # One prediction per sample; duplicates would otherwise multiply test
        # rows in the left merge and inflate per-pair totals.
        submitted = (
            predictions.assign(sample_id=predictions['sample_id'].astype(str))
            .drop_duplicates(subset='sample_id', keep='first')
        )
        merged = reference.merge(submitted, on='sample_id', how='left', suffixes=('', '_pred'))
        # Aggregate once per pair instead of masking the frame for every pair.
        grouped = merged.groupby(['source_language', 'target_language'], observed=True)['prediction']
        totals = grouped.size().to_dict()
        predicted = grouped.count().to_dict()  # count() skips NA values
        coverage = {}
        for src in ALL_UG40_LANGUAGES:
            for tgt in ALL_UG40_LANGUAGES:
                if src == tgt:
                    continue
                total = int(totals.get((src, tgt), 0))
                if total == 0:
                    continue
                predicted_count = int(predicted.get((src, tgt), 0))
                coverage[f"{src}_{tgt}"] = {
                    'total': total,
                    'predicted': predicted_count,
                    'coverage': predicted_count / total,
                    'display_name': format_language_pair(src, tgt)
                }
        return coverage
    except Exception as e:
        print(f"Error calculating language pair coverage: {e}")
        return {}
def safe_divide(numerator: float, denominator: float, default: float = 0.0) -> float:
    """Divide two numbers, returning ``default`` instead of failing.

    Args:
        numerator: Dividend.
        denominator: Divisor.
        default: Value returned when the divisor is zero, either operand is
            missing (None/NaN), the result is not finite, or the operands
            cannot be divided.

    Returns:
        ``numerator / denominator`` as a float, or ``default``.
    """
    # pandas has no ``isfinite``; the original ``pd.isfinite`` raised
    # AttributeError on every non-trivial division.
    import math

    try:
        if denominator == 0 or pd.isna(denominator) or pd.isna(numerator):
            return default
        result = float(numerator / denominator)
        if not math.isfinite(result):  # also rejects NaN
            return default
        return result
    except (TypeError, ValueError, ZeroDivisionError, OverflowError):
        return default
def clean_text_for_evaluation(text: str) -> str:
    """Normalize text before scoring.

    Non-breaking spaces and typographic quotes (left/right single, left/right
    double) are mapped to their ASCII equivalents, whitespace runs collapse
    to one space, and the ends are trimmed.

    Args:
        text: Text to clean. Non-string values are returned as ``str(text)``
            without further cleaning; ``None`` becomes ``""``.

    Returns:
        The cleaned string.
    """
    if not isinstance(text, str):
        return str(text) if text is not None else ""
    # One pass over the string. U+2018 was previously missed even though its
    # closing partner U+2019 was normalized.
    table = str.maketrans({
        '\u00a0': ' ',   # Non-breaking space
        '\u2018': "'",   # Left single quotation mark
        '\u2019': "'",   # Right single quotation mark
        '\u201c': '"',   # Left double quotation mark
        '\u201d': '"',   # Right double quotation mark
    })
    text = text.translate(table)
    return re.sub(r'\s+', ' ', text.strip())
def get_model_summary_stats(model_results: Dict) -> Dict:
    """Flatten a model's evaluation results into a summary dict.

    Returns ``{}`` when ``model_results`` is falsy or has no ``averages`` key.
    Otherwise metric averages default to 0.0 and summary counts default to 0
    when absent.
    """
    if not model_results or 'averages' not in model_results:
        return {}
    avg = model_results.get('averages', {})
    meta = model_results.get('summary', {})
    stats = {
        metric: avg.get(metric, 0.0)
        for metric in ('quality_score', 'bleu', 'chrf', 'rouge1', 'rougeL')
    }
    stats['total_samples'] = meta.get('total_samples', 0)
    stats['language_pairs'] = meta.get('language_pairs_covered', 0)
    stats['google_pairs'] = meta.get('google_comparable_pairs', 0)
    return stats
def generate_model_identifier(model_name: str, author: str) -> str:
    """Generate an identifier of the form ``{name}_{author}_{MMDD_HHMM}``.

    Args:
        model_name: Raw model name; cleaned with ``sanitize_model_name``.
        author: Raw author name. Characters other than word characters and
            ``-`` become ``_``, and the result is capped at 20 characters.
            Missing, non-string or whitespace-only authors become
            ``"Anonymous"``.

    Returns:
        The identifier string. It has minute resolution, so identical inputs
        within the same minute give the same identifier.
    """
    clean_name = sanitize_model_name(model_name)
    # Strip before the emptiness check so a whitespace-only author falls back
    # to "Anonymous" instead of leaving an empty segment ("name__MMDD_HHMM").
    author_text = author.strip() if isinstance(author, str) else ""
    clean_author = re.sub(r'[^\w\-]', '_', author_text)[:20] if author_text else "Anonymous"
    timestamp = datetime.datetime.now().strftime("%m%d_%H%M")
    return f"{clean_name}_{clean_author}_{timestamp}"
def validate_dataframe_structure(df: pd.DataFrame, required_columns: List[str]) -> Tuple[bool, List[str]]:
    """Check that ``df`` is non-empty and contains every required column.

    Returns ``(True, [])`` on success, otherwise ``(False, [message])``
    describing the first problem found (emptiness is checked first).
    """
    problems: List[str] = []
    if df.empty:
        problems.append("DataFrame is empty")
    else:
        absent = [column for column in required_columns if column not in df.columns]
        if absent:
            problems.append(f"Missing columns: {', '.join(absent)}")
    return (not problems), problems
def format_duration(seconds: float) -> str:
    """Render a duration with one decimal in seconds, minutes or hours.

    Under 60 s it is shown in seconds (``"5.0s"``), under 3600 s in minutes
    (``"1.5m"``), and otherwise in hours (``"2.0h"``).
    """
    if seconds < 60:
        amount, unit = seconds, "s"
    elif seconds < 3600:
        amount, unit = seconds / 60, "m"
    else:
        amount, unit = seconds / 3600, "h"
    return f"{amount:.1f}{unit}"
def truncate_text(text: str, max_length: int = 100, suffix: str = "...") -> str:
    """Truncate text so the result is at most ``max_length`` characters.

    Args:
        text: Text to shorten; non-strings are converted with ``str()``.
        max_length: Maximum length of the returned string.
        suffix: Marker appended when text is cut, counted toward
            ``max_length``.

    Returns:
        ``text`` unchanged if it already fits. Otherwise it is cut and
        ``suffix`` appended; if ``suffix`` itself does not fit, ``text`` is
        hard-cut to ``max_length`` (never below 0) with no suffix.
    """
    if not isinstance(text, str):
        text = str(text)
    if len(text) <= max_length:
        return text
    if max_length < len(suffix):
        # max_length - len(suffix) would be negative, and slicing with it
        # keeps most of the text, so the result would exceed max_length.
        return text[:max(max_length, 0)]
    return text[:max_length - len(suffix)] + suffix