import os
import shutil
from pathlib import Path

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import TrainingArguments

from colpali_engine.data.dataset import ColPaliEngineDataset
from colpali_engine.models import BiSiglip, BiSiglipProcessor
from colpali_engine.models.siglip.loss_bisiglip import BiSigLipEncoderLoss
from colpali_engine.trainer.colmodel_training import ColModelTraining, ColModelTrainingConfig
from colpali_engine.utils.dataset_transformation import load_train_set

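# Fine-tunes a SigLIP-based bi-encoder (BiSiglip) with LoRA adapters on the
# ColPali train set, using the BiSigLipEncoderLoss training objective.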
config = ColModelTrainingConfig(
    output_dir="./models/bisiglip-0804",
    # Processor and model are loaded from a local SigLIP2 base checkpoint.
    processor=BiSiglipProcessor.from_pretrained(
        pretrained_model_name_or_path="./models/base_models/siglip2-base-patch32-256",
    ),
    model=BiSiglip.from_pretrained(
        pretrained_model_name_or_path="./models/base_models/siglip2-base-patch32-256",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    ),
    train_dataset=load_train_set(),  # alternative: load_train_set_ir_negs
    # Evaluation set: the "image" column provides the positive targets.
    eval_dataset=ColPaliEngineDataset(
        load_dataset("./data_dir/colpali_train_set", split="test"), pos_target_column_name="image"
    ),
    run_eval=True,
    # loss_func=BiEncoderLoss(),  # BiNegativeCELoss(in_batch_term_weight=0.5),
    loss_func=BiSigLipEncoderLoss(),
    # Hugging Face TrainingArguments forwarded to the underlying Trainer.
    tr_args=TrainingArguments(
        output_dir=None,
        overwrite_output_dir=True,
        num_train_epochs=5,
        per_device_train_batch_size=64,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        per_device_eval_batch_size=16,
        eval_strategy="steps",
        dataloader_num_workers=4,
        save_steps=500,
        logging_steps=10,
        eval_steps=100,
        warmup_steps=100,
        learning_rate=2e-4,
        save_total_limit=1,
    ),
    # LoRA adapters on the text tower's attention projections, plus the logit scale/bias.
    peft_config=LoraConfig(
        r=32,
        lora_alpha=32,
        lora_dropout=0.1,
        init_lora_weights="gaussian",
        bias="none",
        task_type="FEATURE_EXTRACTION",
        target_modules="((.*(text_model).*(k_proj|q_proj|v_proj|out_proj).*$)|logit_scale|logit_bias)",  # noqa: E501
    ),
)


if __name__ == "__main__":
    # ensure output_dir exists
    os.makedirs(config.output_dir, exist_ok=True)
    # version this script by copying it into the output dir
    current_script = Path(__file__)
    shutil.copy(current_script, Path(config.output_dir) / current_script.name)

    training_app = ColModelTraining(config)

    training_app.train()
    training_app.save()
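
# Usage note (assumption: the file name below is illustrative; adjust it to wherever
# this script lives in your checkout):
#   python train_bisiglip.py
# or, for multi-GPU runs with Hugging Face Accelerate:
#   accelerate launch train_bisiglip.py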