In [1]:
import torch
from torch.utils.data import DataLoader
from torch.optim import Optimizer
import os
from datetime import datetime
from train.learner import DiffproLearner

class TrainConfig:
    """Base training configuration that hands a model, data and optimizer to DiffproLearner.

    ``model``, ``train_dl``, ``val_dl`` and ``optimizer`` are only declared here
    (class-level annotations); a subclass must assign them before calling
    :meth:`train`.

    Args:
        params: Hyper-parameter object. ``train`` reads ``params.max_epoch`` and
            forwards the whole object to ``DiffproLearner``.
        param_scheduler: Stored as ``self.param_scheduler``; not used by this class.
        output_dir: Base directory for the run (see :meth:`train` for layout).
    """

    model: torch.nn.Module
    train_dl: DataLoader
    val_dl: DataLoader
    optimizer: Optimizer

    def __init__(self, params, param_scheduler, output_dir) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.params = params
        self.param_scheduler = param_scheduler
        self.output_dir = output_dir

    @staticmethod
    def _is_affirmative(answer: str) -> bool:
        """Return True if a console reply means "yes" (ignores case and surrounding whitespace)."""
        return answer.strip().lower() in ("y", "yes")

    def train(self):
        """Print the trainable-parameter count, choose the output folder, and run training.

        If ``<output_dir>/chkpts/weights.pt`` exists, the user is asked whether to
        resume; any reply other than y/yes returns without training. Otherwise a
        new timestamped subfolder ``<output_dir>/<MM-DD_HHMMSS>`` is used.

        NOTE(review): new runs go into a timestamped subfolder, so to resume one
        you presumably have to pass that subfolder itself as ``output_dir``
        (assuming DiffproLearner writes ``chkpts/`` under the folder it is given
        — confirm).

        Returns:
            None.
        """
        total_parameters = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad
        )
        print(f"Total parameters: {total_parameters}")

        output_dir = self.output_dir
        checkpoint_path = os.path.join(output_dir, "chkpts", "weights.pt")
        if os.path.exists(checkpoint_path):
            print("Checkpoint already exists.")
            # Normalise the reply so "Y", "yes" or "y " don't silently abort training.
            if not self._is_affirmative(input("Resume training? (y/n)")):
                return
        else:
            # Fresh run: never overwrite an existing folder; use a timestamped child.
            output_dir = os.path.join(output_dir, datetime.now().strftime('%m-%d_%H%M%S'))
            print(f"Creating new log folder as {output_dir}")

        learner = DiffproLearner(
            output_dir, self.model, self.train_dl, self.val_dl, self.optimizer,
            self.params
        )
        learner.train(max_epoch=self.params.max_epoch)


In [4]:
from model import init_ldm_model, init_diff_pro_sdf
from data.dataset_loading import load_datasets, create_dataloader

# Rhythm mode passed as `load_datasets(with_rhythm=...)` when LdmTrainConfig builds
# its data loaders. NOTE(review): the accepted values live in data.dataset_loading —
# confirm there what "onset" selects.
WITH_RHYTHM = "onset"

class LdmTrainConfig(TrainConfig):
    """Training configuration for the LDM-based DiffPro model.

    Builds the model, train/val data loaders and an Adam optimizer from
    ``params``, then relies on the inherited ``train`` method. No parameter
    scheduler is used (``None`` is passed to the base class).

    Args:
        params: Hyper-parameter object; ``batch_size`` and ``learning_rate`` are
            read here and the whole object is forwarded to the model factories.
        output_dir: Base directory for the run.
        debug_mode: Forwarded to ``init_ldm_model``.
        with_rhythm: Forwarded as ``load_datasets(with_rhythm=...)``. ``None``
            (default) uses the module-level ``WITH_RHYTHM`` at construction time.
    """

    def __init__(self, params, output_dir, debug_mode=False, with_rhythm=None) -> None:
        super().__init__(params, None, output_dir)
        self.debug_mode = debug_mode
        # TODO: autoregressive / external conditioning, background masking and
        # random pitch augmentation options are not wired into this config yet.

        # Model: build the LDM first, then pass it to init_diff_pro_sdf to get the trained model.
        self.ldm_model = init_ldm_model(params, debug_mode)
        self.model = init_diff_pro_sdf(self.ldm_model, params, self.device)

        # Data loaders. Resolve the default lazily so later edits to WITH_RHYTHM are honoured.
        if with_rhythm is None:
            with_rhythm = WITH_RHYTHM
        train_set = load_datasets(with_rhythm=with_rhythm)
        self.train_dl = create_dataloader(params.batch_size, train_set)
        # TODO: no held-out split yet — validation temporarily reuses the training
        # set, so validation metrics do not measure generalization.
        self.val_dl = create_dataloader(params.batch_size, train_set)

        # Optimizer over all model parameters.
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=params.learning_rate
        )


In [5]:

# Run configuration: arguments are set directly instead of parsed from a CLI.
from train.train_params import params_chord_cond, params_chord
import os

args = {
    'output_dir': 'results',
    'uniform_pitch_shift': False,
    'debug': False,  # forwarded to LdmTrainConfig(debug_mode=...)
    # Not yet supported by LdmTrainConfig:
    # 'data_source': "lmd",
    # 'load_chkpt_from': None,
    # 'dataset_path': "data/lmd_sample/no_drum_sample",
}

# Random pitch augmentation is the inverse of a uniform pitch shift.
# NOTE(review): `random_pitch_aug` is not passed anywhere — LdmTrainConfig does not
# accept it — so `uniform_pitch_shift` currently has no effect on training.
random_pitch_aug = not args['uniform_pitch_shift']

# Run name (hard-coded; not derived from `args`).
fn = 'test'

# Base output directory; TrainConfig.train adds a timestamped subfolder for new runs.
output_dir = os.path.join(args['output_dir'], fn)

config = LdmTrainConfig(params_chord_cond, output_dir, debug_mode=args['debug'])

config.train()

 self.autocast = torch.cuda.amp.autocast(enabled=params.fp16)
 self.scaler = torch.cuda.amp.GradScaler(enabled=params.fp16)


Total parameters: 36755330
Creating new log folder as results/test/09-13_171940
{
 "attention_levels": [
 2,
 3
 ],
 "batch_size": 16,
 "channel_multipliers": [
 1,
 2,
 4,
 4
 ],
 "channels": 64,
 "d_cond": 2,
 "fp16": true,
 "in_channels": 4,
 "latent_scaling_factor": 0.18215,
 "learning_rate": 5e-05,
 "linear_end": 0.012,
 "linear_start": 0.00085,
 "max_epoch": 10,
 "max_grad_norm": 10,
 "n_heads": 4,
 "n_res_blocks": 2,
 "n_steps": 1000,
 "out_channels": 2,
 "tf_layers": 1
}


Epoch 0: 100%|██████████| 1141/1141 [00:51<00:00, 22.08it/s]
Epoch 1: 100%|██████████| 1141/1141 [00:50<00:00, 22.43it/s]
Epoch 2: 100%|██████████| 1141/1141 [00:47<00:00, 24.02it/s]
Epoch 3: 100%|██████████| 1141/1141 [00:47<00:00, 24.07it/s]
Epoch 4: 100%|██████████| 1141/1141 [01:04<00:00, 17.70it/s]
Epoch 5: 100%|██████████| 1141/1141 [00:50<00:00, 22.42it/s]
Epoch 6: 100%|██████████| 1141/1141 [00:50<00:00, 22.38it/s]
Epoch 7: 100%|██████████| 1141/1141 [00:50<00:00, 22.38it/s]
Epoch 8: 100%|██████████| 1141/1141 [01:05<00:00, 17.38it/s]
Epoch 9: 100%|██████████| 1141/1141 [00:49<00:00, 22.83it/s]
