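"""Fine-tune a BiSiglip bi-encoder (SigLIP-2 base) with LoRA via the colpali_engine trainer."""
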
import os
import shutil
from pathlib import Path
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import TrainingArguments
from colpali_engine.data.dataset import ColPaliEngineDataset
from colpali_engine.models import BiSiglip, BiSiglipProcessor
from colpali_engine.models.siglip.loss_bisiglip import BiSigLipEncoderLoss
from colpali_engine.trainer.colmodel_training import ColModelTraining, ColModelTrainingConfig
from colpali_engine.utils.dataset_transformation import load_train_set
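
# Full training configuration: base checkpoint + processor, train/eval datasets,
# loss function, Hugging Face trainer arguments, and the LoRA adapter spec.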
config = ColModelTrainingConfig(
    output_dir="./models/bisiglip-0804",
    processor=BiSiglipProcessor.from_pretrained(
        pretrained_model_name_or_path="./models/base_models/siglip2-base-patch32-256",
    ),
    model=BiSiglip.from_pretrained(
        pretrained_model_name_or_path="./models/base_models/siglip2-base-patch32-256",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    ),
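    # The eval split pairs each query with its positive page image ("image" column).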
    train_dataset=load_train_set(),  # alternative: load_train_set_ir_negs
    eval_dataset=ColPaliEngineDataset(
        load_dataset("./data_dir/colpali_train_set", split="test"),
        pos_target_column_name="image",
    ),
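    # BiSigLipEncoderLoss follows SigLIP's pairwise sigmoid contrastive objective,
    # which is why logit_scale and logit_bias are trainable in the LoRA config below.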
    run_eval=True,
    # loss_func=BiEncoderLoss(),  # BiNegativeCELoss(in_batch_term_weight=0.5)
    loss_func=BiSigLipEncoderLoss(),
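    # Note: output_dir is left None here; presumably ColModelTraining fills it in
    # from config.output_dir above.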
    tr_args=TrainingArguments(
        output_dir=None,
        overwrite_output_dir=True,
        num_train_epochs=5,
        per_device_train_batch_size=64,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        per_device_eval_batch_size=16,
        eval_strategy="steps",
        dataloader_num_workers=4,
        save_steps=500,
        logging_steps=10,
        eval_steps=100,
        warmup_steps=100,
        learning_rate=2e-4,
        save_total_limit=1,
    ),
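    # LoRA adapters on the text tower's attention projections only (everything else
    # stays frozen), plus the scalar logit_scale / logit_bias used by the sigmoid loss.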
    peft_config=LoraConfig(
        r=32,
        lora_alpha=32,
        lora_dropout=0.1,
        init_lora_weights="gaussian",
        bias="none",
        task_type="FEATURE_EXTRACTION",
        target_modules="((.*(text_model).*(k_proj|q_proj|v_proj|out_proj).*$)|logit_scale|logit_bias)",  # noqa: E501
    ),
)

if __name__ == "__main__":
    # Ensure the output directory exists.
    os.makedirs(config.output_dir, exist_ok=True)
    # Version this run by copying the script itself into the output directory.
    current_script = Path(__file__)
    shutil.copy(current_script, Path(config.output_dir) / current_script.name)

    training_app = ColModelTraining(config)
    training_app.train()
    training_app.save()