In [1]:
# Run settings.py as a standalone script; per the cell output it reports the device in use.
!python settings.py

Using device: cuda


In [None]:
import os
import pandas as pd
from datasets import Dataset
from tqdm.autonotebook import tqdm

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.readers       import InputExample
from sentence_transformers.models        import Transformer, Pooling
from sentence_transformers.losses        import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

from settings import MODEL_ID, MODEL_NAME, CACHE_DIR, OUTPUT_DIR, MAX_SEQ_LEN, EPOCHS, LR, BATCH_SIZE, DEVICE

# Turn off Weights & Biases logging for this run.
# NOTE(review): the training cells emit a warning that WANDB_DISABLED is deprecated and
# will be removed in v5; the suggested replacement is the `report_to` setting
# (e.g. report_to='none').
os.environ['WANDB_DISABLED'] = 'true'

Using device: cuda


In [3]:
# Read the pre-processed parquet files: the retrieval corpus and the two question splits.
data = {
    'corpus': pd.read_parquet('data/processed/corpus_data.parquet'),
    'train' : pd.read_parquet('data/processed/train_data.parquet'),
    'test'  : pd.read_parquet('data/processed/test_data.parquet')
}

# The list-valued columns are converted to plain Python lists via each cell's .tolist()
# (the debug cell further down expects <class 'list'> for them).
for split in ('train', 'test'):
    frame = data[split]
    for list_col in ('cid', 'context_list'):
        frame[list_col] = frame[list_col].apply(lambda arr: arr.tolist())

# One (initially empty) bucket of training pairs per split; filled in a later cell.
examples = {name: [] for name in ('train', 'test')}

In [4]:
# Preview the first rows of the training split to sanity-check the loaded columns.
data['train'].head()

Unnamed: 0,question,context_list,qid,cid
0,Liên đoàn Luật sư Việt Nam là tổ chức xã hội –...,[“Điều 2. Địa vị pháp lý của Liên đoàn Luật sư...,72600,[142820]
1,Tên hợp tác xã bị rơi vào trường hợp cấm thì c...,"[""Điều 7. Tên hợp tác xã, liên hiệp hợp tác xã...",147562,"[27817, 72117]"
2,Tài xế lái xe ô tô khách 50 chỗ ngồi bao lâu t...,"[""1. Sử dụng lái xe bảo đảm sức khỏe theo tiêu...",142107,"[33215, 56201]"
3,Các bước chuẩn bị thủ thuật bó bột Cravate sẽ ...,[BỘT CRAVATE\n...\nIV. CHUẨN BỊ\n1. Người thực...,77353,[148158]
4,Viên chức Hộ sinh hạng 4 có những nhiệm vụ gì ...,[Hộ sinh hạng IV - Mã số: V.08.06.16\n1. Nhiệm...,113090,[188132]


In [5]:
# Debug: show the Python type stored in each column of the test split, then verify
# that every row has exactly one context per cid.
for col in data['test'].columns:
    # .iloc[0] picks the first row by position; plain [0] is a label lookup and would
    # fail (or hit a deprecated fallback) if the index does not contain the label 0.
    print(col, type(data['test'][col].iloc[0]))

print((data['test']['cid'].apply(len) == data['test']['context_list'].apply(len)).all())

question <class 'str'>
context_list <class 'list'>
qid <class 'numpy.int64'>
cid <class 'list'>
True


In [6]:
# Expand every question into one InputExample per associated context, i.e. one
# (question, context) pair for each entry of the row's context_list.
for split in ['train', 'test']:
    # Rebuild the list from scratch instead of appending to the bucket created in an
    # earlier cell: re-running this cell must not duplicate the examples.
    examples[split] = [
        InputExample(texts=[row.question, context])
        for row in tqdm(
            data[split].itertuples(index=False),
            total=len(data[split]),  # itertuples() is a generator, so give tqdm the length
            desc=f"Processing {split}",
            unit='rows',
        )
        for context in row.context_list
    ]

print(f"Training examples: {len(examples['train'])}") # Compare with sum(data['train']['cid'].apply(len))

Processing train:   0%|          | 0/89162 [00:00<?, ?rows/s]

Processing test:   0%|          | 0/29723 [00:00<?, ?rows/s]

Training examples: 99580


In [7]:
# Token-level encoder loaded from MODEL_ID, with inputs capped at MAX_SEQ_LEN.
embedding_model = Transformer(
    MODEL_ID,
    max_seq_length=MAX_SEQ_LEN,
    cache_dir=CACHE_DIR,
)

# Mean pooling over token embeddings; the pooling width is taken from the encoder.
embedding_dim = embedding_model.get_word_embedding_dimension()
pooling_model = Pooling(embedding_dim, pooling_mode_mean_tokens=True)

# Metadata attached to the model for its generated model card.
model_card = SentenceTransformerModelCardData(
    model_id=MODEL_ID,
    model_name=MODEL_NAME,
    language='vi',
    license='mit',
)

# Full sentence-embedding model: encoder followed by pooling, placed on DEVICE.
model = SentenceTransformer(
    modules=[embedding_model, pooling_model],
    device=DEVICE,
    cache_folder=CACHE_DIR,
    model_card_data=model_card,
)

In [None]:
# Ranking loss optimised over the (question, context) pairs.
loss = CachedMultipleNegativesRankingLoss(model=model)

# Training hyper-parameters; the tunable values come from settings.py.
args = SentenceTransformerTrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    learning_rate=LR,
    warmup_ratio=0.1,
    fp16=True,
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    logging_steps=100,
    # Explicitly disable external logging integrations. This is the replacement the
    # library recommends for the deprecated WANDB_DISABLED environment variable, so
    # reporting stays off even after that variable stops being honoured.
    report_to='none',
)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [9]:
def to_frame(ex_list):
    """Turn a sequence of examples into a two-column DataFrame.

    Args:
        ex_list: iterable of objects exposing a ``texts`` sequence with at least
            two entries; only the first two entries are used.

    Returns:
        pandas.DataFrame with columns ``text_0`` and ``text_1``, one row per example.
    """
    pairs = []
    for example in ex_list:
        first, second = example.texts[0], example.texts[1]
        pairs.append((first, second))
    return pd.DataFrame(pairs, columns=['text_0', 'text_1'])

# Wrap the training pairs in a `datasets.Dataset` built from the two-column frame.
train_frame = to_frame(examples['train'])
train_ds = Dataset.from_pandas(train_frame)

# Fine-tune the model on the training pairs with the loss and arguments defined above.
trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=train_ds,
    loss=loss,
    args=args,
)
trainer.train()

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
100,1.8827
200,0.4428
300,0.3564
400,0.2856
500,0.2445
600,0.2241
700,0.1938
800,0.1894
900,0.1432
1000,0.1432


TrainOutput(global_step=3890, training_loss=0.1604946916084976, metrics={'train_runtime': 12756.5123, 'train_samples_per_second': 39.031, 'train_steps_per_second': 0.305, 'total_flos': 0.0, 'train_loss': 0.1604946916084976, 'epoch': 5.0})

In [None]:
# Persist the fine-tuned model to OUTPUT_DIR.
model.save_pretrained(OUTPUT_DIR)
# Optional upload step, kept disabled; uncomment to push the saved model to the hub repo below.
# model.push_to_hub(
#     repo_id='YuITC/bert-base-multilingual-cased-finetuned-VNLegalDocs', 
#     commit_message='Update README.md',
#     exist_ok=True,
#     replace_model_card=False,
#     train_datasets=['tmnam20/BKAI-Legal-Retrieval']
# )