# Evaluating models for comma fixing on the wikitext dataset

#### Set up the dataset

In [6]:
import datasets

In [7]:
wikitext = datasets.load_dataset('wikitext', 'wikitext-103-raw-v1', split="validation")

In [8]:
wikitext[3]["text"]

' Homarus gammarus , known as the European lobster or common lobster , is a species of clawed lobster from the eastern Atlantic Ocean , Mediterranean Sea and parts of the Black Sea . It is closely related to the American lobster , H. americanus . It may grow to a length of 60 cm ( 24 in ) and a mass of 6 kilograms ( 13 lb ) , and bears a conspicuous pair of claws . In life , the lobsters are blue , only becoming " lobster red " on cooking . Mating occurs in the summer , producing eggs which are carried by the females for up to a year before hatching into planktonic larvae . Homarus gammarus is a highly esteemed food , and is widely caught using lobster pots , mostly around the British Isles . \n'
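The example above shows the formatting that the rest of this notebook relies on. In wikitext-103-raw, sentence-level commas are standalone tokens surrounded by spaces (" , "), while separators inside numbers are escaped as "@,@" (the "0 @.@ 7" style numbers further down use the same escaping for decimal points). The following is a minimal sketch that counts only the sentence-level commas. It uses a made-up sample sentence and a hypothetical helper name, count_sentence_commas:

```python
def count_sentence_commas(text: str) -> int:
    """Count commas that stand alone as tokens, as in wikitext-103-raw."""
    # Numeric separators such as "1 @,@ 200" are single "@,@" tokens, so they are skipped.
    return sum(1 for token in text.split() if token == ",")


sample = " It weighs 1 @,@ 200 grams , or about 2 @.@ 6 lb , on average . \n"

print(count_sentence_commas(sample))  # 2
print(sample.count(","))              # 3, because the naive count includes "@,@"
```

This distinction explains why the evaluation below matches " , " rather than every "," character.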

## Evaluation

Using our BaselineCommaFixer and CommaFixer classes to obtain predictions

In [9]:
from seqeval.metrics import classification_report, precision_score, recall_score
import re

In [10]:
from commafixer.src.baseline import BaselineCommaFixer
from commafixer.src.fixer import CommaFixer

In [38]:
comma_fixer = CommaFixer()

Some weights of RobertaForTokenClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
baseline_fixer = BaselineCommaFixer(device=0)
baseline_fixer._ner.device

Validating the dataset and testing the evaluate function on a single example.

In [None]:
gt = wikitext[7]["text"]

In [None]:
pred = comma_fixer.fix_commas(gt)

In [None]:
baseline_pred = baseline_fixer.fix_commas(gt)

In [16]:
gt, pred, baseline_pred

(' Homarus gammarus is a large crustacean , with a body length up to 60 centimetres ( 24 in ) and weighing up to 5 – 6 kilograms ( 11 – 13 lb ) , although the lobsters caught in lobster pots are usually 23 – 38 cm ( 9 – 15 in ) long and weigh 0 @.@ 7 – 2 @.@ 2 kg ( 1 @.@ 5 – 4 @.@ 9 lb ) . Like other crustaceans , lobsters have a hard exoskeleton which they must shed in order to grow , in a process called ecdysis ( moulting ) . This may occur several times a year for young lobsters , but decreases to once every 1 – 2 years for larger animals . \n',
 ' Homarus gammarus is a large crustacean with a body length up to 60 centimetres ( 24 in ) and weighing up to 5 – 6 kilograms ( 11 – 13 lb ), although the lobsters caught in lobster pots are usually 23 – 38 cm ( 9 – 15 in ) long and weigh 0 @.@ 7 – 2 @.@ 2 kg ( 1 @.@ 5 – 4 @.@ 9 lb ) . Like other crustaceans, lobsters have a hard exoskeleton, which they must shed in order to grow in a process called ecdysis ( moulting ) . This may occur sev

In [23]:
comma_indices_gt = [m.start() for m in re.finditer(',', gt)]
comma_indices_gt

[40, 142, 312, 385, 485]

In [18]:
comma_indices_pred = [m.start() for m in re.finditer(',', pred)]
comma_indices_pred

[139, 308, 342, 479]

In [22]:
comma_indices_baseline_pred = [m.start() for m in re.finditer(',', baseline_pred)]
comma_indices_baseline_pred

[309, 481]

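The character offsets cannot be compared directly, because the predictions drop the space before each comma. For example, offset 142 in the ground truth corresponds to 139 in our prediction. Matching commas by the word they follow, our model predicts 3 commas correctly, places 1 incorrectly and misses 2. The baseline predicts 2 correctly and misses 3. We therefore expect 75% precision (3/4) and 60% recall (3/5) for our model, and 100% precision (2/2) and 40% recall (2/5) for the baseline.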
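These counts can be cross-checked without seqeval by mapping each comma to the index of the word it follows, which makes the two spacing conventions comparable. The sketch below uses hypothetical helper names and toy strings rather than the notebook's data. The toy prediction is constructed to have 3 correct, 1 spurious and 2 missed commas:

```python
def comma_word_positions(text: str) -> set[int]:
    """Return the indices of words that are directly followed by a comma.

    Handles both standalone commas (" , ", wikitext style) and commas attached
    to the preceding word (", ", as in the model outputs).
    """
    positions = set()
    word_index = -1
    for token in text.split():
        if token == ",":
            positions.add(word_index)  # a standalone comma belongs to the previous word
            continue
        word_index += 1
        if token.endswith(","):
            positions.add(word_index)
    return positions


def comma_precision_recall(gt: str, pred: str) -> tuple[float, float]:
    gt_pos, pred_pos = comma_word_positions(gt), comma_word_positions(pred)
    tp = len(gt_pos & pred_pos)   # commas present in both texts
    fp = len(pred_pos - gt_pos)   # predicted but not in the ground truth
    fn = len(gt_pos - pred_pos)   # in the ground truth but not predicted
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


toy_gt = "a , b , c , d , e , f g"
toy_pred = "a, b, c, d e f, g"
print(comma_precision_recall(toy_gt, toy_pred))  # (0.75, 0.6)
```

Every comma here is an entity of length one, so this set-based count should agree with seqeval's scores for the single-token B-COMMA tags used in evaluate.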

In [19]:
def evaluate(gt_texts: list[str], pred_texts: list[str]) -> tuple[float, float]:
  """
  Evaluates predicted texts against ground truth texts in terms of comma placement.
  A predicted comma that is also present in the ground truth counts as a true positive.
  A predicted comma that is absent from the ground truth counts as a false positive.
  A ground truth comma that is not predicted counts as a false negative.
  :param gt_texts: Ground truth texts, with commas as standalone tokens (" , ").
  :param pred_texts: Predicted texts, with commas attached to the preceding word (", ").
  :return: Prints the classification report and returns the precision and recall scores.
  """
  def _seqeval_tags(normalized_text: str) -> list[str]:
    return ['B-COMMA' if token.endswith('-COMMA') else 'O' for token in normalized_text.split()]
  # Turning each comma into a suffix on the preceding token is a hack that keeps the
  # number of tokens equal between ground truth and prediction. This lets us convert
  # both texts into tag sequences of the same length and score them with seqeval.
  tags_gt = [_seqeval_tags(gt.replace(' , ', '-COMMA ')) for gt in gt_texts]
  tags_pred = [_seqeval_tags(pred.replace(', ', '-COMMA ')) for pred in pred_texts]

  print(classification_report(y_true=tags_gt, y_pred=tags_pred))
  return precision_score(y_true=tags_gt, y_pred=tags_pred), recall_score(y_true=tags_gt, y_pred=tags_pred)
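To show why the suffix trick works, the following applies the tag conversion from evaluate by hand to a made-up sentence pair. The function below is a standalone copy of the _seqeval_tags rule, and the sentences are illustrative only:

```python
def seqeval_tags(normalized_text: str) -> list[str]:
    # Same rule as _seqeval_tags inside evaluate: a token carrying the
    # "-COMMA" suffix becomes B-COMMA, every other token becomes O.
    return ["B-COMMA" if token.endswith("-COMMA") else "O" for token in normalized_text.split()]


gt = " Like other crustaceans , lobsters have a hard exoskeleton which they must shed . \n"
pred = " Like other crustaceans, lobsters have a hard exoskeleton, which they must shed . \n"

tags_gt = seqeval_tags(gt.replace(" , ", "-COMMA "))
tags_pred = seqeval_tags(pred.replace(", ", "-COMMA "))

print(len(tags_gt), len(tags_pred))                          # 13 13
print([i for i, t in enumerate(tags_gt) if t == "B-COMMA"])    # [2]
print([i for i, t in enumerate(tags_pred) if t == "B-COMMA"])  # [2, 7]
```

Both sequences have 13 tags, and position 7 ("exoskeleton") is a false positive. The trick assumes that predicted commas are attached to the previous word. If a prediction kept the wikitext-style " , ", the replacement would produce an extra "-COMMA" token, and the two tag sequences would no longer line up.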

In [24]:
assert evaluate([gt], [pred]) == (0.75, 0.6)
assert evaluate([gt], [baseline_pred]) == (1.0, 0.4)

              precision    recall  f1-score   support

       COMMA       0.75      0.60      0.67         5

   micro avg       0.75      0.60      0.67         5
   macro avg       0.75      0.60      0.67         5
weighted avg       0.75      0.60      0.67         5

              precision    recall  f1-score   support

       COMMA       1.00      0.40      0.57         5

   micro avg       1.00      0.40      0.57         5
   macro avg       1.00      0.40      0.57         5
weighted avg       1.00      0.40      0.57         5


#### Prepare texts for evaluation

In [25]:
gt_texts = [x["text"] for x in wikitext]

## Baseline evaluation

To make full use of the GPU, we run the pipeline directly over the dataset instead of passing texts through our custom class one at a time.

In [26]:
from tqdm import tqdm
from transformers.pipelines.pt_utils import KeyDataset
from commafixer.src.baseline import _remove_punctuation, _fix_commas_based_on_pipeline_output

Preparing a copy of the dataset with punctuation removed, which the baseline model expects as input.

In [27]:
def map_function(example: dict) -> dict:
  t, punctuation_indices = _remove_punctuation(example["text"])
  return {"text": t, "indices": punctuation_indices}

wikitext_no_punct = wikitext.map(map_function)
wikitext_no_punct[3]

{'text': ' Homarus gammarus  known as the European lobster or common lobster  is a species of clawed lobster from the eastern Atlantic Ocean  Mediterranean Sea and parts of the Black Sea  It is closely related to the American lobster  H americanus  It may grow to a length of 60 cm ( 24 in ) and a mass of 6 kilograms ( 13 lb )  and bears a conspicuous pair of claws  In life  the lobsters are blue  only becoming " lobster red " on cooking  Mating occurs in the summer  producing eggs which are carried by the females for up to a year before hatching into planktonic larvae  Homarus gammarus is a highly esteemed food  and is widely caught using lobster pots  mostly around the British Isles  \n',
 'indices': [180, 231, 244, 365, 442, 578, 699]}

In [None]:
pipeline_outs = []
for out in tqdm(baseline_fixer._ner(KeyDataset(wikitext_no_punct, "text"), batched=True)):
  pipeline_outs.append(out)

3649it [03:25, 16.29it/s]

Obtaining predictions based on original texts and the pipeline outputs

In [29]:
pred_texts = [_fix_commas_based_on_pipeline_output(out, s, x["indices"]) for out, s, x in zip(pipeline_outs, gt_texts, wikitext_no_punct)]
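_fix_commas_based_on_pipeline_output is internal to commafixer and is not reproduced here. The sketch below is only a simplified illustration of the general idea, working on the original text rather than on the punctuation-free copy and its indices. It assumes one prediction dict per word, each with an "entity" label and an "end" character offset into the text. It also assumes a hypothetical label "," that marks a word followed by a comma. Commas are inserted from right to left so that the remaining offsets stay valid:

```python
def insert_commas(text: str, predictions: list[dict]) -> str:
    """Insert a comma after every word whose predicted label is ','.

    Assumption: each prediction has an "entity" label and an "end" offset
    pointing just past the word in `text`.
    """
    ends = sorted({p["end"] for p in predictions if p["entity"] == ","}, reverse=True)
    # Right to left: inserting at a later offset never shifts an earlier one.
    for end in ends:
        text = text[:end] + "," + text[end:]
    return text


text = "In life the lobsters are blue only"
predictions = [
    {"entity": ",", "end": 7},   # after "In life"
    {"entity": "0", "end": 11},  # after "the": no punctuation
    {"entity": ",", "end": 29},  # after "blue"
]
print(insert_commas(text, predictions))  # In life, the lobsters are blue, only
```

A real token-classification model usually predicts per subword, so an implementation would also need to decide which subword's label counts for each word.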

#### Baseline evaluation results

In [30]:
evaluate(pred_texts=pred_texts, gt_texts=gt_texts)

              precision    recall  f1-score   support

       COMMA       0.79      0.72      0.75     10079

   micro avg       0.79      0.72      0.75     10079
   macro avg       0.79      0.72      0.75     10079
weighted avg       0.79      0.72      0.75     10079


(0.7898985491436675, 0.7184244468697292)

## Fine-tuned model evaluation

In [31]:
# TODO: batch this as well to speed it up
pred_texts = [comma_fixer.fix_commas(gt) for gt in gt_texts]

In [32]:
pred_texts[:5]

['',
 ' = Homarus gammarus = \n',
 '',
 ' Homarus gammarus, known as the European lobster or common lobster, is a species of clawed lobster from the eastern Atlantic Ocean, Mediterranean Sea and parts of the Black Sea . It is closely related to the American lobster H. americanus . It may grow to a length of 60 cm ( 24 in ) and a mass of 6 kilograms ( 13 lb ), and bears a conspicuous pair of claws . In life, the lobsters are blue, only becoming " lobster red " on cooking . Mating occurs in the summer, producing eggs, which are carried by the females for up to a year before hatching into planktonic larvae . Homarus gammarus is a highly esteemed food and is widely caught using lobster pots, mostly around the British Isles . \n',
 '']

In [33]:
evaluate(pred_texts=pred_texts, gt_texts=gt_texts)

              precision    recall  f1-score   support

       COMMA       0.84      0.84      0.84     10079

   micro avg       0.84      0.84      0.84     10079
   macro avg       0.84      0.84      0.84     10079
weighted avg       0.84      0.84      0.84     10079


(0.8440220110055028, 0.8369877964083738)
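The f1-score column in the reports is the harmonic mean of precision and recall, F1 = 2PR / (P + R). As a quick check, it can be recomputed from the (precision, recall) tuples returned above for the baseline and the fine-tuned model:

```python
def f1(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0.0 if both are zero)."""
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0


baseline = (0.7898985491436675, 0.7184244468697292)
fine_tuned = (0.8440220110055028, 0.8369877964083738)

print(f"baseline F1:   {f1(*baseline):.2f}")    # 0.75
print(f"fine-tuned F1: {f1(*fine_tuned):.2f}")  # 0.84
```

On the validation split, the fine-tuned model improves F1 by roughly nine points over the baseline, mostly through higher recall.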

In [34]:
test_wikitext = datasets.load_dataset('wikitext', 'wikitext-103-raw-v1', split="test")

In [35]:
# comma_fixer.model = comma_fixer.model.cuda()  # TODO: make this work and evaluate on the test split in the notebook as well.
# During training, evaluation on the test split gave roughly the same F1.

In [36]:
gt_texts = [x["text"] for x in test_wikitext]

In [None]:
pred_texts = [comma_fixer.fix_commas(gt) for gt in gt_texts]

In [None]:
evaluate(pred_texts=pred_texts, gt_texts=gt_texts)