import argparse

import mteb


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name_or_path',
        type=str,
        help='Model name or path to a checkpoint on the Hugging Face Hub',
        default='',
    )
    parser.add_argument(
        '--task',
        type=str,
        help='Task to evaluate; if None, the whole VisualDocumentRetrieval benchmark is run',
        default=None,
    )
    args, extra_args = parser.parse_known_args()

    def convert_value(value):
        # Cast numeric strings to int or float; leave everything else as a string.
        if value.replace('.', '', 1).isdigit():
            return int(value) if '.' not in value else float(value)
        return value

    # Collect any unrecognised '--key value' pairs into a dict of extra arguments.
    extra_args_dict = {extra_args[i].lstrip('-'): convert_value(extra_args[i + 1])
                       for i in range(0, len(extra_args), 2)}

    return args, extra_args_dict


if __name__ == "__main__":
    args, add_args = get_args()

    # Load the model and attach its metadata so MTEB records results under the right name.
    model_meta = mteb.get_model_meta(args.model_name_or_path)
    model = model_meta.load_model()
    model.mteb_model_meta = model_meta

    # Evaluate either a single task or the full VisualDocumentRetrieval benchmark.
    if args.task is not None:
        tasks = [mteb.get_task(args.task)]
    else:
        tasks = mteb.get_benchmark("VisualDocumentRetrieval")

    evaluation = mteb.MTEB(tasks=tasks)
    # Forward any extra command-line arguments to the evaluation run.
    results = evaluation.run(model, corpus_chunk_size=250, **add_args)
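
# Example invocation (illustrative; the script name and argument values are placeholders):
#   python evaluate_vdr.py --model_name_or_path <hf-model-id> --task <mteb-task-name>
# Any additional '--key value' flags are parsed into add_args and passed through
# to evaluation.run().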