# Llama3-ChatQA-2-70B / evaluation / long_32k_eval / dataset_evaluator_retro_longbench.py
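"""Evaluate long-context generations against LongBench-style references with
the LongBench scorer.

Illustrative invocation (the paths below are placeholders, not files shipped
with this repo):

    python dataset_evaluator_retro_longbench.py \
        --datapath /path/to/musique_reference.json \
        --gen_test_file /path/to/musique_generations.txt \
        --dataset musique
"""
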
import os
import argparse
import json
import shutil
import re
from datasets import load_dataset, load_metric
from huggingface_hub import hf_hub_download
from longbench.eval import scorer
LONGBENCH_DATASETS = [
    'musique',  # TODO: add the other 20 LongBench datasets
    'hotpotqa',
    'multifieldqa_en'
]
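
# Matches a standalone multiple-choice letter A-D; used to pull the chosen
# option out of a prediction for multiple-choice (quality-style) datasets.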
PATTERN = re.compile(r'\b[A-D]\b')
def find_answer(s):
    match = PATTERN.search(s)
    if match is None:
        return None  # None signals that no multiple-choice letter was found
    return match.group()
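# Intended behavior (illustrative):
#   find_answer("The answer is B because ...")  -> "B"
#   find_answer("no option letter given")       -> None
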
def read_json_data(data_path):
    references = []
    questions = []
    id_to_labels = dict()
    id_list = list()
    idx = 0
    with open(data_path, "r") as f:
        examples = json.load(f)
    for data_item in examples:  # dict_keys(['source', 'paragraph_id', 'question', 'answer', 'sub-paragraphs', 'word_count', 'id', 'ctxs'])
        idx_str = str(idx) if 'id' not in data_item else str(data_item['id'])
        idx += 1
        id_list.append(idx_str)
        questions.append(data_item['question'])
        if "answers" in data_item:
            references.append(data_item['answers'])  # keep all reference answers
            answer_list = list(data_item['answers'])
            id_to_labels[idx_str] = answer_list
        elif "answer" in data_item:
            references.append([data_item['answer']])  # take the single answer, as a list
            id_to_labels[idx_str] = [data_item['answer']]
        else:
            raise ValueError("need answer or answers from input json")
    return id_to_labels, id_list, questions, references  # references are the gold answers
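# Sketch of the reference JSON this reader assumes, inferred from the keys
# accessed above (real files may carry extra fields such as 'ctxs'):
# [
#   {"id": "0", "question": "...", "answers": ["gold answer 1", "gold answer 2"]},
#   {"id": "1", "question": "...", "answer": "single gold answer"}
# ]
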
def convert_to_seq(aquestion, apred):
    if apred is None:
        apred = ""
    matched_pred = find_answer(apred)
    if matched_pred is None:
        matched_pred = apred
    apred = '({})'.format(matched_pred)
    alist = aquestion.split('\n')
    for aitem in alist:
        aitem = aitem.strip()
        if aitem.startswith(apred):
            pred_out = ' '.join(aitem.split(' ')[1:])
            print('from {} to [{}]'.format(apred, pred_out))
            return pred_out
    print('Warning: could not find {} in question {}'.format(apred, aquestion))
    return apred
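# Illustrative example for a 'quality'-style multiple-choice question whose
# options are listed one per line, e.g. a line "(B) Paris": a prediction of
# "B" (or "(B)") is rewritten to the option text "Paris" before scoring.
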
def load_prediction_openai(test_file):
    predictions = []
    with open(test_file, "r") as f:
        apred_list = list()
        for aline in f.readlines():
            if aline.startswith('assistant: '):
                if len(apred_list) > 0:
                    print('\n'.join(apred_list))
                    predictions.append('\n'.join(apred_list))
                    apred_list = list()
                apred_list.append(aline[len('assistant: '):].strip())
            else:
                apred_list.append(aline.strip())
    if len(apred_list) > 0:
        predictions.append('\n'.join(apred_list))
    print(len(predictions))
    return predictions
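# Assumed layout of an OpenAI-assistant prediction file, based on the parsing
# above: each sample starts with an 'assistant: ' line, and unprefixed lines
# that follow are treated as continuations of the same prediction, e.g.
#   assistant: The answer is
#   Paris.
#   assistant: 42
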
# 500 -> 100
def load_prediction(test_file, id_list, id_to_labels,
                    questions, dataset_name, is_openai_assistant=False):
    if is_openai_assistant:
        predictions = load_prediction_openai(test_file)
    else:
        predictions = []
        with open(test_file, "r") as f:
            for line in f.readlines():
                predictions.append(line.strip())
    if len(predictions) != len(id_list):
        print("NOTE: different number of samples, {} in prediction but {} in reference".format(
            len(predictions), len(id_list)))
        id_list = id_list[0: len(predictions)]
    id_to_prediction = dict()
    for aid, apred in zip(id_list, predictions):
        id_to_prediction[aid] = apred
    if dataset_name.startswith('quality'):
        print('quality dataset: rewriting predictions to the full textual sequence...')
        questions = questions[0: len(predictions)]
        id_to_prediction = dict()
        for aid, aquestion, apred in zip(id_list, questions, predictions):
            apred_seq = convert_to_seq(aquestion, apred)
            id_to_prediction[aid] = apred_seq
    return id_to_prediction, id_list, predictions
def main(args, raise_on_errors=False):
    datasets = [args.dataset] if args.dataset in LONGBENCH_DATASETS else LONGBENCH_DATASETS
    for dataset_name in datasets:
        print(dataset_name)
        id_to_labels, id_list, questions, answers = read_json_data(args.datapath)
        id_to_pred, id_list, predictions = load_prediction(args.gen_test_file,
                                                           id_list, id_to_labels, questions,
                                                           dataset_name, args.is_openai_assistant)
        if len(id_to_labels) > len(id_list):
            print('NOTE: prune the reference set from {} to {}'.format(
                len(id_to_labels), len(id_list)))
            id_to_labels = {aid: id_to_labels[aid] for aid in id_list}
        errors, details = verify(id_to_pred, id_to_labels)
        if len(errors) == 0:
            score = scorer(dataset_name, predictions, answers, all_classes=None)
            print('final display:', dataset_name, score, "\n", args.gen_test_file)
        else:
            errors_msg = errors[0] if len(errors) == 1 else " ".join(f"{i}: {err}" for i, err in enumerate(errors))
            print(json.dumps(errors, indent=4))
            raise ValueError(f"Failed to evaluate due to: {errors_msg}")
def download_metric():
    scrolls_metric_path = hf_hub_download(repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type="dataset")
    updated_scrolls_metric_path = os.path.join(
        os.path.dirname(scrolls_metric_path), os.path.basename(scrolls_metric_path).replace(".", "_") + ".py"
    )
    shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
    return updated_scrolls_metric_path
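# Note: download_metric() is not called in this script; presumably it is kept
# so the SCROLLS metric can be loaded from the copied local script, e.g.
# (assumption, mirroring the SCROLLS evaluator):
#   scrolls_metric = load_metric(download_metric())
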
def verify(id_to_pred, id_to_labels):
    errors = []
    details = {"missing_keys": [], "redundant_keys": []}
    if not isinstance(id_to_pred, dict):
        errors.append('The predictions must be saved as a JSON object: {"id1": "prediction1", "id2": "prediction2", ...}')
    else:
        if not all(isinstance(key, str) for key in id_to_pred.keys()):
            errors.append("All keys of the predictions dictionary must be strings")
        if not all(isinstance(value, str) for value in id_to_pred.values()):
            errors.append("All values of the predictions dictionary must be strings")
    if len(errors) == 0:
        predictions_keys, reference_keys = set(id_to_pred.keys()), set(id_to_labels.keys())
        missing_keys = reference_keys - predictions_keys
        redundant_keys = predictions_keys - reference_keys
        if len(missing_keys) > 0:
            details["missing_keys"] = list(missing_keys)
            errors.append("There are missing example IDs.")
        else:
            del details["missing_keys"]
        if len(redundant_keys) > 0:
            details["redundant_keys"] = list(redundant_keys)
            errors.append("There are redundant example IDs.")
        else:
            del details["redundant_keys"]
    return errors, details
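# Illustrative return value when one reference ID is absent from the
# predictions:
#   errors  == ["There are missing example IDs."]
#   details == {"missing_keys": ["42"]}
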
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate LongBench predictions per dataset")
    dataset_help = "name of the dataset used in longbench: {}".format(LONGBENCH_DATASETS)
    parser.add_argument("--datapath", type=str, required=True,
                        help="datapath for test json file [reference]")
    parser.add_argument("--gen_test_file", type=str, required=True,
                        help="generations for test file [system prediction]")
    parser.add_argument("--dataset", type=str, required=True,
                        help=dataset_help)
    parser.add_argument("--is_openai_assistant", action="store_true",
                        help="if set, predictions are OpenAI-assistant transcripts: each sample starts "
                             "with an 'assistant: ' line and may continue over the following lines")
    args = parser.parse_args()
    print(args)
    main(args)