# src/utils.py
import re
import datetime
import pandas as pd
from typing import Dict, List, Tuple, Set, Optional
from config import ALL_UG40_LANGUAGES, LANGUAGE_NAMES, GOOGLE_SUPPORTED_LANGUAGES, DISPLAY_CONFIG
def get_all_language_pairs() -> List[Tuple[str, str]]:
"""Get all possible UG40 language pairs."""
pairs = []
for src in ALL_UG40_LANGUAGES:
for tgt in ALL_UG40_LANGUAGES:
if src != tgt:
pairs.append((src, tgt))
return pairs
def get_google_comparable_pairs() -> List[Tuple[str, str]]:
"""Get language pairs that can be compared with Google Translate."""
pairs = []
for src in GOOGLE_SUPPORTED_LANGUAGES:
for tgt in GOOGLE_SUPPORTED_LANGUAGES:
if src != tgt:
pairs.append((src, tgt))
return pairs
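# Illustrative usage (counts depend on the language lists defined in config;
# the codes below are assumed examples, not guaranteed entries):
#   all_pairs = get_all_language_pairs()            # N * (N - 1) ordered pairs for N UG40 languages
#   google_pairs = get_google_comparable_pairs()    # subset restricted to Google-supported codes
#   ("lug", "eng") in google_pairs                  # True only if both codes are Google-supported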
def format_language_pair(src: str, tgt: str) -> str:
"""Format language pair for display."""
src_name = LANGUAGE_NAMES.get(src, src.upper())
tgt_name = LANGUAGE_NAMES.get(tgt, tgt.upper())
    return f"{src_name} → {tgt_name}"
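# Illustrative (assuming LANGUAGE_NAMES maps these codes; "lug" and "eng" are example codes):
#   format_language_pair("lug", "eng")  -> "Luganda → English"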
def validate_language_code(lang: str) -> bool:
"""Validate if language code is supported."""
return lang in ALL_UG40_LANGUAGES
def create_submission_id() -> str:
"""Create unique submission ID."""
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
def sanitize_model_name(name: str) -> str:
"""Sanitize model name for display and storage."""
if not name or not isinstance(name, str):
return "Anonymous_Model"
# Remove special characters, limit length
name = re.sub(r'[^\w\-.]', '_', name.strip())
# Remove multiple consecutive underscores
name = re.sub(r'_+', '_', name)
# Remove leading/trailing underscores
name = name.strip('_')
# Ensure minimum length
if len(name) < 3:
name = f"Model_{name}"
return name[:50] # Limit to 50 characters
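# Illustrative behaviour of the sanitizer (traced by hand, not exhaustive):
#   sanitize_model_name("  my model: v2!! ")  -> "my_model_v2"
#   sanitize_model_name("ab")                 -> "Model_ab"   (padded to minimum length)
#   sanitize_model_name(None)                 -> "Anonymous_Model"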
def format_metric_value(value: float, metric: str) -> str:
"""Format metric value for display with appropriate precision."""
if pd.isna(value) or value is None:
return "N/A"
try:
precision = DISPLAY_CONFIG['decimal_places'].get(metric, 4)
if metric == 'coverage_rate':
return f"{value:.{precision}%}"
elif metric in ['bleu']:
return f"{value:.{precision}f}"
elif metric in ['cer', 'wer'] and value > 1:
# Cap error rates at 1.0 for display
return f"{min(value, 1.0):.{precision}f}"
else:
return f"{value:.{precision}f}"
except (ValueError, TypeError):
return str(value)
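# Illustrative formatting (precision comes from DISPLAY_CONFIG['decimal_places'];
# the values assumed here are 2 for 'bleu' and 1 for 'coverage_rate'):
#   format_metric_value(0.3456, 'bleu')           -> "0.35"
#   format_metric_value(0.875, 'coverage_rate')   -> "87.5%"
#   format_metric_value(float('nan'), 'chrf')     -> "N/A"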
def get_language_pair_stats(test_data: pd.DataFrame) -> Dict[str, Dict]:
"""Get statistics about language pair coverage in test data."""
if test_data.empty:
return {}
stats = {}
try:
for src in ALL_UG40_LANGUAGES:
for tgt in ALL_UG40_LANGUAGES:
if src != tgt:
pair_data = test_data[
(test_data['source_language'] == src) &
(test_data['target_language'] == tgt)
]
stats[f"{src}_{tgt}"] = {
'count': len(pair_data),
'google_comparable': src in GOOGLE_SUPPORTED_LANGUAGES and tgt in GOOGLE_SUPPORTED_LANGUAGES,
'display_name': format_language_pair(src, tgt),
'source_language': src,
'target_language': tgt
}
except Exception as e:
print(f"Error calculating language pair stats: {e}")
return {}
return stats
def validate_submission_completeness(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
"""Validate that submission covers all required samples."""
if predictions.empty or test_set.empty:
return {
'is_complete': False,
'missing_count': len(test_set) if not test_set.empty else 0,
'extra_count': len(predictions) if not predictions.empty else 0,
'missing_ids': [],
'coverage': 0.0
}
try:
required_ids = set(test_set['sample_id'].astype(str))
provided_ids = set(predictions['sample_id'].astype(str))
missing_ids = required_ids - provided_ids
extra_ids = provided_ids - required_ids
return {
'is_complete': len(missing_ids) == 0,
'missing_count': len(missing_ids),
'extra_count': len(extra_ids),
'missing_ids': list(missing_ids)[:10], # First 10 for display
'coverage': len(provided_ids & required_ids) / len(required_ids) if required_ids else 0.0
}
except Exception as e:
print(f"Error validating submission completeness: {e}")
return {
'is_complete': False,
'missing_count': 0,
'extra_count': 0,
'missing_ids': [],
'coverage': 0.0
}
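# Illustrative usage with minimal frames (column names follow this module's conventions):
#   preds = pd.DataFrame({'sample_id': ['1', '2'], 'prediction': ['a', 'b']})
#   tests = pd.DataFrame({'sample_id': ['1', '2', '3']})
#   validate_submission_completeness(preds, tests)['coverage']  # -> 2/3 (about 0.667)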
def calculate_language_pair_coverage(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
"""Calculate coverage by language pair."""
if predictions.empty or test_set.empty:
return {}
try:
# Merge to get language info
merged = test_set.merge(predictions, on='sample_id', how='left', suffixes=('', '_pred'))
coverage = {}
for src in ALL_UG40_LANGUAGES:
for tgt in ALL_UG40_LANGUAGES:
if src != tgt:
pair_data = merged[
(merged['source_language'] == src) &
(merged['target_language'] == tgt)
]
if len(pair_data) > 0:
predicted_count = pair_data['prediction'].notna().sum()
coverage[f"{src}_{tgt}"] = {
'total': len(pair_data),
'predicted': predicted_count,
'coverage': predicted_count / len(pair_data),
'display_name': format_language_pair(src, tgt)
}
return coverage
except Exception as e:
print(f"Error calculating language pair coverage: {e}")
return {}
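# Note: the left merge keeps every test-set row, so samples without a matching prediction
# contribute NaN to the 'prediction' column and lower that pair's coverage ratio.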
def safe_divide(numerator: float, denominator: float, default: float = 0.0) -> float:
"""Safely divide two numbers, handling edge cases."""
try:
if denominator == 0 or pd.isna(denominator) or pd.isna(numerator):
return default
result = numerator / denominator
if pd.isna(result) or not pd.isfinite(result):
return default
return float(result)
except (TypeError, ValueError, ZeroDivisionError):
return default
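# Illustrative: safe_divide(3, 4) -> 0.75; safe_divide(1, 0) -> 0.0 (the default).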
def clean_text_for_evaluation(text: str) -> str:
"""Clean text for evaluation, handling common encoding issues."""
if not isinstance(text, str):
return str(text) if text is not None else ""
# Remove extra whitespace
text = re.sub(r'\s+', ' ', text.strip())
# Handle common encoding issues
text = text.replace('\u00a0', ' ') # Non-breaking space
text = text.replace('\u2019', "'") # Right single quotation mark
text = text.replace('\u201c', '"') # Left double quotation mark
text = text.replace('\u201d', '"') # Right double quotation mark
return text
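# Illustrative: clean_text_for_evaluation("Hello\u00a0 world\u2019s")
#   -> "Hello world's"  (whitespace collapsed, curly apostrophe normalised)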
def get_model_summary_stats(model_results: Dict) -> Dict:
"""Extract summary statistics from model evaluation results."""
if not model_results or 'averages' not in model_results:
return {}
averages = model_results.get('averages', {})
summary = model_results.get('summary', {})
return {
'quality_score': averages.get('quality_score', 0.0),
'bleu': averages.get('bleu', 0.0),
'chrf': averages.get('chrf', 0.0),
'rouge1': averages.get('rouge1', 0.0),
'rougeL': averages.get('rougeL', 0.0),
'total_samples': summary.get('total_samples', 0),
'language_pairs': summary.get('language_pairs_covered', 0),
'google_pairs': summary.get('google_comparable_pairs', 0)
}
def generate_model_identifier(model_name: str, author: str) -> str:
"""Generate a unique identifier for a model."""
clean_name = sanitize_model_name(model_name)
clean_author = re.sub(r'[^\w\-]', '_', author.strip())[:20] if author else "Anonymous"
timestamp = datetime.datetime.now().strftime("%m%d_%H%M")
return f"{clean_name}_{clean_author}_{timestamp}"
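# Illustrative (timestamp varies; the name and author strings are assumed examples):
#   generate_model_identifier("My Model", "Jane Doe")  -> "My_Model_Jane_Doe_0612_1430"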
def validate_dataframe_structure(df: pd.DataFrame, required_columns: List[str]) -> Tuple[bool, List[str]]:
"""Validate that a DataFrame has the required structure."""
if df.empty:
return False, ["DataFrame is empty"]
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
return False, [f"Missing columns: {', '.join(missing_columns)}"]
return True, []
def format_duration(seconds: float) -> str:
"""Format duration in seconds to human-readable format."""
if seconds < 60:
return f"{seconds:.1f}s"
elif seconds < 3600:
return f"{seconds/60:.1f}m"
else:
return f"{seconds/3600:.1f}h"
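# Illustrative: format_duration(42) -> "42.0s"; format_duration(90) -> "1.5m";
#   format_duration(5400) -> "1.5h"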
def truncate_text(text: str, max_length: int = 100, suffix: str = "...") -> str:
"""Truncate text to specified length with suffix."""
if not isinstance(text, str):
text = str(text)
if len(text) <= max_length:
return text
return text[:max_length - len(suffix)] + suffix
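# A minimal, hedged smoke test for the deterministic helpers above; run the module
# directly to execute it (requires the top-level `config` import to succeed, as it
# does for the app itself). Values mirror the illustrative examples in the comments.
if __name__ == "__main__":
    assert sanitize_model_name("  my model: v2!! ") == "my_model_v2"
    assert safe_divide(1, 0) == 0.0
    assert format_duration(90) == "1.5m"
    assert truncate_text("abcdef", max_length=5) == "ab..."
    print("utils.py smoke test passed")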