|
|
|
|
|
|
|
|
|
|
|
import os |
|
import sys |
|
|
|
import argparse |
|
import json |
|
import torch |
|
import types |
|
import pandas as pd |
|
|
|
from typing import Annotated, Dict, List, Optional, cast |
|
|
|
from datasets import load_dataset |
|
from tqdm import tqdm |
|
|
|
from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorQA, ViDoReEvaluatorBEIR |
|
from vidore_benchmark.evaluation.interfaces import MetadataModel, ViDoReBenchmarkResults |
|
from vidore_benchmark.utils.data_utils import get_datasets_from_collection |
|
from typing import List, Optional, Union |
|
|
|
from datetime import datetime |
|
from importlib.metadata import version |
|
|
|
import torch |
|
from transformers import AutoModel |
|
|
|
def get_args():
    """Parse the known CLI flags plus arbitrary extra ``--key value`` pairs.

    Returns:
        tuple: ``(args, extra_args_dict)`` where ``args`` is the
        :class:`argparse.Namespace` of declared flags and
        ``extra_args_dict`` maps each unrecognized flag (leading dashes
        stripped) to its value, converted to ``int``/``float`` when the
        text is numeric (signed values included), else left as ``str``.

    Raises:
        ValueError: if the unrecognized arguments do not come in
            ``--flag value`` pairs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name_or_path',
        type=str,
        help='Path to model checkpoint if HF',
        default=''
    )
    parser.add_argument(
        '--model_revision',
        type=str,
        help='Commit Hash of the model as custom code is downloaded and executed',
        default=None
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        help='Batch Size',
        default=1
    )
    parser.add_argument(
        '--savedir_datasets',
        type=str,
        help='Path to save results',
        default='./default/'
    )
    args, extra_args = parser.parse_known_args()

    def convert_value(value):
        """Convert a CLI string to int/float when numeric, else return it unchanged."""
        # Allow one leading sign so "-3" / "+0.5" are converted too; the
        # plain isdigit() check would reject them and leave them as strings.
        core = value[1:] if value[:1] in ('+', '-') else value
        if core.replace('.', '', 1).isdigit():
            return float(value) if '.' in value else int(value)
        return value

    # Unrecognized args must pair up as --flag value; fail loudly instead of
    # letting the comprehension below die with an opaque IndexError.
    if len(extra_args) % 2 != 0:
        raise ValueError(
            f"Extra arguments must come in '--flag value' pairs, got: {extra_args}"
        )

    extra_args_dict = {extra_args[i].lstrip('-'): convert_value(extra_args[i + 1])
                       for i in range(0, len(extra_args), 2)}

    return args, extra_args_dict
|
|
|
if __name__ == "__main__":
    args, add_args = get_args()
    batch_size = int(args.batch_size)
    savedir_datasets = args.savedir_datasets

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(savedir_datasets, exist_ok=True)

    # NOTE: trust_remote_code executes code from the model repo; the
    # revision argument pins it to a specific, reviewed commit.
    vision_retriever = AutoModel.from_pretrained(
        args.model_name_or_path,
        device_map='cuda',
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        revision=args.model_revision
    ).eval()

    vidore_evaluator_qa = ViDoReEvaluatorQA(vision_retriever)
    vidore_evaluator_beir = ViDoReEvaluatorBEIR(vision_retriever)

    # Pinned dataset commits for the BEIR-style ViDoRe-v2 sets, so results
    # stay reproducible even if the datasets are updated upstream.
    vidore_v2_original_commits = {
        "vidore/synthetic_rse_restaurant_filtered_v1.0_multilingual": "c05a1da867bbedebef239d4aa96cab19160b3d88",
        "vidore/synthetic_mit_biomedical_tissue_interactions_unfiltered_multilingual": "9daa25abc1026f812834ca9a6b48b26ecbc61317",
        "vidore/synthetics_economics_macro_economy_2024_filtered_v1.0_multilingual": "909aa23589332c30d7c6c9a89102fe2711cbb7a9",
        "vidore/restaurant_esg_reports_beir": "d8830ba2d04b285cfb2532b95be3748214e305da",
        "vidore/synthetic_rse_restaurant_filtered_v1.0": "4e52fd878318adb8799d0b6567f1134b3985b9d3",
        "vidore/synthetic_economics_macro_economy_2024_filtered_v1.0": "b6ff628a0b3c49f074abdcc86d29bc0ec21fd0c1",
        "vidore/synthetic_mit_biomedical_tissue_interactions_unfiltered": "c1b889b051113c41e32960cd6b7c5ba5b27e39e2",
    }

    metrics_all: Dict[str, Dict[str, Optional[float]]] = {}
    results_all: List[ViDoReBenchmarkResults] = []

    def load_cached_results(path):
        """Rebuild (metrics, ViDoReBenchmarkResults) from a saved metrics JSON.

        Shared by both evaluation loops so a dataset already evaluated on a
        previous run is not re-scored.
        """
        # 'with' closes the handle; the old json.load(open(...)) leaked it.
        with open(path, 'r', encoding='utf-8') as f:
            saved = json.load(f)
        results = ViDoReBenchmarkResults(
            metadata=MetadataModel(
                timestamp=saved['metadata']['timestamp'],
                vidore_benchmark_version=saved['metadata']['vidore_benchmark_version'],
            ),
            metrics=saved['metrics'],
        )
        return saved['metrics'], results

    # --- ViDoRe v1: QA-style datasets from the benchmark collection ---
    dataset_names = get_datasets_from_collection("vidore/vidore-benchmark-667173f98e70a1c0fa4db00d")
    for dataset_name in tqdm(dataset_names, desc="Evaluating dataset(s)"):
        sanitized_dataset_name = dataset_name.replace("/", "_")
        savepath_results = os.path.join(savedir_datasets, f"{sanitized_dataset_name}_metrics.json")
        if os.path.isfile(savepath_results):
            metrics, results = load_cached_results(savepath_results)
        else:
            metrics = {dataset_name: vidore_evaluator_qa.evaluate_dataset(
                ds=load_dataset(dataset_name, split="test"),
                batch_query=batch_size,
                batch_passage=batch_size,
                batch_score=128,
                dataloader_prebatch_query=512,
                dataloader_prebatch_passage=512,
            )}
            results = ViDoReBenchmarkResults(
                metadata=MetadataModel(
                    timestamp=datetime.now(),
                    vidore_benchmark_version=version("vidore_benchmark"),
                ),
                metrics={dataset_name: metrics[dataset_name]},
            )
            with open(savepath_results, "w", encoding="utf-8") as f:
                f.write(results.model_dump_json(indent=4))

        metrics_all.update(metrics)
        print(f"nDCG@5 on {dataset_name}: {metrics[dataset_name]['ndcg_at_5']}")
        results_all.append(results)

    # --- ViDoRe v2: BEIR-style datasets, each loaded at its pinned commit ---
    for dataset_name, revision in vidore_v2_original_commits.items():
        sanitized_dataset_name = dataset_name.replace("/", "_")
        savepath_results = os.path.join(savedir_datasets, f"{sanitized_dataset_name}_metrics.json")
        if os.path.isfile(savepath_results):
            metrics, results = load_cached_results(savepath_results)
        else:
            ds = {
                "corpus": load_dataset(dataset_name, name="corpus", split="test", revision=revision),
                "queries": load_dataset(dataset_name, name="queries", split="test", revision=revision),
                "qrels": load_dataset(dataset_name, name="qrels", split="test", revision=revision)
            }
            metrics = {dataset_name: vidore_evaluator_beir.evaluate_dataset(
                ds=ds,
                batch_query=batch_size,
                batch_passage=batch_size,
                batch_score=128,
                dataloader_prebatch_query=512,
                dataloader_prebatch_passage=512,
            )}
            results = ViDoReBenchmarkResults(
                metadata=MetadataModel(
                    timestamp=datetime.now(),
                    vidore_benchmark_version=version("vidore_benchmark"),
                ),
                metrics={dataset_name: metrics[dataset_name]},
            )
            with open(savepath_results, "w", encoding="utf-8") as f:
                f.write(results.model_dump_json(indent=4))

        metrics_all.update(metrics)
        print(f"nDCG@5 on {dataset_name}: {metrics[dataset_name]['ndcg_at_5']}")
        results_all.append(results)

    # Merge per-dataset results into a single benchmark report.
    results_merged = ViDoReBenchmarkResults.merge(results_all)
    savepath_results_merged = os.path.join(savedir_datasets, "merged_metrics.json")

    with open(savepath_results_merged, "w", encoding="utf-8") as f:
        f.write(results_merged.model_dump_json(indent=4))
|
|