# NOTE: removed non-Python web-page residue that leaked into this file
# during extraction ("nv-bschifferer's picture / adding license / 780d274").
# --------------------------------------------------------
# Copyright (c) 2025 NVIDIA
# Licensed under customized NSCLv1 [see LICENSE.md for details]
# --------------------------------------------------------
import os
import sys
import argparse
import json
import torch
import types
import pandas as pd
from typing import Annotated, Dict, List, Optional, cast
from datasets import load_dataset
from tqdm import tqdm
from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorQA, ViDoReEvaluatorBEIR
from vidore_benchmark.evaluation.interfaces import MetadataModel, ViDoReBenchmarkResults
from vidore_benchmark.utils.data_utils import get_datasets_from_collection
from typing import List, Optional, Union
from datetime import datetime
from importlib.metadata import version
import torch
from transformers import AutoModel
def get_args():
    """Parse the declared CLI arguments plus arbitrary extra ``--key value`` pairs.

    Returns:
        tuple:
            - argparse.Namespace with ``model_name_or_path``, ``model_revision``,
              ``batch_size`` and ``savedir_datasets``.
            - dict mapping each extra flag name (leading dashes stripped) to its
              value, converted to ``int`` or ``float`` when possible, else ``str``.

    Raises:
        ValueError: if the extra arguments do not form complete key/value pairs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name_or_path',
        type=str,
        help='Path to model checkpoint if HF',
        default=''
    )
    parser.add_argument(
        '--model_revision',
        type=str,
        help='Commit Hash of the model as custom code is downloaded and executed',
        default=None
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        help='Batch Size',
        default=1
    )
    parser.add_argument(
        '--savedir_datasets',
        type=str,
        help='Path to save results',
        default='./default/'
    )
    args, extra_args = parser.parse_known_args()

    def convert_value(value):
        """Best-effort conversion of a CLI string: int first, then float, else str."""
        # EAFP conversion also accepts negative numbers and scientific notation,
        # which the previous isdigit()-based check misclassified as strings.
        try:
            return int(value)
        except ValueError:
            pass
        try:
            return float(value)
        except ValueError:
            return value

    if len(extra_args) % 2 != 0:
        # Fail loudly instead of raising an opaque IndexError below.
        raise ValueError(
            f"Extra arguments must come in '--key value' pairs, got: {extra_args}"
        )
    # Convert the flat extra_args list into a dict with proper type conversion.
    extra_args_dict = {
        extra_args[i].lstrip('-'): convert_value(extra_args[i + 1])
        for i in range(0, len(extra_args), 2)
    }
    return args, extra_args_dict
if __name__ == "__main__":
    args, add_args = get_args()
    batch_size = int(args.batch_size)
    savedir_datasets = args.savedir_datasets
    # exist_ok avoids the check-then-create race of the previous version.
    os.makedirs(savedir_datasets, exist_ok=True)

    # Load the retriever once and share it between both evaluators.
    # trust_remote_code executes custom hub code; the revision is pinned on
    # the command line for that reason.
    vision_retriever = AutoModel.from_pretrained(
        args.model_name_or_path,
        device_map='cuda',
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        revision=args.model_revision
    ).eval()
    vidore_evaluator_qa = ViDoReEvaluatorQA(vision_retriever)  # ViDoRe-v1
    vidore_evaluator_beir = ViDoReEvaluatorBEIR(vision_retriever)  # ViDoRe-v2

    # Dataset revisions are pinned to exact commits for reproducibility.
    vidore_v2_original_commits = {
        "vidore/synthetic_rse_restaurant_filtered_v1.0_multilingual": "c05a1da867bbedebef239d4aa96cab19160b3d88",
        "vidore/synthetic_mit_biomedical_tissue_interactions_unfiltered_multilingual": "9daa25abc1026f812834ca9a6b48b26ecbc61317",
        "vidore/synthetics_economics_macro_economy_2024_filtered_v1.0_multilingual": "909aa23589332c30d7c6c9a89102fe2711cbb7a9",
        "vidore/restaurant_esg_reports_beir": "d8830ba2d04b285cfb2532b95be3748214e305da",
        "vidore/synthetic_rse_restaurant_filtered_v1.0": "4e52fd878318adb8799d0b6567f1134b3985b9d3",
        "vidore/synthetic_economics_macro_economy_2024_filtered_v1.0": "b6ff628a0b3c49f074abdcc86d29bc0ec21fd0c1",
        "vidore/synthetic_mit_biomedical_tissue_interactions_unfiltered": "c1b889b051113c41e32960cd6b7c5ba5b27e39e2",
    }
    # NOTE: the previous version declared an unused byte-identical duplicate of
    # this dict ("original_commits"); it has been removed as dead code.

    metrics_all: Dict[str, Dict[str, Optional[float]]] = {}
    results_all: List[ViDoReBenchmarkResults] = []  # same as metrics_all but structured + with metadata

    def _savepath_for(dataset_name):
        """Per-dataset JSON result path inside savedir_datasets."""
        return os.path.join(
            savedir_datasets, f"{dataset_name.replace('/', '_')}_metrics.json"
        )

    def _load_cached_results(savepath_results):
        """Rebuild (metrics, results) from a previous run's JSON file."""
        # Context manager: the old json.load(open(...)) leaked the file handle.
        with open(savepath_results, 'r', encoding='utf-8') as f:
            saved = json.load(f)
        results = ViDoReBenchmarkResults(
            metadata=MetadataModel(
                timestamp=saved['metadata']['timestamp'],
                vidore_benchmark_version=saved['metadata']['vidore_benchmark_version'],
            ),
            metrics=saved['metrics'],
        )
        return saved['metrics'], results

    def _save_fresh_results(dataset_name, metrics, savepath_results):
        """Wrap freshly computed metrics with current metadata and persist them."""
        results = ViDoReBenchmarkResults(
            metadata=MetadataModel(
                timestamp=datetime.now(),
                vidore_benchmark_version=version("vidore_benchmark"),
            ),
            metrics={dataset_name: metrics[dataset_name]},
        )
        with open(savepath_results, "w", encoding="utf-8") as f:
            f.write(results.model_dump_json(indent=4))
        return results

    # Evaluate ViDoRe V1 with QA datasets; cached per-dataset results are reused.
    dataset_names = get_datasets_from_collection("vidore/vidore-benchmark-667173f98e70a1c0fa4db00d")
    for dataset_name in tqdm(dataset_names, desc="Evaluating dataset(s)"):
        savepath_results = _savepath_for(dataset_name)
        if os.path.isfile(savepath_results):
            metrics, results = _load_cached_results(savepath_results)
        else:
            metrics = {dataset_name: vidore_evaluator_qa.evaluate_dataset(
                ds=load_dataset(dataset_name, split="test"),
                batch_query=batch_size,
                batch_passage=batch_size,
                batch_score=128,
                dataloader_prebatch_query=512,
                dataloader_prebatch_passage=512,
            )}
            results = _save_fresh_results(dataset_name, metrics, savepath_results)
        metrics_all.update(metrics)
        print(f"nDCG@5 on {dataset_name}: {metrics[dataset_name]['ndcg_at_5']}")
        results_all.append(results)

    # Evaluate ViDoRe V2 with BEIR-style datasets at their pinned revisions.
    for dataset_name, revision in vidore_v2_original_commits.items():
        savepath_results = _savepath_for(dataset_name)
        if os.path.isfile(savepath_results):
            metrics, results = _load_cached_results(savepath_results)
        else:
            ds = {
                "corpus": load_dataset(dataset_name, name="corpus", split="test", revision=revision),
                "queries": load_dataset(dataset_name, name="queries", split="test", revision=revision),
                "qrels": load_dataset(dataset_name, name="qrels", split="test", revision=revision)
            }
            metrics = {dataset_name: vidore_evaluator_beir.evaluate_dataset(
                ds=ds,
                batch_query=batch_size,
                batch_passage=batch_size,
                batch_score=128,
                dataloader_prebatch_query=512,
                dataloader_prebatch_passage=512,
            )}
            results = _save_fresh_results(dataset_name, metrics, savepath_results)
        metrics_all.update(metrics)
        print(f"nDCG@5 on {dataset_name}: {metrics[dataset_name]['ndcg_at_5']}")
        results_all.append(results)

    # Merge all per-dataset results into a single summary file.
    results_merged = ViDoReBenchmarkResults.merge(results_all)
    savepath_results_merged = os.path.join(savedir_datasets, "merged_metrics.json")
    with open(savepath_results_merged, "w", encoding="utf-8") as f:
        f.write(results_merged.model_dump_json(indent=4))