vimey commited on
Commit
1f9f72d
·
1 Parent(s): 25839d1

Upload 2 files

Browse files
results/base_encoder_freezing_normal.csv ADDED
The diff for this file is too large to render. See raw diff
 
scripts/encoder_freezing.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # load the requirements
2
+ import torch
3
+ import os
4
+ from transformers import (
5
+ WhisperFeatureExtractor,
6
+ WhisperTokenizer, WhisperProcessor,
7
+ Seq2SeqTrainingArguments,
8
+ WhisperForConditionalGeneration,
9
+ TrainerCallback,
10
+ Seq2SeqTrainer,
11
+ )
12
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
13
+ from torch.utils.data import IterableDataset
14
+ import evaluate
15
+ from datasets import load_dataset, Audio
16
+ from dataclasses import dataclass
17
+ import pandas as pd
18
+ import subprocess
19
+ import datetime
20
+ import csv
21
+
22
+ # define the model id
23
+ model_id = "openai/insert_model_id"
24
+
25
+ # specify the output file path of the wrong predictions
26
+ output_file_path = "path/to/your/output/wrong_predictions.csv"
27
+
28
+ # specify the output file path of the computational resources data
29
+ output_file_path_gpu = "path/to/your/output/efficiency_data.csv"
30
+
31
+ # load and define the feature extractor and the tokenizer
32
+ feature_extractor = WhisperFeatureExtractor.from_pretrained(model_id)
33
+
34
+ tokenizer = WhisperTokenizer.from_pretrained(model_id, language = "English", task = "transcribe")
35
+
36
+ # load audio dataset
37
+ audio_dataset_train = load_dataset("audiofolder", data_dir = "/path/to/dataset/train")
38
+ audio_dataset_test = load_dataset("audiofolder", data_dir = "/path/to/dataset/test")
39
+
40
+ # load the processor
41
+ processor = WhisperProcessor.from_pretrained(model_id, language = "English", task = "transcribe")
42
+
43
+ # preprocess the data
44
+ audio_dataset_train = audio_dataset_train.cast_column("audio", Audio(sampling_rate=16000))
45
+ audio_dataset_test = audio_dataset_test.cast_column("audio", Audio(sampling_rate=16000))
46
+
47
+ do_lower_case = False
48
+ do_remove_punctuation = False
49
+ normalizer = BasicTextNormalizer()
50
+
51
+ def prepare_dataset(batch):
52
+
53
+ audio = batch["audio"]
54
+ batch["input_features"] = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
55
+ batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]
56
+ transcription = batch["transcription"]
57
+ if do_lower_case:
58
+ transcription = transcription.lower()
59
+ if do_remove_punctuation:
60
+ transcription = normalizer(transcription).strip()
61
+ batch["labels"] = processor.tokenizer(transcription).input_ids
62
+ return batch
63
+
64
+ # apply 'prepare dataset' function to each sample in the dataset
65
+ vectorized_audio_dataset_train = audio_dataset_train.map(
66
+ prepare_dataset,
67
+ remove_columns=list(next(iter(audio_dataset_train.values())).features)).with_format("torch")
68
+ vectorized_audio_dataset_test = audio_dataset_test.map(
69
+ prepare_dataset,
70
+ remove_columns=list(next(iter(audio_dataset_test.values())).features)).with_format("torch")
71
+
72
+ # shuffle the audioset, shard selects the whole dataset, seed and contigiuguos=TRUE ensure the reproducibility of the shuffling order
73
+ vectorized_audio_dataset_train["train"] = vectorized_audio_dataset_train["train"].shuffle(
74
+ seed=0,
75
+ load_from_cache_file=False).shard(
76
+ num_shards=1, index=0, contiguous=True)
77
+
78
+ # training and evaluation
79
+
80
+ # define a data collator
81
+ @dataclass
82
+ class DataCollatorSpeechSeq2SeqWithPadding:
83
+ processor: any
84
+
85
+ def __call__(self, features):
86
+ input_features = [{"input_features": feature["input_features"]} for feature in features]
87
+ batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
88
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
89
+ labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
90
+ labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
91
+ if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
92
+ labels = labels[:, 1:]
93
+ batch["labels"] = labels
94
+ return batch
95
+
96
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
97
+
98
+ # evaluation matrix WER
99
+ metric = evaluate.load("wer")
100
+ do_normalize_eval = True
101
+
102
+ # store filenames, predictions and references
103
+ predicted_words_list = []
104
+ target_words_list = []
105
+ filenames = []
106
+
107
+ def compute_metrics(pred):
108
+ pred_ids = pred.predictions
109
+ label_ids = pred.label_ids
110
+
111
+ # replace -100 with the pad_token_id
112
+ label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
113
+ pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
114
+ label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
115
+
116
+ if do_normalize_eval:
117
+ pred_str = [normalizer(pred) for pred in pred_str]
118
+ label_str = [normalizer(label) for label in label_str]
119
+
120
+ # filtering step to only evaluate the samples that correspond to non-zero references:
121
+ pred_str = [pred_str[i] for i in range(len(pred_str)) if len(label_str[i]) > 0]
122
+ label_str = [label_str[i] for i in range(len(label_str)) if len(label_str[i]) > 0]
123
+
124
+ wer = 100 * metric.compute(predictions=pred_str, references=label_str)
125
+
126
+ # append wrong predictions and references to the respective lists, if it is a wrong prediction
127
+ for pred_word, target_word, filename in zip(pred_str, label_str, audio_dataset_test["train"]["audio"]):
128
+ if pred_word.strip() != "" and pred_word != target_word:
129
+ predicted_words_list.append(pred_word)
130
+ target_words_list.append(target_word)
131
+ filenames.append(os.path.basename(str(filename)))
132
+
133
+ print(f"WER: {wer}")
134
+ return {"wer": wer}
135
+
136
+ # load a pre-trained checkpoint
137
+ model = WhisperForConditionalGeneration.from_pretrained(model_id).to(torch.device(0))
138
+
139
+ # disable the use of forced ids, suppressing tokens and the cache
140
+ model.config.forced_decoder_ids = None
141
+ model.config.suppress_tokens = []
142
+ model.config.use_cache = False
143
+
144
+ # freeze the encoder
145
+ for param in model.get_encoder().parameters():
146
+ param.requires_grad = False
147
+
148
+ # define the training parameters
149
+ training_args = Seq2SeqTrainingArguments(
150
+ output_dir="./",
151
+ save_total_limit=2,
152
+ per_device_train_batch_size=64,
153
+ gradient_accumulation_steps=1,
154
+ eval_accumulation_steps=1,
155
+ learning_rate=1e-5,
156
+ warmup_steps=100,
157
+ max_steps=1000,
158
+ gradient_checkpointing=True,
159
+ fp16=True,
160
+ evaluation_strategy="steps",
161
+ per_device_eval_batch_size=8,
162
+ predict_with_generate=True,
163
+ generation_max_length=225,
164
+ save_steps=1000,
165
+ eval_steps=25,
166
+ logging_steps=25,
167
+ report_to=["tensorboard"],
168
+ load_best_model_at_end=True,
169
+ metric_for_best_model="wer",
170
+ greater_is_better=False,
171
+ push_to_hub=False,
172
+ )
173
+
174
+ # trainer callback to reinitialise and reshuffle the datasets at the beginning of each epoch
175
+ class ShuffleCallback(TrainerCallback):
176
+ def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
177
+ if not isinstance(train_dataloader.dataset, IterableDataset):
178
+ train_dataloader.dataset.shuffle()
179
+
180
+
181
+ trainer = Seq2SeqTrainer(
182
+ args=training_args,
183
+ model=model,
184
+ train_dataset=vectorized_audio_dataset_train["train"],
185
+ eval_dataset=vectorized_audio_dataset_test["train"],
186
+ data_collator=data_collator,
187
+ compute_metrics=compute_metrics,
188
+ tokenizer=processor,
189
+ callbacks=[ShuffleCallback()],
190
+ )
191
+
192
+ model.save_pretrained(training_args.output_dir)
193
+ processor.save_pretrained(training_args.output_dir)
194
+
195
+ # log start and endtime of the training
196
+ start_time = datetime.datetime.now()
197
+
198
+ # launch training
199
+ trainer.train()
200
+
201
+ end_time = datetime.datetime.now()
202
+
203
+ # determine the maximum length among the lists
204
+ max_length = max(len(filenames), len(predicted_words_list), len(target_words_list))
205
+
206
+ # fill in missing values with empty strings to ensure equal lengths
207
+ filenames += [""] * (max_length - len(filenames))
208
+ predicted_words_list += [""] * (max_length - len(predicted_words_list))
209
+ target_words_list += [""] * (max_length - len(target_words_list))
210
+
211
+ # save the wrong predictions
212
+ df_wrong_predictions = pd.DataFrame({
213
+ "File Name": filenames,
214
+ "Predictions": predicted_words_list,
215
+ "References": target_words_list
216
+ })
217
+
218
+ pred_words_split = [pred.split() for pred in predicted_words_list]
219
+ target_words_split = [target.split() for target in target_words_list]
220
+ filtered_pred_words = [" ".join([word for word in pred if word != target_word]) for pred, target_word in zip(pred_words_split, target_words_split)]
221
+ filtered_target_words = [" ".join([word for word in target if word != pred_word]) for target, pred_word in zip(target_words_split, pred_words_split)]
222
+
223
+ # update the DataFrame with the filtered files
224
+ df_wrong_predictions["Predictions"] = filtered_pred_words
225
+ df_wrong_predictions["References"] = filtered_target_words
226
+ df_wrong_predictions = df_wrong_predictions[df_wrong_predictions["Predictions"] != df_wrong_predictions["References"]]
227
+
228
+ # save the DataFrame as a CSV file
229
+ df_wrong_predictions.to_csv(output_file_path, index=False)
230
+
231
+ # get training speed
232
+ duration = end_time - start_time
233
+ duration_hours = duration.total_seconds() / 3600 # Convert duration to hours
234
+
235
+ # get the GPU infos
236
+ def get_gpu_info():
237
+ try:
238
+ output = subprocess.check_output(["nvidia-smi", "--query-gpu=index,name,memory.used", "--format=csv,noheader,nounits"])
239
+ gpu_info = [line.strip().split(", ") for line in output.decode("utf-8").split("\n") if line.strip()]
240
+ return gpu_info
241
+ except Exception as e:
242
+ return []
243
+
244
+ gpu_info = get_gpu_info()
245
+ if gpu_info:
246
+ gpu_name = gpu_info[0][1]
247
+ gpu_memory_used = int(gpu_info[0][2])
248
+
249
+ with open(output_file_path_gpu, mode="w", newline="") as file:
250
+ writer = csv.writer(file)
251
+ writer.writerow(["Training Duration (hours)", "GPU Name", "GPU Memory Used (MB)"])
252
+ writer.writerow([duration_hours, gpu_name, gpu_memory_used])