akera committed on
Commit
ce626d3
·
verified ·
1 Parent(s): 796d1cd

Update src/plotting.py

Browse files
Files changed (1) hide show
  1. src/plotting.py +124 -57
src/plotting.py CHANGED
@@ -27,14 +27,12 @@ def create_leaderboard_ranking_plot(df: pd.DataFrame, metric: str = 'quality_sco
27
  x=0.5, y=0.5, showarrow=False,
28
  font=dict(size=16)
29
  )
 
30
  return fig
31
 
32
  # Get top N models
33
  top_models = df.head(top_n)
34
 
35
- # Create color scale based on scores
36
- colors = px.colors.qualitative.Set3[:len(top_models)]
37
-
38
  # Create horizontal bar chart
39
  fig = go.Figure(data=[
40
  go.Bar(
@@ -79,7 +77,10 @@ def create_metrics_comparison_plot(df: pd.DataFrame, models: List[str] = None, m
79
  """Create radar chart comparing multiple metrics across models."""
80
 
81
  if df.empty:
82
- return go.Figure().add_annotation(text="No data available", x=0.5, y=0.5)
 
 
 
83
 
84
  # Select models to compare
85
  if models is None:
@@ -88,7 +89,10 @@ def create_metrics_comparison_plot(df: pd.DataFrame, models: List[str] = None, m
88
  selected_models = df[df['model_name'].isin(models)].head(max_models)
89
 
90
  if len(selected_models) == 0:
91
- return go.Figure().add_annotation(text="No models found", x=0.5, y=0.5)
 
 
 
92
 
93
  # Metrics to include in radar chart
94
  metrics = ['quality_score', 'bleu', 'chrf', 'rouge1', 'rougeL']
@@ -139,7 +143,10 @@ def create_language_pair_heatmap(results_dict: Dict, metric: str = 'quality_scor
139
  """Create heatmap showing performance across language pairs."""
140
 
141
  if not results_dict or 'pair_metrics' not in results_dict:
142
- return go.Figure().add_annotation(text="No language pair data available", x=0.5, y=0.5)
 
 
 
143
 
144
  pair_metrics = results_dict['pair_metrics']
145
 
@@ -168,7 +175,7 @@ def create_language_pair_heatmap(results_dict: Dict, metric: str = 'quality_scor
168
  colorscale='Viridis',
169
  showscale=True,
170
  colorbar=dict(title=metric.replace('_', ' ').title()),
171
- hoverinfotemplate=(
172
  "Source: %{y}<br>" +
173
  "Target: %{x}<br>" +
174
  f"{metric.replace('_', ' ').title()}: %{{z:.3f}}<br>" +
@@ -190,7 +197,10 @@ def create_coverage_analysis_plot(df: pd.DataFrame) -> go.Figure:
190
  """Create plot analyzing test set coverage across submissions."""
191
 
192
  if df.empty:
193
- return go.Figure().add_annotation(text="No data available", x=0.5, y=0.5)
 
 
 
194
 
195
  fig = make_subplots(
196
  rows=2, cols=2,
@@ -258,7 +268,10 @@ def create_model_performance_timeline(df: pd.DataFrame) -> go.Figure:
258
  """Create timeline showing model performance over time."""
259
 
260
  if df.empty:
261
- return go.Figure().add_annotation(text="No data available", x=0.5, y=0.5)
 
 
 
262
 
263
  # Convert submission_date to datetime
264
  df_copy = df.copy()
@@ -319,10 +332,13 @@ def create_google_comparison_plot(df: pd.DataFrame) -> go.Figure:
319
  google_models = df[df['google_pairs_covered'] > 0].copy()
320
 
321
  if google_models.empty:
322
- return go.Figure().add_annotation(
 
323
  text="No models with Google Translate comparable results",
324
- x=0.5, y=0.5
325
  )
 
 
326
 
327
  fig = go.Figure()
328
 
@@ -360,10 +376,13 @@ def create_google_comparison_plot(df: pd.DataFrame) -> go.Figure:
360
  return fig
361
 
362
  def create_detailed_model_analysis(model_results: Dict, model_name: str) -> go.Figure:
363
- """Create detailed analysis plot for a specific model."""
364
 
365
  if not model_results or 'pair_metrics' not in model_results:
366
- return go.Figure().add_annotation(text="No detailed results available", x=0.5, y=0.5)
 
 
 
367
 
368
  pair_metrics = model_results['pair_metrics']
369
 
@@ -388,22 +407,26 @@ def create_detailed_model_analysis(model_results: Dict, model_name: str) -> go.F
388
  google_comparable.append(is_google)
389
 
390
  if not pairs:
391
- return go.Figure().add_annotation(text="No language pair data found", x=0.5, y=0.5)
 
 
 
392
 
393
- # Create subplot
394
  fig = make_subplots(
395
  rows=2, cols=1,
396
  subplot_titles=(
397
- f"{model_name} - BLEU Scores by Language Pair",
398
- f"{model_name} - Quality Scores by Language Pair"
399
  ),
400
- vertical_spacing=0.1
 
401
  )
402
 
403
  # Color code by Google comparable
404
  colors = ['#1f77b4' if gc else '#ff7f0e' for gc in google_comparable]
405
 
406
- # BLEU scores
407
  fig.add_trace(
408
  go.Bar(
409
  x=pairs,
@@ -411,12 +434,14 @@ def create_detailed_model_analysis(model_results: Dict, model_name: str) -> go.F
411
  marker_color=colors,
412
  name="BLEU",
413
  text=[f"{score:.1f}" for score in bleu_scores],
414
- textposition='auto'
 
 
415
  ),
416
  row=1, col=1
417
  )
418
 
419
- # Quality scores
420
  fig.add_trace(
421
  go.Bar(
422
  x=pairs,
@@ -424,27 +449,47 @@ def create_detailed_model_analysis(model_results: Dict, model_name: str) -> go.F
424
  marker_color=colors,
425
  name="Quality",
426
  text=[f"{score:.3f}" for score in quality_scores],
427
- textposition='auto',
 
428
  showlegend=False
429
  ),
430
  row=2, col=1
431
  )
432
 
 
433
  fig.update_layout(
434
- height=800,
435
- title=f"📊 Detailed Analysis: {model_name}",
436
- showlegend=True
 
 
 
 
 
437
  )
438
 
439
- # Rotate x-axis labels
440
- fig.update_xaxes(tickangle=45)
 
 
 
 
 
 
 
 
 
 
 
 
 
441
 
442
- # Add legend for colors
443
  fig.add_trace(
444
  go.Scatter(
445
  x=[None], y=[None],
446
  mode='markers',
447
- marker=dict(size=10, color='#1f77b4'),
448
  name="Google Comparable",
449
  showlegend=True
450
  )
@@ -454,7 +499,7 @@ def create_detailed_model_analysis(model_results: Dict, model_name: str) -> go.F
454
  go.Scatter(
455
  x=[None], y=[None],
456
  mode='markers',
457
- marker=dict(size=10, color='#ff7f0e'),
458
  name="UG40 Only",
459
  showlegend=True
460
  )
@@ -468,25 +513,25 @@ def create_submission_summary_plot(validation_info: Dict, evaluation_results: Di
468
  fig = make_subplots(
469
  rows=2, cols=2,
470
  subplot_titles=(
471
- "Coverage by Language Pair",
472
  "Primary Metrics",
473
- "Error Analysis",
474
- "Sample Distribution"
475
  ),
476
- specs=[[{"type": "bar"}, {"type": "bar"}],
477
- [{"type": "bar"}, {"type": "pie"}]]
478
  )
479
 
480
- # Coverage by language pair
481
- if 'pair_coverage' in validation_info:
482
- pair_data = validation_info['pair_coverage']
483
- pairs = list(pair_data.keys())[:10] # Top 10 pairs
484
- coverage_rates = [pair_data[p]['coverage_rate'] for p in pairs]
485
-
486
- fig.add_trace(
487
- go.Bar(x=pairs, y=coverage_rates, name="Coverage"),
488
- row=1, col=1
489
- )
490
 
491
  # Primary metrics
492
  if 'summary' in evaluation_results:
@@ -495,7 +540,13 @@ def create_submission_summary_plot(validation_info: Dict, evaluation_results: Di
495
  metric_values = list(metrics_data.values())
496
 
497
  fig.add_trace(
498
- go.Bar(x=metric_names, y=metric_values, name="Metrics"),
 
 
 
 
 
 
499
  row=1, col=2
500
  )
501
 
@@ -505,20 +556,36 @@ def create_submission_summary_plot(validation_info: Dict, evaluation_results: Di
505
  error_values = [evaluation_results['averages'].get(m, 0) for m in error_metrics]
506
 
507
  fig.add_trace(
508
- go.Bar(x=error_metrics, y=error_values, name="Errors"),
 
 
 
 
 
 
509
  row=2, col=1
510
  )
511
 
512
- # Sample distribution (placeholder)
513
- fig.add_trace(
514
- go.Pie(
515
- labels=["Evaluated", "Missing"],
516
- values=[validation_info.get('coverage', 0.8) * 100,
517
- (1 - validation_info.get('coverage', 0.8)) * 100],
518
- name="Samples"
519
- ),
520
- row=2, col=2
521
- )
 
 
 
 
 
 
 
 
 
 
522
 
523
  fig.update_layout(
524
  title="📋 Submission Summary",
 
27
  x=0.5, y=0.5, showarrow=False,
28
  font=dict(size=16)
29
  )
30
+ fig.update_layout(title="No Data Available")
31
  return fig
32
 
33
  # Get top N models
34
  top_models = df.head(top_n)
35
 
 
 
 
36
  # Create horizontal bar chart
37
  fig = go.Figure(data=[
38
  go.Bar(
 
77
  """Create radar chart comparing multiple metrics across models."""
78
 
79
  if df.empty:
80
+ fig = go.Figure()
81
+ fig.add_annotation(text="No data available", x=0.5, y=0.5, showarrow=False)
82
+ fig.update_layout(title="No Data Available")
83
+ return fig
84
 
85
  # Select models to compare
86
  if models is None:
 
89
  selected_models = df[df['model_name'].isin(models)].head(max_models)
90
 
91
  if len(selected_models) == 0:
92
+ fig = go.Figure()
93
+ fig.add_annotation(text="No models found", x=0.5, y=0.5, showarrow=False)
94
+ fig.update_layout(title="No Models Found")
95
+ return fig
96
 
97
  # Metrics to include in radar chart
98
  metrics = ['quality_score', 'bleu', 'chrf', 'rouge1', 'rougeL']
 
143
  """Create heatmap showing performance across language pairs."""
144
 
145
  if not results_dict or 'pair_metrics' not in results_dict:
146
+ fig = go.Figure()
147
+ fig.add_annotation(text="No language pair data available", x=0.5, y=0.5, showarrow=False)
148
+ fig.update_layout(title="No Language Pair Data Available")
149
+ return fig
150
 
151
  pair_metrics = results_dict['pair_metrics']
152
 
 
175
  colorscale='Viridis',
176
  showscale=True,
177
  colorbar=dict(title=metric.replace('_', ' ').title()),
178
+ hovertemplate=(
179
  "Source: %{y}<br>" +
180
  "Target: %{x}<br>" +
181
  f"{metric.replace('_', ' ').title()}: %{{z:.3f}}<br>" +
 
197
  """Create plot analyzing test set coverage across submissions."""
198
 
199
  if df.empty:
200
+ fig = go.Figure()
201
+ fig.add_annotation(text="No data available", x=0.5, y=0.5, showarrow=False)
202
+ fig.update_layout(title="No Data Available")
203
+ return fig
204
 
205
  fig = make_subplots(
206
  rows=2, cols=2,
 
268
  """Create timeline showing model performance over time."""
269
 
270
  if df.empty:
271
+ fig = go.Figure()
272
+ fig.add_annotation(text="No data available", x=0.5, y=0.5, showarrow=False)
273
+ fig.update_layout(title="No Data Available")
274
+ return fig
275
 
276
  # Convert submission_date to datetime
277
  df_copy = df.copy()
 
332
  google_models = df[df['google_pairs_covered'] > 0].copy()
333
 
334
  if google_models.empty:
335
+ fig = go.Figure()
336
+ fig.add_annotation(
337
  text="No models with Google Translate comparable results",
338
+ x=0.5, y=0.5, showarrow=False
339
  )
340
+ fig.update_layout(title="No Google Comparable Models")
341
+ return fig
342
 
343
  fig = go.Figure()
344
 
 
376
  return fig
377
 
378
  def create_detailed_model_analysis(model_results: Dict, model_name: str) -> go.Figure:
379
+ """Create detailed analysis plot for a specific model - FIXED version."""
380
 
381
  if not model_results or 'pair_metrics' not in model_results:
382
+ fig = go.Figure()
383
+ fig.add_annotation(text="No detailed results available", x=0.5, y=0.5, showarrow=False)
384
+ fig.update_layout(title=f"No Data for {model_name}")
385
+ return fig
386
 
387
  pair_metrics = model_results['pair_metrics']
388
 
 
407
  google_comparable.append(is_google)
408
 
409
  if not pairs:
410
+ fig = go.Figure()
411
+ fig.add_annotation(text="No language pair data found", x=0.5, y=0.5, showarrow=False)
412
+ fig.update_layout(title=f"No Language Pair Data for {model_name}")
413
+ return fig
414
 
415
+ # Create subplot with proper spacing and titles
416
  fig = make_subplots(
417
  rows=2, cols=1,
418
  subplot_titles=(
419
+ f"BLEU Scores by Language Pair",
420
+ f"Quality Scores by Language Pair"
421
  ),
422
+ vertical_spacing=0.15,
423
+ row_heights=[0.45, 0.45]
424
  )
425
 
426
  # Color code by Google comparable
427
  colors = ['#1f77b4' if gc else '#ff7f0e' for gc in google_comparable]
428
 
429
+ # BLEU scores (top subplot)
430
  fig.add_trace(
431
  go.Bar(
432
  x=pairs,
 
434
  marker_color=colors,
435
  name="BLEU",
436
  text=[f"{score:.1f}" for score in bleu_scores],
437
+ textposition='outside',
438
+ textfont=dict(size=10),
439
+ showlegend=True
440
  ),
441
  row=1, col=1
442
  )
443
 
444
+ # Quality scores (bottom subplot)
445
  fig.add_trace(
446
  go.Bar(
447
  x=pairs,
 
449
  marker_color=colors,
450
  name="Quality",
451
  text=[f"{score:.3f}" for score in quality_scores],
452
+ textposition='outside',
453
+ textfont=dict(size=10),
454
  showlegend=False
455
  ),
456
  row=2, col=1
457
  )
458
 
459
+ # Update layout
460
  fig.update_layout(
461
+ height=900,
462
+ title=dict(
463
+ text=f"📊 Detailed Analysis: {model_name}",
464
+ x=0.5,
465
+ xanchor='center'
466
+ ),
467
+ showlegend=True,
468
+ margin=dict(l=50, r=50, t=100, b=150)
469
  )
470
 
471
+ # Update x-axes to rotate labels properly
472
+ fig.update_xaxes(
473
+ tickangle=45,
474
+ tickfont=dict(size=10),
475
+ row=1, col=1
476
+ )
477
+ fig.update_xaxes(
478
+ tickangle=45,
479
+ tickfont=dict(size=10),
480
+ row=2, col=1
481
+ )
482
+
483
+ # Update y-axes
484
+ fig.update_yaxes(title_text="BLEU Score", row=1, col=1)
485
+ fig.update_yaxes(title_text="Quality Score", row=2, col=1)
486
 
487
+ # Add legend manually for Google vs UG40 only
488
  fig.add_trace(
489
  go.Scatter(
490
  x=[None], y=[None],
491
  mode='markers',
492
+ marker=dict(size=15, color='#1f77b4', symbol='square'),
493
  name="Google Comparable",
494
  showlegend=True
495
  )
 
499
  go.Scatter(
500
  x=[None], y=[None],
501
  mode='markers',
502
+ marker=dict(size=15, color='#ff7f0e', symbol='square'),
503
  name="UG40 Only",
504
  showlegend=True
505
  )
 
513
  fig = make_subplots(
514
  rows=2, cols=2,
515
  subplot_titles=(
516
+ "Sample Distribution",
517
  "Primary Metrics",
518
+ "Error Analysis",
519
+ "Coverage Summary"
520
  ),
521
+ specs=[[{"type": "pie"}, {"type": "bar"}],
522
+ [{"type": "bar"}, {"type": "bar"}]]
523
  )
524
 
525
+ # Sample distribution (pie chart)
526
+ coverage = validation_info.get('coverage', 0.8)
527
+ fig.add_trace(
528
+ go.Pie(
529
+ labels=["Evaluated", "Missing"],
530
+ values=[coverage * 100, (1 - coverage) * 100],
531
+ name="Samples"
532
+ ),
533
+ row=1, col=1
534
+ )
535
 
536
  # Primary metrics
537
  if 'summary' in evaluation_results:
 
540
  metric_values = list(metrics_data.values())
541
 
542
  fig.add_trace(
543
+ go.Bar(
544
+ x=metric_names,
545
+ y=metric_values,
546
+ name="Metrics",
547
+ text=[f"{val:.3f}" for val in metric_values],
548
+ textposition='auto'
549
+ ),
550
  row=1, col=2
551
  )
552
 
 
556
  error_values = [evaluation_results['averages'].get(m, 0) for m in error_metrics]
557
 
558
  fig.add_trace(
559
+ go.Bar(
560
+ x=error_metrics,
561
+ y=error_values,
562
+ name="Errors",
563
+ text=[f"{val:.3f}" for val in error_values],
564
+ textposition='auto'
565
+ ),
566
  row=2, col=1
567
  )
568
 
569
+ # Coverage summary
570
+ if 'summary' in evaluation_results:
571
+ summary = evaluation_results['summary']
572
+ coverage_labels = ["Total Samples", "Lang Pairs", "Google Pairs"]
573
+ coverage_values = [
574
+ summary.get('total_samples', 0),
575
+ summary.get('language_pairs_covered', 0),
576
+ summary.get('google_comparable_pairs', 0)
577
+ ]
578
+
579
+ fig.add_trace(
580
+ go.Bar(
581
+ x=coverage_labels,
582
+ y=coverage_values,
583
+ name="Coverage",
584
+ text=[f"{val}" for val in coverage_values],
585
+ textposition='auto'
586
+ ),
587
+ row=2, col=2
588
+ )
589
 
590
  fig.update_layout(
591
  title="📋 Submission Summary",