{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.utils.data import DataLoader\n", "from torch.optim import Optimizer\n", "import os\n", "from datetime import datetime\n", "from train.learner import DiffproLearner\n", "\n", "class TrainConfig:\n", "\n", " model: torch.nn.Module\n", " train_dl: DataLoader\n", " val_dl: DataLoader\n", " optimizer: Optimizer\n", "\n", " def __init__(self, params, param_scheduler, output_dir) -> None:\n", " self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " self.params = params\n", " self.param_scheduler = param_scheduler\n", " self.output_dir = output_dir\n", "\n", " def train(self):\n", " # collect and display total parameters\n", " total_parameters = sum(\n", " p.numel() for p in self.model.parameters() if p.requires_grad\n", " )\n", " print(f\"Total parameters: {total_parameters}\")\n", "\n", " # dealing with the output storing\n", " output_dir = self.output_dir\n", " if os.path.exists(f\"{output_dir}/chkpts/weights.pt\"):\n", " print(\"Checkpoint already exists.\")\n", " if input(\"Resume training? (y/n)\") != \"y\":\n", " return\n", " else:\n", " output_dir = f\"{output_dir}/{datetime.now().strftime('%m-%d_%H%M%S')}\"\n", " print(f\"Creating new log folder as {output_dir}\")\n", "\n", " # prepare the learner structure and parameters\n", " learner = DiffproLearner(\n", " output_dir, self.model, self.train_dl, self.val_dl, self.optimizer,\n", " self.params\n", " )\n", " learner.train(max_epoch=self.params.max_epoch)\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from model import init_ldm_model, init_diff_pro_sdf\n", "from data.dataset_loading import load_datasets, create_dataloader\n", "\n", "WITH_RHYTHM = \"onset\"\n", "\n", "class LdmTrainConfig(TrainConfig):\n", "\n", " def __init__(self, params, output_dir, debug_mode=False) -> None:\n", " super().__init__(params, None, output_dir)\n", " self.debug_mode = debug_mode\n", " #self.use_autoreg_cond = use_autoreg_cond\n", " #self.use_external_cond = use_external_cond\n", " #self.mask_background = mask_background\n", " #self.random_pitch_aug = random_pitch_aug\n", "\n", " # create model\n", " self.ldm_model = init_ldm_model(params, debug_mode)\n", " self.model = init_diff_pro_sdf(self.ldm_model, params, self.device)\n", "\n", " # Create dataloader\n", " train_set = load_datasets(with_rhythm=WITH_RHYTHM)\n", " self.train_dl = create_dataloader(params.batch_size, train_set)\n", " self.val_dl = create_dataloader(params.batch_size, train_set) # we temporarily use train_set for validation\n", "\n", " # Create optimizer4\n", " self.optimizer = torch.optim.Adam(\n", " self.model.parameters(), lr=params.learning_rate\n", " )\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/music/chord_trainer/train/learner.py:45: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n", " self.autocast = torch.cuda.amp.autocast(enabled=params.fp16)\n", "/home/music/chord_trainer/train/learner.py:46: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. 
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/music/chord_trainer/train/learner.py:45: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
      "  self.autocast = torch.cuda.amp.autocast(enabled=params.fp16)\n",
      "/home/music/chord_trainer/train/learner.py:46: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n",
      "  self.scaler = torch.cuda.amp.GradScaler(enabled=params.fp16)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total parameters: 36755330\n",
      "Creating new log folder as results/test/09-13_171940\n",
      "{\n",
      "    \"attention_levels\": [\n",
      "        2,\n",
      "        3\n",
      "    ],\n",
      "    \"batch_size\": 16,\n",
      "    \"channel_multipliers\": [\n",
      "        1,\n",
      "        2,\n",
      "        4,\n",
      "        4\n",
      "    ],\n",
      "    \"channels\": 64,\n",
      "    \"d_cond\": 2,\n",
      "    \"fp16\": true,\n",
      "    \"in_channels\": 4,\n",
      "    \"latent_scaling_factor\": 0.18215,\n",
      "    \"learning_rate\": 5e-05,\n",
      "    \"linear_end\": 0.012,\n",
      "    \"linear_start\": 0.00085,\n",
      "    \"max_epoch\": 10,\n",
      "    \"max_grad_norm\": 10,\n",
      "    \"n_heads\": 4,\n",
      "    \"n_res_blocks\": 2,\n",
      "    \"n_steps\": 1000,\n",
      "    \"out_channels\": 2,\n",
      "    \"tf_layers\": 1\n",
      "}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|██████████| 1141/1141 [00:51<00:00, 22.08it/s]\n",
      "Epoch 1: 100%|██████████| 1141/1141 [00:50<00:00, 22.43it/s]\n",
      "Epoch 2: 100%|██████████| 1141/1141 [00:47<00:00, 24.02it/s]\n",
      "Epoch 3: 100%|██████████| 1141/1141 [00:47<00:00, 24.07it/s]\n",
      "Epoch 4: 100%|██████████| 1141/1141 [01:04<00:00, 17.70it/s]\n",
      "Epoch 5: 100%|██████████| 1141/1141 [00:50<00:00, 22.42it/s]\n",
      "Epoch 6: 100%|██████████| 1141/1141 [00:50<00:00, 22.38it/s]\n",
      "Epoch 7: 100%|██████████| 1141/1141 [00:50<00:00, 22.38it/s]\n",
      "Epoch 8: 100%|██████████| 1141/1141 [01:05<00:00, 17.38it/s]\n",
      "Epoch 9: 100%|██████████| 1141/1141 [00:49<00:00, 22.83it/s]\n"
     ]
    }
   ],
   "source": [
    "# Import the hyperparameter preset for chord-conditioned training\n",
    "from train.train_params import params_chord_cond\n",
    "import os\n",
    "\n",
    "# Set the argument values directly instead of parsing them from the CLI\n",
    "args = {\n",
    "    'output_dir': 'results',\n",
    "    'uniform_pitch_shift': False,\n",
    "    # 'debug': False,\n",
    "    # 'data_source': \"lmd\",\n",
    "    # 'load_chkpt_from': None,\n",
    "    # 'dataset_path': \"data/lmd_sample/no_drum_sample\",\n",
    "}\n",
    "\n",
    "# Random pitch augmentation is on unless a uniform pitch shift is\n",
    "# requested (currently unused by LdmTrainConfig)\n",
    "random_pitch_aug = not args['uniform_pitch_shift']\n",
    "\n",
    "# Name of this run's output subfolder\n",
    "fn = 'test'\n",
    "output_dir = os.path.join(args['output_dir'], fn)\n",
    "\n",
    "# Create the training configuration and start training\n",
    "config = LdmTrainConfig(params_chord_cond, output_dir)\n",
    "config.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "music_demo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}