import os
import json
import argparse
import numpy as np
from metrics import (
qa_f1_score,
rouge_zh_score,
qa_f1_zh_score,
rouge_score,
classification_score,
retrieval_score,
retrieval_zh_score,
count_score,
code_sim_score,
)
# Map each LongBench dataset to the metric used to score it.
dataset2metric = {
    "narrativeqa": qa_f1_score,
    "qasper": qa_f1_score,
    "multifieldqa_en": qa_f1_score,
    "multifieldqa_zh": qa_f1_zh_score,
    "hotpotqa": qa_f1_score,
    "2wikimqa": qa_f1_score,
    "musique": qa_f1_score,
    "dureader": rouge_zh_score,
    "gov_report": rouge_score,
    "qmsum": rouge_score,
    "multi_news": rouge_score,
    "vcsum": rouge_zh_score,
    "trec": classification_score,
    "triviaqa": qa_f1_score,
    "samsum": rouge_score,
    "lsht": classification_score,
    "passage_retrieval_en": retrieval_score,
    "passage_count": count_score,
    "passage_retrieval_zh": retrieval_zh_score,
    "lcc": code_sim_score,
    "repobench-p": code_sim_score,
}
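# A hedged usage sketch (hypothetical strings): as the calls below in
# scorer()/scorer_e() show, every metric shares the signature
# (prediction, ground_truth, **kwargs), and classification-style metrics
# additionally read all_classes. For example:
#
#   metric = dataset2metric["hotpotqa"]          # -> qa_f1_score
#   metric("Paris", "paris", all_classes=None)   # token-level F1, ~1.0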
def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default=None,
                        help="Model name; predictions are read from pred/<model>/ or pred_e/<model>/")
    parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E")
    return parser.parse_args(args)
def scorer_e(dataset, predictions, answers, lengths, all_classes):
    """Score predictions for LongBench-E, bucketed by context length.

    Returns a dict with the mean score (as a percentage) per length
    bucket: 0-4k, 4-8k, and 8k+. An empty bucket averages to nan.
    """
    scores = {"0-4k": [], "4-8k": [], "8k+": []}
    for (prediction, ground_truths, length) in zip(predictions, answers, lengths):
        score = 0.
        if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
            # These datasets are scored on the first non-empty line only.
            prediction = prediction.lstrip('\n').split('\n')[0]
        for ground_truth in ground_truths:
            score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
        if length < 4000:
            scores["0-4k"].append(score)
        elif length < 8000:
            scores["4-8k"].append(score)
        else:
            scores["8k+"].append(score)
    for key in scores:
        scores[key] = round(100 * np.mean(scores[key]), 2)
    return scores
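# Usage sketch (hypothetical inputs): with context lengths of 3000 and
# 9000 tokens, the two examples land in separate buckets and the middle
# bucket stays empty (so it averages to nan):
#
#   scorer_e("hotpotqa", ["Paris", "1999"], [["Paris"], ["1999"]], [3000, 9000], None)
#   # -> {"0-4k": 100.0, "4-8k": nan, "8k+": 100.0}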
def scorer(dataset, predictions, answers, all_classes):
    """Score a full dataset split.

    predictions is a list of prediction strings; answers is a list of
    reference-answer lists (one list per example); all_classes carries the
    label set for classification datasets and is None otherwise (e.g. for
    'hotpotqa'). Each prediction is scored against its best-matching
    reference, and the mean is returned as a percentage.
    """
    total_score = 0.
    for (prediction, ground_truths) in zip(predictions, answers):
        score = 0.
        if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
            # These datasets are scored on the first non-empty line only.
            prediction = prediction.lstrip('\n').split('\n')[0]
        for ground_truth in ground_truths:
            score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
        total_score += score
    outscore = round(100 * total_score / len(predictions), 2)
    print(dataset, outscore)
    return outscore
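# Example (hypothetical data): the second prediction matches one of its
# two references, so both examples score 1.0 and the mean is 100.0:
#
#   scorer("hotpotqa",
#          predictions=["Paris", "1999"],
#          answers=[["Paris"], ["1998", "1999"]],
#          all_classes=None)   # prints "hotpotqa 100.0", returns 100.0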
if __name__ == '__main__':
args = parse_args()
scores = dict()
if args.e:
path = f"pred_e/{args.model}/"
else:
path = f"pred/{args.model}/" # 'pred/chatglm2-6b-32k/' NOTE
all_files = os.listdir(path) # 21 files
print("Evaluating on:", all_files)
for filename in all_files:
if not filename.endswith("jsonl"):
continue
predictions, answers, lengths = [], [], []
        dataset = filename.split('.')[0]  # dataset name comes from the filename
with open(f"{path}{filename}", "r", encoding="utf-8") as f:
            for line in f:  # one JSON record per line
data = json.loads(line)
predictions.append(data["pred"])
answers.append(data["answers"])
                all_classes = data["all_classes"]  # reassigned on every line; the file carries a single constant value
if "length" in data:
lengths.append(data["length"])
if args.e:
score = scorer_e(dataset, predictions, answers, lengths, all_classes)
else:
            score = scorer(dataset, predictions, answers, all_classes)  # main scoring entry point
scores[dataset] = score
    out_path = f"{path}result.json"
print(scores)
    with open(out_path, "w", encoding="utf-8") as f:
json.dump(scores, f, ensure_ascii=False, indent=4)
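# Usage sketch (assuming this file is saved as eval.py and predictions
# were generated under pred/<model>/ or pred_e/<model>/ beforehand):
#
#   python eval.py --model chatglm2-6b-32k        # standard LongBench
#   python eval.py --model chatglm2-6b-32k --e    # LongBench-E length buckets
#
# Per-dataset scores are written to pred[_e]/<model>/result.json.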