Commit 62f1377
Parent(s): 0
Initial commit with cleaned history

This view is limited to 50 files because it contains too many changes.
- .DS_Store +0 -0
- .gitattributes +6 -0
- Aptfile +2 -0
- README.md +18 -0
- __pycache__/app.cpython-39.pyc +0 -0
- __pycache__/learner.cpython-39.pyc +0 -0
- __pycache__/params.cpython-39.pyc +0 -0
- __pycache__/train_params.cpython-39.pyc +0 -0
- app.py +508 -0
- filter_data/filter_by_instrument.ipynb +353 -0
- filter_data/midi_utils.py +139 -0
- generation/__pycache__/gen_utils.cpython-39.pyc +0 -0
- generation/gen_utils.py +302 -0
- model/__init__.py +59 -0
- model/__pycache__/__init__.cpython-39.pyc +0 -0
- model/__pycache__/latent_diffusion.cpython-39.pyc +0 -0
- model/__pycache__/model_sdf.cpython-39.pyc +0 -0
- model/__pycache__/sampler_sdf.cpython-39.pyc +0 -0
- model/architecture/__pycache__/unet.cpython-39.pyc +0 -0
- model/architecture/__pycache__/unet_attention.cpython-39.pyc +0 -0
- model/architecture/unet.py +364 -0
- model/architecture/unet_attention.py +321 -0
- model/latent_diffusion.py +222 -0
- model/model_sdf.py +55 -0
- model/sampler_sdf.py +538 -0
- output_0.mid +0 -0
- output_0.wav +3 -0
- packages.txt +1 -0
- piano_roll.png +3 -0
- requirements.txt +22 -0
- results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/chkpts/weights_best.pt +3 -0
- results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/events.out.tfevents.1726894943.berkeleyaisim3.16517.0 +0 -0
- results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/train_grad_norm/events.out.tfevents.1726894943.berkeleyaisim3.16517.2 +0 -0
- results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/train_loss/events.out.tfevents.1726894943.berkeleyaisim3.16517.1 +0 -0
- results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/val_grad_norm/events.out.tfevents.1726895010.berkeleyaisim3.16517.4 +0 -0
- results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/val_loss/events.out.tfevents.1726895010.berkeleyaisim3.16517.3 +0 -0
- results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/params.json +1 -0
- rhythm_plot_0.png +3 -0
- runtime.txt +1 -0
- samples/control_vs_uncontrol/example_1_acc_control.jpg +3 -0
- samples/control_vs_uncontrol/example_1_acc_control.wav +3 -0
- samples/control_vs_uncontrol/example_1_acc_uncontrol.jpg +3 -0
- samples/control_vs_uncontrol/example_1_acc_uncontrol.wav +3 -0
- samples/control_vs_uncontrol/example_1_mel_chd.jpg +3 -0
- samples/control_vs_uncontrol/example_1_mel_chd.wav +3 -0
- samples/control_vs_uncontrol/example_2_acc_control.jpg +3 -0
- samples/control_vs_uncontrol/example_2_acc_control.wav +3 -0
- samples/control_vs_uncontrol/example_2_acc_uncontrol.jpg +3 -0
- samples/control_vs_uncontrol/example_2_acc_uncontrol.wav +3 -0
- samples/control_vs_uncontrol/example_2_mel_chd.jpg +3 -0
.DS_Store
ADDED
Binary file (6.15 kB)
.gitattributes
ADDED
@@ -0,0 +1,6 @@
*.pt filter=lfs diff=lfs merge=lfs -text
*.mp4 filter=lfs diff=lfs merge=lfs -text
*.wav filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
*.jpg filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
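(For context: these are the patterns that `git lfs track` writes, so checkpoints and media files are stored with Git LFS. Running something like `git lfs track "*.pt" "*.wav"` locally and committing the updated .gitattributes produces entries of this form.)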
Aptfile
ADDED
@@ -0,0 +1,2 @@
lilypond
fluidsynth
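These are system packages needed at runtime: lilypond (score engraving) and fluidsynth (MIDI synthesis). A minimal sketch of the synthesis path that depends on fluidsynth, mirroring the midi_to_wav helper in app.py below (the file names and sample rate here are assumptions, not part of the repository):

import pretty_midi
import soundfile as sf

pm = pretty_midi.PrettyMIDI("output_0.mid")        # assumed: a MIDI file written by the app
audio = pm.fluidsynth(fs=44100)                    # requires the fluidsynth library (and pyfluidsynth)
sf.write("output_0.wav", audio, samplerate=44100)  # render to WAV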
README.md
ADDED
@@ -0,0 +1,18 @@
---
title: Interactive Symbolic Music Demo
emoji: 🖼
colorFrom: purple
colorTo: red
sdk: gradio
sdk_version: 4.42.0
app_file: app.py
pinned: false
python_version: 3.9.19
license: mit
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

Note: locate the mido.MidiFile(...) call inside the pretty_midi package and set the argument clip=True in every such call.

Use set_seed(42) in sampler_sdf.py; the generation result from chord slice (index = 2) is a good example (a wrong note is shifted to a correct one).
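One way to apply that clip setting without hand-editing the installed pretty_midi package is to monkey-patch mido.MidiFile before any MIDI file is loaded. This wrapper is a sketch of that idea, not part of the repository:

import mido

_OrigMidiFile = mido.MidiFile

def _clipped_midifile(*args, **kwargs):
    # Force clip=True so out-of-range data bytes are clipped instead of
    # raising "data byte must be in range 0..127" during parsing.
    kwargs.setdefault("clip", True)
    return _OrigMidiFile(*args, **kwargs)

mido.MidiFile = _clipped_midifile  # pretty_midi.PrettyMIDI picks this up at call time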
__pycache__/app.cpython-39.pyc
ADDED
Binary file (12.6 kB)
__pycache__/learner.cpython-39.pyc
ADDED
Binary file (7.06 kB)
__pycache__/params.cpython-39.pyc
ADDED
Binary file (1.27 kB)
__pycache__/train_params.cpython-39.pyc
ADDED
Binary file (1.38 kB)
app.py
ADDED
@@ -0,0 +1,508 @@
import gradio as gr
import pretty_midi
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import cv2
import imageio

import sys
import subprocess
import os
import torch
from model import init_ldm_model
from model.model_sdf import Diffpro_SDF
from model.sampler_sdf import SDFSampler

import pickle
from train.train_params import params_chord_lsh_cond
from generation.gen_utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = 'results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/chkpts/weights_best.pt'
chord_list = list(CHORD_DICTIONARY.keys())

def get_shape(file_path):
    if file_path.endswith('.jpg'):
        img = cv2.imread(file_path)
        return img.shape  # (height, width, channels)

    elif file_path.endswith('.mp4'):
        vid = imageio.get_reader(file_path)
        return vid.get_meta_data()['size']  # (width, height)

    else:
        raise ValueError("Unsupported file type")

# Function to convert MIDI to WAV
def midi_to_wav(midi, output_file):
    # Synthesize the waveform from the MIDI using pretty_midi
    audio_data = midi.fluidsynth()

    # Write the waveform to a WAV file
    sf.write(output_file, audio_data, samplerate=44100)

def update_musescore_image(selected_prompt):
    # Logic to return the correct image file based on the selected prompt
    if selected_prompt == "example 1":
        return "samples/diy_examples/example1/example1.jpg"
    elif selected_prompt == "example 2":
        return "samples/diy_examples/example2/example2.jpg"
    elif selected_prompt == "example 3":
        return "samples/diy_examples/example3/example3.jpg"
    elif selected_prompt == "example 4":
        return "samples/diy_examples/example4/example4.jpg"
    elif selected_prompt == "example 5":
        return "samples/diy_examples/example5/example5.jpg"
    elif selected_prompt == "example 6":
        return "samples/diy_examples/example6/example6.jpg"


# Model for generating music (example)
def generate_music(prompt, tempo, num_samples=1, mode="example", rhythm_control="Yes"):

    ldm_model = init_ldm_model(params_chord_lsh_cond, debug_mode=False)
    model = Diffpro_SDF.load_trained(ldm_model, model_path).to(device)
    sampler = SDFSampler(model.ldm, 64, 64, is_autocast=False, device=device, debug_mode=False)

    if mode=="example":
        if prompt == "example 1":
            background_condition = np.load("samples/diy_examples/example1/example1.npy")
            tempo=70
        elif prompt == "example 2":
            background_condition = np.load("samples/diy_examples/example2/example2.npy")
        elif prompt == "example 3":
            background_condition = np.load("samples/diy_examples/example3/example3.npy")
        elif prompt == "example 4":
            background_condition = np.load("samples/diy_examples/example4/example4.npy")

        background_condition = np.tile(background_condition, (num_samples,1,1,1))
        background_condition = torch.Tensor(background_condition).to(device)
    else:
        background_condition = np.tile(prompt, (num_samples,1,1,1))
        background_condition = torch.Tensor(background_condition).to(device)

    if rhythm_control!="Yes":
        background_condition[:,0:2] = background_condition[:,2:4]
    # generate samples
    output_x = sampler.generate(background_cond=background_condition, batch_size=num_samples,
                                same_noise_all_measure=False, X0EditFunc=X0EditFunc,
                                use_classifier_free_guidance=True, use_lsh=True, reduce_extra_notes=False,
                                rhythm_control=rhythm_control)
    output_x = torch.clamp(output_x, min=0, max=1)
    output_x = output_x.cpu().numpy()

    # save samples
    for i in range(num_samples):
        full_roll = extend_piano_roll(output_x[i]) # accompaniment roll
        full_chd_roll = extend_piano_roll(-background_condition[i,2:4,:,:].cpu().numpy()-1) # chord roll
        full_lsh_roll = None
        if background_condition.shape[1]>=6:
            if background_condition[:,4:6,:,:].min()>=0:
                full_lsh_roll = extend_piano_roll(background_condition[i,4:6,:,:].cpu().numpy())
        midi_file = piano_roll_to_midi(full_roll, full_chd_roll, full_lsh_roll, bpm=tempo)
        # filename = f'DDIM_w_rhythm_onset_0to10_{i}_edit_x0_and_eps'+'.mid'
        filename = f"output_{i}.mid"
        save_midi(midi_file, filename)
        subprocess.Popen(['timidity',f'output_{i}.mid','-Ow','-o',f'output_{i}.wav']).communicate()

    return 'output_0.mid', 'output_0.wav', midi_file

# Function to visualize MIDI notes
def visualize_midi(midi):
    # Get piano roll from MIDI
    roll = midi.get_piano_roll(fs=100)

    # Plot the piano roll
    plt.figure(figsize=(10, 4))
    plt.imshow(roll, aspect='auto', origin='lower', cmap='gray_r', interpolation='nearest')
    plt.title("Piano Roll")
    plt.xlabel("Time")
    plt.ylabel("Pitch")
    plt.colorbar()

    # Save the plot as an image
    output_image_path = "piano_roll.png"
    plt.savefig(output_image_path)
    return output_image_path

def plot_rhythm(rhythm_str, label):
    if rhythm_str=="null rhythm":
        return None
    fig, ax = plt.subplots(figsize=(6, 2))

    # Ensure it's a 16-bit string
    rhythm_str = rhythm_str[:16]

    # Convert string to a list of 0s and 1s
    rhythm = [0 if bit=="0" else 1 for bit in rhythm_str]

    # Define the x axis for the 16 sixteenth notes
    x = list(range(1, 17)) # 1 to 16 sixteenth notes

    # Plot each note (1 as filled circle, 0 as empty circle)
    for i, bit in enumerate(rhythm):
        if bit == 1:
            ax.scatter(i + 1, 1, color='black', s=100, label="Note" if i == 0 else "")
        else:
            ax.scatter(i + 1, 1, edgecolor='black', facecolor='none', s=100, label="Rest" if i == 0 else "")

    # Distinguish groups of 4 using vertical dashed lines (no solid grid lines)
    for i in range(4, 17, 4):
        ax.axvline(x=i + 0.5, color='grey', linestyle='--')

    # Remove solid vertical grid lines by setting the grid off
    ax.grid(False)

    # Formatting the plot
    ax.set_xlim(0.5, 16.5)
    ax.set_ylim(0.8, 1.2)
    ax.set_xticks(x)
    ax.set_yticks([])
    ax.set_xlabel("16th Notes")
    ax.set_title("Rhythm Pattern")

    fig.savefig(f'samples/diy_examples/rhythm_plot_{label}.png')
    plt.close(fig)
    return f'samples/diy_examples/rhythm_plot_{label}.png'

def adjust_rhythm_string(s):
    # Truncate if longer than 16 characters
    if len(s) > 16:
        return s[:16]
    # Pad with zeros if shorter than 16 characters
    else:
        return s.ljust(16, '0')
def rhythm_string_to_array(s):
    # Ensure the string is 16 characters long
    s = s[:16].ljust(16, '0') # Truncate or pad with '0' to make it 16 characters
    # Convert to numpy array, treating non-'0' as '1'
    arr = np.array([1 if char != '0' else 0 for char in s], dtype=int)
    arr = arr*np.array([3,1,2,1,3,1,2,1,3,1,2,1,3,1,2,1])
    print(arr)
    return arr

# Gradio main function
def generate_from_example(prompt):
    midi_output, audio_output, midi = generate_music(prompt, tempo=80, mode="example", rhythm_control=False)
    piano_roll_image = visualize_midi(midi)
    return audio_output, piano_roll_image

def generate_diy(m1_chord, m2_chord, m3_chord, m4_chord,
                 m1_rhythm, m2_rhythm, m3_rhythm, m4_rhythm, tempo):
    print("\n\n\n",m1_chord,type(m1_chord), "\n\n\n")
    test_chd_roll = np.concatenate([np.tile(CHORD_DICTIONARY[m1_chord], (16, 1)),
                                    np.tile(CHORD_DICTIONARY[m2_chord], (16, 1)),
                                    np.tile(CHORD_DICTIONARY[m3_chord], (16, 1)),
                                    np.tile(CHORD_DICTIONARY[m4_chord], (16, 1))])
    rhythms = [m1_rhythm, m2_rhythm, m3_rhythm, m4_rhythm]

    chd_roll = np.concatenate([test_chd_roll[np.newaxis,:,:], test_chd_roll[np.newaxis,:,:]], axis=0)

    chd_roll = circular_extend(chd_roll)
    chd_roll = -chd_roll-1

    real_chd_roll = chd_roll

    melody_roll = -np.ones_like(chd_roll)

    if "null rhythm" not in rhythms:
        rhythm_full = []
        for i in range(len(rhythms)):
            rhythm = adjust_rhythm_string(rhythms[i])
            rhythm = rhythm_string_to_array(rhythm)
            rhythm_full.append(rhythm)
        rhythm_full = np.concatenate(rhythm_full, axis=0)

        onset_roll = test_chd_roll*rhythm_full[:, np.newaxis]
        sustain_roll = np.zeros_like(onset_roll)
        no_onset_pos = np.all(onset_roll == 0, axis=-1)
        sustain_roll[no_onset_pos] = test_chd_roll[no_onset_pos]

        real_chd_roll = np.concatenate([onset_roll[np.newaxis,:,:], sustain_roll[np.newaxis,:,:]], axis=0)
        real_chd_roll = circular_extend(real_chd_roll)

    background_condition = np.concatenate([real_chd_roll, chd_roll, melody_roll], axis=0)

    midi_output, audio_output, midi = generate_music(background_condition, tempo, mode="diy")
    piano_roll_image = visualize_midi(midi)
    return midi_output, audio_output, piano_roll_image

# Prompt list
prompt_list = ["example 1", "example 2", "example 3", "example 4"]
rhythm_list = ["null rhythm", "1010101010101010", "1011101010111010","1111101010111010","1010001010101010","1010101000101010"]


custom_css = """
.custom-row1 {
    background-color: #fdebd0;
    padding: 10px;
    border-radius: 5px;
}
.custom-row2 {
    background-color: #d1f2eb;
    padding: 10px;
    border-radius: 5px;
}
.custom-grey {
    background-color: #f0f0f0;
    padding: 10px;
    border-radius: 5px;
}
.custom-purple {
    background-color: #d7bde2;
    padding: 10px;
    border-radius: 5px;
}
.audio_waveform-container {
    display: none !important;
}
"""


with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# <div style='text-align: center;font-size:40px'> Efficient Fine-Grained Guidance for Diffusion-Based Symbolic Music Generation <div style='text-align: center;'>")

    gr.Markdown("<span style='font-size:25px;'> We introduce **Fine-Grained Guidance (FG)**, an efficient approach for symbolic music generation using **diffusion models**. Our method enhances guidance through:\
                \n   (1) Fine-grained conditioning during training,\
                \n   (2) Fine-grained control during the diffusion sampling process.\
                \n In particular, **sampling control** ensures tonal accuracy in every generated sample, allowing our model to produce music with high precision, consistent rhythmic patterns,\
                and even stylistic variations that align with user intent.<span>")
    gr.Markdown("<span style='font-size:25px;color: red'> At the bottom of this page, we provide an interactive space for you to try our model by yourself! <span>")


    gr.Markdown("\n\n\n")
    gr.Markdown("# 1. Accompaniment Generation given Melody and Chord")
    gr.Markdown("<span style='font-size:20px;'> In each example, the left column displays the melody provided as inputs to the model.\
                The right column showcases music samples generated by the model.<span>")

    with gr.Column(elem_classes="custom-row1"):
        gr.Markdown("## Example 1")
        with gr.Row():
            with gr.Column():
                gr.Markdown("<span style='font-size:20px;'> With the following melody as condition <span>")
                example1_mel = gr.Audio(value="samples/diy_examples/example1/example_1_mel.wav", label="Melody", scale = 5)
            with gr.Column():
                gr.Markdown("<span style='font-size:20px;'> Generated Accompaniments <span>")
                example1_audio = gr.Audio(value="samples/diy_examples/example1/sample1.wav", label="Generated Accompaniment", scale = 5)

    with gr.Column(elem_classes="custom-row2"):
        gr.Markdown("## Example 2")
        with gr.Row():
            with gr.Column():
                gr.Markdown("<span style='font-size:20px;'> With the following melody as condition <span>")
                example1_mel = gr.Audio(value="samples/diy_examples/example2/example_2_mel.wav", label="Melody", scale = 5)
            with gr.Column():
                gr.Markdown("<span style='font-size:20px;'> Generated Accompaniments <span>")
                example1_audio = gr.Audio(value="samples/diy_examples/example2/sample1.wav", label="Generated Accompaniment", scale = 5)

    with gr.Column(elem_classes="custom-row1"):
        gr.Markdown("## Example 3")
        with gr.Row():
            with gr.Column():
                gr.Markdown("<span style='font-size:20px;'> With the following melody as condition <span>")
                example1_mel = gr.Audio(value="samples/diy_examples/example3/example_3_mel.wav", label="Melody", scale = 5)
            with gr.Column():
                gr.Markdown("<span style='font-size:20px;'> Generated Accompaniments <span>")
                example1_audio = gr.Audio(value="samples/diy_examples/example3/sample1.wav", label="Generated Accompaniment", scale = 5)

    with gr.Column(elem_classes="custom-row2"):
        gr.Markdown("## Example 4")
        with gr.Row():
            with gr.Column():
                gr.Markdown("<span style='font-size:20px;'> With the following melody as condition <span>")
                example1_mel = gr.Audio(value="samples/diy_examples/example4/example_4_mel.wav", label="Melody", scale = 5)
            with gr.Column():
                gr.Markdown("<span style='font-size:20px;'> Generated Accompaniments <span>")
                example1_audio = gr.Audio(value="samples/diy_examples/example4/sample1.wav", label="Generated Accompaniment", scale = 5)

    gr.HTML("<div style='height: 50px;'></div>")
    gr.Markdown("# \n\n\n")
    gr.Markdown("# 2. Style-Controlled Music Generation")
    gr.Markdown("<span style='font-size:20px;'>Our approach enables controllable stylization in music generation. The sampling control is able to\
                ensure that all generated notes strictly adhere to the target musical style's scale.\
                This allows the model to generate music in specific styles — even those that were not present in \
                the training data.<span>")
    gr.Markdown("<span style='font-size:20px;'> Below, we demonstrate several examples of style-controlled music generation for:\
                \n   (1) Dorian Mode: (with scale being A-B-C-D-E-F#-G);\
                \n   (2) Chinese Style: (with scale being C-D-E-G-A). <span>")

    with gr.Column(elem_classes="custom-row1"):
        gr.Markdown("## Dorian Mode")
        gr.Markdown("<span style='font-size:20px;'> The following are two examples generated by our method <span>")
        with gr.Row():
            with gr.Column(elem_classes="custom-grey"):
                gr.Markdown("<span style='font-size:20px;'> Example 1 <span>")
                example1_mel = gr.Audio(value="samples/different_styles/dorian_1.wav", scale = 5)
            with gr.Column(elem_classes="custom-grey"):
                gr.Markdown("<span style='font-size:20px;'> Example 2 <span>")
                example1_audio = gr.Audio(value="samples/different_styles/dorian_2.wav", scale = 5)

    with gr.Column(elem_classes="custom-row2"):
        gr.Markdown("## Chinese Style")
        gr.Markdown("<span style='font-size:20px;'> The following are two examples generated by our method <span>")
        with gr.Row():
            with gr.Column(elem_classes="custom-grey"):
                gr.Markdown("<span style='font-size:20px;'> Example 1 <span>")
                example1_mel = gr.Audio(value="samples/different_styles/chinese_1.wav", scale = 5)
            with gr.Column(elem_classes="custom-grey"):
                gr.Markdown("<span style='font-size:20px;'> Example 2 <span>")
                example1_audio = gr.Audio(value="samples/different_styles/chinese_2.wav", scale = 5)

    gr.HTML("<div style='height: 50px;'></div>")
    gr.Markdown("\n\n\n")
    gr.Markdown("# 3. Demonstrating the Effectiveness of Sampling Control by Comparison")

    gr.Markdown("<span style='font-size:20px;'> We demonstrate the impact of sampling control in an **accompaniment generation** task, given a melody and chord progression.\
                \n Each example generates accompaniments with and without sampling control using the same random seed, ensuring that the two results are comparable.\
                \n Sampling control effectively removes or replaces harmonically conflicting notes, ensuring tonal consistency.\
                \n We provide music sheets and audio files for both versions.<span>")

    gr.Markdown("<span style='font-size:20px;'> Comparison of the results indicates that sampling control not only eliminates out-of-key notes but also enhances \
                the overall coherence and harmonic consistency of the accompaniments.\
                This highlights the effectiveness of our approach in maintaining musical coherence. <span>")


    with gr.Column(elem_classes="custom-row1"):
        gr.Markdown("## Example 1")

        with gr.Row(elem_classes="custom-grey"):
            gr.Markdown("<span style='font-size:20px;'> With pre-defined melody and chord as follows<span>")
            with gr.Column(scale=2, min_width=10, ):
                gr.Markdown("Melody Sheet")
                example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_1_mel_chd.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
            with gr.Column(scale=1, min_width=10, ):
                gr.Markdown("Melody Audio")
                example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_mel_chd.wav", label="Melody, wav", waveform_options=gr.WaveformOptions(show_recording_waveform=False), scale = 1, min_width=10)

        gr.Markdown("## Generated Accompaniments")
        with gr.Row(elem_classes="custom-grey"):
            gr.Markdown("<span style='font-size:20px;'> Without sampling control<span>")
            with gr.Column(scale=2, min_width=300):
                gr.Markdown("Music Sheet")
                example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_1_acc_uncontrol.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
            with gr.Column(scale=1, min_width=150):
                gr.Markdown("Audio")
                example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_uncontrol.wav", scale = 1, min_width=10)
        gr.Markdown("\n\n\n")
        with gr.Row(elem_classes="custom-grey"):
            with gr.Column(scale=1, min_width=150):
                gr.Markdown("<span style='font-size:20px;'>With sampling control<span>")
            with gr.Column(scale=2, min_width=300):
                gr.Markdown("Music Sheet")
                example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_1_acc_control.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
            with gr.Column(scale=1, min_width=150):
                gr.Markdown("Audio")
                example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_control.wav", scale = 1, min_width=10)


    with gr.Column(elem_classes="custom-row2"):
        gr.Markdown("## Example 2")

        with gr.Row(elem_classes="custom-grey"):
            gr.Markdown("<span style='font-size:20px;'> With pre-defined melody and chord as follows<span>")
            with gr.Column(scale=2, min_width=10, ):
                gr.Markdown("Melody Sheet")
                example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_2_mel_chd.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
            with gr.Column(scale=1, min_width=10, ):
                gr.Markdown("Melody Audio")
                example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_2_mel_chd.wav", label="Melody, wav", waveform_options=gr.WaveformOptions(show_recording_waveform=False), scale = 1, min_width=10)

        gr.Markdown("## Generated Accompaniments")
        with gr.Row(elem_classes="custom-grey"):
            gr.Markdown("<span style='font-size:20px;'> Without sampling control<span>")
            with gr.Column(scale=2, min_width=300):
                gr.Markdown("Music Sheet")
                example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_2_acc_uncontrol.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
            with gr.Column(scale=1, min_width=150):
                gr.Markdown("Audio")
                example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_2_acc_uncontrol.wav", scale = 1, min_width=10)
        gr.Markdown("\n\n\n")
        with gr.Row(elem_classes="custom-grey"):
            with gr.Column(scale=1, min_width=150):
                gr.Markdown("<span style='font-size:20px;'>With sampling control<span>")
            with gr.Column(scale=2, min_width=300):
                gr.Markdown("Music Sheet")
                example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_2_acc_control.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
            with gr.Column(scale=1, min_width=150):
                gr.Markdown("Audio")
                example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_2_acc_control.wav", scale = 1, min_width=10)

    # with gr.Row():
    #     with gr.Column(scale=1, min_width=300, elem_classes="custom-row1"):
    #         gr.Markdown("## Example 1")
    #         gr.Markdown("<span style='font-size:20px;'> With pre-defined melody and chord as follows<span>")
    #         example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_1_mel_chd.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
    #         # Audio component to play the audio
    #         example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_mel_chd.wav", label="Melody, wav", waveform_options=gr.WaveformOptions(show_recording_waveform=False), scale = 1, min_width=10)

    #         gr.Markdown("## Generated Accompaniments")
    #         with gr.Row():
    #             with gr.Column(scale=1, min_width=150):
    #                 gr.Markdown("<span style='font-size:20px;'> without sampling control<span>")
    #                 example1_sheet = gr.Image(value="samples/control_vs_uncontrol/sample_1.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
    #                 example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_uncontrol.wav", scale = 1, min_width=10)
    #             with gr.Column(scale=1, min_width=150):
    #                 gr.Markdown("<span style='font-size:20px;'> with sampling control<span>")
    #                 example1_sheet = gr.Image(value="samples/control_vs_uncontrol/sample_1.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
    #                 example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_control.wav", scale = 1, min_width=10)
    #     with gr.Column(scale=1, min_width=300, elem_classes="custom-row2"):
    #         gr.Markdown("## Example 2")
    #         gr.Markdown("<span style='font-size:20px;'> With pre-defined melody and chord as follows<span>")
    #         example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_1_mel_chd.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
    #         # Audio component to play the audio
    #         example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_mel_chd.wav", label="Melody, wav", waveform_options=gr.WaveformOptions(show_recording_waveform=False), scale = 1, min_width=10)

    #         gr.Markdown("## Generated Accompaniments")
    #         with gr.Row():
    #             with gr.Column(scale=1, min_width=150):
    #                 gr.Markdown("<span style='font-size:20px;'> without sampling control<span>")
    #                 example1_sheet = gr.Image(value="samples/control_vs_uncontrol/sample_1.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
    #                 example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_uncontrol.wav", scale = 1, min_width=10)
    #             with gr.Column(scale=1, min_width=150):
    #                 gr.Markdown("<span style='font-size:20px;'> with sampling control<span>")
    #                 example1_sheet = gr.Image(value="samples/control_vs_uncontrol/sample_1.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
    #                 example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_control.wav", scale = 1, min_width=10)




    ''' Try to generate by users '''
    gr.HTML("<div style='height: 50px;'></div>")
    gr.Markdown("\n\n\n")
    gr.Markdown("# <span style='color: red;'> 4. DIY in real time! </span>")
    gr.Markdown("<span style='font-size:20px;'> Here is an interactive tool for you to try our model and generate by yourself.\
                You can generate new accompaniments for given melody and chord conditions <span>")

    gr.Markdown("### <span style='color: blue;'> Currently this space is supported with Hugging Face CPU and on average,\
                it takes about 15 seconds to generate a 4-measure music piece. However, if other users are generating\
                music at the same time, one may enter a queue, which could slow down the process significantly.\
                If that happens, feel free to refresh the page. We appreciate your patience and understanding.\
                </span>")

    with gr.Column(elem_classes="custom-purple"):
        gr.Markdown("### Select an example to generate music given melody and chord condition")
        with gr.Row():
            with gr.Column():
                prompt_selector = gr.Dropdown(choices=prompt_list, label="Select an example", value="example 1")
                gr.Markdown("### This is the melody to be conditioned on:")
                condition_musescore = gr.Image("samples/diy_examples/example1/example1.jpg", label="melody, chord, and rhythm condition")
                prompt_selector.change(fn=update_musescore_image, inputs=prompt_selector, outputs=condition_musescore)

            with gr.Column():
                generate_button = gr.Button("Generate")
                gr.Markdown("### Generation results:")
                audio_output = gr.Audio(label="Generated Music")
                piano_roll_output = gr.Image(label="Generated Piano Roll")

                generate_button.click(
                    fn=generate_from_example,
                    inputs=[prompt_selector],
                    outputs=[audio_output, piano_roll_output]
                )


# Launch Gradio interface
if __name__ == "__main__":
    demo.launch()
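For reference, a minimal sketch of driving this generation path without the UI, in the style of generate_diy above. The chord names are placeholders, not values confirmed by the repository — they must be keys of CHORD_DICTIONARY from generation.gen_utils — and the rhythm strings follow the 16-character-per-measure format used by rhythm_list:

# Hypothetical driver for generate_diy(); chord names below are assumptions
# and must be replaced with real keys of CHORD_DICTIONARY.
midi_path, wav_path, roll_png = generate_diy(
    "C", "F", "G", "C",                       # one chord per measure (assumed names)
    "1010101010101010", "1011101010111010",   # one 16th-note rhythm string per measure
    "1111101010111010", "1010001010101010",   # any "null rhythm" entry skips rhythm conditioning
    tempo=80,
)
print(midi_path, wav_path, roll_png)          # 'output_0.mid', 'output_0.wav', 'piano_roll.png'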
filter_data/filter_by_instrument.ipynb
ADDED
@@ -0,0 +1,353 @@
Cell 1 (code):

    from midi_utils import is_timesig_44, gather_full_instr, gather_instr
    from midi_utils import has_brass
    from midi_utils import has_piano, has_string, has_guitar, has_drums

Cell 2 (code):

    import pretty_midi

    def display_all_instrument_names():
        for program_number in range(128):
            instrument_name = pretty_midi.program_to_instrument_name(program_number)
            print(f"Program number {program_number}: {instrument_name}")

    # Call the function to display all instrument names
    display_all_instrument_names()

    Output: the 128 General MIDI program names, from "Program number 0: Acoustic Grand Piano" through "Program number 127: Gunshot".

Cell 3 (code):

    #import mido
    import pretty_midi
    #from mido import KeySignatureError

    def filter_midi(file_path, max_track = 8, max_time = 300):
        try:
            pm = pretty_midi.PrettyMIDI(file_path)
        except Exception as e:
            return False

        # time signature 4/4
        #if is_timesig_44(pm) == False:
            #print("timesig")
            #return False

        # number of tracks
        if len(pm.instruments)>max_track:
            #print("tracks")
            return False

        # length of song
        #if pm.get_end_time()>max_time:
            #print("length")
            #return False

        # now filter by instruments

        # filter out the ones without drums
        #if has_drums(pm)==False:
            #print("no drums")
            #return False

        # filter out the ones with brass
        instr = gather_instr(pm)
        if has_brass(instr):
            #print("has brass")
            return False

        # filter out the ones without full string and piano
        full_instr = gather_full_instr(pm, threshold=0.7)

        if has_piano(full_instr)== False:
            #print("no piano")
            return False

        if has_guitar(full_instr)== False:
            return False

        #if has_string(full_instr)== False:
            #print("no string")
            #return False

        return True

Cell 4 (code):

    #import pretty_midi
    #midi_path = '/home/ubuntu/lakh-pianoroll-dataset/data/samples_with_strings/c10e69ec7f8212c68ff3658cceef5b9b.mid'
    #pm = pretty_midi.PrettyMIDI(midi_path)
    #mid = mido.MidiFile(midi_path)

Cell 5 (code):

    import shutil
    import glob
    import os
    from tqdm import tqdm

    def find_midi_files_upto(root_dir, sample_size):
        midi_files = glob.glob(os.path.join(root_dir, '**/*.mid'), recursive=True)
        matching_files = []
        match_count = 0

        pbar = tqdm(total=len(midi_files), desc="Processing MIDI files")
        for midi_file in midi_files:
            if filter_midi(midi_file):
                matching_files.append(midi_file)
                match_count += 1
                pbar.set_postfix({'Matching files': match_count})
                if match_count >= sample_size:
                    break
            pbar.update(1)
        pbar.close()
        return matching_files

    def copy_files(files, target_dir):
        os.makedirs(target_dir, exist_ok=True)
        for file in files:
            shutil.copy(file, target_dir)

Cell 6 (code, shown only partially in this view; its source is visible in the recorded traceback):

    ROOT_DIR = '/home/ubuntu/lakh-pianoroll-dataset/data/lmd/lmd_full'
    sample_files = find_midi_files_upto(ROOT_DIR, sample_size=1500)
    tgt_dir = '/home/ubuntu/lakh-pianoroll-dataset/data/instrument_samples'
    copy_files(sample_files, tgt_dir)

    Output (stderr): the tqdm progress bar (4 matching files after 148 of 178,561 MIDI files) and a pretty_midi RuntimeWarning: "Tempo, Key or Time signature change events found on non-zero tracks. This is not a valid type 0 or type 1 MIDI file. Tempo, Key or Time Signature may be wrong." The run ends with a KeyboardInterrupt traceback raised inside mido's message decoding while pretty_midi was parsing a file; the traceback continues beyond this view.
|
305 |
+
"\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/messages/decode.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;31m# TODO: better name than args?\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m \u001b[0mnames\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'value_names'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m'channel'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnames\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
306 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
307 |
+
]
|
308 |
+
},
|
309 |
+
{
|
310 |
+
"name": "stderr",
|
311 |
+
"output_type": "stream",
|
312 |
+
"text": [
|
313 |
+
"Processing MIDI files: 0%| | 150/178561 [00:30<4:21:08, 11.39it/s, Matching files=4]"
|
314 |
+
]
|
315 |
+
}
|
316 |
+
],
|
317 |
+
"source": [
|
318 |
+
"ROOT_DIR = '/home/ubuntu/lakh-pianoroll-dataset/data/lmd/lmd_full'\n",
|
319 |
+
"sample_files = find_midi_files_upto(ROOT_DIR, sample_size=1500)\n",
|
320 |
+
"tgt_dir = '/home/ubuntu/lakh-pianoroll-dataset/data/instrument_samples'\n",
|
321 |
+
"copy_files(sample_files, tgt_dir)"
|
322 |
+
]
|
323 |
+
},
|
324 |
+
{
|
325 |
+
"cell_type": "code",
|
326 |
+
"execution_count": null,
|
327 |
+
"metadata": {},
|
328 |
+
"outputs": [],
|
329 |
+
"source": []
|
330 |
+
}
|
331 |
+
],
|
332 |
+
"metadata": {
|
333 |
+
"kernelspec": {
|
334 |
+
"display_name": "Python 3",
|
335 |
+
"language": "python",
|
336 |
+
"name": "python3"
|
337 |
+
},
|
338 |
+
"language_info": {
|
339 |
+
"codemirror_mode": {
|
340 |
+
"name": "ipython",
|
341 |
+
"version": 3
|
342 |
+
},
|
343 |
+
"file_extension": ".py",
|
344 |
+
"mimetype": "text/x-python",
|
345 |
+
"name": "python",
|
346 |
+
"nbconvert_exporter": "python",
|
347 |
+
"pygments_lexer": "ipython3",
|
348 |
+
"version": "3.10.12"
|
349 |
+
}
|
350 |
+
},
|
351 |
+
"nbformat": 4,
|
352 |
+
"nbformat_minor": 2
|
353 |
+
}
|
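The notebook cell above relies on two helpers, find_midi_files_upto and copy_files, that are defined elsewhere in the notebook and do not appear in this diff. Below is a minimal sketch of what such helpers could look like, assuming they only walk the Lakh MIDI tree and copy a bounded sample of .mid files; the names and signatures are taken from the call sites above, everything else is an assumption.

import os
import shutil

def find_midi_files_upto(root_dir, sample_size=1500):
    # Walk root_dir and collect up to sample_size MIDI file paths.
    found = []
    for dirpath, _, filenames in os.walk(root_dir):
        for name in filenames:
            if name.lower().endswith(('.mid', '.midi')):
                found.append(os.path.join(dirpath, name))
                if len(found) >= sample_size:
                    return found
    return found

def copy_files(file_paths, tgt_dir):
    # Copy the sampled files into a flat target directory.
    os.makedirs(tgt_dir, exist_ok=True)
    for path in file_paths:
        shutil.copy2(path, os.path.join(tgt_dir, os.path.basename(path)))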
filter_data/midi_utils.py
ADDED
@@ -0,0 +1,139 @@
|
1 |
+
|
2 |
+
import pretty_midi
|
3 |
+
|
4 |
+
PIANO = (0,1)
|
5 |
+
STRING = (40,41,42,48,49,50,51)
|
6 |
+
GUITAR = (24,25,27)
|
7 |
+
BRASS = (56,57,58,59,61,62,63,64,65,66,67)
|
8 |
+
|
9 |
+
############################## For single track analysis
|
10 |
+
|
11 |
+
def calculate_active_duration(instrument):
|
12 |
+
# Collect all note intervals
|
13 |
+
intervals = [(note.start, note.end) for note in instrument.notes]
|
14 |
+
|
15 |
+
# Sort intervals by start time
|
16 |
+
intervals.sort()
|
17 |
+
|
18 |
+
# Merge overlapping intervals and calculate the active duration
|
19 |
+
active_duration = 0
|
20 |
+
current_start, current_end = intervals[0]
|
21 |
+
|
22 |
+
for start, end in intervals[1:]:
|
23 |
+
if start <= current_end: # There is an overlap
|
24 |
+
current_end = max(current_end, end)
|
25 |
+
else: # No overlap, add the previous interval duration and start a new interval
|
26 |
+
active_duration += current_end - current_start
|
27 |
+
current_start, current_end = start, end
|
28 |
+
|
29 |
+
# Add the last interval
|
30 |
+
active_duration += current_end - current_start
|
31 |
+
|
32 |
+
return active_duration
|
33 |
+
|
34 |
+
def is_full_track(midi, instrument, threshold=0.6):
|
35 |
+
# Calculate the total duration of the track
|
36 |
+
total_duration = midi.get_end_time()
|
37 |
+
|
38 |
+
# Calculate the active duration (time during which notes are playing)
|
39 |
+
active_duration = calculate_active_duration(instrument)
|
40 |
+
|
41 |
+
# Calculate the percentage of active duration
|
42 |
+
active_percentage = active_duration / total_duration
|
43 |
+
|
44 |
+
#print(f"Total duration: {total_duration:.2f} seconds")
|
45 |
+
#print(f"Active duration: {active_duration:.2f} seconds")
|
46 |
+
#print(f"Active percentage: {active_percentage:.2%}")
|
47 |
+
|
48 |
+
# Check if the active duration meets or exceeds the threshold
|
49 |
+
return active_percentage >= threshold
|
50 |
+
|
51 |
+
#################################### For gathering full tracks
|
52 |
+
|
53 |
+
def gather_instr(pm):
|
54 |
+
# Gather all the program indexes of the instrument tracks
|
55 |
+
program_indexes = [instrument.program for instrument in pm.instruments]
|
56 |
+
|
57 |
+
# Sort the program indexes
|
58 |
+
program_indexes.sort()
|
59 |
+
|
60 |
+
# Convert the sorted list of program indexes to a tuple
|
61 |
+
program_indexes_tuple = tuple(program_indexes)
|
62 |
+
return program_indexes_tuple
|
63 |
+
|
64 |
+
def gather_full_instr(pm, threshold = 0.6):
|
65 |
+
# Gather all the program indexes of the instrument tracks that exceed the duration threshold
|
66 |
+
program_indexes = []
|
67 |
+
for instrument in pm.instruments:
|
68 |
+
if is_full_track(pm, instrument, threshold):
|
69 |
+
program_indexes.append(instrument.program)
|
70 |
+
program_indexes.sort()
|
71 |
+
# Convert the list of program indexes to a tuple
|
72 |
+
program_indexes_tuple = tuple(program_indexes)
|
73 |
+
|
74 |
+
return program_indexes_tuple
|
75 |
+
|
76 |
+
####################################### For finding instruments
|
77 |
+
|
78 |
+
def has_intersection(wanted_instr, exist_instr):
|
79 |
+
# Convert both the tuple and the group of integers to sets
|
80 |
+
tuple_set = set(wanted_instr)
|
81 |
+
group_set = set(exist_instr)
|
82 |
+
|
83 |
+
# Check if there is any intersection
|
84 |
+
return not tuple_set.isdisjoint(group_set)
|
85 |
+
|
86 |
+
# The functions checking instruments in the midi file tracks
|
87 |
+
def has_piano(exist_instr):
|
88 |
+
wanted_instr = PIANO
|
89 |
+
return has_intersection(wanted_instr, exist_instr)
|
90 |
+
|
91 |
+
def has_string(exist_instr):
|
92 |
+
wanted_instr = STRING
|
93 |
+
return has_intersection(wanted_instr, exist_instr)
|
94 |
+
|
95 |
+
def has_guitar(exist_instr):
|
96 |
+
wanted_instr = GUITAR
|
97 |
+
return has_intersection(wanted_instr, exist_instr)
|
98 |
+
|
99 |
+
def has_brass(exist_instr):
|
100 |
+
wanted_instr = BRASS
|
101 |
+
return has_intersection(wanted_instr, exist_instr)
|
102 |
+
|
103 |
+
def has_drums(pm):
|
104 |
+
for instrument in pm.instruments:
|
105 |
+
if instrument.is_drum:
|
106 |
+
return True
|
107 |
+
return False
|
108 |
+
|
109 |
+
|
110 |
+
def print_track_details(instrument):
|
111 |
+
"""
|
112 |
+
For visualizing the information in a midi track
|
113 |
+
"""
|
114 |
+
print(f"Instrument: {pretty_midi.program_to_instrument_name(instrument.program)}")
|
115 |
+
print(f"Is drum: {instrument.is_drum}")
|
116 |
+
|
117 |
+
print("\nNotes:")
|
118 |
+
for note in instrument.notes:
|
119 |
+
print(f"Start: {note.start:.2f}, End: {note.end:.2f}, Pitch: {note.pitch}, Velocity: {note.velocity}")
|
120 |
+
|
121 |
+
print("\nControl Changes:")
|
122 |
+
for cc in instrument.control_changes:
|
123 |
+
print(f"Time: {cc.time:.2f}, Number: {cc.number}, Value: {cc.value}")
|
124 |
+
|
125 |
+
print("\nPitch Bends:")
|
126 |
+
for pb in instrument.pitch_bends:
|
127 |
+
print(f"Time: {pb.time:.2f}, Pitch: {pb.pitch}")
|
128 |
+
|
129 |
+
def is_timesig_44(pm):
|
130 |
+
for time_signature in pm.time_signature_changes:
|
131 |
+
if time_signature.numerator != 4 or time_signature.denominator != 4:
|
132 |
+
return False
|
133 |
+
return True
|
134 |
+
|
135 |
+
def is_timesig_34(pm):
|
136 |
+
for time_signature in pm.time_signature_changes:
|
137 |
+
if time_signature.numerator != 3 or time_signature.denominator != 4:
|
138 |
+
return False
|
139 |
+
return True
|
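For reference, a short usage sketch of the filtering helpers above, assuming a single PrettyMIDI file from the sampled set (the path is a placeholder and the import assumes running from the repo root). Note that calculate_active_duration indexes intervals[0], so files containing instrument tracks with no notes may raise an IndexError.

import pretty_midi
from filter_data.midi_utils import gather_full_instr, has_piano, has_drums, is_timesig_44

pm = pretty_midi.PrettyMIDI('song.mid')            # placeholder path
full_instr = gather_full_instr(pm, threshold=0.6)  # programs active for >= 60% of the piece
if is_timesig_44(pm) and has_piano(full_instr) and has_drums(pm):
    print('keeping file, full tracks:', full_instr)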
generation/__pycache__/gen_utils.cpython-39.pyc
ADDED
Binary file (10.1 kB). View file
|
|
generation/gen_utils.py
ADDED
@@ -0,0 +1,302 @@
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import pretty_midi as pm
|
4 |
+
|
5 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
6 |
+
|
7 |
+
CHORD_DICTIONARY = {
|
8 |
+
"C:major": np.array([1,0,0,0,1,0,0,1,0,0,0,0]),
|
9 |
+
"C#:major": np.array([0,1,0,0,0,1,0,0,1,0,0,0]),
|
10 |
+
"D:major": np.array([0,0,1,0,0,0,1,0,0,1,0,0]),
|
11 |
+
"Eb:major": np.array([0,0,0,1,0,0,0,1,0,0,1,0]),
|
12 |
+
"E:major": np.array([0,0,0,0,1,0,0,0,1,0,0,1]),
|
13 |
+
"F:major": np.array([1,0,0,0,0,1,0,0,0,1,0,0]),
|
14 |
+
"F#:major": np.array([0,1,0,0,0,0,1,0,0,0,1,0]),
|
15 |
+
"G:major": np.array([0,0,1,0,0,0,0,1,0,0,0,1]),
|
16 |
+
"Ab:major": np.array([1,0,0,1,0,0,0,0,1,0,0,0]),
|
17 |
+
"A:major": np.array([0,1,0,0,1,0,0,0,0,1,0,0]),
|
18 |
+
"Bb:major": np.array([0,0,1,0,0,1,0,0,0,0,1,0]),
|
19 |
+
"B:major": np.array([0,0,0,1,0,0,1,0,0,0,0,1]),
|
20 |
+
|
21 |
+
"c:minor": np.array([1,0,0,1,0,0,0,1,0,0,0,0]),
|
22 |
+
"c#:minor": np.array([0,1,0,0,1,0,0,0,1,0,0,0]),
|
23 |
+
"d:minor": np.array([0,0,1,0,0,1,0,0,0,1,0,0]),
|
24 |
+
"eb:minor": np.array([0,0,0,1,0,0,1,0,0,0,1,0]),
|
25 |
+
"e:minor": np.array([0,0,0,0,1,0,0,1,0,0,0,1]),
|
26 |
+
"f:minor": np.array([1,0,0,0,0,1,0,0,1,0,0,0]),
|
27 |
+
"f#:minor": np.array([0,1,0,0,0,0,1,0,0,1,0,0]),
|
28 |
+
"g:minor": np.array([0,0,1,0,0,0,0,1,0,0,1,0]),
|
29 |
+
"g#:minor": np.array([0,0,0,1,0,0,0,0,1,0,0,1]),
|
30 |
+
"a:minor": np.array([1,0,0,0,1,0,0,0,0,1,0,0]),
|
31 |
+
"bb:minor": np.array([0,1,0,0,0,1,0,0,0,0,1,0]),
|
32 |
+
"b:minor": np.array([0,0,1,0,0,0,1,0,0,0,0,1]),
|
33 |
+
}
|
34 |
+
|
35 |
+
|
36 |
+
def edit_rhythm(piano_roll_full, num_notes_onset, mask_full, reduce_extra_notes=True):
|
37 |
+
'''
|
38 |
+
piano_roll_full: a tensor with shape (batch_size, 2, length, h) # length=64 is the length of the roll, h is the number of possible pitches
|
39 |
+
num_notes_onset: a tensor with shape (batch_size, length)
|
40 |
+
mask_full: a tensor with the same shape as piano_roll_full, corresponding to the 7-note scale chroma
|
41 |
+
reduce_extra_notes: True if want to reduce extra notes
|
42 |
+
'''
|
43 |
+
########## for those greater than the threshold, if the number of notes exceeds num_notes[i],
|
44 |
+
########## will keep the first ones and set others to threshold
|
45 |
+
print("Coming")
|
46 |
+
# we only edit onset
|
47 |
+
onset_roll = piano_roll_full[:,0,:,:]
|
48 |
+
mask = mask_full[:,0,:,:]
|
49 |
+
shape = onset_roll.shape
|
50 |
+
|
51 |
+
onset_roll = onset_roll.reshape(-1,shape[-1])
|
52 |
+
mask = mask.reshape(-1,shape[-1])
|
53 |
+
num_notes = num_notes_onset.reshape(-1)
|
54 |
+
|
55 |
+
reduce_note_threshold = 0.499
|
56 |
+
increase_note_threshold = 0.501
|
57 |
+
|
58 |
+
# Initialize a tensor to store the modified values
|
59 |
+
final_onset_roll = onset_roll.clone()
|
60 |
+
|
61 |
+
########### if number of notes > required, remove the extra notes ###############
|
62 |
+
if reduce_extra_notes:
|
63 |
+
threshold_mask = onset_roll > reduce_note_threshold
|
64 |
+
# Set all values <= reduce_note_threshold to -inf to exclude them from top-k selection
|
65 |
+
values_above_threshold = torch.where(threshold_mask & (mask == 1), onset_roll, torch.tensor(-float('inf')).to(onset_roll.device))
|
66 |
+
|
67 |
+
# Get the top num_notes.max() values for each row
|
68 |
+
num_notes_max = int(num_notes.max().item()) # Maximum number of notes needed in any row
|
69 |
+
topk_values, topk_indices = torch.topk(values_above_threshold, num_notes_max, dim=1)
|
70 |
+
|
71 |
+
# Create a mask for the top num_notes[i] values for each row
|
72 |
+
col_indices = torch.arange(num_notes_max, device=onset_roll.device).expand(len(onset_roll), num_notes_max)
|
73 |
+
topk_mask = (col_indices < num_notes.unsqueeze(1)) & (topk_values > -float("inf"))
|
74 |
+
|
75 |
+
# Set all values greater than reduce_note_threshold to reduce_note_threshold initially
|
76 |
+
final_onset_roll[threshold_mask & (mask == 1)] = reduce_note_threshold
|
77 |
+
|
78 |
+
# Create a flattened index to scatter the top values back into final_onset_roll
|
79 |
+
flat_row_indices = torch.arange(onset_roll.size(0), device=onset_roll.device).unsqueeze(1).expand_as(topk_indices)
|
80 |
+
flat_row_indices = flat_row_indices[topk_mask]
|
81 |
+
|
82 |
+
# Gather the valid topk_indices and corresponding values
|
83 |
+
valid_topk_indices = topk_indices[topk_mask]
|
84 |
+
valid_topk_values = topk_values[topk_mask]
|
85 |
+
|
86 |
+
# Use scatter to place the top num_notes[i] values back to their original positions
|
87 |
+
final_onset_roll = final_onset_roll.index_put_((flat_row_indices, valid_topk_indices), valid_topk_values)
|
88 |
+
|
89 |
+
########### if number of notes < required, add some notes ###############
|
90 |
+
pitch_less_84_mask = torch.ones_like(mask)
|
91 |
+
pitch_less_84_mask[:,51:] = 0
|
92 |
+
|
93 |
+
# Count how many values >= increase_note_threshold for each row
|
94 |
+
threshold_mask_2 = (final_onset_roll >= increase_note_threshold)&(mask==1)
|
95 |
+
greater_than_threshold2_count = threshold_mask_2.sum(dim=1)
|
96 |
+
|
97 |
+
# For those rows, find the remaining number of values needed to be set to increase_note_threshold
|
98 |
+
remaining_needed = num_notes - greater_than_threshold2_count
|
99 |
+
remaining_needed_max = int(remaining_needed.max().item())
|
100 |
+
print("\n\n\n",remaining_needed_max,"\n\n\n")
|
101 |
+
if remaining_needed_max>=0: # need to add notes
|
102 |
+
# Find the values in each row that are < increase_note_threshold but are the highest (so we can set them to increase_note_threshold)
|
103 |
+
values_below_threshold2 = torch.where((final_onset_roll < increase_note_threshold)&(mask==1)&(pitch_less_84_mask==1), final_onset_roll, torch.tensor(-float('inf')).to(onset_roll.device))
|
104 |
+
topk_below_threshold2_values, topk_below_threshold2_indices = torch.topk(values_below_threshold2, remaining_needed_max, dim=1)
|
105 |
+
|
106 |
+
# Mask to only adjust the needed number of values in each row
|
107 |
+
col_indices_below_threshold2 = torch.arange(remaining_needed_max, device=onset_roll.device).expand(len(onset_roll), remaining_needed_max)
|
108 |
+
adjust_mask = (col_indices_below_threshold2 < remaining_needed.unsqueeze(1)) & (topk_below_threshold2_values > -float("inf"))
|
109 |
+
|
110 |
+
# Flatten row indices for the new top-k below increase_note_threshold
|
111 |
+
flat_row_indices_below_threshold2 = torch.arange(onset_roll.size(0), device=onset_roll.device).unsqueeze(1).expand_as(topk_below_threshold2_indices)
|
112 |
+
flat_row_indices_below_threshold2 = flat_row_indices_below_threshold2[adjust_mask]
|
113 |
+
|
114 |
+
# Gather the valid indices and set them to increase_note_threshold
|
115 |
+
valid_below_threshold2_indices = topk_below_threshold2_indices[adjust_mask]
|
116 |
+
|
117 |
+
# Update the final_onset_roll to make sure we now have exactly num_notes[i] values >= increase_note_threshold
|
118 |
+
final_onset_roll = final_onset_roll.index_put_((flat_row_indices_below_threshold2, valid_below_threshold2_indices), torch.tensor(increase_note_threshold, device=onset_roll.device))
|
119 |
+
|
120 |
+
final_onset_roll = final_onset_roll.reshape(shape)
|
121 |
+
piano_roll_full[:,0,:,:] = final_onset_roll
|
122 |
+
return piano_roll_full
|
123 |
+
|
124 |
+
def X0EditFunc(x0, background_condition, sampler_device=device, reduce_extra_notes=True, rhythm_control="Yes"):
|
125 |
+
# Precompute all rotations of the major and minor chords
|
126 |
+
maj_chd = torch.tensor([[1.,0,0,0,1,0,0,1,0,0,0,0],[1,0,1,0,1,1,0,1,0,1,0,1]], device=sampler_device)
|
127 |
+
maj_chd = torch.tile(maj_chd, (1, 64 // maj_chd.size(1) + 1))
|
128 |
+
min_chd = torch.tensor([[1.,0,0,0,1,0,0,0,0,1,0,0],[1,0,1,0,1,1,0,1,0,1,0,1]], device=sampler_device)
|
129 |
+
min_chd = torch.tile(min_chd, (1, 64 // min_chd.size(1) + 1))
|
130 |
+
|
131 |
+
# all chords, with rotation
|
132 |
+
maj_chd_rotations = torch.stack([torch.roll(maj_chd, shifts=-i, dims=1) for i in range(12)], dim=0)[:,:,:64]
|
133 |
+
min_chd_rotations = torch.stack([torch.roll(min_chd, shifts=-i, dims=1) for i in range(12)], dim=0)[:,:,:64]
|
134 |
+
|
135 |
+
# combine all chords
|
136 |
+
# chd_scale_map is a tensor with shape (N, 2, 64), N is total number of chord types,
|
137 |
+
# 2 is (chord_chroma, corresponding_scale_chroma), 64 is number of possible notes
|
138 |
+
chd_scale_map = torch.concat([maj_chd_rotations, min_chd_rotations], axis=0)
|
139 |
+
|
140 |
+
# if using null rhythm condition, have to convert -2 to 1 and -1 to 0
|
141 |
+
if background_condition[:,:2,:,:].min()<0:
|
142 |
+
correct_chord_condition = -background_condition[:,:2,:,:]-1
|
143 |
+
else:
|
144 |
+
correct_chord_condition = background_condition[:,:2,:,:]
|
145 |
+
merged_chd_roll = torch.max(correct_chord_condition[:,0,:,:], correct_chord_condition[:,1,:,:]) # chd roll of our bg_cond
|
146 |
+
chd_chroma_ours = torch.clamp(merged_chd_roll, min=0.0, max=1.0) # chd chroma of our bg_cond
|
147 |
+
shape = chd_chroma_ours.shape
|
148 |
+
chd_chroma_ours = chd_chroma_ours.reshape(-1,64)
|
149 |
+
matches = (chd_scale_map[:, 0, :].unsqueeze(0) - chd_chroma_ours.unsqueeze(1)>=0).all(dim=-1)
|
150 |
+
seven_notes_chroma_ours = torch.einsum('ij,jk->ik', matches.float(), chd_scale_map[:, 1, :]).reshape(shape)
|
151 |
+
seven_notes_chroma_ours = seven_notes_chroma_ours.unsqueeze(1).repeat((1,2,1,1))
|
152 |
+
|
153 |
+
no_chd_match = torch.all(seven_notes_chroma_ours == 0, dim=-1)
|
154 |
+
seven_notes_chroma_ours[no_chd_match] = 1.
|
155 |
+
|
156 |
+
# edit notes based on chroma
|
157 |
+
x0 = torch.where((seven_notes_chroma_ours==0)&(x0>0), 0.0 , x0)
|
158 |
+
print("See Coming?")
|
159 |
+
# edit rhythm
|
160 |
+
if (background_condition[:,:2,:,:].min()>=0) and (rhythm_control=="Yes"): # only edit if rhythm is provided
|
161 |
+
num_onset_notes, _ = torch.max(background_condition[:,0,:,:], axis=-1)
|
162 |
+
x0 = edit_rhythm(x0, num_onset_notes, seven_notes_chroma_ours, reduce_extra_notes)
|
163 |
+
|
164 |
+
return x0
|
165 |
+
|
166 |
+
def expand_roll(roll, unit=4, contain_onset=False):
|
167 |
+
# roll: (Channel, T, H) -> (Channel, T * unit, H)
|
168 |
+
n_channel, length, height = roll.shape
|
169 |
+
|
170 |
+
expanded_roll = roll.repeat(unit, axis=1)
|
171 |
+
if contain_onset:
|
172 |
+
expanded_roll = expanded_roll.reshape((n_channel, length, unit, height))
|
173 |
+
expanded_roll[1::2, :, 1:] = np.maximum(expanded_roll[::2, :, 1:], expanded_roll[1::2, :, 1:])
|
174 |
+
|
175 |
+
expanded_roll[::2, :, 1:] = 0
|
176 |
+
expanded_roll = expanded_roll.reshape((n_channel, length * unit, height))
|
177 |
+
return expanded_roll
|
178 |
+
|
179 |
+
def cut_piano_roll(piano_roll, resolution=16, lowest=33, highest=96):
|
180 |
+
piano_roll_cut = piano_roll[:,:,lowest:highest+1]
|
181 |
+
return piano_roll_cut
|
182 |
+
|
183 |
+
def circular_extend(chd_roll, lowest=33, highest=96):
|
184 |
+
#chd_roll: 6*L*12->6*L*64
|
185 |
+
C4 = 60-lowest
|
186 |
+
C3 = C4-12
|
187 |
+
shape = chd_roll.shape
|
188 |
+
ext_chd = np.zeros((shape[0],shape[1],highest+1-lowest))
|
189 |
+
ext_chd[:,:,C4:C4+12] = chd_roll
|
190 |
+
ext_chd[:,:,C3:C3+12] = chd_roll
|
191 |
+
return ext_chd
|
192 |
+
|
193 |
+
|
194 |
+
def default_quantization(v):
|
195 |
+
return 1 if v > 0.5 else 0
|
196 |
+
|
197 |
+
def extend_piano_roll(piano_roll: np.ndarray, lowest=33, highest=96):
|
198 |
+
## this function extends the cut piano rolls back into the full 128-pitch piano rolls
|
199 |
+
## recall that the piano rolls have dimensions (2, L, 64); we pad with zeros to obtain (2, L, 128)
|
200 |
+
padded_roll = np.pad(piano_roll, ((0, 0), (0, 0), (lowest, 127-highest)), mode='constant', constant_values=0)
|
201 |
+
return padded_roll
|
202 |
+
|
203 |
+
|
204 |
+
|
205 |
+
def piano_roll_to_note_mat(piano_roll: np.ndarray, quantization_func=None):
|
206 |
+
"""
|
207 |
+
piano_roll: (2, L, 128), onset and sustain channel.
|
208 |
+
raise_chord: whether pitch below 48 (mel-chd boundary) will be raised an octave
|
209 |
+
"""
|
210 |
+
def convert_p(p_, note_list):
|
211 |
+
edit_note_flag = False
|
212 |
+
for t in range(n_step):
|
213 |
+
onset_state = quantization_func(piano_roll[0, t, p_])
|
214 |
+
sustain_state = quantization_func(piano_roll[1, t, p_])
|
215 |
+
|
216 |
+
is_onset = bool(onset_state)
|
217 |
+
is_sustain = bool(sustain_state) and not is_onset
|
218 |
+
|
219 |
+
pitch = p_
|
220 |
+
|
221 |
+
if is_onset:
|
222 |
+
edit_note_flag = True
|
223 |
+
note_list.append([t, pitch, 1])
|
224 |
+
elif is_sustain:
|
225 |
+
if edit_note_flag:
|
226 |
+
note_list[-1][-1] += 1
|
227 |
+
else:
|
228 |
+
edit_note_flag = False
|
229 |
+
return note_list
|
230 |
+
|
231 |
+
quantization_func = default_quantization if quantization_func is None else quantization_func
|
232 |
+
assert len(piano_roll.shape) == 3 and piano_roll.shape[0] == 2 and piano_roll.shape[2] == 128, f"{piano_roll.shape}"
|
233 |
+
|
234 |
+
n_step = piano_roll.shape[1]
|
235 |
+
|
236 |
+
notes = []
|
237 |
+
for p in range(128):
|
238 |
+
convert_p(p, notes)
|
239 |
+
|
240 |
+
return notes
|
241 |
+
|
242 |
+
|
243 |
+
def note_mat_to_notes(note_mat, bpm, unit=1/4, shift_beat=0., shift_sec=0., vel=100):
|
244 |
+
"""Default use shift beat"""
|
245 |
+
|
246 |
+
beat_alpha = 60 / bpm
|
247 |
+
step_alpha = unit * beat_alpha
|
248 |
+
|
249 |
+
notes = []
|
250 |
+
|
251 |
+
shift_sec = shift_sec if shift_beat is None else shift_beat * beat_alpha
|
252 |
+
|
253 |
+
for note in note_mat:
|
254 |
+
onset, pitch, dur = note
|
255 |
+
start = onset * step_alpha + shift_sec
|
256 |
+
end = (onset + dur) * step_alpha + shift_sec
|
257 |
+
|
258 |
+
notes.append(pm.Note(vel, int(pitch), start, end))
|
259 |
+
|
260 |
+
return notes
|
261 |
+
|
262 |
+
|
263 |
+
def create_pm_object(bpm, piano_notes_list, chd_notes_list, lsh_notes_list=None):
|
264 |
+
midi = pm.PrettyMIDI(initial_tempo=bpm)
|
265 |
+
|
266 |
+
piano_program = pm.instrument_name_to_program('Acoustic Grand Piano')
|
267 |
+
piano = pm.Instrument(program=piano_program)
|
268 |
+
piano.notes+=piano_notes_list
|
269 |
+
midi.instruments.append(piano)
|
270 |
+
|
271 |
+
# chd_program = pm.instrument_name_to_program('Violin')
|
272 |
+
# chd = pm.Instrument(program=chd_program)
|
273 |
+
# chd.notes+=chd_notes_list
|
274 |
+
# midi.instruments.append(chd)
|
275 |
+
|
276 |
+
if lsh_notes_list is not None:
|
277 |
+
lsh_program = pm.instrument_name_to_program('Acoustic Grand Piano')
|
278 |
+
lsh = pm.Instrument(program=lsh_program)
|
279 |
+
lsh.notes+=lsh_notes_list
|
280 |
+
midi.instruments.append(lsh)
|
281 |
+
|
282 |
+
return midi
|
283 |
+
|
284 |
+
def piano_roll_to_midi(piano_roll: np.ndarray, chd_roll: np.ndarray, lsh_roll=None, bpm=80):
|
285 |
+
piano_mat = piano_roll_to_note_mat(piano_roll)
|
286 |
+
piano_notes = note_mat_to_notes(piano_mat, bpm)
|
287 |
+
|
288 |
+
chd_mat = piano_roll_to_note_mat(chd_roll)
|
289 |
+
chd_notes = note_mat_to_notes(chd_mat, bpm)
|
290 |
+
|
291 |
+
if lsh_roll is not None:
|
292 |
+
lsh_mat = piano_roll_to_note_mat(lsh_roll)
|
293 |
+
lsh_notes = note_mat_to_notes(lsh_mat, bpm)
|
294 |
+
else:
|
295 |
+
lsh_notes=None
|
296 |
+
|
297 |
+
piano_pm = create_pm_object(bpm=bpm, piano_notes_list=piano_notes,
|
298 |
+
chd_notes_list=chd_notes, lsh_notes_list=lsh_notes)
|
299 |
+
return piano_pm
|
300 |
+
|
301 |
+
def save_midi(pm, filename):
|
302 |
+
pm.write(filename)
|
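For orientation, a minimal sketch of the roll-to-MIDI path defined in gen_utils.py, using random placeholder rolls in place of real model output; shapes follow the (2, L, 64) convention and the lowest/highest pitch bounds above, and the import assumes running from the repo root.

import numpy as np
from generation.gen_utils import extend_piano_roll, piano_roll_to_midi, save_midi

L = 64  # number of 16th-note steps
acc_roll = (np.random.rand(2, L, 64) > 0.95).astype(float)   # placeholder onset/sustain roll
chd_roll = (np.random.rand(2, L, 64) > 0.97).astype(float)   # placeholder chord roll

acc_full = extend_piano_roll(acc_roll)   # (2, L, 64) -> (2, L, 128)
chd_full = extend_piano_roll(chd_roll)

midi = piano_roll_to_midi(acc_full, chd_full, bpm=80)
save_midi(midi, 'output_0.mid')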
model/__init__.py
ADDED
@@ -0,0 +1,59 @@
|
1 |
+
from .latent_diffusion import LatentDiffusion
|
2 |
+
from .model_sdf import Diffpro_SDF
|
3 |
+
from .architecture.unet import UNetModel
|
4 |
+
import os
|
5 |
+
|
6 |
+
|
7 |
+
def init_ldm_model(params, debug_mode=False):
|
8 |
+
unet_model = UNetModel(
|
9 |
+
in_channels=params.in_channels,
|
10 |
+
out_channels=params.out_channels,
|
11 |
+
channels=params.channels,
|
12 |
+
attention_levels=params.attention_levels,
|
13 |
+
n_res_blocks=params.n_res_blocks,
|
14 |
+
channel_multipliers=params.channel_multipliers,
|
15 |
+
n_heads=params.n_heads,
|
16 |
+
tf_layers=params.tf_layers,
|
17 |
+
#d_cond=params.d_cond,
|
18 |
+
)
|
19 |
+
|
20 |
+
|
21 |
+
ldm_model = LatentDiffusion(
|
22 |
+
unet_model=unet_model,
|
23 |
+
#autoencoder=None,
|
24 |
+
#autoreg_cond_enc=autoreg_cond_enc,
|
25 |
+
#external_cond_enc=external_cond_enc,
|
26 |
+
latent_scaling_factor=params.latent_scaling_factor,
|
27 |
+
n_steps=params.n_steps,
|
28 |
+
linear_start=params.linear_start,
|
29 |
+
linear_end=params.linear_end,
|
30 |
+
debug_mode=debug_mode
|
31 |
+
)
|
32 |
+
|
33 |
+
return ldm_model
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
def init_diff_pro_sdf(ldm_model, params, device):
|
38 |
+
return Diffpro_SDF(ldm_model).to(device)
|
39 |
+
|
40 |
+
|
41 |
+
def get_model_path(model_dir, model_id=None):
|
42 |
+
model_desc = os.path.basename(model_dir)
|
43 |
+
if model_id is None:
|
44 |
+
model_path = os.path.join(model_dir, 'chkpts', 'weights.pt')
|
45 |
+
|
46 |
+
# retrieve real model_id from the actual file weights.pt is pointing to
|
47 |
+
model_id = os.path.basename(os.path.realpath(model_path)).split('-')[1].split('.')[0]
|
48 |
+
|
49 |
+
elif model_id == 'best':
|
50 |
+
model_path = os.path.join(model_dir, 'chkpts', 'weights_best.pt')
|
51 |
+
# retrieve real model_id from the actual file weights.pt is pointing to
|
52 |
+
model_id = os.path.basename(os.path.realpath(model_path)).split('-')[1].split('.')[0]
|
53 |
+
elif model_id == 'default':
|
54 |
+
model_path = os.path.join(model_dir, 'chkpts', 'weights_default.pt')
|
55 |
+
if not os.path.exists(model_path):
|
56 |
+
return get_model_path(model_dir, 'best')
|
57 |
+
else:
|
58 |
+
model_path = f"{model_dir}/chkpts/weights-{model_id}.pt"
|
59 |
+
return model_path, model_id, model_desc
|
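A sketch of how these factory functions might be wired together for inference. The hyperparameter values below are illustrative placeholders only; the real configuration lives in the params.json next to the checkpoint, and the checkpoint layout is not shown in this commit.

import torch
from types import SimpleNamespace
from model import init_ldm_model, init_diff_pro_sdf

# Placeholder hyperparameters; the trained model's values live in results/.../params.json.
params = SimpleNamespace(
    in_channels=8, out_channels=2, channels=64,
    attention_levels=[2, 3], n_res_blocks=2, channel_multipliers=[1, 2, 4, 4],
    n_heads=4, tf_layers=1,
    latent_scaling_factor=0.18215, n_steps=1000,
    linear_start=0.00085, linear_end=0.012,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ldm = init_ldm_model(params, debug_mode=False)
model = init_diff_pro_sdf(ldm, params, device)

# get_model_path(model_dir, 'best') resolves chkpts/weights_best.pt, but it expects that file
# to be a symlink to a weights-<id>.pt checkpoint; loading the path directly avoids that assumption.
ckpt = 'results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/chkpts/weights_best.pt'
state = torch.load(ckpt, map_location=device)  # exact checkpoint structure is an assumption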
model/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (1.53 kB). View file
|
|
model/__pycache__/latent_diffusion.cpython-39.pyc
ADDED
Binary file (5.99 kB). View file
|
|
model/__pycache__/model_sdf.cpython-39.pyc
ADDED
Binary file (2.39 kB). View file
|
|
model/__pycache__/sampler_sdf.cpython-39.pyc
ADDED
Binary file (11.3 kB). View file
|
|
model/architecture/__pycache__/unet.cpython-39.pyc
ADDED
Binary file (9.04 kB). View file
|
|
model/architecture/__pycache__/unet_attention.cpython-39.pyc
ADDED
Binary file (8.98 kB). View file
|
|
model/architecture/unet.py
ADDED
@@ -0,0 +1,364 @@
|
1 |
+
"""
|
2 |
+
---
|
3 |
+
title: U-Net for Stable Diffusion
|
4 |
+
summary: >
|
5 |
+
Annotated PyTorch implementation/tutorial of the U-Net in stable diffusion.
|
6 |
+
---
|
7 |
+
|
8 |
+
# U-Net for [Stable Diffusion](../index.html)
|
9 |
+
|
10 |
+
This implements the U-Net that
|
11 |
+
gives $\epsilon_\text{cond}(x_t, c)$
|
12 |
+
|
13 |
+
We have kept to the model definition and naming unchanged from
|
14 |
+
[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
|
15 |
+
so that we can load the checkpoints directly.
|
16 |
+
"""
|
17 |
+
|
18 |
+
import math
|
19 |
+
from typing import List
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
import torch.nn.functional as F
|
25 |
+
|
26 |
+
from .unet_attention import SpatialTransformer
|
27 |
+
|
28 |
+
|
29 |
+
class UNetModel(nn.Module):
|
30 |
+
"""
|
31 |
+
## U-Net model
|
32 |
+
"""
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
*,
|
36 |
+
in_channels: int,
|
37 |
+
out_channels: int,
|
38 |
+
channels: int,
|
39 |
+
n_res_blocks: int,
|
40 |
+
attention_levels: List[int],
|
41 |
+
channel_multipliers: List[int],
|
42 |
+
n_heads: int,
|
43 |
+
tf_layers: int = 1,
|
44 |
+
#d_cond: int = 768
|
45 |
+
):
|
46 |
+
"""
|
47 |
+
:param in_channels: is the number of channels in the input feature map
|
48 |
+
:param out_channels: is the number of channels in the output feature map
|
49 |
+
:param channels: is the base channel count for the model
|
50 |
+
:param n_res_blocks: number of residual blocks at each level
|
51 |
+
:param attention_levels: are the levels at which attention should be performed
|
52 |
+
:param channel_multipliers: are the multiplicative factors for number of channels for each level
|
53 |
+
:param n_heads: the number of attention heads in the transformers
|
54 |
+
"""
|
55 |
+
super().__init__()
|
56 |
+
self.channels = channels
|
57 |
+
self.out_channels = out_channels
|
58 |
+
#self.d_cond = d_cond
|
59 |
+
|
60 |
+
# Number of levels
|
61 |
+
levels = len(channel_multipliers)
|
62 |
+
# Size time embeddings
|
63 |
+
d_time_emb = channels * 4
|
64 |
+
self.time_embed = nn.Sequential(
|
65 |
+
nn.Linear(channels, d_time_emb),
|
66 |
+
nn.SiLU(),
|
67 |
+
nn.Linear(d_time_emb, d_time_emb),
|
68 |
+
)
|
69 |
+
|
70 |
+
# Input half of the U-Net
|
71 |
+
self.input_blocks = nn.ModuleList()
|
72 |
+
# Initial $3 \times 3$ convolution that maps the input to `channels`.
|
73 |
+
# The blocks are wrapped in `TimestepEmbedSequential` module because
|
74 |
+
# different modules have different forward function signatures;
|
75 |
+
# for example, convolution only accepts the feature map and
|
76 |
+
# residual blocks accept the feature map and time embedding.
|
77 |
+
# `TimestepEmbedSequential` calls them accordingly.
|
78 |
+
self.input_blocks.append(
|
79 |
+
TimestepEmbedSequential(nn.Conv2d(in_channels, channels, 3, padding=1))
|
80 |
+
)
|
81 |
+
# Number of channels at each block in the input half of U-Net
|
82 |
+
input_block_channels = [channels]
|
83 |
+
# Number of channels at each level
|
84 |
+
channels_list = [channels * m for m in channel_multipliers]
|
85 |
+
# Prepare levels
|
86 |
+
for i in range(levels):
|
87 |
+
# Add the residual blocks and attentions
|
88 |
+
for _ in range(n_res_blocks):
|
89 |
+
# Residual block maps from previous number of channels to the number of
|
90 |
+
# channels in the current level
|
91 |
+
layers = [ResBlock(channels, d_time_emb, out_channels=channels_list[i])]
|
92 |
+
channels = channels_list[i]
|
93 |
+
# Add transformer
|
94 |
+
if i in attention_levels:
|
95 |
+
layers.append(
|
96 |
+
SpatialTransformer(channels, n_heads, tf_layers)
|
97 |
+
)
|
98 |
+
# Add them to the input half of the U-Net and keep track of the number of channels of
|
99 |
+
# its output
|
100 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
101 |
+
input_block_channels.append(channels)
|
102 |
+
# Down sample at all levels except last
|
103 |
+
if i != levels - 1:
|
104 |
+
self.input_blocks.append(TimestepEmbedSequential(DownSample(channels)))
|
105 |
+
input_block_channels.append(channels)
|
106 |
+
|
107 |
+
# The middle of the U-Net
|
108 |
+
self.middle_block = TimestepEmbedSequential(
|
109 |
+
ResBlock(channels, d_time_emb),
|
110 |
+
SpatialTransformer(channels, n_heads, tf_layers),
|
111 |
+
ResBlock(channels, d_time_emb),
|
112 |
+
)
|
113 |
+
|
114 |
+
# Second half of the U-Net
|
115 |
+
self.output_blocks = nn.ModuleList([])
|
116 |
+
# Prepare levels in reverse order
|
117 |
+
for i in reversed(range(levels)):
|
118 |
+
# Add the residual blocks and attentions
|
119 |
+
for j in range(n_res_blocks + 1):
|
120 |
+
# Residual block maps from previous number of channels plus the
|
121 |
+
# skip connections from the input half of U-Net to the number of
|
122 |
+
# channels in the current level.
|
123 |
+
layers = [
|
124 |
+
ResBlock(
|
125 |
+
channels + input_block_channels.pop(),
|
126 |
+
d_time_emb,
|
127 |
+
out_channels=channels_list[i]
|
128 |
+
)
|
129 |
+
]
|
130 |
+
channels = channels_list[i]
|
131 |
+
# Add transformer
|
132 |
+
if i in attention_levels:
|
133 |
+
layers.append(
|
134 |
+
SpatialTransformer(channels, n_heads, tf_layers)
|
135 |
+
)
|
136 |
+
# Up-sample at every level after last residual block
|
137 |
+
# except the last one.
|
138 |
+
# Note that we are iterating in reverse; i.e. `i == 0` is the last.
|
139 |
+
if i != 0 and j == n_res_blocks:
|
140 |
+
layers.append(UpSample(channels))
|
141 |
+
# Add to the output half of the U-Net
|
142 |
+
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
143 |
+
|
144 |
+
# Final normalization and $3 \times 3$ convolution
|
145 |
+
self.out = nn.Sequential(
|
146 |
+
normalization(channels),
|
147 |
+
nn.SiLU(),
|
148 |
+
nn.Conv2d(channels, out_channels, 3, padding=1),
|
149 |
+
)
|
150 |
+
|
151 |
+
def time_step_embedding(self, time_steps: torch.Tensor, max_period: int = 10000):
|
152 |
+
"""
|
153 |
+
## Create sinusoidal time step embeddings
|
154 |
+
|
155 |
+
:param time_steps: are the time steps of shape `[batch_size]`
|
156 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
157 |
+
"""
|
158 |
+
# $\frac{c}{2}$; half the channels are sin and the other half is cos,
|
159 |
+
half = self.channels // 2
|
160 |
+
# $\frac{1}{10000^{\frac{2i}{c}}}$
|
161 |
+
frequencies = torch.exp(
|
162 |
+
-math.log(max_period) *
|
163 |
+
torch.arange(start=0, end=half, dtype=torch.float32) / half
|
164 |
+
).to(device=time_steps.device)
|
165 |
+
# $\frac{t}{10000^{\frac{2i}{c}}}$
|
166 |
+
args = time_steps[:, None].float() * frequencies[None]
|
167 |
+
# $\cos\Bigg(\frac{t}{10000^{\frac{2i}{c}}}\Bigg)$ and $\sin\Bigg(\frac{t}{10000^{\frac{2i}{c}}}\Bigg)$
|
168 |
+
return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
169 |
+
|
170 |
+
def forward(self, x: torch.Tensor, time_steps: torch.Tensor):
|
171 |
+
"""
|
172 |
+
:param x: is the input feature map of shape `[batch_size, channels, width, height]`
|
173 |
+
:param time_steps: are the time steps of shape `[batch_size]`
|
174 |
+
:param cond: conditioning of shape `[batch_size, n_cond, d_cond]`
|
175 |
+
"""
|
176 |
+
# To store the input half outputs for skip connections
|
177 |
+
x_input_block = []
|
178 |
+
|
179 |
+
# Get time step embeddings
|
180 |
+
t_emb = self.time_step_embedding(time_steps)
|
181 |
+
t_emb = self.time_embed(t_emb)
|
182 |
+
|
183 |
+
# Input half of the U-Net
|
184 |
+
for module in self.input_blocks:
|
185 |
+
##########################
|
186 |
+
#print("x:", x.dtype,"t_emb:",t_emb.dtype)
|
187 |
+
##########################
|
188 |
+
#x = x.to(torch.float16)
|
189 |
+
x = module(x, t_emb)
|
190 |
+
x_input_block.append(x)
|
191 |
+
# Middle of the U-Net
|
192 |
+
x = self.middle_block(x, t_emb)
|
193 |
+
# Output half of the U-Net
|
194 |
+
for module in self.output_blocks:
|
195 |
+
# print(x.size(), 'a')
|
196 |
+
x = torch.cat([x, x_input_block.pop()], dim=1)
|
197 |
+
# print(x.size(), 'b')
|
198 |
+
x = module(x, t_emb)
|
199 |
+
|
200 |
+
# Final normalization and $3 \times 3$ convolution
|
201 |
+
return self.out(x)
|
202 |
+
|
203 |
+
|
204 |
+
class TimestepEmbedSequential(nn.Sequential):
|
205 |
+
"""
|
206 |
+
### Sequential block for modules with different inputs
|
207 |
+
|
208 |
+
This sequential module can be composed of different modules such as `ResBlock`,
|
209 |
+
`nn.Conv` and `SpatialTransformer` and calls them with the matching signatures
|
210 |
+
"""
|
211 |
+
def forward(self, x, t_emb, cond=None):
|
212 |
+
for layer in self:
|
213 |
+
if isinstance(layer, ResBlock):
|
214 |
+
x = layer(x, t_emb)
|
215 |
+
elif isinstance(layer, SpatialTransformer):
|
216 |
+
x = layer(x)
|
217 |
+
else:
|
218 |
+
x = layer(x)
|
219 |
+
return x
|
220 |
+
|
221 |
+
|
222 |
+
class UpSample(nn.Module):
|
223 |
+
"""
|
224 |
+
### Up-sampling layer
|
225 |
+
"""
|
226 |
+
def __init__(self, channels: int):
|
227 |
+
"""
|
228 |
+
:param channels: is the number of channels
|
229 |
+
"""
|
230 |
+
super().__init__()
|
231 |
+
# $3 \times 3$ convolution mapping
|
232 |
+
self.conv = nn.Conv2d(channels, channels, 3, padding=1)
|
233 |
+
|
234 |
+
def forward(self, x: torch.Tensor):
|
235 |
+
"""
|
236 |
+
:param x: is the input feature map with shape `[batch_size, channels, height, width]`
|
237 |
+
"""
|
238 |
+
# Up-sample by a factor of $2$
|
239 |
+
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
240 |
+
# Apply convolution
|
241 |
+
return self.conv(x)
|
242 |
+
|
243 |
+
|
244 |
+
class DownSample(nn.Module):
|
245 |
+
"""
|
246 |
+
## Down-sampling layer
|
247 |
+
"""
|
248 |
+
def __init__(self, channels: int):
|
249 |
+
"""
|
250 |
+
:param channels: is the number of channels
|
251 |
+
"""
|
252 |
+
super().__init__()
|
253 |
+
# $3 \times 3$ convolution with stride length of $2$ to down-sample by a factor of $2$
|
254 |
+
self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
|
255 |
+
|
256 |
+
def forward(self, x: torch.Tensor):
|
257 |
+
"""
|
258 |
+
:param x: is the input feature map with shape `[batch_size, channels, height, width]`
|
259 |
+
"""
|
260 |
+
# Apply convolution
|
261 |
+
return self.op(x)
|
262 |
+
|
263 |
+
|
264 |
+
class ResBlock(nn.Module):
|
265 |
+
"""
|
266 |
+
## ResNet Block
|
267 |
+
"""
|
268 |
+
def __init__(self, channels: int, d_t_emb: int, *, out_channels=None):
|
269 |
+
"""
|
270 |
+
:param channels: the number of input channels
|
271 |
+
:param d_t_emb: the size of timestep embeddings
|
272 |
+
:param out_channels: is the number of output channels. Defaults to `channels`.
|
273 |
+
"""
|
274 |
+
super().__init__()
|
275 |
+
# `out_channels` not specified
|
276 |
+
if out_channels is None:
|
277 |
+
out_channels = channels
|
278 |
+
|
279 |
+
# First normalization and convolution
|
280 |
+
self.in_layers = nn.Sequential(
|
281 |
+
normalization(channels),
|
282 |
+
nn.SiLU(),
|
283 |
+
nn.Conv2d(channels, out_channels, 3, padding=1),
|
284 |
+
)
|
285 |
+
|
286 |
+
# Time step embeddings
|
287 |
+
self.emb_layers = nn.Sequential(
|
288 |
+
nn.SiLU(),
|
289 |
+
nn.Linear(d_t_emb, out_channels),
|
290 |
+
)
|
291 |
+
# Final convolution layer
|
292 |
+
self.out_layers = nn.Sequential(
|
293 |
+
normalization(out_channels), nn.SiLU(), nn.Dropout(0.),
|
294 |
+
nn.Conv2d(out_channels, out_channels, 3, padding=1)
|
295 |
+
)
|
296 |
+
|
297 |
+
# `channels` to `out_channels` mapping layer for residual connection
|
298 |
+
if out_channels == channels:
|
299 |
+
self.skip_connection = nn.Identity()
|
300 |
+
else:
|
301 |
+
self.skip_connection = nn.Conv2d(channels, out_channels, 1)
|
302 |
+
|
303 |
+
def forward(self, x: torch.Tensor, t_emb: torch.Tensor):
|
304 |
+
"""
|
305 |
+
:param x: is the input feature map with shape `[batch_size, channels, height, width]`
|
306 |
+
:param t_emb: is the time step embeddings of shape `[batch_size, d_t_emb]`
|
307 |
+
"""
|
308 |
+
# Initial convolution
|
309 |
+
h = self.in_layers(x)
|
310 |
+
# Time step embeddings
|
311 |
+
t_emb = self.emb_layers(t_emb).type(h.dtype)
|
312 |
+
# Add time step embeddings
|
313 |
+
h = h + t_emb[:, :, None, None]
|
314 |
+
# Final convolution
|
315 |
+
h = self.out_layers(h)
|
316 |
+
# Add skip connection
|
317 |
+
return self.skip_connection(x) + h
|
318 |
+
|
319 |
+
|
320 |
+
class GroupNorm32(nn.GroupNorm):
|
321 |
+
"""
|
322 |
+
### Group normalization with float32 casting
|
323 |
+
"""
|
324 |
+
def forward(self, x):
|
325 |
+
return super().forward(x.float()).type(x.dtype)
|
326 |
+
|
327 |
+
|
328 |
+
def normalization(channels):
|
329 |
+
"""
|
330 |
+
### Group normalization
|
331 |
+
|
332 |
+
This is a helper function, with a fixed number of groups.
|
333 |
+
"""
|
334 |
+
return GroupNorm32(32, channels)
|
335 |
+
|
336 |
+
|
337 |
+
def _test_time_embeddings():
|
338 |
+
"""
|
339 |
+
Test sinusoidal time step embeddings
|
340 |
+
"""
|
341 |
+
import matplotlib.pyplot as plt
|
342 |
+
|
343 |
+
plt.figure(figsize=(15, 5))
|
344 |
+
m = UNetModel(
|
345 |
+
in_channels=1,
|
346 |
+
out_channels=1,
|
347 |
+
channels=320,
|
348 |
+
n_res_blocks=1,
|
349 |
+
attention_levels=[],
|
350 |
+
channel_multipliers=[],
|
351 |
+
n_heads=1,
|
352 |
+
tf_layers=1,
|
353 |
+
#d_cond=1
|
354 |
+
)
|
355 |
+
te = m.time_step_embedding(torch.arange(0, 1000))
|
356 |
+
plt.plot(np.arange(1000), te[:, [50, 100, 190, 260]].numpy())
|
357 |
+
plt.legend(["dim %d" % p for p in [50, 100, 190, 260]])
|
358 |
+
plt.title("Time embeddings")
|
359 |
+
plt.show()
|
360 |
+
|
361 |
+
|
362 |
+
#
|
363 |
+
if __name__ == '__main__':
|
364 |
+
_test_time_embeddings()
|
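A small shape check for the U-Net above. The channel configuration here is a placeholder chosen so that the 32-group normalization divides every channel width; it is not the configuration used for the released checkpoint.

import torch
from model.architecture.unet import UNetModel

# Placeholder configuration; the real hyperparameters live in the params.json shipped with the checkpoint.
unet = UNetModel(
    in_channels=8,
    out_channels=2,
    channels=64,
    n_res_blocks=2,
    attention_levels=[2, 3],
    channel_multipliers=[1, 2, 4, 4],
    n_heads=4,
    tf_layers=1,
)
x = torch.randn(1, 8, 64, 64)      # (batch, channels, time steps, pitch bins)
t = torch.randint(0, 1000, (1,))   # one diffusion time step per sample
eps = unet(x, t)
print(eps.shape)                   # -> torch.Size([1, 2, 64, 64])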
model/architecture/unet_attention.py
ADDED
@@ -0,0 +1,321 @@
|
1 |
+
"""
|
2 |
+
---
|
3 |
+
title: Transformer for Stable Diffusion U-Net
|
4 |
+
summary: >
|
5 |
+
Annotated PyTorch implementation/tutorial of the transformer
|
6 |
+
for U-Net in stable diffusion.
|
7 |
+
---
|
8 |
+
|
9 |
+
# Transformer for Stable Diffusion [U-Net](unet.html)
|
10 |
+
|
11 |
+
This implements the transformer module used in [U-Net](unet.html) that
|
12 |
+
gives $\epsilon_\text{cond}(x_t, c)$
|
13 |
+
|
14 |
+
We have kept to the model definition and naming unchanged from
|
15 |
+
[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
|
16 |
+
so that we can load the checkpoints directly.
|
17 |
+
"""
|
18 |
+
|
19 |
+
from typing import Optional
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
from torch import nn
|
24 |
+
|
25 |
+
|
26 |
+
class SpatialTransformer(nn.Module):
|
27 |
+
"""
|
28 |
+
## Spatial Transformer
|
29 |
+
"""
|
30 |
+
def __init__(self, channels: int, n_heads: int, n_layers: int):
|
31 |
+
"""
|
32 |
+
:param channels: is the number of channels in the feature map
|
33 |
+
:param n_heads: is the number of attention heads
|
34 |
+
:param n_layers: is the number of transformer layers
|
35 |
+
:param d_cond: is the size of the conditional embedding
|
36 |
+
"""
|
37 |
+
super().__init__()
|
38 |
+
# Initial group normalization
|
39 |
+
self.norm = torch.nn.GroupNorm(
|
40 |
+
num_groups=32, num_channels=channels, eps=1e-6, affine=True
|
41 |
+
)
|
42 |
+
# Initial $1 \times 1$ convolution
|
43 |
+
self.proj_in = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
|
44 |
+
|
45 |
+
# Transformer layers
|
46 |
+
self.transformer_blocks = nn.ModuleList(
|
47 |
+
[
|
48 |
+
BasicTransformerBlock(
|
49 |
+
channels, n_heads, channels // n_heads
|
50 |
+
) for _ in range(n_layers)
|
51 |
+
]
|
52 |
+
)
|
53 |
+
|
54 |
+
# Final $1 \times 1$ convolution
|
55 |
+
self.proj_out = nn.Conv2d(
|
56 |
+
channels, channels, kernel_size=1, stride=1, padding=0
|
57 |
+
)
|
58 |
+
|
59 |
+
def forward(self, x: torch.Tensor):
|
60 |
+
"""
|
61 |
+
:param x: is the feature map of shape `[batch_size, channels, height, width]`
|
62 |
+
:param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
|
63 |
+
"""
|
64 |
+
# Get shape `[batch_size, channels, height, width]`
|
65 |
+
b, c, h, w = x.shape
|
66 |
+
# For residual connection
|
67 |
+
x_in = x
|
68 |
+
# Normalize
|
69 |
+
x = self.norm(x)
|
70 |
+
# Initial $1 \times 1$ convolution
|
71 |
+
x = self.proj_in(x)
|
72 |
+
# Transpose and reshape from `[batch_size, channels, height, width]`
|
73 |
+
# to `[batch_size, height * width, channels]`
|
74 |
+
x = x.permute(0, 2, 3, 1).view(b, h * w, c)
|
75 |
+
# Apply the transformer layers
|
76 |
+
for block in self.transformer_blocks:
|
77 |
+
x = block(x)
|
78 |
+
# Reshape and transpose from `[batch_size, height * width, channels]`
|
79 |
+
# to `[batch_size, channels, height, width]`
|
80 |
+
x = x.view(b, h, w, c).permute(0, 3, 1, 2)
|
81 |
+
# Final $1 \times 1$ convolution
|
82 |
+
x = self.proj_out(x)
|
83 |
+
# Add residual
|
84 |
+
return x + x_in
|
85 |
+
|
86 |
+
|
87 |
+
class BasicTransformerBlock(nn.Module):
|
88 |
+
"""
|
89 |
+
### Transformer Layer
|
90 |
+
"""
|
91 |
+
def __init__(self, d_model: int, n_heads: int, d_head: int):
|
92 |
+
"""
|
93 |
+
:param d_model: is the input embedding size
|
94 |
+
:param n_heads: is the number of attention heads
|
95 |
+
:param d_head: is the size of an attention head
|
96 |
+
:param d_cond: is the size of the conditional embeddings
|
97 |
+
"""
|
98 |
+
super().__init__()
|
99 |
+
# Self-attention layer and pre-norm layer
|
100 |
+
self.attn1 = CrossAttention(d_model, d_model, n_heads, d_head)
|
101 |
+
self.norm1 = nn.LayerNorm(d_model)
|
102 |
+
# Cross attention layer and pre-norm layer
|
103 |
+
#self.attn2 = CrossAttention(d_model, d_cond, n_heads, d_head)
|
104 |
+
self.norm2 = nn.LayerNorm(d_model)
|
105 |
+
# Feed-forward network and pre-norm layer
|
106 |
+
self.ff = FeedForward(d_model)
|
107 |
+
self.norm3 = nn.LayerNorm(d_model)
|
108 |
+
|
109 |
+
def forward(self, x: torch.Tensor):
|
110 |
+
"""
|
111 |
+
:param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
|
112 |
+
:param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
|
113 |
+
"""
|
114 |
+
# Self attention
|
115 |
+
x = self.attn1(self.norm1(x)) + x
|
116 |
+
# Cross-attention with conditioning
|
117 |
+
# x = self.attn2(self.norm2(x), cond=cond) + x
|
118 |
+
# Feed-forward network
|
119 |
+
x = self.ff(self.norm3(x)) + x
|
120 |
+
#
|
121 |
+
return x
|
122 |
+
|
123 |
+
|
124 |
+
class CrossAttention(nn.Module):
|
125 |
+
"""
|
126 |
+
### Cross Attention Layer
|
127 |
+
|
128 |
+
This falls back to self-attention when conditional embeddings are not specified.
|
129 |
+
"""
|
130 |
+
|
131 |
+
use_flash_attention: bool = False
|
132 |
+
|
133 |
+
def __init__(
|
134 |
+
self,
|
135 |
+
d_model: int,
|
136 |
+
d_cond: int,
|
137 |
+
n_heads: int,
|
138 |
+
d_head: int,
|
139 |
+
is_inplace: bool = True
|
140 |
+
):
|
141 |
+
"""
|
142 |
+
:param d_model: is the input embedding size
|
143 |
+
:param n_heads: is the number of attention heads
|
144 |
+
:param d_head: is the size of an attention head
|
145 |
+
:param d_cond: is the size of the conditional embeddings
|
146 |
+
:param is_inplace: specifies whether to perform the attention softmax computation inplace to
|
147 |
+
save memory
|
148 |
+
"""
|
149 |
+
super().__init__()
|
150 |
+
|
151 |
+
self.is_inplace = is_inplace
|
152 |
+
self.n_heads = n_heads
|
153 |
+
self.d_head = d_head
|
154 |
+
|
155 |
+
# Attention scaling factor
|
156 |
+
self.scale = d_head**-0.5
|
157 |
+
|
158 |
+
# Query, key and value mappings
|
159 |
+
d_attn = d_head * n_heads
|
160 |
+
self.to_q = nn.Linear(d_model, d_attn, bias=False)
|
161 |
+
self.to_k = nn.Linear(d_cond, d_attn, bias=False)
|
162 |
+
self.to_v = nn.Linear(d_cond, d_attn, bias=False)
|
163 |
+
|
164 |
+
# Final linear layer
|
165 |
+
self.to_out = nn.Sequential(nn.Linear(d_attn, d_model))
|
166 |
+
|
167 |
+
# Setup [flash attention](https://github.com/HazyResearch/flash-attention).
|
168 |
+
# Flash attention is only used if it's installed
|
169 |
+
# and `CrossAttention.use_flash_attention` is set to `True`.
|
170 |
+
try:
|
171 |
+
# You can install flash attention by cloning their Github repo,
|
172 |
+
# [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
|
173 |
+
# and then running `python setup.py install`
|
174 |
+
from flash_attn.flash_attention import FlashAttention
|
175 |
+
self.flash = FlashAttention()
|
176 |
+
# Set the scale for scaled dot-product attention.
|
177 |
+
self.flash.softmax_scale = self.scale
|
178 |
+
# Set to `None` if it's not installed
|
179 |
+
except ImportError:
|
180 |
+
self.flash = None
|
181 |
+
|
182 |
+
def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):
|
183 |
+
"""
|
184 |
+
:param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
|
185 |
+
:param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
|
186 |
+
"""
|
187 |
+
|
188 |
+
# If `cond` is `None` we perform self attention
|
189 |
+
has_cond = cond is not None
|
190 |
+
if not has_cond:
|
191 |
+
cond = x
|
192 |
+
|
193 |
+
# Get query, key and value vectors
|
194 |
+
q = self.to_q(x)
|
195 |
+
k = self.to_k(cond)
|
196 |
+
v = self.to_v(cond)
|
197 |
+
|
198 |
+
# Use flash attention if it's available and the head size is less than or equal to `128`
|
199 |
+
if CrossAttention.use_flash_attention and self.flash is not None and not has_cond and self.d_head <= 128:
|
200 |
+
return self.flash_attention(q, k, v)
|
201 |
+
# Otherwise, fallback to normal attention
|
202 |
+
else:
|
203 |
+
return self.normal_attention(q, k, v)
|
204 |
+
|
205 |
+
def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
206 |
+
"""
|
207 |
+
#### Flash Attention
|
208 |
+
|
209 |
+
:param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
210 |
+
:param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
211 |
+
:param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
212 |
+
"""
|
213 |
+
|
214 |
+
# Get batch size and number of elements along sequence axis (`width * height`)
|
215 |
+
batch_size, seq_len, _ = q.shape
|
216 |
+
|
217 |
+
# Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
|
218 |
+
# shape `[batch_size, seq_len, 3, n_heads * d_head]`
|
219 |
+
qkv = torch.stack((q, k, v), dim=2)
|
220 |
+
# Split the heads
|
221 |
+
qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
|
222 |
+
|
223 |
+
# Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
|
224 |
+
# fit this size.
|
225 |
+
if self.d_head <= 32:
|
226 |
+
pad = 32 - self.d_head
|
227 |
+
elif self.d_head <= 64:
|
228 |
+
pad = 64 - self.d_head
|
229 |
+
elif self.d_head <= 128:
|
230 |
+
pad = 128 - self.d_head
|
231 |
+
else:
|
232 |
+
raise ValueError(f'Head size {self.d_head} too large for Flash Attention')
|
233 |
+
|
234 |
+
# Pad the heads
|
235 |
+
if pad:
|
236 |
+
qkv = torch.cat(
|
237 |
+
(qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
|
238 |
+
)
|
239 |
+
|
240 |
+
# Compute attention
|
241 |
+
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
|
242 |
+
# This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
|
243 |
+
out, _ = self.flash(qkv)
|
244 |
+
# Truncate the extra head size
|
245 |
+
out = out[:, :, :, : self.d_head]
|
246 |
+
# Reshape to `[batch_size, seq_len, n_heads * d_head]`
|
247 |
+
out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
|
248 |
+
|
249 |
+
# Map to `[batch_size, height * width, d_model]` with a linear layer
|
250 |
+
return self.to_out(out)
|
251 |
+
|
252 |
+
def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
253 |
+
"""
|
254 |
+
#### Normal Attention
|
255 |
+
|
256 |
+
:param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
257 |
+
:param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
258 |
+
:param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
259 |
+
"""
|
260 |
+
|
261 |
+
# Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
|
262 |
+
q = q.view(*q.shape[: 2], self.n_heads, -1)
|
263 |
+
k = k.view(*k.shape[: 2], self.n_heads, -1)
|
264 |
+
v = v.view(*v.shape[: 2], self.n_heads, -1)
|
265 |
+
|
266 |
+
# Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
|
267 |
+
attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale
|
268 |
+
|
269 |
+
# Compute softmax
|
270 |
+
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
|
271 |
+
if self.is_inplace:
|
272 |
+
half = attn.shape[0] // 2
|
273 |
+
attn[half :] = attn[half :].softmax(dim=-1)
|
274 |
+
attn[: half] = attn[: half].softmax(dim=-1)
|
275 |
+
else:
|
276 |
+
attn = attn.softmax(dim=-1)
|
277 |
+
|
278 |
+
# Compute attention output
|
279 |
+
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
|
280 |
+
out = torch.einsum('bhij,bjhd->bihd', attn, v)
|
281 |
+
# Reshape to `[batch_size, height * width, n_heads * d_head]`
|
282 |
+
out = out.reshape(*out.shape[: 2], -1)
|
283 |
+
# Map to `[batch_size, height * width, d_model]` with a linear layer
|
284 |
+
return self.to_out(out)
|
285 |
+
|
286 |
+
|
287 |
+
class FeedForward(nn.Module):
|
288 |
+
"""
|
289 |
+
### Feed-Forward Network
|
290 |
+
"""
|
291 |
+
def __init__(self, d_model: int, d_mult: int = 4):
|
292 |
+
"""
|
293 |
+
:param d_model: is the input embedding size
|
294 |
+
:param d_mult: is the multiplicative factor for the hidden layer size
|
295 |
+
"""
|
296 |
+
super().__init__()
|
297 |
+
self.net = nn.Sequential(
|
298 |
+
GeGLU(d_model, d_model * d_mult), nn.Dropout(0.),
|
299 |
+
nn.Linear(d_model * d_mult, d_model)
|
300 |
+
)
|
301 |
+
|
302 |
+
def forward(self, x: torch.Tensor):
|
303 |
+
return self.net(x)
|
304 |
+
|
305 |
+
|
306 |
+
class GeGLU(nn.Module):
|
307 |
+
"""
|
308 |
+
### GeGLU Activation
|
309 |
+
|
310 |
+
$$\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$$
|
311 |
+
"""
|
312 |
+
def __init__(self, d_in: int, d_out: int):
|
313 |
+
super().__init__()
|
314 |
+
# Combined linear projections $xW + b$ and $xV + c$
|
315 |
+
self.proj = nn.Linear(d_in, d_out * 2)
|
316 |
+
|
317 |
+
def forward(self, x: torch.Tensor):
|
318 |
+
# Get $xW + b$ and $xV + c$
|
319 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
320 |
+
# $\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$
|
321 |
+
return x * F.gelu(gate)
|
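Note: since the cross-attention call in `BasicTransformerBlock.forward` is commented out above, the block currently reduces to self-attention plus the feed-forward. A minimal shape check of `CrossAttention` on dummy tensors (all dimensions below are illustrative, not taken from this repo's config) might look like:

```python
import torch

# Illustrative sizes only.
batch, seq, n_cond, d_model, d_cond = 2, 16, 8, 64, 32

# Cross-attention: queries come from x, keys/values from the condition.
cross = CrossAttention(d_model=d_model, d_cond=d_cond, n_heads=4, d_head=16)
# The self-attention fallback needs d_cond == d_model, since to_k/to_v map from d_cond.
self_attn = CrossAttention(d_model=d_model, d_cond=d_model, n_heads=4, d_head=16)

x = torch.randn(batch, seq, d_model)
c = torch.randn(batch, n_cond, d_cond)

out_cond = cross(x, cond=c)   # [batch, seq, d_model]
out_self = self_attn(x)       # cond=None falls back to self-attention
assert out_cond.shape == out_self.shape == (batch, seq, d_model)
```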
model/latent_diffusion.py
ADDED
@@ -0,0 +1,222 @@
1 |
+
"""
|
2 |
+
---
|
3 |
+
title: Latent Diffusion Models
|
4 |
+
summary: >
|
5 |
+
Annotated PyTorch implementation/tutorial of latent diffusion models from paper
|
6 |
+
High-Resolution Image Synthesis with Latent Diffusion Models
|
7 |
+
---
|
8 |
+
|
9 |
+
# Latent Diffusion Models
|
10 |
+
|
11 |
+
Latent diffusion models use an auto-encoder to map between image space and
|
12 |
+
latent space. The diffusion model works on the latent space, which makes it
|
13 |
+
a lot easier to train.
|
14 |
+
It is based on paper
|
15 |
+
[High-Resolution Image Synthesis with Latent Diffusion Models](https://papers.labml.ai/paper/2112.10752).
|
16 |
+
|
17 |
+
They use a pre-trained auto-encoder and train the diffusion U-Net on the latent
|
18 |
+
space of the pre-trained auto-encoder.
|
19 |
+
|
20 |
+
For a simpler diffusion implementation refer to our [DDPM implementation](../ddpm/index.html).
|
21 |
+
We use the same notation for the $\alpha_t$, $\beta_t$ schedules, etc.
|
22 |
+
"""
|
23 |
+
|
24 |
+
from typing import List, Tuple, Optional, Union
|
25 |
+
import torch
|
26 |
+
import torch.nn as nn
|
27 |
+
import torch.nn.functional as F
|
28 |
+
from .architecture.unet import UNetModel
|
29 |
+
import random
|
30 |
+
|
31 |
+
|
32 |
+
def gather(consts: torch.Tensor, t: torch.Tensor):
|
33 |
+
"""Gather consts for $t$ and reshape to feature map shape"""
|
34 |
+
c = consts.gather(-1, t)
|
35 |
+
return c.reshape(-1, 1, 1, 1)
|
36 |
+
|
37 |
+
|
38 |
+
class LatentDiffusion(nn.Module):
|
39 |
+
"""
|
40 |
+
## Latent diffusion model
|
41 |
+
|
42 |
+
This contains the following components:
|
43 |
+
|
44 |
+
* [AutoEncoder](model/autoencoder.html)
|
45 |
+
* [U-Net](model/unet.html) with [attention](model/unet_attention.html)
|
46 |
+
"""
|
47 |
+
eps_model: UNetModel
|
48 |
+
#first_stage_model: Optional[Autoencoder] = None
|
49 |
+
|
50 |
+
def __init__(
|
51 |
+
self,
|
52 |
+
unet_model: UNetModel,
|
53 |
+
latent_scaling_factor: float,
|
54 |
+
n_steps: int,
|
55 |
+
linear_start: float,
|
56 |
+
linear_end: float,
|
57 |
+
debug_mode: Optional[bool] = False
|
58 |
+
):
|
59 |
+
"""
|
60 |
+
:param unet_model: is the [U-Net](model/unet.html) that predicts noise
|
61 |
+
$\epsilon_\text{cond}(x_t, c)$, in latent space
|
62 |
+
:param autoencoder: is the [AutoEncoder](model/autoencoder.html)
|
63 |
+
:param latent_scaling_factor: is the scaling factor for the latent space. The encodings of
|
64 |
+
the autoencoder are scaled by this before feeding into the U-Net.
|
65 |
+
:param n_steps: is the number of diffusion steps $T$.
|
66 |
+
:param linear_start: is the start of the $\beta$ schedule.
|
67 |
+
:param linear_end: is the end of the $\beta$ schedule.
|
68 |
+
"""
|
69 |
+
super().__init__()
|
70 |
+
# Wrap the [U-Net](model/unet.html) to keep the same model structure as
|
71 |
+
# [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion).
|
72 |
+
self.eps_model = unet_model
|
73 |
+
self.latent_scaling_factor = latent_scaling_factor
|
74 |
+
|
75 |
+
# Number of steps $T$
|
76 |
+
self.n_steps = n_steps
|
77 |
+
|
78 |
+
# $\beta$ schedule
|
79 |
+
beta = torch.linspace(
|
80 |
+
linear_start**0.5, linear_end**0.5, n_steps, dtype=torch.float64
|
81 |
+
) ** 2
|
82 |
+
# $\alpha_t = 1 - \beta_t$
|
83 |
+
alpha = 1. - beta
|
84 |
+
# $\bar\alpha_t = \prod_{s=1}^t \alpha_s$
|
85 |
+
alpha_bar = torch.cumprod(alpha, dim=0)
|
86 |
+
self.alpha = nn.Parameter(alpha.to(torch.float32), requires_grad=False)
|
87 |
+
self.beta = nn.Parameter(beta.to(torch.float32), requires_grad=False)
|
88 |
+
self.alpha_bar = nn.Parameter(alpha_bar.to(torch.float32), requires_grad=False)
|
89 |
+
self.alpha_bar_prev = torch.cat([alpha_bar.new_tensor([1.]), alpha_bar[:-1]])
|
90 |
+
self.sigma_ddim = torch.sqrt((1-self.alpha_bar_prev)/(1-self.alpha_bar)*(1-self.alpha_bar/self.alpha_bar_prev))
|
91 |
+
self.sigma2 = self.beta
|
92 |
+
|
93 |
+
self.debug_mode = debug_mode
|
94 |
+
|
95 |
+
@property
|
96 |
+
def device(self):
|
97 |
+
"""
|
98 |
+
### Get model device
|
99 |
+
"""
|
100 |
+
return next(iter(self.eps_model.parameters())).device
|
101 |
+
|
102 |
+
|
103 |
+
|
104 |
+
def forward(self, x: torch.Tensor, t: torch.Tensor):
|
105 |
+
"""
|
106 |
+
### Predict noise
|
107 |
+
|
108 |
+
Predict noise given the latent representation $x_t$ and the time step $t$. Any
|
109 |
+
conditioning is concatenated to $x_t$ as extra channels before this call.
|
110 |
+
|
111 |
+
$$\epsilon_\text{cond}(x_t, c)$$
|
112 |
+
"""
|
113 |
+
return self.eps_model(x, t)
|
114 |
+
|
115 |
+
def q_xt_x0(self, x0: torch.Tensor,
|
116 |
+
t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
117 |
+
"""
|
118 |
+
#### Get $q(x_t|x_0)$ distribution
|
119 |
+
"""
|
120 |
+
|
121 |
+
# [gather](utils.html) $\alpha_t$ and compute $\sqrt{\bar\alpha_t} x_0$
|
122 |
+
mean = gather(self.alpha_bar, t)**0.5 * x0
|
123 |
+
# $(1-\bar\alpha_t) \mathbf{I}$
|
124 |
+
var = 1 - gather(self.alpha_bar, t)
|
125 |
+
#
|
126 |
+
return mean, var
|
127 |
+
|
128 |
+
def q_sample(
|
129 |
+
self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None
|
130 |
+
):
|
131 |
+
"""
|
132 |
+
#### Sample from $q(x_t|x_0)$
|
133 |
+
"""
|
134 |
+
|
135 |
+
# $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
|
136 |
+
if eps is None:
|
137 |
+
eps = torch.randn_like(x0)
|
138 |
+
|
139 |
+
# get $q(x_t|x_0)$
|
140 |
+
mean, var = self.q_xt_x0(x0, t)
|
141 |
+
# Sample from $q(x_t|x_0)$
|
142 |
+
return mean + (var**0.5) * eps
|
143 |
+
|
144 |
+
def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
|
145 |
+
"""
|
146 |
+
#### Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$
|
147 |
+
"""
|
148 |
+
|
149 |
+
# $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$
|
150 |
+
eps_theta = self.eps_model(xt, t)
|
151 |
+
# [gather](utils.html) $\bar\alpha_t$
|
152 |
+
alpha_bar = gather(self.alpha_bar, t)
|
153 |
+
# [gather](utils.html) $\bar\alpha_{t-1}$
|
154 |
+
alpha_bar_prev = gather(self.alpha_bar_prev, t)
|
155 |
+
# [gather](utils.html) $\sigma_t$
|
156 |
+
sigma_ddim = gather(self.sigma_ddim, t)
|
157 |
+
|
158 |
+
# DDIM sampling
|
159 |
+
# $\frac{x_t-\sqrt{1-\bar\alpha_t}\epsilon}{\sqrt{\bar\alpha_t}}$
|
160 |
+
predicted_x0 = (xt - (1-alpha_bar)**0.5 * eps_theta) / (alpha_bar)**.5
|
161 |
+
# $\sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\,\epsilon_\theta$
|
162 |
+
direction_to_xt = (1 - alpha_bar_prev - sigma_ddim**2)**0.5 * eps_theta
|
163 |
+
|
164 |
+
# $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
|
165 |
+
eps = torch.randn(xt.shape, device=xt.device)
|
166 |
+
|
167 |
+
# Sample
|
168 |
+
x_tm_1 = alpha_bar_prev**0.5 * predicted_x0 + direction_to_xt + sigma_ddim * eps
|
169 |
+
return x_tm_1
|
170 |
+
|
171 |
+
def loss(
|
172 |
+
self,
|
173 |
+
x0: torch.Tensor,
|
174 |
+
#autoreg_cond: Union[torch.Tensor, None], #This means it can be either a tensor or none
|
175 |
+
#external_cond: Union[torch.Tensor, None],
|
176 |
+
noise: Optional[torch.Tensor] = None,
|
177 |
+
):
|
178 |
+
"""
|
179 |
+
#### Simplified Loss
|
180 |
+
"""
|
181 |
+
# Get batch size
|
182 |
+
batch_size = x0.shape[0]
|
183 |
+
# Get random $t$ for each sample in the batch
|
184 |
+
t = torch.randint(
|
185 |
+
0, self.n_steps, (batch_size, ), device=x0.device, dtype=torch.long
|
186 |
+
)
|
187 |
+
|
188 |
+
|
189 |
+
#autoreg_cond = -torch.ones(x0.size(0), 1, self.eps_model.d_cond, device=x0.device, dtype=x0.dtype)
|
190 |
+
#cond = autoreg_cond
|
191 |
+
|
192 |
+
if x0.size(1) == self.eps_model.out_channels: # generating form
|
193 |
+
if self.debug_mode:
|
194 |
+
print('In the mode of root level:', x0.size())
|
195 |
+
if noise is None:
|
196 |
+
x0 = x0.to(torch.float32)
|
197 |
+
noise = torch.randn_like(x0)
|
198 |
+
|
199 |
+
xt = self.q_sample(x0, t, eps=noise)
|
200 |
+
|
201 |
+
eps_theta = self.eps_model(xt, t)
|
202 |
+
|
203 |
+
loss = F.mse_loss(noise, eps_theta)
|
204 |
+
else:
|
205 |
+
if self.debug_mode:
|
206 |
+
print('In the mode of non-root level:', x0.size())
|
207 |
+
|
208 |
+
if noise is None:
|
209 |
+
noise = torch.randn_like(x0[:, 0: 2])
|
210 |
+
|
211 |
+
front_t = self.q_sample(x0[:, 0: 2], t, eps=noise)
|
212 |
+
|
213 |
+
background_cond = x0[:, 2:]
|
214 |
+
|
215 |
+
xt = torch.cat([front_t, background_cond], 1)
|
216 |
+
|
217 |
+
eps_theta = self.eps_model(xt, t)
|
218 |
+
|
219 |
+
loss = F.mse_loss(noise, eps_theta)
|
220 |
+
if self.debug_mode:
|
221 |
+
print('loss:', loss)
|
222 |
+
return loss
|
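A quick sanity check of the $q(x_t|x_0)$ parameterisation above, reproducing the $\beta$ / $\bar\alpha_t$ schedule outside the class (the `n_steps`, `linear_start`, `linear_end` values are the ones stored in `params.json` later in this commit; the tensor shape is illustrative):

```python
import torch

n_steps, linear_start, linear_end = 1000, 0.00085, 0.012

beta = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_steps, dtype=torch.float64) ** 2
alpha_bar = torch.cumprod(1. - beta, dim=0)

print(float(alpha_bar[0]))    # ~0.999: x_1 is still almost x_0
print(float(alpha_bar[-1]))   # ~5e-3: x_T is close to pure noise

# q_sample(x0, t) is the closed form sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps
x0 = torch.randn(4, 2, 64, 128, dtype=torch.float64)
eps = torch.randn_like(x0)
t = torch.full((4,), n_steps - 1)
ab = alpha_bar[t].reshape(-1, 1, 1, 1)
xt = ab ** 0.5 * x0 + (1 - ab) ** 0.5 * eps
```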
model/model_sdf.py
ADDED
@@ -0,0 +1,55 @@
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
import torch.nn as nn
|
4 |
+
from .latent_diffusion import LatentDiffusion
|
5 |
+
|
6 |
+
|
7 |
+
class Diffpro_SDF(nn.Module):
|
8 |
+
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
ldm: LatentDiffusion,
|
12 |
+
):
|
13 |
+
"""
|
14 |
+
cond_type: {chord, texture}
|
15 |
+
cond_mode: {cond, mix, uncond}
|
16 |
+
mix: use a special condition for unconditional learning with probability of 0.2
|
17 |
+
use_enc: whether to use pretrained chord encoder to generate encoded condition
|
18 |
+
"""
|
19 |
+
super(Diffpro_SDF, self).__init__()
|
20 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
21 |
+
self.ldm = ldm
|
22 |
+
|
23 |
+
@classmethod
|
24 |
+
def load_trained(
|
25 |
+
cls,
|
26 |
+
ldm,
|
27 |
+
chkpt_fpath,
|
28 |
+
):
|
29 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
30 |
+
model = cls(ldm)
|
31 |
+
trained_leaner = torch.load(chkpt_fpath, map_location=device)
|
32 |
+
try:
|
33 |
+
model.load_state_dict(trained_leaner["model"])
|
34 |
+
except RuntimeError:
|
35 |
+
model_dict = trained_leaner["model"]
|
36 |
+
model_dict = {k.replace('cond_enc', 'autoreg_cond_enc'): v for k, v in model_dict.items()}
|
37 |
+
model_dict = {k.replace('style_enc', 'external_cond_enc'): v for k, v in model_dict.items()}
|
38 |
+
model.load_state_dict(model_dict)
|
39 |
+
return model
|
40 |
+
|
41 |
+
def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
|
42 |
+
return self.ldm.p_sample(xt, t)
|
43 |
+
|
44 |
+
def q_sample(self, x0: torch.Tensor, t: torch.Tensor):
|
45 |
+
return self.ldm.q_sample(x0, t)
|
46 |
+
|
47 |
+
def get_loss_dict(self, batch, step):
|
48 |
+
"""
|
49 |
+
z_y is the stuff the diffusion model needs to learn
|
50 |
+
"""
|
51 |
+
# x = batch.float().to(self.device)
|
52 |
+
|
53 |
+
x = batch
|
54 |
+
loss = self.ldm.loss(x)
|
55 |
+
return {"loss": loss}
|
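A minimal training step for `Diffpro_SDF`, assuming a dataloader that yields pianoroll batches shaped `[batch, channels, length, pitch]` (the optimizer, `train_loader`, and the learning-rate / grad-norm values are placeholders echoing `params.json`, not code from this commit):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Diffpro_SDF(ldm).to(device)      # `ldm` built as in latent_diffusion.py
opt = torch.optim.Adam(model.parameters(), lr=5e-5)

for step, batch in enumerate(train_loader):   # hypothetical dataloader
    batch = batch.float().to(device)
    loss = model.get_loss_dict(batch, step)["loss"]
    opt.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
    opt.step()
```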
model/sampler_sdf.py
ADDED
@@ -0,0 +1,538 @@
1 |
+
from typing import Optional, List, Union
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from labml import monit
|
5 |
+
from .latent_diffusion import LatentDiffusion
|
6 |
+
|
7 |
+
def set_seed(seed):
|
8 |
+
np.random.seed(seed)
|
9 |
+
torch.manual_seed(seed)
|
10 |
+
if torch.cuda.is_available():
|
11 |
+
torch.cuda.manual_seed(seed)
|
12 |
+
torch.cuda.manual_seed_all(seed)
|
13 |
+
|
14 |
+
# Call the function to set the seed
|
15 |
+
# set_seed(42)
|
16 |
+
|
17 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
18 |
+
"""
|
19 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
20 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
21 |
+
"""
|
22 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
23 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
24 |
+
# rescale the results from guidance (fixes overexposure)
|
25 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
26 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
27 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
28 |
+
return noise_cfg
|
29 |
+
|
30 |
+
|
31 |
+
class DiffusionSampler:
|
32 |
+
"""
|
33 |
+
## Base class for sampling algorithms
|
34 |
+
"""
|
35 |
+
model: LatentDiffusion
|
36 |
+
|
37 |
+
def __init__(self, model: LatentDiffusion):
|
38 |
+
"""
|
39 |
+
:param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
|
40 |
+
"""
|
41 |
+
super().__init__()
|
42 |
+
# Set the model $\epsilon_\text{cond}(x_t, c)$
|
43 |
+
self.model = model
|
44 |
+
# Get number of steps the model was trained with $T$
|
45 |
+
self.n_steps = model.n_steps
|
46 |
+
|
47 |
+
|
48 |
+
class SDFSampler(DiffusionSampler):
|
49 |
+
"""
|
50 |
+
## DDPM Sampler
|
51 |
+
|
52 |
+
This extends the [`DiffusionSampler` base class](index.html).
|
53 |
+
|
54 |
+
DDPM samples images by repeatedly removing noise by sampling step by step from
|
55 |
+
$p_\theta(x_{t-1} | x_t)$,
|
56 |
+
|
57 |
+
\begin{align}
|
58 |
+
|
59 |
+
p_\theta(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \tilde\beta_t \mathbf{I} \big) \\
|
60 |
+
|
61 |
+
\mu_t(x_t, t) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0
|
62 |
+
+ \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t \\
|
63 |
+
|
64 |
+
\tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t \\
|
65 |
+
|
66 |
+
x_0 &= \frac{1}{\sqrt{\bar\alpha_t}} x_t - \Big(\sqrt{\frac{1}{\bar\alpha_t} - 1}\Big)\epsilon_\theta \\
|
67 |
+
|
68 |
+
\end{align}
|
69 |
+
"""
|
70 |
+
|
71 |
+
model: LatentDiffusion
|
72 |
+
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
model: LatentDiffusion,
|
76 |
+
max_l,
|
77 |
+
h,
|
78 |
+
is_autocast=False,
|
79 |
+
is_show_image=False,
|
80 |
+
device=None,
|
81 |
+
debug_mode=False
|
82 |
+
):
|
83 |
+
"""
|
84 |
+
:param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
|
85 |
+
"""
|
86 |
+
super().__init__(model)
|
87 |
+
if device is None:
|
88 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
89 |
+
else:
|
90 |
+
self.device = device
|
91 |
+
# selected time steps ($\tau$) $1, 2, \dots, T$
|
92 |
+
# self.time_steps = np.asarray(list(range(self.n_steps)), dtype=np.int32)
|
93 |
+
self.tau = torch.tensor([13, 53, 116, 193, 310, 443, 587, 730, 845, 999], device=self.device) # torch.tensor([999, 845, 730, 587, 443, 310, 193, 116, 53, 13])
|
94 |
+
# self.tau = torch.tensor(np.asarray(list(range(self.n_steps)), dtype=np.int32), device=self.device)
|
95 |
+
self.used_n_steps = len(self.tau)
|
96 |
+
|
97 |
+
self.is_show_image = is_show_image
|
98 |
+
|
99 |
+
self.autocast = torch.cuda.amp.autocast(enabled=is_autocast)
|
100 |
+
|
101 |
+
self.out_channel = self.model.eps_model.out_channels
|
102 |
+
self.max_l = max_l
|
103 |
+
self.h = h
|
104 |
+
self.debug_mode = debug_mode
|
105 |
+
self.guidance_scale = 7.5
|
106 |
+
self.guidance_rescale = 0.7
|
107 |
+
|
108 |
+
# now, we set the coefficients
|
109 |
+
with torch.no_grad():
|
110 |
+
# $\bar\alpha_t$
|
111 |
+
self.alpha_bar = self.model.alpha_bar
|
112 |
+
# $\beta_t$ schedule
|
113 |
+
beta = self.model.beta
|
114 |
+
# $\bar\alpha_{t-1}$
|
115 |
+
self.alpha_bar_prev = torch.cat([self.alpha_bar.new_tensor([1.]), self.alpha_bar[:-1]])
|
116 |
+
# $\sigma_t$ in DDIM
|
117 |
+
self.sigma_ddim = torch.sqrt((1-self.alpha_bar_prev)/(1-self.alpha_bar)*(1-self.alpha_bar/self.alpha_bar_prev)) # DDPM noise schedule
|
118 |
+
|
119 |
+
# $\frac{1}{\sqrt{\bar\alpha}}$
|
120 |
+
self.one_over_sqrt_alpha_bar = 1 / (self.alpha_bar ** 0.5)
|
121 |
+
# $\frac{\sqrt{1-\bar\alpha}}{\sqrt{\bar\alpha}}$
|
122 |
+
self.sqrt_1m_alpha_bar_over_sqrt_alpha_bar = (1 - self.alpha_bar)**0.5 / self.alpha_bar**0.5
|
123 |
+
|
124 |
+
# $\sqrt{\bar\alpha}$
|
125 |
+
self.sqrt_alpha_bar = self.alpha_bar ** 0.5
|
126 |
+
# $\sqrt{1 - \bar\alpha}$
|
127 |
+
self.sqrt_1m_alpha_bar = (1 - self.alpha_bar) ** 0.5
|
128 |
+
# # $\sqrt{\bar\alpha_{t-1}}$
|
129 |
+
# self.sqrt_alpha_bar_prev = self.alpha_bar_prev ** 0.5
|
130 |
+
# # $\sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}$
|
131 |
+
# self.sqrt_1m_alpha_bar_prev_m_sigma2 = (1 - self.alpha_bar_prev - self.sigma_ddim ** 2) ** 0.5
|
132 |
+
|
133 |
+
#@property
|
134 |
+
# def d_cond(self):
|
135 |
+
#return self.model.eps_model.d_cond
|
136 |
+
|
137 |
+
def get_eps(
|
138 |
+
self,
|
139 |
+
x: torch.Tensor,
|
140 |
+
t: torch.Tensor,
|
141 |
+
background_cond: Optional[torch.Tensor],
|
142 |
+
|
143 |
+
uncond_scale: Optional[float],
|
144 |
+
):
|
145 |
+
"""
|
146 |
+
## Get $\epsilon(x_t, c)$
|
147 |
+
|
148 |
+
:param x: is $x_t$ of shape `[batch_size, channels, height, width]`
|
149 |
+
:param t: is $t$ of shape `[batch_size]`
|
150 |
+
:param background_cond: background condition
|
151 |
+
:param autoreg_cond: autoregressive condition
|
152 |
+
:param external_cond: external condition
|
153 |
+
:param c: is the conditional embeddings $c$ of shape `[batch_size, emb_size]`
|
154 |
+
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
|
155 |
+
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$
|
156 |
+
:param uncond_cond: is the conditional embedding for empty prompt $c_u$
|
157 |
+
"""
|
158 |
+
# When the scale $s = 1$
|
159 |
+
# $$\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c)$$
|
160 |
+
|
161 |
+
batch_size = x.size(0)
|
162 |
+
|
163 |
+
# if hasattr(self.model, 'style_enc'):
|
164 |
+
# if external_cond is not None:
|
165 |
+
# external_cond = self.model.external_cond_enc(external_cond)
|
166 |
+
# if uncond_scale is None or uncond_scale == 1:
|
167 |
+
# external_uncond = (-torch.ones_like(external_cond)).to(self.device)
|
168 |
+
# else:
|
169 |
+
# external_uncond = None
|
170 |
+
# # if random.random() < 0.2:
|
171 |
+
# # external_cond = (-torch.ones_like(external_cond)).to(self.device)
|
172 |
+
# else:
|
173 |
+
# external_cond = -torch.ones(batch_size, 4, self.d_cond, device=x.device, dtype=x.dtype)
|
174 |
+
# external_uncond = None
|
175 |
+
# cond = torch.cat([autoreg_cond, external_cond], 1)
|
176 |
+
# if external_uncond is None:
|
177 |
+
# uncond = None
|
178 |
+
# else:
|
179 |
+
# uncond = torch.cat([autoreg_cond, external_uncond], 1)
|
180 |
+
# else:
|
181 |
+
# cond = autoreg_cond
|
182 |
+
# uncond = None
|
183 |
+
|
184 |
+
if background_cond is not None:
|
185 |
+
x = torch.cat([x, background_cond], 1)
|
186 |
+
|
187 |
+
# if uncond is None:
|
188 |
+
# e_t = self.model(x, t, cond)
|
189 |
+
# else:
|
190 |
+
# e_t_cond = self.model(x, t, cond)
|
191 |
+
# e_t_uncond = self.model(x, t, uncond)
|
192 |
+
# e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)
|
193 |
+
|
194 |
+
e_t = self.model(x, t)
|
195 |
+
return e_t
|
196 |
+
|
197 |
+
@torch.no_grad()
|
198 |
+
def p_sample(
|
199 |
+
self,
|
200 |
+
x: torch.Tensor,
|
201 |
+
background_cond: Optional[torch.Tensor],
|
202 |
+
#autoreg_cond: Optional[torch.Tensor],
|
203 |
+
#external_cond: Optional[torch.Tensor],
|
204 |
+
t: torch.Tensor,
|
205 |
+
step: int,
|
206 |
+
repeat_noise: bool = False,
|
207 |
+
temperature: float = 1.,
|
208 |
+
uncond_scale: float = 1.,
|
209 |
+
same_noise_all_measure: bool = False,
|
210 |
+
X0EditFunc = None,
|
211 |
+
use_classifier_free_guidance = False,
|
212 |
+
use_lsh = False,
|
213 |
+
reduce_extra_notes=True,
|
214 |
+
rhythm_control="Yes",
|
215 |
+
):
|
216 |
+
print("p_sample")
|
217 |
+
"""
|
218 |
+
### Sample $x_{t-1}$ from $p_\theta(x_{t-1} | x_t)$
|
219 |
+
|
220 |
+
:param x: is $x_t$ of shape `[batch_size, channels, height, width]`
|
221 |
+
:param background_cond: background condition
|
222 |
+
:param autoreg_cond: autoregressive condition
|
223 |
+
:param external_cond: external condition
|
224 |
+
:param t: is $t$ of shape `[batch_size]`
|
225 |
+
:param step: is the step $t$ as an integer
|
226 |
+
:param repeat_noise: specifies whether the noise should be the same for all samples in the batch
|
227 |
+
:param temperature: is the noise temperature (random noise gets multiplied by this)
|
228 |
+
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
|
229 |
+
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$
|
230 |
+
"""
|
231 |
+
# Get current tau_i and tau_{i-1}
|
232 |
+
tau_i = self.tau[t]
|
233 |
+
step_tau_i = self.tau[step]
|
234 |
+
|
235 |
+
# Get $\epsilon_\theta$
|
236 |
+
with self.autocast:
|
237 |
+
if use_classifier_free_guidance:
|
238 |
+
if use_lsh:
|
239 |
+
assert background_cond.shape[1] == 6 # chd_onset, chd_sustain, null_chd_onset, null_chd_sustain, lsh_onset, lsh_sustain
|
240 |
+
null_lsh = -torch.ones_like(background_cond[:,4:,:,:])
|
241 |
+
null_background_cond = torch.cat([background_cond[:,2:4,:,:], null_lsh], axis=1)
|
242 |
+
real_background_cond = torch.cat([background_cond[:,:2,:,:], background_cond[:,4:,:,:]], axis=1)
|
243 |
+
|
244 |
+
e_tau_i_null = self.get_eps(x, tau_i, null_background_cond, uncond_scale=uncond_scale)
|
245 |
+
e_tau_i_real = self.get_eps(x, tau_i, real_background_cond, uncond_scale=uncond_scale)
|
246 |
+
e_tau_i = e_tau_i_null + self.guidance_scale * (e_tau_i_real-e_tau_i_null)
|
247 |
+
if self.guidance_rescale > 0:
|
248 |
+
e_tau_i = rescale_noise_cfg(e_tau_i, e_tau_i_real, guidance_rescale=self.guidance_rescale)
|
249 |
+
else:
|
250 |
+
assert background_cond.shape[1] == 4 # chd_onset, chd_sustain, null_chd_onset, null_chd_sustain
|
251 |
+
null_background_cond = background_cond[:,2:,:,:]
|
252 |
+
real_background_cond = background_cond[:,:2,:,:]
|
253 |
+
e_tau_i_null = self.get_eps(x, tau_i, null_background_cond, uncond_scale=uncond_scale)
|
254 |
+
e_tau_i_real = self.get_eps(x, tau_i, real_background_cond, uncond_scale=uncond_scale)
|
255 |
+
e_tau_i = e_tau_i_null + self.guidance_scale * (e_tau_i_real-e_tau_i_null)
|
256 |
+
if self.guidance_rescale > 0:
|
257 |
+
e_tau_i = rescale_noise_cfg(e_tau_i, e_tau_i_real, guidance_rescale=self.guidance_rescale)
|
258 |
+
else:
|
259 |
+
if use_lsh:
|
260 |
+
assert background_cond.shape[1] == 4 # chd_onset, chd_sustain, lsh_onset, lsh_sustain
|
261 |
+
e_tau_i = self.get_eps(x, tau_i, background_cond, uncond_scale=uncond_scale)
|
262 |
+
else:
|
263 |
+
assert background_cond.shape[1] == 2 # chd_onset, chd_sustain
|
264 |
+
e_tau_i = self.get_eps(x, tau_i, background_cond, uncond_scale=uncond_scale)
|
265 |
+
|
266 |
+
# Get batch size
|
267 |
+
bs = x.shape[0]
|
268 |
+
|
269 |
+
# $\frac{1}{\sqrt{\bar\alpha}}$
|
270 |
+
one_over_sqrt_alpha_bar = x.new_full(
|
271 |
+
(bs, 1, 1, 1), self.one_over_sqrt_alpha_bar[step_tau_i]
|
272 |
+
)
|
273 |
+
# $\frac{\sqrt{1-\bar\alpha}}{\sqrt{\bar\alpha}}$
|
274 |
+
sqrt_1m_alpha_bar_over_sqrt_alpha_bar = x.new_full(
|
275 |
+
(bs, 1, 1, 1), self.sqrt_1m_alpha_bar_over_sqrt_alpha_bar[step_tau_i]
|
276 |
+
)
|
277 |
+
|
278 |
+
# $\sigma_t$ in DDIM
|
279 |
+
sigma_ddim = x.new_full(
|
280 |
+
(bs, 1, 1, 1), self.sigma_ddim[step_tau_i]
|
281 |
+
)
|
282 |
+
|
283 |
+
|
284 |
+
# Calculate $x_0$ with current $\epsilon_\theta$
|
285 |
+
#
|
286 |
+
# predicted x_0 in DDIM
|
287 |
+
predicted_x0 = one_over_sqrt_alpha_bar * x[:, 0: e_tau_i.size(1)] - sqrt_1m_alpha_bar_over_sqrt_alpha_bar * e_tau_i
|
288 |
+
|
289 |
+
# edit predicted x_0
|
290 |
+
if X0EditFunc is not None:
|
291 |
+
predicted_x0 = X0EditFunc(predicted_x0, background_cond, reduce_extra_notes=reduce_extra_notes, rhythm_control=rhythm_control)
|
292 |
+
e_tau_i = (one_over_sqrt_alpha_bar * x[:, 0: e_tau_i.size(1)] - predicted_x0) / sqrt_1m_alpha_bar_over_sqrt_alpha_bar
|
293 |
+
|
294 |
+
# Do not add noise when $t = 1$ (the final step of the sampling process).
|
295 |
+
# (Note that `step` is `0` when $t = 1$.)
|
296 |
+
if step == 0:
|
297 |
+
noise = 0
|
298 |
+
# If same noise is used for all samples in the batch
|
299 |
+
elif repeat_noise:
|
300 |
+
if same_noise_all_measure:
|
301 |
+
noise = torch.randn((1, predicted_x0.shape[1], 16, predicted_x0.shape[3]), device=self.device).repeat(1,1,int(predicted_x0.shape[2]/16),1)
|
302 |
+
else:
|
303 |
+
noise = torch.randn((1, *predicted_x0.shape[1:]), device=self.device)
|
304 |
+
# Different noise for each sample
|
305 |
+
else:
|
306 |
+
if same_noise_all_measure:
|
307 |
+
noise = torch.randn(predicted_x0.shape[0], predicted_x0.shape[1], 16, predicted_x0.shape[3], device=self.device).repeat(1,1,int(predicted_x0.shape[2]/16),1)
|
308 |
+
else:
|
309 |
+
noise = torch.randn(predicted_x0.shape, device=self.device)
|
310 |
+
|
311 |
+
# Multiply noise by the temperature
|
312 |
+
noise = noise * temperature
|
313 |
+
|
314 |
+
if step > 0:
|
315 |
+
step_tau_i_m_1 = self.tau[step-1]
|
316 |
+
# $\sqrt{\bar\alpha_{\tau_i-1}}$
|
317 |
+
sqrt_alpha_bar_prev = x.new_full(
|
318 |
+
(bs, 1, 1, 1), self.sqrt_alpha_bar[step_tau_i_m_1]
|
319 |
+
)
|
320 |
+
# $\sqrt{1-\bar\alpha_{\tau_i-1}-\sigma_\tau^2}$
|
321 |
+
sqrt_1m_alpha_bar_prev_m_sigma2 = x.new_full(
|
322 |
+
(bs, 1, 1, 1), (1 - self.alpha_bar[step_tau_i_m_1] - self.sigma_ddim[step_tau_i] ** 2) ** 0.5
|
323 |
+
)
|
324 |
+
direction_to_xt = sqrt_1m_alpha_bar_prev_m_sigma2 * e_tau_i
|
325 |
+
x_prev = sqrt_alpha_bar_prev * predicted_x0 + direction_to_xt + sigma_ddim * noise
|
326 |
+
else:
|
327 |
+
x_prev = predicted_x0 + sigma_ddim * noise
|
328 |
+
|
329 |
+
# Sample from,
|
330 |
+
#
|
331 |
+
# $$p_\theta(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \tilde\beta_t \mathbf{I} \big)$$
|
332 |
+
|
333 |
+
#
|
334 |
+
return x_prev, predicted_x0, e_tau_i
|
335 |
+
|
336 |
+
@torch.no_grad()
|
337 |
+
def q_sample(
|
338 |
+
self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None
|
339 |
+
):
|
340 |
+
"""
|
341 |
+
### Sample from $q(x_t|x_0)$
|
342 |
+
|
343 |
+
$$q(x_t|x_0) = \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$$
|
344 |
+
|
345 |
+
:param x0: is $x_0$ of shape `[batch_size, channels, height, width]`
|
346 |
+
:param index: is the time step $t$ index
|
347 |
+
:param noise: is the noise, $\epsilon$
|
348 |
+
"""
|
349 |
+
|
350 |
+
# Random noise, if noise is not specified
|
351 |
+
if noise is None:
|
352 |
+
noise = torch.randn_like(x0, device=self.device)
|
353 |
+
|
354 |
+
# Sample from $\mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$
|
355 |
+
return self.sqrt_alpha_bar[index] * x0 + self.sqrt_1m_alpha_bar[index] * noise
|
356 |
+
|
357 |
+
@torch.no_grad()
|
358 |
+
def sample(
|
359 |
+
self,
|
360 |
+
shape: List[int],
|
361 |
+
background_cond: Optional[torch.Tensor] = None,
|
362 |
+
#autoreg_cond: Optional[torch.Tensor] = None,
|
363 |
+
#external_cond: Optional[torch.Tensor] = None,
|
364 |
+
repeat_noise: bool = False,
|
365 |
+
temperature: float = 1.,
|
366 |
+
uncond_scale: float = 1.,
|
367 |
+
x_last: Optional[torch.Tensor] = None,
|
368 |
+
t_start: int = 0,
|
369 |
+
same_noise_all_measure: bool = False,
|
370 |
+
X0EditFunc = None,
|
371 |
+
use_classifier_free_guidance = False,
|
372 |
+
use_lsh = False,
|
373 |
+
reduce_extra_notes=True,
|
374 |
+
rhythm_control="Yes",
|
375 |
+
):
|
376 |
+
"""
|
377 |
+
### Sampling Loop
|
378 |
+
|
379 |
+
:param shape: is the shape of the generated images in the
|
380 |
+
form `[batch_size, channels, height, width]`
|
381 |
+
:param background_cond: background condition
|
382 |
+
:param autoreg_cond: autoregressive condition
|
383 |
+
:param external_cond: external condition
|
384 |
+
:param repeat_noise: specifies whether the noise should be the same for all samples in the batch
|
385 |
+
:param temperature: is the noise temperature (random noise gets multiplied by this)
|
386 |
+
:param x_last: is $x_T$. If not provided random noise will be used.
|
387 |
+
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
|
388 |
+
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$
|
389 |
+
:param t_start: t_start
|
390 |
+
"""
|
391 |
+
|
392 |
+
# Get device and batch size
|
393 |
+
bs = shape[0]
|
394 |
+
|
395 |
+
######
|
396 |
+
print(shape)
|
397 |
+
######
|
398 |
+
|
399 |
+
|
400 |
+
# Get $x_T$
|
401 |
+
if same_noise_all_measure:
|
402 |
+
x = x_last if x_last is not None else torch.randn(shape[0],shape[1],16,shape[3], device=self.device).repeat(1,1,int(shape[2]/16),1)
|
403 |
+
else:
|
404 |
+
x = x_last if x_last is not None else torch.randn(shape, device=self.device)
|
405 |
+
|
406 |
+
# Time steps to sample at $T - t', T - t' - 1, \dots, 1$
|
407 |
+
time_steps = np.flip(np.asarray(list(range(self.used_n_steps)), dtype=np.int32))[t_start:]
|
408 |
+
|
409 |
+
# Sampling loop
|
410 |
+
for step in monit.iterate('Sample', time_steps):
|
411 |
+
# Time step $t$
|
412 |
+
ts = x.new_full((bs, ), step, dtype=torch.long)
|
413 |
+
|
414 |
+
x, pred_x0, e_t = self.p_sample(
|
415 |
+
x,
|
416 |
+
background_cond,
|
417 |
+
#autoreg_cond,
|
418 |
+
#external_cond,
|
419 |
+
ts,
|
420 |
+
step,
|
421 |
+
repeat_noise=repeat_noise,
|
422 |
+
temperature=temperature,
|
423 |
+
uncond_scale=uncond_scale,
|
424 |
+
same_noise_all_measure=same_noise_all_measure,
|
425 |
+
X0EditFunc = X0EditFunc,
|
426 |
+
use_classifier_free_guidance = use_classifier_free_guidance,
|
427 |
+
use_lsh=use_lsh,
|
428 |
+
reduce_extra_notes=reduce_extra_notes,
|
429 |
+
rhythm_control=rhythm_control
|
430 |
+
)
|
431 |
+
|
432 |
+
s1 = step + 1
|
433 |
+
|
434 |
+
# if self.is_show_image:
|
435 |
+
# if s1 % 100 == 0 or (s1 <= 100 and s1 % 25 == 0):
|
436 |
+
# show_image(x, f"exp/img/x{s1}.png")
|
437 |
+
|
438 |
+
# Return $x_0$
|
439 |
+
# if self.is_show_image:
|
440 |
+
# show_image(x, f"exp/img/x0.png")
|
441 |
+
|
442 |
+
return x
|
443 |
+
|
444 |
+
@torch.no_grad()
|
445 |
+
def paint(
|
446 |
+
self,
|
447 |
+
x: Optional[torch.Tensor] = None,
|
448 |
+
background_cond: Optional[torch.Tensor] = None,
|
449 |
+
#autoreg_cond: Optional[torch.Tensor] = None,
|
450 |
+
#external_cond: Optional[torch.Tensor] = None,
|
451 |
+
t_start: int = 0,
|
452 |
+
orig: Optional[torch.Tensor] = None,
|
453 |
+
mask: Optional[torch.Tensor] = None,
|
454 |
+
orig_noise: Optional[torch.Tensor] = None,
|
455 |
+
uncond_scale: float = 1.,
|
456 |
+
same_noise_all_measure: bool = False,
|
457 |
+
X0EditFunc = None,
|
458 |
+
use_classifier_free_guidance = False,
|
459 |
+
use_lsh = False,
|
460 |
+
):
|
461 |
+
"""
|
462 |
+
### Painting Loop
|
463 |
+
|
464 |
+
:param x: is $x_{S'}$ of shape `[batch_size, channels, height, width]`
|
465 |
+
:param background_cond: background condition
|
466 |
+
:param autoreg_cond: autoregressive condition
|
467 |
+
:param external_cond: external condition
|
468 |
+
:param t_start: is the sampling step to start from, $S'$
|
469 |
+
:param orig: is the original image in latent space which we are inpainting.
|
470 |
+
If this is not provided, it'll be an image to image transformation.
|
471 |
+
:param mask: is the mask to keep the original image.
|
472 |
+
:param orig_noise: is fixed noise to be added to the original image.
|
473 |
+
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
|
474 |
+
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$
|
475 |
+
"""
|
476 |
+
# Get batch size
|
477 |
+
bs = orig.size(0)
|
478 |
+
|
479 |
+
if x is None:
|
480 |
+
x = torch.randn(orig.shape, device=self.device)
|
481 |
+
|
482 |
+
# Time steps to sample at $\tau_{S'}, \tau_{S'-1}, \dots, \tau_1$
|
483 |
+
# time_steps = np.flip(self.time_steps[: t_start])
|
484 |
+
time_steps = np.flip(np.asarray(list(range(self.used_n_steps)), dtype=np.int32))[t_start:]
|
485 |
+
|
486 |
+
for i, step in monit.enum('Paint', time_steps):
|
487 |
+
# Index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
|
488 |
+
# index = len(time_steps) - i - 1
|
489 |
+
# Time step $\tau_i$
|
490 |
+
ts = x.new_full((bs, ), step, dtype=torch.long)
|
491 |
+
|
492 |
+
# Sample $x_{\tau_{i-1}}$
|
493 |
+
x, _, _ = self.p_sample(
|
494 |
+
x,
|
495 |
+
background_cond,
|
496 |
+
#autoreg_cond,
|
497 |
+
#external_cond,
|
498 |
+
t=ts,
|
499 |
+
step=step,
|
500 |
+
uncond_scale=uncond_scale,
|
501 |
+
same_noise_all_measure=same_noise_all_measure,
|
502 |
+
X0EditFunc = X0EditFunc,
|
503 |
+
use_classifier_free_guidance = use_classifier_free_guidance,
|
504 |
+
use_lsh=use_lsh,
|
505 |
+
)
|
506 |
+
|
507 |
+
# Replace the masked area with original image
|
508 |
+
if orig is not None:
|
509 |
+
assert mask is not None
|
510 |
+
# Get the $q_{\sigma,\tau}(x_{\tau_i}|x_0)$ for original image in latent space
|
511 |
+
orig_t = self.q_sample(orig, self.tau[step], noise=orig_noise)
|
512 |
+
# Replace the masked area
|
513 |
+
x = orig_t * mask + x * (1 - mask)
|
514 |
+
|
515 |
+
s1 = step + 1
|
516 |
+
|
517 |
+
# if self.is_show_image:
|
518 |
+
# if s1 % 100 == 0 or (s1 <= 100 and s1 % 25 == 0):
|
519 |
+
# show_image(x, f"exp/img/x{s1}.png")
|
520 |
+
|
521 |
+
# if self.is_show_image:
|
522 |
+
# show_image(x, f"exp/img/x0.png")
|
523 |
+
return x
|
524 |
+
|
525 |
+
def generate(self, background_cond=None, batch_size=1, uncond_scale=None,
|
526 |
+
same_noise_all_measure=False, X0EditFunc=None,
|
527 |
+
use_classifier_free_guidance=False, use_lsh=False, reduce_extra_notes=True, rhythm_control="Yes"):
|
528 |
+
|
529 |
+
shape = [batch_size, self.out_channel, self.max_l, self.h]
|
530 |
+
|
531 |
+
if self.debug_mode:
|
532 |
+
return torch.randn(shape, dtype=torch.float)
|
533 |
+
|
534 |
+
return self.sample(shape, background_cond, uncond_scale=uncond_scale, same_noise_all_measure=same_noise_all_measure,
|
535 |
+
X0EditFunc=X0EditFunc, use_classifier_free_guidance=use_classifier_free_guidance, use_lsh=use_lsh,
|
536 |
+
reduce_extra_notes=reduce_extra_notes, rhythm_control=rhythm_control
|
537 |
+
)
|
538 |
+
|
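To tie the sampler to the model, a sketch of chord / lead-sheet-conditioned generation (the `max_l` and `h` values and the 4-channel `background_cond` layout are assumptions read off the `background_cond.shape[1] == 4` branch of `p_sample` and the `in_channels: 6` entry in `params.json`; they are not asserted anywhere else in this commit):

```python
import torch

sampler = SDFSampler(ldm, max_l=128, h=128, device="cpu")

# 4 background channels: chd_onset, chd_sustain, lsh_onset, lsh_sustain
# (matches in_channels = out_channels + 4 = 6 in params.json).
background_cond = torch.zeros(1, 4, 128, 128)

piano_roll = sampler.generate(
    background_cond=background_cond,
    batch_size=1,
    uncond_scale=1.,
    use_classifier_free_guidance=False,
    use_lsh=True,
)
print(piano_roll.shape)   # [1, out_channels, max_l, h]
```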
output_0.mid
ADDED
Binary file (376 Bytes). View file
|
|
output_0.wav
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f9abffb8b039f86161f025cab6419eecf93ec741ea67e66964ca8e79d333c9d4
|
3 |
+
size 2469720
|
packages.txt
ADDED
@@ -0,0 +1 @@
1 |
+
TiMidity++
|
piano_roll.png
ADDED
|
requirements.txt
ADDED
@@ -0,0 +1,22 @@
1 |
+
gradio#==4.44.0
|
2 |
+
imageio#==2.35.1
|
3 |
+
imageio[ffmpeg]
|
4 |
+
labml#==0.5.3
|
5 |
+
librosa
|
6 |
+
mir_eval
|
7 |
+
matplotlib
|
8 |
+
music21
|
9 |
+
numba==0.53.1
|
10 |
+
numpy==1.19.5
|
11 |
+
opencv-python
|
12 |
+
# pandas==1.2.5
|
13 |
+
pretty_midi
|
14 |
+
pydub
|
15 |
+
requests
|
16 |
+
soundfile
|
17 |
+
fluidsynth
|
18 |
+
scikit-learn
|
19 |
+
torch==2.4.1
|
20 |
+
torchvision
|
21 |
+
tqdm
|
22 |
+
tensorboard
|
results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/chkpts/weights_best.pt
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:852cfd7b011bb1ba0ce1d0d05a7acd672c7cd4934756b2e0d357d8002b5ecb6b
|
3 |
+
size 441623592
|
results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/events.out.tfevents.1726894943.berkeleyaisim3.16517.0
ADDED
Binary file (157 Bytes). View file
|
|
results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/train_grad_norm/events.out.tfevents.1726894943.berkeleyaisim3.16517.2
ADDED
Binary file (10.2 kB). View file
|
|
results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/train_loss/events.out.tfevents.1726894943.berkeleyaisim3.16517.1
ADDED
Binary file (10.2 kB). View file
|
|
results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/val_grad_norm/events.out.tfevents.1726895010.berkeleyaisim3.16517.4
ADDED
Binary file (592 Bytes). View file
|
|
results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/val_loss/events.out.tfevents.1726895010.berkeleyaisim3.16517.3
ADDED
Binary file (592 Bytes). View file
|
|
results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/params.json
ADDED
@@ -0,0 +1 @@
1 |
+
{"batch_size": 16, "max_epoch": 10, "learning_rate": 5e-05, "max_grad_norm": 10, "fp16": true, "in_channels": 6, "out_channels": 2, "channels": 64, "attention_levels": [2, 3], "n_res_blocks": 2, "channel_multipliers": [1, 2, 4, 4], "n_heads": 4, "tf_layers": 1, "d_cond": 2, "linear_start": 0.00085, "linear_end": 0.012, "n_steps": 1000, "latent_scaling_factor": 0.18215}
|
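The hyperparameters above can be read back and cross-checked against the code in this commit; for example, `in_channels - out_channels` is the number of background-condition channels the U-Net expects (a small sketch, using the file path as committed here):

```python
import json

path = "results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/params.json"
with open(path) as f:
    params = json.load(f)

# 6 - 2 = 4 background channels (chord + lead-sheet, onset/sustain pairs).
print(params["in_channels"] - params["out_channels"])
# Diffusion schedule settings used by LatentDiffusion.
print(params["n_steps"], params["linear_start"], params["linear_end"])
```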
rhythm_plot_0.png
ADDED
|
runtime.txt
ADDED
@@ -0,0 +1 @@
1 |
+
python-3.9
|
samples/control_vs_uncontrol/example_1_acc_control.jpg
ADDED
|
samples/control_vs_uncontrol/example_1_acc_control.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:98254caf2f2943c0f4bf4a07cfddf0c0dfa8bc33f41e5aafcd2626f33763b680
|
3 |
+
size 2469720
|
samples/control_vs_uncontrol/example_1_acc_uncontrol.jpg
ADDED
|
samples/control_vs_uncontrol/example_1_acc_uncontrol.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:582e9d559424e40216f3d7e2ded5b07844311da1783fd77792b58a27fa27ba8c
|
3 |
+
size 2469720
|
samples/control_vs_uncontrol/example_1_mel_chd.jpg
ADDED
|
samples/control_vs_uncontrol/example_1_mel_chd.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cedfa10bc44605e9cd1c0a04a82d90f7b6c31dfe68d790a7d9c8e6e27117c90e
|
3 |
+
size 5292046
|
samples/control_vs_uncontrol/example_2_acc_control.jpg
ADDED
|
samples/control_vs_uncontrol/example_2_acc_control.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e6e70d13d7d97c0cb3cbf38d9ad2473f5df5eb04c69a54ddb228aca1134f819a
|
3 |
+
size 2364108
|
samples/control_vs_uncontrol/example_2_acc_uncontrol.jpg
ADDED
|
samples/control_vs_uncontrol/example_2_acc_uncontrol.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a60053aca94dac302379e536774832a3f90bf2fd7a0d7a23c2dc97b742f78910
|
3 |
+
size 2364108
|
samples/control_vs_uncontrol/example_2_mel_chd.jpg
ADDED
|