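"""Fine-tune a BiSiglip bi-encoder (SigLIP2 backbone) for visual document retrieval.

This script wires a local siglip2-base-patch32-256 checkpoint, the ColPali train set,
a SigLIP-style contrastive loss, and LoRA adapters into colpali_engine's
ColModelTraining harness, then trains the model and saves the result to ``output_dir``.

Run it directly as a Python script; for multi-GPU training a launcher such as
``accelerate launch`` can be used instead, depending on your setup.
"""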
import os
import shutil
from pathlib import Path

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import TrainingArguments

from colpali_engine.data.dataset import ColPaliEngineDataset
from colpali_engine.models import BiSiglip, BiSiglipProcessor
from colpali_engine.models.siglip.loss_bisiglip import BiSigLipEncoderLoss
from colpali_engine.trainer.colmodel_training import ColModelTraining, ColModelTrainingConfig
from colpali_engine.utils.dataset_transformation import load_train_set

config = ColModelTrainingConfig(
    output_dir="./models/bisiglip-0804",
    # Processor and model are initialized from a local SigLIP2 base checkpoint.
    processor=BiSiglipProcessor.from_pretrained(
        pretrained_model_name_or_path="./models/base_models/siglip2-base-patch32-256",
    ),
    model=BiSiglip.from_pretrained(
        pretrained_model_name_or_path="./models/base_models/siglip2-base-patch32-256",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    ),
    train_dataset=load_train_set(),
    # Evaluate on the held-out test split, using the `image` column as positive targets.
    eval_dataset=ColPaliEngineDataset(
        load_dataset("./data_dir/colpali_train_set", split="test"),
        pos_target_column_name="image",
    ),
    run_eval=True,
    loss_func=BiSigLipEncoderLoss(),
    tr_args=TrainingArguments(
        output_dir=None,
        overwrite_output_dir=True,
        num_train_epochs=5,
        per_device_train_batch_size=64,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        per_device_eval_batch_size=16,
        eval_strategy="steps",
        dataloader_num_workers=4,
        save_steps=500,
        logging_steps=10,
        eval_steps=100,
        warmup_steps=100,
        learning_rate=2e-4,
        save_total_limit=1,
    ),
    # LoRA adapters target the text tower's attention projections plus the
    # SigLIP logit scale and bias.
    peft_config=LoraConfig(
        r=32,
        lora_alpha=32,
        lora_dropout=0.1,
        init_lora_weights="gaussian",
        bias="none",
        task_type="FEATURE_EXTRACTION",
        target_modules="((.*(text_model).*(k_proj|q_proj|v_proj|out_proj).*$)|logit_scale|logit_bias)",
    ),
)


if __name__ == "__main__":
    os.makedirs(config.output_dir, exist_ok=True)

    # Keep a copy of this script alongside the checkpoints for reproducibility.
    current_script = Path(__file__)
    shutil.copy(current_script, Path(config.output_dir) / current_script.name)

    training_app = ColModelTraining(config)
    training_app.train()
    training_app.save()