interactive-symbolic-music committed
Commit 62f1377 · 0 Parent(s)

Initial commit with cleaned history

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .DS_Store +0 -0
  2. .gitattributes +6 -0
  3. Aptfile +2 -0
  4. README.md +18 -0
  5. __pycache__/app.cpython-39.pyc +0 -0
  6. __pycache__/learner.cpython-39.pyc +0 -0
  7. __pycache__/params.cpython-39.pyc +0 -0
  8. __pycache__/train_params.cpython-39.pyc +0 -0
  9. app.py +508 -0
  10. filter_data/filter_by_instrument.ipynb +353 -0
  11. filter_data/midi_utils.py +139 -0
  12. generation/__pycache__/gen_utils.cpython-39.pyc +0 -0
  13. generation/gen_utils.py +302 -0
  14. model/__init__.py +59 -0
  15. model/__pycache__/__init__.cpython-39.pyc +0 -0
  16. model/__pycache__/latent_diffusion.cpython-39.pyc +0 -0
  17. model/__pycache__/model_sdf.cpython-39.pyc +0 -0
  18. model/__pycache__/sampler_sdf.cpython-39.pyc +0 -0
  19. model/architecture/__pycache__/unet.cpython-39.pyc +0 -0
  20. model/architecture/__pycache__/unet_attention.cpython-39.pyc +0 -0
  21. model/architecture/unet.py +364 -0
  22. model/architecture/unet_attention.py +321 -0
  23. model/latent_diffusion.py +222 -0
  24. model/model_sdf.py +55 -0
  25. model/sampler_sdf.py +538 -0
  26. output_0.mid +0 -0
  27. output_0.wav +3 -0
  28. packages.txt +1 -0
  29. piano_roll.png +3 -0
  30. requirements.txt +22 -0
  31. results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/chkpts/weights_best.pt +3 -0
  32. results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/events.out.tfevents.1726894943.berkeleyaisim3.16517.0 +0 -0
  33. results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/train_grad_norm/events.out.tfevents.1726894943.berkeleyaisim3.16517.2 +0 -0
  34. results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/train_loss/events.out.tfevents.1726894943.berkeleyaisim3.16517.1 +0 -0
  35. results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/val_grad_norm/events.out.tfevents.1726895010.berkeleyaisim3.16517.4 +0 -0
  36. results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/val_loss/events.out.tfevents.1726895010.berkeleyaisim3.16517.3 +0 -0
  37. results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/params.json +1 -0
  38. rhythm_plot_0.png +3 -0
  39. runtime.txt +1 -0
  40. samples/control_vs_uncontrol/example_1_acc_control.jpg +3 -0
  41. samples/control_vs_uncontrol/example_1_acc_control.wav +3 -0
  42. samples/control_vs_uncontrol/example_1_acc_uncontrol.jpg +3 -0
  43. samples/control_vs_uncontrol/example_1_acc_uncontrol.wav +3 -0
  44. samples/control_vs_uncontrol/example_1_mel_chd.jpg +3 -0
  45. samples/control_vs_uncontrol/example_1_mel_chd.wav +3 -0
  46. samples/control_vs_uncontrol/example_2_acc_control.jpg +3 -0
  47. samples/control_vs_uncontrol/example_2_acc_control.wav +3 -0
  48. samples/control_vs_uncontrol/example_2_acc_uncontrol.jpg +3 -0
  49. samples/control_vs_uncontrol/example_2_acc_uncontrol.wav +3 -0
  50. samples/control_vs_uncontrol/example_2_mel_chd.jpg +3 -0
.DS_Store ADDED
Binary file (6.15 kB)
 
.gitattributes ADDED
@@ -0,0 +1,6 @@
1
+ *.pt filter=lfs diff=lfs merge=lfs -text
2
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
3
+ *.wav filter=lfs diff=lfs merge=lfs -text
4
+ *.png filter=lfs diff=lfs merge=lfs -text
5
+ *.jpg filter=lfs diff=lfs merge=lfs -text
6
+ *.pth filter=lfs diff=lfs merge=lfs -text
Aptfile ADDED
@@ -0,0 +1,2 @@
1
+ lilypond
2
+ fluidsynth
README.md ADDED
@@ -0,0 +1,18 @@
1
+ ---
2
+ title: Interactive Symbolic Music Demo
3
+ emoji: 🖼
4
+ colorFrom: purple
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 4.42.0
8
+ app_file: app.py
9
+ pinned: false
10
+ python_version: 3.9.19
11
+ license: mit
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
15
+
16
+ Please find the mido.MidiFile call inside the pretty_midi package and set its "clip" argument to clip=True (a monkey-patch alternative is sketched after this file's diff).
17
+
18
+ Use set_seed(42) in sampler_sdf.py; the generation result from chord slice (index = 2) is a good example (a wrong note is shifted to a correct one).
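The clip note above refers to manually editing the installed pretty_midi package. A less invasive alternative is a small monkey-patch applied before any MIDI file is loaded. The sketch below is an illustration only — it is not part of this repository — and assumes a mido version whose MidiFile accepts the `clip` argument:

```python
import mido
import pretty_midi  # pretty_midi calls mido.MidiFile(...) internally when reading files

_OriginalMidiFile = mido.MidiFile  # keep a reference to the unpatched class


class _ClippedMidiFile(_OriginalMidiFile):
    """MidiFile that defaults to clip=True, so out-of-range data bytes are clipped instead of raising."""

    def __init__(self, *args, **kwargs):
        # Respect an explicit clip=... if the caller passes one; otherwise default to True.
        kwargs.setdefault("clip", True)
        super().__init__(*args, **kwargs)


# pretty_midi looks up mido.MidiFile at call time, so patching the module attribute is enough.
# Note: this patch is an assumption-based sketch, not code shipped in this repository.
mido.MidiFile = _ClippedMidiFile

# Example: this load would otherwise fail on MIDI files containing data bytes > 127.
pm = pretty_midi.PrettyMIDI("output_0.mid")
```

Either approach (editing site-packages as the README suggests, or patching at import time as above) should have the same effect wherever pretty_midi.PrettyMIDI is used to read MIDI files.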
__pycache__/app.cpython-39.pyc ADDED
Binary file (12.6 kB)
 
__pycache__/learner.cpython-39.pyc ADDED
Binary file (7.06 kB)
 
__pycache__/params.cpython-39.pyc ADDED
Binary file (1.27 kB)
 
__pycache__/train_params.cpython-39.pyc ADDED
Binary file (1.38 kB)
 
app.py ADDED
@@ -0,0 +1,508 @@
1
+ import gradio as gr
2
+ import pretty_midi
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import soundfile as sf
6
+ import cv2
7
+ import imageio
8
+
9
+ import sys
10
+ import subprocess
11
+ import os
12
+ import torch
13
+ from model import init_ldm_model
14
+ from model.model_sdf import Diffpro_SDF
15
+ from model.sampler_sdf import SDFSampler
16
+
17
+ import pickle
18
+ from train.train_params import params_chord_lsh_cond
19
+ from generation.gen_utils import *
20
+
21
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+ model_path = 'results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/chkpts/weights_best.pt'
23
+ chord_list = list(CHORD_DICTIONARY.keys())
24
+
25
+ def get_shape(file_path):
26
+ if file_path.endswith('.jpg'):
27
+ img = cv2.imread(file_path)
28
+ return img.shape # (height, width, channels)
29
+
30
+ elif file_path.endswith('.mp4'):
31
+ vid = imageio.get_reader(file_path)
32
+ return vid.get_meta_data()['size'] # (width, height)
33
+
34
+ else:
35
+ raise ValueError("Unsupported file type")
36
+
37
+ # Function to convert MIDI to WAV
38
+ def midi_to_wav(midi, output_file):
39
+ # Synthesize the waveform from the MIDI using pretty_midi
40
+ audio_data = midi.fluidsynth()
41
+
42
+ # Write the waveform to a WAV file
43
+ sf.write(output_file, audio_data, samplerate=44100)
44
+
45
+ def update_musescore_image(selected_prompt):
46
+ # Logic to return the correct image file based on the selected prompt
47
+ if selected_prompt == "example 1":
48
+ return "samples/diy_examples/example1/example1.jpg"
49
+ elif selected_prompt == "example 2":
50
+ return "samples/diy_examples/example2/example2.jpg"
51
+ elif selected_prompt == "example 3":
52
+ return "samples/diy_examples/example3/example3.jpg"
53
+ elif selected_prompt == "example 4":
54
+ return "samples/diy_examples/example4/example4.jpg"
55
+ elif selected_prompt == "example 5":
56
+ return "samples/diy_examples/example5/example5.jpg"
57
+ elif selected_prompt == "example 6":
58
+ return "samples/diy_examples/example6/example6.jpg"
59
+
60
+
61
+ # Model for generating music (example)
62
+ def generate_music(prompt, tempo, num_samples=1, mode="example", rhythm_control="Yes"):
63
+
64
+ ldm_model = init_ldm_model(params_chord_lsh_cond, debug_mode=False)
65
+ model = Diffpro_SDF.load_trained(ldm_model, model_path).to(device)
66
+ sampler = SDFSampler(model.ldm, 64, 64, is_autocast=False, device=device, debug_mode=False)
67
+
68
+ if mode=="example":
69
+ if prompt == "example 1":
70
+ background_condition = np.load("samples/diy_examples/example1/example1.npy")
71
+ tempo=70
72
+ elif prompt == "example 2":
73
+ background_condition = np.load("samples/diy_examples/example2/example2.npy")
74
+ elif prompt == "example 3":
75
+ background_condition = np.load("samples/diy_examples/example3/example3.npy")
76
+ elif prompt == "example 4":
77
+ background_condition = np.load("samples/diy_examples/example4/example4.npy")
78
+
79
+ background_condition = np.tile(background_condition, (num_samples,1,1,1))
80
+ background_condition = torch.Tensor(background_condition).to(device)
81
+ else:
82
+ background_condition = np.tile(prompt, (num_samples,1,1,1))
83
+ background_condition = torch.Tensor(background_condition).to(device)
84
+
85
+ if rhythm_control!="Yes":
86
+ background_condition[:,0:2] = background_condition[:,2:4]
87
+ # generate samples
88
+ output_x = sampler.generate(background_cond=background_condition, batch_size=num_samples,
89
+ same_noise_all_measure=False, X0EditFunc=X0EditFunc,
90
+ use_classifier_free_guidance=True, use_lsh=True, reduce_extra_notes=False,
91
+ rhythm_control=rhythm_control)
92
+ output_x = torch.clamp(output_x, min=0, max=1)
93
+ output_x = output_x.cpu().numpy()
94
+
95
+ # save samples
96
+ for i in range(num_samples):
97
+ full_roll = extend_piano_roll(output_x[i]) # accompaniment roll
98
+ full_chd_roll = extend_piano_roll(-background_condition[i,2:4,:,:].cpu().numpy()-1) # chord roll
99
+ full_lsh_roll = None
100
+ if background_condition.shape[1]>=6:
101
+ if background_condition[:,4:6,:,:].min()>=0:
102
+ full_lsh_roll = extend_piano_roll(background_condition[i,4:6,:,:].cpu().numpy())
103
+ midi_file = piano_roll_to_midi(full_roll, full_chd_roll, full_lsh_roll, bpm=tempo)
104
+ # filename = f'DDIM_w_rhythm_onset_0to10_{i}_edit_x0_and_eps'+'.mid'
105
+ filename = f"output_{i}.mid"
106
+ save_midi(midi_file, filename)
107
+ subprocess.Popen(['timidity',f'output_{i}.mid','-Ow','-o',f'output_{i}.wav']).communicate()
108
+
109
+ return 'output_0.mid', 'output_0.wav', midi_file
110
+
111
+ # Function to visualize MIDI notes
112
+ def visualize_midi(midi):
113
+ # Get piano roll from MIDI
114
+ roll = midi.get_piano_roll(fs=100)
115
+
116
+ # Plot the piano roll
117
+ plt.figure(figsize=(10, 4))
118
+ plt.imshow(roll, aspect='auto', origin='lower', cmap='gray_r', interpolation='nearest')
119
+ plt.title("Piano Roll")
120
+ plt.xlabel("Time")
121
+ plt.ylabel("Pitch")
122
+ plt.colorbar()
123
+
124
+ # Save the plot as an image
125
+ output_image_path = "piano_roll.png"
126
+ plt.savefig(output_image_path)
127
+ return output_image_path
128
+
129
+ def plot_rhythm(rhythm_str, label):
130
+ if rhythm_str=="null rhythm":
131
+ return None
132
+ fig, ax = plt.subplots(figsize=(6, 2))
133
+
134
+ # Ensure it's a 16-character string
135
+ rhythm_str = rhythm_str[:16]
136
+
137
+ # Convert string to a list of 0s and 1s
138
+ rhythm = [0 if bit=="0" else 1 for bit in rhythm_str]
139
+
140
+ # Define the x axis for the 16 sixteenth notes
141
+ x = list(range(1, 17)) # 1 to 16 sixteenth notes
142
+
143
+ # Plot each note (1 as filled circle, 0 as empty circle)
144
+ for i, bit in enumerate(rhythm):
145
+ if bit == 1:
146
+ ax.scatter(i + 1, 1, color='black', s=100, label="Note" if i == 0 else "")
147
+ else:
148
+ ax.scatter(i + 1, 1, edgecolor='black', facecolor='none', s=100, label="Rest" if i == 0 else "")
149
+
150
+ # Distinguish groups of 4 using vertical dashed lines (no solid grid lines)
151
+ for i in range(4, 17, 4):
152
+ ax.axvline(x=i + 0.5, color='grey', linestyle='--')
153
+
154
+ # Remove solid vertical grid lines by setting the grid off
155
+ ax.grid(False)
156
+
157
+ # Formatting the plot
158
+ ax.set_xlim(0.5, 16.5)
159
+ ax.set_ylim(0.8, 1.2)
160
+ ax.set_xticks(x)
161
+ ax.set_yticks([])
162
+ ax.set_xlabel("16th Notes")
163
+ ax.set_title("Rhythm Pattern")
164
+
165
+ fig.savefig(f'samples/diy_examples/rhythm_plot_{label}.png')
166
+ plt.close(fig)
167
+ return f'samples/diy_examples/rhythm_plot_{label}.png'
168
+
169
+ def adjust_rhythm_string(s):
170
+ # Truncate if longer than 16 characters
171
+ if len(s) > 16:
172
+ return s[:16]
173
+ # Pad with zeros if shorter than 16 characters
174
+ else:
175
+ return s.ljust(16, '0')
176
+ def rhythm_string_to_array(s):
177
+ # Ensure the string is 16 characters long
178
+ s = s[:16].ljust(16, '0') # Truncate or pad with '0' to make it 16 characters
179
+ # Convert to numpy array, treating non-'0' as '1'
180
+ arr = np.array([1 if char != '0' else 0 for char in s], dtype=int)
181
+ arr = arr*np.array([3,1,2,1,3,1,2,1,3,1,2,1,3,1,2,1])
182
+ print(arr)
183
+ return arr
184
+
185
+ # Gradio main function
186
+ def generate_from_example(prompt):
187
+ midi_output, audio_output, midi = generate_music(prompt, tempo=80, mode="example", rhythm_control=False)
188
+ piano_roll_image = visualize_midi(midi)
189
+ return audio_output, piano_roll_image
190
+
191
+ def generate_diy(m1_chord, m2_chord, m3_chord, m4_chord,
192
+ m1_rhythm, m2_rhythm, m3_rhythm, m4_rhythm, tempo):
193
+ print("\n\n\n",m1_chord,type(m1_chord), "\n\n\n")
194
+ test_chd_roll = np.concatenate([np.tile(CHORD_DICTIONARY[m1_chord], (16, 1)),
195
+ np.tile(CHORD_DICTIONARY[m2_chord], (16, 1)),
196
+ np.tile(CHORD_DICTIONARY[m3_chord], (16, 1)),
197
+ np.tile(CHORD_DICTIONARY[m4_chord], (16, 1))])
198
+ rhythms = [m1_rhythm, m2_rhythm, m3_rhythm, m4_rhythm]
199
+
200
+ chd_roll = np.concatenate([test_chd_roll[np.newaxis,:,:], test_chd_roll[np.newaxis,:,:]], axis=0)
201
+
202
+ chd_roll = circular_extend(chd_roll)
203
+ chd_roll = -chd_roll-1
204
+
205
+ real_chd_roll = chd_roll
206
+
207
+ melody_roll = -np.ones_like(chd_roll)
208
+
209
+ if "null rhythm" not in rhythms:
210
+ rhythm_full = []
211
+ for i in range(len(rhythms)):
212
+ rhythm = adjust_rhythm_string(rhythms[i])
213
+ rhythm = rhythm_string_to_array(rhythm)
214
+ rhythm_full.append(rhythm)
215
+ rhythm_full = np.concatenate(rhythm_full, axis=0)
216
+
217
+ onset_roll = test_chd_roll*rhythm_full[:, np.newaxis]
218
+ sustain_roll = np.zeros_like(onset_roll)
219
+ no_onset_pos = np.all(onset_roll == 0, axis=-1)
220
+ sustain_roll[no_onset_pos] = test_chd_roll[no_onset_pos]
221
+
222
+ real_chd_roll = np.concatenate([onset_roll[np.newaxis,:,:], sustain_roll[np.newaxis,:,:]], axis=0)
223
+ real_chd_roll = circular_extend(real_chd_roll)
224
+
225
+ background_condition = np.concatenate([real_chd_roll, chd_roll, melody_roll], axis=0)
226
+
227
+ midi_output, audio_output, midi = generate_music(background_condition, tempo, mode="diy")
228
+ piano_roll_image = visualize_midi(midi)
229
+ return midi_output, audio_output, piano_roll_image
230
+
231
+ # Prompt list
232
+ prompt_list = ["example 1", "example 2", "example 3", "example 4"]
233
+ rhythm_list = ["null rhythm", "1010101010101010", "1011101010111010","1111101010111010","1010001010101010","1010101000101010"]
234
+
235
+
236
+ custom_css = """
237
+ .custom-row1 {
238
+ background-color: #fdebd0;
239
+ padding: 10px;
240
+ border-radius: 5px;
241
+ }
242
+ .custom-row2 {
243
+ background-color: #d1f2eb;
244
+ padding: 10px;
245
+ border-radius: 5px;
246
+ }
247
+ .custom-grey {
248
+ background-color: #f0f0f0;
249
+ padding: 10px;
250
+ border-radius: 5px;
251
+ }
252
+ .custom-purple {
253
+ background-color: #d7bde2;
254
+ padding: 10px;
255
+ border-radius: 5px;
256
+ }
257
+ .audio_waveform-container {
258
+ display: none !important;
259
+ }
260
+ """
261
+
262
+
263
+ with gr.Blocks(css=custom_css) as demo:
264
+ gr.Markdown("# <div style='text-align: center;font-size:40px'> Efficient Fine-Grained Guidance for Diffusion-Based Symbolic Music Generation <div style='text-align: center;'>")
265
+
266
+ gr.Markdown("<span style='font-size:25px;'> We introduce **Fine-Grained Guidance (FG)**, an efficient approach for symbolic music generation using **diffusion models**. Our method enhances guidance through:\
267
+ \n &emsp; (1) Fine-grained conditioning during training,\
268
+ \n &emsp; (2) Fine-grained control during the diffusion sampling process.\
269
+ \n In particular, **sampling control** ensures tonal accuracy in every generated sample, allowing our model to produce music with high precision, consistent rhythmic patterns,\
270
+ and even stylistic variations that align with user intent.<span>")
271
+ gr.Markdown("<span style='font-size:25px;color: red'> At the bottom of this page, we provide an interactive space for you to try our model by yourself! <span>")
272
+
273
+
274
+ gr.Markdown("\n\n\n")
275
+ gr.Markdown("# 1. Accompaniment Generation given Melody and Chord")
276
+ gr.Markdown("<span style='font-size:20px;'> In each example, the left column displays the melody provided as input to the model.\
277
+ The right column showcases music samples generated by the model.<span>")
278
+
279
+ with gr.Column(elem_classes="custom-row1"):
280
+ gr.Markdown("## Example 1")
281
+ with gr.Row():
282
+ with gr.Column():
283
+ gr.Markdown("<span style='font-size:20px;'> With the following melody as condition <span>")
284
+ example1_mel = gr.Audio(value="samples/diy_examples/example1/example_1_mel.wav", label="Melody", scale = 5)
285
+ with gr.Column():
286
+ gr.Markdown("<span style='font-size:20px;'> Generated Accompaniments <span>")
287
+ example1_audio = gr.Audio(value="samples/diy_examples/example1/sample1.wav", label="Generated Accompaniment", scale = 5)
288
+
289
+ with gr.Column(elem_classes="custom-row2"):
290
+ gr.Markdown("## Example 2")
291
+ with gr.Row():
292
+ with gr.Column():
293
+ gr.Markdown("<span style='font-size:20px;'> With the following melody as condition <span>")
294
+ example1_mel = gr.Audio(value="samples/diy_examples/example2/example_2_mel.wav", label="Melody", scale = 5)
295
+ with gr.Column():
296
+ gr.Markdown("<span style='font-size:20px;'> Generated Accompaniments <span>")
297
+ example1_audio = gr.Audio(value="samples/diy_examples/example2/sample1.wav", label="Generated Accompaniment", scale = 5)
298
+
299
+ with gr.Column(elem_classes="custom-row1"):
300
+ gr.Markdown("## Example 3")
301
+ with gr.Row():
302
+ with gr.Column():
303
+ gr.Markdown("<span style='font-size:20px;'> With the following melody as condition <span>")
304
+ example1_mel = gr.Audio(value="samples/diy_examples/example3/example_3_mel.wav", label="Melody", scale = 5)
305
+ with gr.Column():
306
+ gr.Markdown("<span style='font-size:20px;'> Generated Accompaniments <span>")
307
+ example1_audio = gr.Audio(value="samples/diy_examples/example3/sample1.wav", label="Generated Accompaniment", scale = 5)
308
+
309
+ with gr.Column(elem_classes="custom-row2"):
310
+ gr.Markdown("## Example 4")
311
+ with gr.Row():
312
+ with gr.Column():
313
+ gr.Markdown("<span style='font-size:20px;'> With the following melody as condition <span>")
314
+ example1_mel = gr.Audio(value="samples/diy_examples/example4/example_4_mel.wav", label="Melody", scale = 5)
315
+ with gr.Column():
316
+ gr.Markdown("<span style='font-size:20px;'> Generated Accompaniments <span>")
317
+ example1_audio = gr.Audio(value="samples/diy_examples/example4/sample1.wav", label="Generated Accompaniment", scale = 5)
318
+
319
+ gr.HTML("<div style='height: 50px;'></div>")
320
+ gr.Markdown("# \n\n\n")
321
+ gr.Markdown("# 2. Style-Controlled Music Generation")
322
+ gr.Markdown("<span style='font-size:20px;'>Our approach enables controllable stylization in music generation. The sampling control is able to\
323
+ ensure that all generated notes strictly adhere to the target musical style's scale.\
324
+ This allows the model to generate music in specific styles — even those that were not present in \
325
+ the training data.<span>")
326
+ gr.Markdown("<span style='font-size:20px;'> Below, we demonstrate several examples of style-controlled music generation for:\
327
+ \n &emsp; (1) Dorian Mode: (with scale being A-B-C-D-E-F#-G);\
328
+ \n &emsp; (2) Chinese Style: (with scale being C-D-E-G-A). <span>")
329
+
330
+ with gr.Column(elem_classes="custom-row1"):
331
+ gr.Markdown("## Dorian Mode")
332
+ gr.Markdown("<span style='font-size:20px;'> The following are two examples generated by our method <span>")
333
+ with gr.Row():
334
+ with gr.Column(elem_classes="custom-grey"):
335
+ gr.Markdown("<span style='font-size:20px;'> Example 1 <span>")
336
+ example1_mel = gr.Audio(value="samples/different_styles/dorian_1.wav", scale = 5)
337
+ with gr.Column(elem_classes="custom-grey"):
338
+ gr.Markdown("<span style='font-size:20px;'> Example 2 <span>")
339
+ example1_audio = gr.Audio(value="samples/different_styles/dorian_2.wav", scale = 5)
340
+
341
+ with gr.Column(elem_classes="custom-row2"):
342
+ gr.Markdown("## Chinese Style")
343
+ gr.Markdown("<span style='font-size:20px;'> The following are two examples generated by our method <span>")
344
+ with gr.Row():
345
+ with gr.Column(elem_classes="custom-grey"):
346
+ gr.Markdown("<span style='font-size:20px;'> Example 1 <span>")
347
+ example1_mel = gr.Audio(value="samples/different_styles/chinese_1.wav", scale = 5)
348
+ with gr.Column(elem_classes="custom-grey"):
349
+ gr.Markdown("<span style='font-size:20px;'> Example 2 <span>")
350
+ example1_audio = gr.Audio(value="samples/different_styles/chinese_2.wav", scale = 5)
351
+
352
+ gr.HTML("<div style='height: 50px;'></div>")
353
+ gr.Markdown("\n\n\n")
354
+ gr.Markdown("# 3. Demonstrating the Effectiveness of Sampling Control by Comparison")
355
+
356
+ gr.Markdown("<span style='font-size:20px;'> We demonstrate the impact of sampling control in an **accompaniment generation** task, given a melody and chord progression.\
357
+ \n Each example generates accompaniments with and without sampling control using the same random seed, ensuring that the two results are comparable.\
358
+ \n Sampling control effectively removes or replaces harmonically conflicting notes, ensuring tonal consistency.\
359
+ \n We provide music sheets and audio files for both versions.<span>")
360
+
361
+ gr.Markdown("<span style='font-size:20px;'> Comparison of the results indicates that sampling control not only eliminates out-of-key notes but also enhances \
362
+ the overall coherence and harmonic consistency of the accompaniments.\
363
+ This highlights the effectiveness of our approach in maintaining musical coherence. <span>")
364
+
365
+
366
+ with gr.Column(elem_classes="custom-row1"):
367
+ gr.Markdown("## Example 1")
368
+
369
+ with gr.Row(elem_classes="custom-grey"):
370
+ gr.Markdown("<span style='font-size:20px;'> With pre-defined melody and chord as follows<span>")
371
+ with gr.Column(scale=2, min_width=10, ):
372
+ gr.Markdown("Melody Sheet")
373
+ example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_1_mel_chd.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
374
+ with gr.Column(scale=1, min_width=10, ):
375
+ gr.Markdown("Melody Audio")
376
+ example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_mel_chd.wav", label="Melody, wav", waveform_options=gr.WaveformOptions(show_recording_waveform=False), scale = 1, min_width=10)
377
+
378
+ gr.Markdown("## Generated Accompaniments")
379
+ with gr.Row(elem_classes="custom-grey"):
380
+ gr.Markdown("<span style='font-size:20px;'> Without sampling control<span>")
381
+ with gr.Column(scale=2, min_width=300):
382
+ gr.Markdown("Music Sheet")
383
+ example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_1_acc_uncontrol.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
384
+ with gr.Column(scale=1, min_width=150):
385
+ gr.Markdown("Audio")
386
+ example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_uncontrol.wav", scale = 1, min_width=10)
387
+ gr.Markdown("\n\n\n")
388
+ with gr.Row(elem_classes="custom-grey"):
389
+ with gr.Column(scale=1, min_width=150):
390
+ gr.Markdown("<span style='font-size:20px;'>With sampling control<span>")
391
+ with gr.Column(scale=2, min_width=300):
392
+ gr.Markdown("Music Sheet")
393
+ example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_1_acc_control.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
394
+ with gr.Column(scale=1, min_width=150):
395
+ gr.Markdown("Audio")
396
+ example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_control.wav", scale = 1, min_width=10)
397
+
398
+
399
+ with gr.Column(elem_classes="custom-row2"):
400
+ gr.Markdown("## Example 2")
401
+
402
+ with gr.Row(elem_classes="custom-grey"):
403
+ gr.Markdown("<span style='font-size:20px;'> With pre-defined melody and chord as follows<span>")
404
+ with gr.Column(scale=2, min_width=10, ):
405
+ gr.Markdown("Melody Sheet")
406
+ example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_2_mel_chd.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
407
+ with gr.Column(scale=1, min_width=10, ):
408
+ gr.Markdown("Melody Audio")
409
+ example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_2_mel_chd.wav", label="Melody, wav", waveform_options=gr.WaveformOptions(show_recording_waveform=False), scale = 1, min_width=10)
410
+
411
+ gr.Markdown("## Generated Accompaniments")
412
+ with gr.Row(elem_classes="custom-grey"):
413
+ gr.Markdown("<span style='font-size:20px;'> Without sampling control<span>")
414
+ with gr.Column(scale=2, min_width=300):
415
+ gr.Markdown("Music Sheet")
416
+ example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_2_acc_uncontrol.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
417
+ with gr.Column(scale=1, min_width=150):
418
+ gr.Markdown("Audio")
419
+ example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_2_acc_uncontrol.wav", scale = 1, min_width=10)
420
+ gr.Markdown("\n\n\n")
421
+ with gr.Row(elem_classes="custom-grey"):
422
+ with gr.Column(scale=1, min_width=150):
423
+ gr.Markdown("<span style='font-size:20px;'>With sampling control<span>")
424
+ with gr.Column(scale=2, min_width=300):
425
+ gr.Markdown("Music Sheet")
426
+ example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_2_acc_control.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
427
+ with gr.Column(scale=1, min_width=150):
428
+ gr.Markdown("Audio")
429
+ example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_2_acc_control.wav", scale = 1, min_width=10)
430
+
431
+ # with gr.Row():
432
+ # with gr.Column(scale=1, min_width=300, elem_classes="custom-row1"):
433
+ # gr.Markdown("## Example 1")
434
+ # gr.Markdown("<span style='font-size:20px;'> With pre-defined melody and chord as follows<span>")
435
+ # example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_1_mel_chd.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
436
+ # # Audio component to play the audio
437
+ # example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_mel_chd.wav", label="Melody, wav", waveform_options=gr.WaveformOptions(show_recording_waveform=False), scale = 1, min_width=10)
438
+
439
+ # gr.Markdown("## Generated Accompaniments")
440
+ # with gr.Row():
441
+ # with gr.Column(scale=1, min_width=150):
442
+ # gr.Markdown("<span style='font-size:20px;'> without sampling control<span>")
443
+ # example1_sheet = gr.Image(value="samples/control_vs_uncontrol/sample_1.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
444
+ # example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_uncontrol.wav", scale = 1, min_width=10)
445
+ # with gr.Column(scale=1, min_width=150):
446
+ # gr.Markdown("<span style='font-size:20px;'> with sampling control<span>")
447
+ # example1_sheet = gr.Image(value="samples/control_vs_uncontrol/sample_1.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
448
+ # example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_control.wav", scale = 1, min_width=10)
449
+ # with gr.Column(scale=1, min_width=300, elem_classes="custom-row2"):
450
+ # gr.Markdown("## Example 2")
451
+ # gr.Markdown("<span style='font-size:20px;'> With pre-defined melody and chord as follows<span>")
452
+ # example1_sheet = gr.Image(value="samples/control_vs_uncontrol/example_1_mel_chd.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
453
+ # # Audio component to play the audio
454
+ # example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_mel_chd.wav", label="Melody, wav", waveform_options=gr.WaveformOptions(show_recording_waveform=False), scale = 1, min_width=10)
455
+
456
+ # gr.Markdown("## Generated Accompaniments")
457
+ # with gr.Row():
458
+ # with gr.Column(scale=1, min_width=150):
459
+ # gr.Markdown("<span style='font-size:20px;'> without sampling control<span>")
460
+ # example1_sheet = gr.Image(value="samples/control_vs_uncontrol/sample_1.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
461
+ # example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_uncontrol.wav", scale = 1, min_width=10)
462
+ # with gr.Column(scale=1, min_width=150):
463
+ # gr.Markdown("<span style='font-size:20px;'> with sampling control<span>")
464
+ # example1_sheet = gr.Image(value="samples/control_vs_uncontrol/sample_1.jpg", label="Music Sheet of Melody and Chord", scale=1, min_width=10)
465
+ # example1_melody = gr.Audio(value="samples/control_vs_uncontrol/example_1_acc_control.wav", scale = 1, min_width=10)
466
+
467
+
468
+
469
+
470
+
471
+ ''' Interactive generation by users '''
472
+ gr.HTML("<div style='height: 50px;'></div>")
473
+ gr.Markdown("\n\n\n")
474
+ gr.Markdown("# <span style='color: red;'> 4. DIY in real time! </span>")
475
+ gr.Markdown("<span style='font-size:20px;'> Here is an interactive tool for you to try our model and generate music yourself.\
476
+ You can generate new accompaniments for a given melody and chord condition. <span>")
477
+
478
+ gr.Markdown("### <span style='color: blue;'> Currently this Space runs on Hugging Face CPU hardware and, on average,\
479
+ it takes about 15 seconds to generate a 4-measure music piece. However, if other users are generating\
480
+ music at the same time, your request may enter a queue, which could slow down the process significantly.\
481
+ If that happens, feel free to refresh the page. We appreciate your patience and understanding.\
482
+ </span>")
483
+
484
+ with gr.Column(elem_classes="custom-purple"):
485
+ gr.Markdown("### Select an example to generate music given melody and chord condition")
486
+ with gr.Row():
487
+ with gr.Column():
488
+ prompt_selector = gr.Dropdown(choices=prompt_list, label="Select an example", value="example 1")
489
+ gr.Markdown("### This is the melody to be conditioned on:")
490
+ condition_musescore = gr.Image("samples/diy_examples/example1/example1.jpg", label="melody, chord, and rhythm condition")
491
+ prompt_selector.change(fn=update_musescore_image, inputs=prompt_selector, outputs=condition_musescore)
492
+
493
+ with gr.Column():
494
+ generate_button = gr.Button("Generate")
495
+ gr.Markdown("### Generation results:")
496
+ audio_output = gr.Audio(label="Generated Music")
497
+ piano_roll_output = gr.Image(label="Generated Piano Roll")
498
+
499
+ generate_button.click(
500
+ fn=generate_from_example,
501
+ inputs=[prompt_selector],
502
+ outputs=[audio_output, piano_roll_output]
503
+ )
504
+
505
+
506
+ # Launch Gradio interface
507
+ if __name__ == "__main__":
508
+ demo.launch()
filter_data/filter_by_instrument.ipynb ADDED
@@ -0,0 +1,353 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from midi_utils import is_timesig_44, gather_full_instr, gather_instr\n",
10
+ "from midi_utils import has_brass\n",
11
+ "from midi_utils import has_piano, has_string, has_guitar, has_drums\n"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 2,
17
+ "metadata": {},
18
+ "outputs": [
19
+ {
20
+ "name": "stdout",
21
+ "output_type": "stream",
22
+ "text": [
23
+ "Program number 0: Acoustic Grand Piano\n",
24
+ "Program number 1: Bright Acoustic Piano\n",
25
+ "Program number 2: Electric Grand Piano\n",
26
+ "Program number 3: Honky-tonk Piano\n",
27
+ "Program number 4: Electric Piano 1\n",
28
+ "Program number 5: Electric Piano 2\n",
29
+ "Program number 6: Harpsichord\n",
30
+ "Program number 7: Clavinet\n",
31
+ "Program number 8: Celesta\n",
32
+ "Program number 9: Glockenspiel\n",
33
+ "Program number 10: Music Box\n",
34
+ "Program number 11: Vibraphone\n",
35
+ "Program number 12: Marimba\n",
36
+ "Program number 13: Xylophone\n",
37
+ "Program number 14: Tubular Bells\n",
38
+ "Program number 15: Dulcimer\n",
39
+ "Program number 16: Drawbar Organ\n",
40
+ "Program number 17: Percussive Organ\n",
41
+ "Program number 18: Rock Organ\n",
42
+ "Program number 19: Church Organ\n",
43
+ "Program number 20: Reed Organ\n",
44
+ "Program number 21: Accordion\n",
45
+ "Program number 22: Harmonica\n",
46
+ "Program number 23: Tango Accordion\n",
47
+ "Program number 24: Acoustic Guitar (nylon)\n",
48
+ "Program number 25: Acoustic Guitar (steel)\n",
49
+ "Program number 26: Electric Guitar (jazz)\n",
50
+ "Program number 27: Electric Guitar (clean)\n",
51
+ "Program number 28: Electric Guitar (muted)\n",
52
+ "Program number 29: Overdriven Guitar\n",
53
+ "Program number 30: Distortion Guitar\n",
54
+ "Program number 31: Guitar Harmonics\n",
55
+ "Program number 32: Acoustic Bass\n",
56
+ "Program number 33: Electric Bass (finger)\n",
57
+ "Program number 34: Electric Bass (pick)\n",
58
+ "Program number 35: Fretless Bass\n",
59
+ "Program number 36: Slap Bass 1\n",
60
+ "Program number 37: Slap Bass 2\n",
61
+ "Program number 38: Synth Bass 1\n",
62
+ "Program number 39: Synth Bass 2\n",
63
+ "Program number 40: Violin\n",
64
+ "Program number 41: Viola\n",
65
+ "Program number 42: Cello\n",
66
+ "Program number 43: Contrabass\n",
67
+ "Program number 44: Tremolo Strings\n",
68
+ "Program number 45: Pizzicato Strings\n",
69
+ "Program number 46: Orchestral Harp\n",
70
+ "Program number 47: Timpani\n",
71
+ "Program number 48: String Ensemble 1\n",
72
+ "Program number 49: String Ensemble 2\n",
73
+ "Program number 50: Synth Strings 1\n",
74
+ "Program number 51: Synth Strings 2\n",
75
+ "Program number 52: Choir Aahs\n",
76
+ "Program number 53: Voice Oohs\n",
77
+ "Program number 54: Synth Choir\n",
78
+ "Program number 55: Orchestra Hit\n",
79
+ "Program number 56: Trumpet\n",
80
+ "Program number 57: Trombone\n",
81
+ "Program number 58: Tuba\n",
82
+ "Program number 59: Muted Trumpet\n",
83
+ "Program number 60: French Horn\n",
84
+ "Program number 61: Brass Section\n",
85
+ "Program number 62: Synth Brass 1\n",
86
+ "Program number 63: Synth Brass 2\n",
87
+ "Program number 64: Soprano Sax\n",
88
+ "Program number 65: Alto Sax\n",
89
+ "Program number 66: Tenor Sax\n",
90
+ "Program number 67: Baritone Sax\n",
91
+ "Program number 68: Oboe\n",
92
+ "Program number 69: English Horn\n",
93
+ "Program number 70: Bassoon\n",
94
+ "Program number 71: Clarinet\n",
95
+ "Program number 72: Piccolo\n",
96
+ "Program number 73: Flute\n",
97
+ "Program number 74: Recorder\n",
98
+ "Program number 75: Pan Flute\n",
99
+ "Program number 76: Blown bottle\n",
100
+ "Program number 77: Shakuhachi\n",
101
+ "Program number 78: Whistle\n",
102
+ "Program number 79: Ocarina\n",
103
+ "Program number 80: Lead 1 (square)\n",
104
+ "Program number 81: Lead 2 (sawtooth)\n",
105
+ "Program number 82: Lead 3 (calliope)\n",
106
+ "Program number 83: Lead 4 chiff\n",
107
+ "Program number 84: Lead 5 (charang)\n",
108
+ "Program number 85: Lead 6 (voice)\n",
109
+ "Program number 86: Lead 7 (fifths)\n",
110
+ "Program number 87: Lead 8 (bass + lead)\n",
111
+ "Program number 88: Pad 1 (new age)\n",
112
+ "Program number 89: Pad 2 (warm)\n",
113
+ "Program number 90: Pad 3 (polysynth)\n",
114
+ "Program number 91: Pad 4 (choir)\n",
115
+ "Program number 92: Pad 5 (bowed)\n",
116
+ "Program number 93: Pad 6 (metallic)\n",
117
+ "Program number 94: Pad 7 (halo)\n",
118
+ "Program number 95: Pad 8 (sweep)\n",
119
+ "Program number 96: FX 1 (rain)\n",
120
+ "Program number 97: FX 2 (soundtrack)\n",
121
+ "Program number 98: FX 3 (crystal)\n",
122
+ "Program number 99: FX 4 (atmosphere)\n",
123
+ "Program number 100: FX 5 (brightness)\n",
124
+ "Program number 101: FX 6 (goblins)\n",
125
+ "Program number 102: FX 7 (echoes)\n",
126
+ "Program number 103: FX 8 (sci-fi)\n",
127
+ "Program number 104: Sitar\n",
128
+ "Program number 105: Banjo\n",
129
+ "Program number 106: Shamisen\n",
130
+ "Program number 107: Koto\n",
131
+ "Program number 108: Kalimba\n",
132
+ "Program number 109: Bagpipe\n",
133
+ "Program number 110: Fiddle\n",
134
+ "Program number 111: Shanai\n",
135
+ "Program number 112: Tinkle Bell\n",
136
+ "Program number 113: Agogo\n",
137
+ "Program number 114: Steel Drums\n",
138
+ "Program number 115: Woodblock\n",
139
+ "Program number 116: Taiko Drum\n",
140
+ "Program number 117: Melodic Tom\n",
141
+ "Program number 118: Synth Drum\n",
142
+ "Program number 119: Reverse Cymbal\n",
143
+ "Program number 120: Guitar Fret Noise\n",
144
+ "Program number 121: Breath Noise\n",
145
+ "Program number 122: Seashore\n",
146
+ "Program number 123: Bird Tweet\n",
147
+ "Program number 124: Telephone Ring\n",
148
+ "Program number 125: Helicopter\n",
149
+ "Program number 126: Applause\n",
150
+ "Program number 127: Gunshot\n"
151
+ ]
152
+ }
153
+ ],
154
+ "source": [
155
+ "import pretty_midi\n",
156
+ "\n",
157
+ "def display_all_instrument_names():\n",
158
+ " for program_number in range(128):\n",
159
+ " instrument_name = pretty_midi.program_to_instrument_name(program_number)\n",
160
+ " print(f\"Program number {program_number}: {instrument_name}\")\n",
161
+ "\n",
162
+ "# Call the function to display all instrument names\n",
163
+ "display_all_instrument_names()\n"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": 3,
169
+ "metadata": {},
170
+ "outputs": [],
171
+ "source": [
172
+ "#import mido\n",
173
+ "import pretty_midi\n",
174
+ "#from mido import KeySignatureError\n",
175
+ "\n",
176
+ "def filter_midi(file_path, max_track = 8, max_time = 300):\n",
177
+ " try:\n",
178
+ " pm = pretty_midi.PrettyMIDI(file_path)\n",
179
+ " except Exception as e:\n",
180
+ " return False\n",
181
+ " \n",
182
+ " # time signature 4/4\n",
183
+ " #if is_timesig_44(pm) == False:\n",
184
+ " #print(\"timesig\")\n",
185
+ " #return False\n",
186
+ " \n",
187
+ " # number of tracks\n",
188
+ " if len(pm.instruments)>max_track:\n",
189
+ " #print(\"tracks\")\n",
190
+ " return False\n",
191
+ " \n",
192
+ " # length of song\n",
193
+ " #if pm.get_end_time()>max_time:\n",
194
+ " #print(\"length\")\n",
195
+ " #return False\n",
196
+ "\n",
197
+ " # now filter by instruments\n",
198
+ "\n",
199
+ " # filter out the ones without drums\n",
200
+ " #if has_drums(pm)==False:\n",
201
+ " #print(\"no drums\")\n",
202
+ " #return False\n",
203
+ " \n",
204
+ " # filter out the ones with brass\n",
205
+ " instr = gather_instr(pm)\n",
206
+ " if has_brass(instr):\n",
207
+ " #print(\"has brass\")\n",
208
+ " return False\n",
209
+ " \n",
210
+ " # filter out the ones without full string and piano\n",
211
+ " full_instr = gather_full_instr(pm, threshold=0.7)\n",
212
+ " \n",
213
+ " if has_piano(full_instr)== False:\n",
214
+ " #print(\"no piano\")\n",
215
+ " return False\n",
216
+ "\n",
217
+ " if has_guitar(full_instr)== False:\n",
218
+ " return False\n",
219
+ " \n",
220
+ " #if has_string(full_instr)== False:\n",
221
+ " #print(\"no string\")\n",
222
+ " #return False\n",
223
+ " \n",
224
+ " return True\n"
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": 4,
230
+ "metadata": {},
231
+ "outputs": [],
232
+ "source": [
233
+ "#import pretty_midi\n",
234
+ "#midi_path = '/home/ubuntu/lakh-pianoroll-dataset/data/samples_with_strings/c10e69ec7f8212c68ff3658cceef5b9b.mid'\n",
235
+ "#pm = pretty_midi.PrettyMIDI(midi_path)\n",
236
+ "#mid = mido.MidiFile(midi_path)"
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "code",
241
+ "execution_count": 5,
242
+ "metadata": {},
243
+ "outputs": [],
244
+ "source": [
245
+ "import shutil\n",
246
+ "import glob\n",
247
+ "import os\n",
248
+ "from tqdm import tqdm\n",
249
+ "\n",
250
+ "def find_midi_files_upto(root_dir, sample_size):\n",
251
+ " midi_files = glob.glob(os.path.join(root_dir, '**/*.mid'), recursive=True)\n",
252
+ " matching_files = []\n",
253
+ " match_count = 0\n",
254
+ "\n",
255
+ " pbar = tqdm(total=len(midi_files), desc=\"Processing MIDI files\")\n",
256
+ " for midi_file in midi_files:\n",
257
+ " if filter_midi(midi_file):\n",
258
+ " matching_files.append(midi_file)\n",
259
+ " match_count += 1\n",
260
+ " pbar.set_postfix({'Matching files': match_count})\n",
261
+ " if match_count >= sample_size:\n",
262
+ " break\n",
263
+ " pbar.update(1)\n",
264
+ " pbar.close()\n",
265
+ " return matching_files\n",
266
+ "\n",
267
+ "def copy_files(files, target_dir):\n",
268
+ " os.makedirs(target_dir, exist_ok=True)\n",
269
+ " for file in files:\n",
270
+ " shutil.copy(file, target_dir)"
271
+ ]
272
+ },
273
+ {
274
+ "cell_type": "code",
275
+ "execution_count": 6,
276
+ "metadata": {},
277
+ "outputs": [
278
+ {
279
+ "name": "stderr",
280
+ "output_type": "stream",
281
+ "text": [
282
+ "Processing MIDI files: 0%| | 14/178561 [00:02<11:02:01, 4.49it/s]/home/ubuntu/.local/lib/python3.10/site-packages/pretty_midi/pretty_midi.py:100: RuntimeWarning: Tempo, Key or Time signature change events found on non-zero tracks. This is not a valid type 0 or type 1 MIDI file. Tempo, Key or Time Signature may be wrong.\n",
283
+ " warnings.warn(\n",
284
+ "Processing MIDI files: 0%| | 148/178561 [00:18<4:21:09, 11.39it/s, Matching files=4] "
285
+ ]
286
+ },
287
+ {
288
+ "ename": "KeyboardInterrupt",
289
+ "evalue": "",
290
+ "output_type": "error",
291
+ "traceback": [
292
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
293
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
294
+ "\u001b[0;32m/tmp/ipykernel_47712/2263633239.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mROOT_DIR\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'/home/ubuntu/lakh-pianoroll-dataset/data/lmd/lmd_full'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0msample_files\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfind_midi_files_upto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mROOT_DIR\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1500\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mtgt_dir\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'/home/ubuntu/lakh-pianoroll-dataset/data/instrument_samples'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mcopy_files\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msample_files\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtgt_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
295
+ "\u001b[0;32m/tmp/ipykernel_47712/2439092612.py\u001b[0m in \u001b[0;36mfind_midi_files_upto\u001b[0;34m(root_dir, sample_size)\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mpbar\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtotal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmidi_files\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdesc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Processing MIDI files\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmidi_file\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmidi_files\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mfilter_midi\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmidi_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 14\u001b[0m \u001b[0mmatching_files\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmidi_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0mmatch_count\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
296
+ "\u001b[0;32m/tmp/ipykernel_47712/3402906515.py\u001b[0m in \u001b[0;36mfilter_midi\u001b[0;34m(file_path, max_track, max_time)\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mfilter_midi\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_track\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m8\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m300\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mpm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpretty_midi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPrettyMIDI\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
297
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/pretty_midi/pretty_midi.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, midi_file, resolution, initial_tempo)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmidi_file\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msix\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstring_types\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;31m# If a string was given, pass it as the string filename\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m \u001b[0mmidi_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmido\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMidiFile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmidi_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 64\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0;31m# Otherwise, try passing it in as a file pointer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
298
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/midifiles/midifiles.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, filename, file, type, ticks_per_beat, charset, debug, clip, tracks)\u001b[0m\n\u001b[1;32m 318\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfilename\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 319\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mfile\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 320\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_load\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 321\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 322\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
299
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/midifiles/midifiles.py\u001b[0m in \u001b[0;36m_load\u001b[0;34m(self, infile)\u001b[0m\n\u001b[1;32m 369\u001b[0m \u001b[0m_dbg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'Track {i}:'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 370\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 371\u001b[0;31m self.tracks.append(read_track(infile,\n\u001b[0m\u001b[1;32m 372\u001b[0m \u001b[0mdebug\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 373\u001b[0m clip=self.clip))\n",
300
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/midifiles/midifiles.py\u001b[0m in \u001b[0;36mread_track\u001b[0;34m(infile, debug, clip)\u001b[0m\n\u001b[1;32m 216\u001b[0m \u001b[0mmsg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mread_sysex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minfile\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdelta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclip\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 217\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 218\u001b[0;31m \u001b[0mmsg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mread_message\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minfile\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstatus_byte\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpeek_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdelta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclip\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 219\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[0mtrack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
301
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/midifiles/midifiles.py\u001b[0m in \u001b[0;36mread_message\u001b[0;34m(infile, status_byte, peek_data, delta, clip)\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mOSError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'data byte must be in range 0..127'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 133\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mMessage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_bytes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstatus_byte\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mdata_bytes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdelta\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 134\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 135\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
302
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/messages/messages.py\u001b[0m in \u001b[0;36mfrom_bytes\u001b[0;34m(cl, data, time)\u001b[0m\n\u001b[1;32m 161\u001b[0m \"\"\"\n\u001b[1;32m 162\u001b[0m \u001b[0mmsg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__new__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 163\u001b[0;31m \u001b[0mmsgdict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdecode_message\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 164\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m'data'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmsgdict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0mmsgdict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'data'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSysexData\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsgdict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'data'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
303
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/messages/decode.py\u001b[0m in \u001b[0;36mdecode_message\u001b[0;34m(msg_bytes, time, check)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_SPECIAL_CASES\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstatus_byte\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 110\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 111\u001b[0;31m \u001b[0mmsg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_decode_data_bytes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstatus_byte\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 112\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
304
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/messages/decode.py\u001b[0m in \u001b[0;36m_decode_data_bytes\u001b[0;34m(status_byte, data, spec)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;31m# TODO: better name than args?\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m \u001b[0mnames\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'value_names'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m'channel'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnames\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
305
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/mido/messages/decode.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;31m# TODO: better name than args?\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m \u001b[0mnames\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mspec\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'value_names'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m'channel'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnames\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
306
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
307
+ ]
308
+ },
309
+ {
310
+ "name": "stderr",
311
+ "output_type": "stream",
312
+ "text": [
313
+ "Processing MIDI files: 0%| | 150/178561 [00:30<4:21:08, 11.39it/s, Matching files=4]"
314
+ ]
315
+ }
316
+ ],
317
+ "source": [
318
+ "ROOT_DIR = '/home/ubuntu/lakh-pianoroll-dataset/data/lmd/lmd_full'\n",
319
+ "sample_files = find_midi_files_upto(ROOT_DIR, sample_size=1500)\n",
320
+ "tgt_dir = '/home/ubuntu/lakh-pianoroll-dataset/data/instrument_samples'\n",
321
+ "copy_files(sample_files, tgt_dir)"
322
+ ]
323
+ },
324
+ {
325
+ "cell_type": "code",
326
+ "execution_count": null,
327
+ "metadata": {},
328
+ "outputs": [],
329
+ "source": []
330
+ }
331
+ ],
332
+ "metadata": {
333
+ "kernelspec": {
334
+ "display_name": "Python 3",
335
+ "language": "python",
336
+ "name": "python3"
337
+ },
338
+ "language_info": {
339
+ "codemirror_mode": {
340
+ "name": "ipython",
341
+ "version": 3
342
+ },
343
+ "file_extension": ".py",
344
+ "mimetype": "text/x-python",
345
+ "name": "python",
346
+ "nbconvert_exporter": "python",
347
+ "pygments_lexer": "ipython3",
348
+ "version": "3.10.12"
349
+ }
350
+ },
351
+ "nbformat": 4,
352
+ "nbformat_minor": 2
353
+ }
filter_data/midi_utils.py ADDED
@@ -0,0 +1,139 @@
1
+
2
+ import pretty_midi
3
+
4
+ PIANO = (0,1)
5
+ STRING = (40,41,42,48,49,50,51)
6
+ GUITAR = (24,25,27)
7
+ BRASS = (56,57,58,59,61,62,63,64,65,66,67)
8
+
9
+ ############################## For single track analysis
10
+
11
+ def calculate_active_duration(instrument):
12
+ # Collect all note intervals
13
+ intervals = [(note.start, note.end) for note in instrument.notes]
+ if not intervals: # a track with no notes has zero active time
+ return 0.0
14
+
15
+ # Sort intervals by start time
16
+ intervals.sort()
17
+
18
+ # Merge overlapping intervals and calculate the active duration
19
+ active_duration = 0
20
+ current_start, current_end = intervals[0]
21
+
22
+ for start, end in intervals[1:]:
23
+ if start <= current_end: # There is an overlap
24
+ current_end = max(current_end, end)
25
+ else: # No overlap, add the previous interval duration and start a new interval
26
+ active_duration += current_end - current_start
27
+ current_start, current_end = start, end
28
+
29
+ # Add the last interval
30
+ active_duration += current_end - current_start
31
+
32
+ return active_duration
33
+
34
+ def is_full_track(midi, instrument, threshold=0.6):
35
+ # Calculate the total duration of the track
36
+ total_duration = midi.get_end_time()
37
+
38
+ # Calculate the active duration (time during which notes are playing)
39
+ active_duration = calculate_active_duration(instrument)
40
+
41
+ # Calculate the percentage of active duration
42
+ active_percentage = active_duration / total_duration
43
+
44
+ #print(f"Total duration: {total_duration:.2f} seconds")
45
+ #print(f"Active duration: {active_duration:.2f} seconds")
46
+ #print(f"Active percentage: {active_percentage:.2%}")
47
+
48
+ # Check if the active duration meets or exceeds the threshold
49
+ return active_percentage >= threshold
50
+
51
+ #################################### For gathering full tracks
52
+
53
+ def gather_instr(pm):
54
+ # Gather all the program indexes of the instrument tracks
55
+ program_indexes = [instrument.program for instrument in pm.instruments]
56
+
57
+ # Sort the program indexes
58
+ program_indexes.sort()
59
+
60
+ # Convert the sorted list of program indexes to a tuple
61
+ program_indexes_tuple = tuple(program_indexes)
62
+ return program_indexes_tuple
63
+
64
+ def gather_full_instr(pm, threshold = 0.6):
65
+ # Gather all the program indexes of the instrument tracks that exceed the duration threshold
66
+ program_indexes = []
67
+ for instrument in pm.instruments:
68
+ if is_full_track(pm, instrument, threshold):
69
+ program_indexes.append(instrument.program)
70
+ program_indexes.sort()
71
+ # Convert the list of program indexes to a tuple
72
+ program_indexes_tuple = tuple(program_indexes)
73
+
74
+ return program_indexes_tuple
75
+
76
+ ####################################### For finding instruments
77
+
78
+ def has_intersection(wanted_instr, exist_instr):
79
+ # Convert both the tuple and the group of integers to sets
80
+ tuple_set = set(wanted_instr)
81
+ group_set = set(exist_instr)
82
+
83
+ # Check if there is any intersection
84
+ return not tuple_set.isdisjoint(group_set)
85
+
86
+ # The functions checking instruments in the midi file tracks
87
+ def has_piano(exist_instr):
88
+ wanted_instr = PIANO
89
+ return has_intersection(wanted_instr, exist_instr)
90
+
91
+ def has_string(exist_instr):
92
+ wanted_instr = STRING
93
+ return has_intersection(wanted_instr, exist_instr)
94
+
95
+ def has_guitar(exist_instr):
96
+ wanted_instr = GUITAR
97
+ return has_intersection(wanted_instr, exist_instr)
98
+
99
+ def has_brass(exist_instr):
100
+ wanted_instr = BRASS
101
+ return has_intersection(wanted_instr, exist_instr)
102
+
103
+ def has_drums(pm):
104
+ for instrument in pm.instruments:
105
+ if instrument.is_drum:
106
+ return True
107
+ return False
108
+
109
+
110
+ def print_track_details(instrument):
111
+ """
112
+ For visualizing the information in a midi track
113
+ """
114
+ print(f"Instrument: {pretty_midi.program_to_instrument_name(instrument.program)}")
115
+ print(f"Is drum: {instrument.is_drum}")
116
+
117
+ print("\nNotes:")
118
+ for note in instrument.notes:
119
+ print(f"Start: {note.start:.2f}, End: {note.end:.2f}, Pitch: {note.pitch}, Velocity: {note.velocity}")
120
+
121
+ print("\nControl Changes:")
122
+ for cc in instrument.control_changes:
123
+ print(f"Time: {cc.time:.2f}, Number: {cc.number}, Value: {cc.value}")
124
+
125
+ print("\nPitch Bends:")
126
+ for pb in instrument.pitch_bends:
127
+ print(f"Time: {pb.time:.2f}, Pitch: {pb.pitch}")
128
+
129
+ def is_timesig_44(pm):
130
+ for time_signature in pm.time_signature_changes:
131
+ if time_signature.numerator != 4 or time_signature.denominator != 4:
132
+ return False
133
+ return True
134
+
135
+ def is_timesig_34(pm):
136
+ for time_signature in pm.time_signature_changes:
137
+ if time_signature.numerator != 3 or time_signature.denominator != 4:
138
+ return False
139
+ return True
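A minimal usage sketch of the helpers above, assuming it is run next to midi_utils.py; the file path is hypothetical and only functions defined in this module are used:

import pretty_midi
from midi_utils import gather_full_instr, has_piano, has_drums, is_timesig_44

pm_obj = pretty_midi.PrettyMIDI('song.mid')                 # hypothetical input file
full_programs = gather_full_instr(pm_obj, threshold=0.6)    # programs active for >= 60% of the piece
if is_timesig_44(pm_obj) and has_piano(full_programs) and not has_drums(pm_obj):
    print('keep:', full_programs)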
generation/__pycache__/gen_utils.cpython-39.pyc ADDED
Binary file (10.1 kB). View file
 
generation/gen_utils.py ADDED
@@ -0,0 +1,302 @@
1
+ import torch
2
+ import numpy as np
3
+ import pretty_midi as pm
4
+
5
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6
+
7
+ CHORD_DICTIONARY = {
8
+ "C:major": np.array([1,0,0,0,1,0,0,1,0,0,0,0]),
9
+ "C#:major": np.array([0,1,0,0,0,1,0,0,1,0,0,0]),
10
+ "D:major": np.array([0,0,1,0,0,0,1,0,0,1,0,0]),
11
+ "Eb:major": np.array([0,0,0,1,0,0,0,1,0,0,1,0]),
12
+ "E:major": np.array([0,0,0,0,1,0,0,0,1,0,0,1]),
13
+ "F:major": np.array([1,0,0,0,0,1,0,0,0,1,0,0]),
14
+ "F#:major": np.array([0,1,0,0,0,0,1,0,0,0,1,0]),
15
+ "G:major": np.array([0,0,1,0,0,0,0,1,0,0,0,1]),
16
+ "Ab:major": np.array([1,0,0,1,0,0,0,0,1,0,0,0]),
17
+ "A:major": np.array([0,1,0,0,1,0,0,0,0,1,0,0]),
18
+ "Bb:major": np.array([0,0,1,0,0,1,0,0,0,0,1,0]),
19
+ "B:major": np.array([0,0,0,1,0,0,1,0,0,0,0,1]),
20
+
21
+ "c:minor": np.array([1,0,0,1,0,0,0,1,0,0,0,0]),
22
+ "c#:minor": np.array([0,1,0,0,1,0,0,0,1,0,0,0]),
23
+ "d:minor": np.array([0,0,1,0,0,1,0,0,0,1,0,0]),
24
+ "eb:minor": np.array([0,0,0,1,0,0,1,0,0,0,1,0]),
25
+ "e:minor": np.array([0,0,0,0,1,0,0,1,0,0,0,1]),
26
+ "f:minor": np.array([1,0,0,0,0,1,0,0,1,0,0,0]),
27
+ "f#:minor": np.array([0,1,0,0,0,0,1,0,0,1,0,0]),
28
+ "g:minor": np.array([0,0,1,0,0,0,0,1,0,0,1,0]),
29
+ "g#:minor": np.array([0,0,0,1,0,0,0,0,1,0,0,1]),
30
+ "a:minor": np.array([1,0,0,0,1,0,0,0,0,1,0,0]),
31
+ "bb:minor": np.array([0,1,0,0,0,1,0,0,0,0,1,0]),
32
+ "b:minor": np.array([0,0,1,0,0,0,1,0,0,0,0,1]),
33
+ }
34
+
35
+
36
+ def edit_rhythm(piano_roll_full, num_notes_onset, mask_full, reduce_extra_notes=True):
37
+ '''
38
+ piano_roll_full: a tensor with shape (batch_size, 2, length, h) # length=64 is the roll length, h is the number of possible pitches
39
+ num_notes_onset: a tensor with shape (batch_size, length)
40
+ mask_full: a tensor with the same shape as piano_roll_full, corresponding to the 7-note scale chroma
41
+ reduce_extra_notes: set to True to remove onsets beyond the required note count
42
+ '''
43
+ ########## for those greater than the threshold, if num of notes exceed num_notes[i],
44
+ ########## will keep the first ones and set others to threshold
45
+ print("Coming")
46
+ # we only edit onset
47
+ onset_roll = piano_roll_full[:,0,:,:]
48
+ mask = mask_full[:,0,:,:]
49
+ shape = onset_roll.shape
50
+
51
+ onset_roll = onset_roll.reshape(-1,shape[-1])
52
+ mask = mask.reshape(-1,shape[-1])
53
+ num_notes = num_notes_onset.reshape(-1)
54
+
55
+ reduce_note_threshold = 0.499
56
+ increase_note_threshold = 0.501
57
+
58
+ # Initialize a tensor to store the modified values
59
+ final_onset_roll = onset_roll.clone()
60
+
61
+ ########### if number of notes > required, remove the extra notes ###############
62
+ if reduce_extra_notes:
63
+ threshold_mask = onset_roll > reduce_note_threshold
64
+ # Set all values <= reduce_note_threshold to -inf to exclude them from top-k selection
65
+ values_above_threshold = torch.where(threshold_mask & (mask == 1), onset_roll, torch.tensor(-float('inf')).to(onset_roll.device))
66
+
67
+ # Get the top num_notes.max() values for each row
68
+ num_notes_max = int(num_notes.max().item()) # Maximum number of notes needed in any row
69
+ topk_values, topk_indices = torch.topk(values_above_threshold, num_notes_max, dim=1)
70
+
71
+ # Create a mask for the top num_notes[i] values for each row
72
+ col_indices = torch.arange(num_notes_max, device=onset_roll.device).expand(len(onset_roll), num_notes_max)
73
+ topk_mask = (col_indices < num_notes.unsqueeze(1)) & (topk_values > -float("inf"))
74
+
75
+ # Set all values greater than reduce_note_threshold to reduce_note_threshold initially
76
+ final_onset_roll[threshold_mask & (mask == 1)] = reduce_note_threshold
77
+
78
+ # Create a flattened index to scatter the top values back into final_onset_roll
79
+ flat_row_indices = torch.arange(onset_roll.size(0), device=onset_roll.device).unsqueeze(1).expand_as(topk_indices)
80
+ flat_row_indices = flat_row_indices[topk_mask]
81
+
82
+ # Gather the valid topk_indices and corresponding values
83
+ valid_topk_indices = topk_indices[topk_mask]
84
+ valid_topk_values = topk_values[topk_mask]
85
+
86
+ # Use scatter to place the top num_notes[i] values back to their original positions
87
+ final_onset_roll = final_onset_roll.index_put_((flat_row_indices, valid_topk_indices), valid_topk_values)
88
+
89
+ ########### if number of notes < required, add some notes ###############
90
+ pitch_less_84_mask = torch.ones_like(mask)
91
+ pitch_less_84_mask[:,51:] = 0
92
+
93
+ # Count how many values >= increase_note_threshold for each row
94
+ threshold_mask_2 = (final_onset_roll >= increase_note_threshold)&(mask==1)
95
+ greater_than_threshold2_count = threshold_mask_2.sum(dim=1)
96
+
97
+ # For those rows, find the remaining number of values needed to be set to increase_note_threshold
98
+ remaining_needed = num_notes - greater_than_threshold2_count
99
+ remaining_needed_max = int(remaining_needed.max().item())
100
+ print("\n\n\n",remaining_needed_max,"\n\n\n")
101
+ if remaining_needed_max>=0: # need to add notes
102
+ # Find the values in each row that are < increase_note_threshold but are the highest (so we can set them to increase_note_threshold)
103
+ values_below_threshold2 = torch.where((final_onset_roll < increase_note_threshold)&(mask==1)&(pitch_less_84_mask==1), final_onset_roll, torch.tensor(-float('inf')).to(onset_roll.device))
104
+ topk_below_threshold2_values, topk_below_threshold2_indices = torch.topk(values_below_threshold2, remaining_needed_max, dim=1)
105
+
106
+ # Mask to only adjust the needed number of values in each row
107
+ col_indices_below_threshold2 = torch.arange(remaining_needed_max, device=onset_roll.device).expand(len(onset_roll), remaining_needed_max)
108
+ adjust_mask = (col_indices_below_threshold2 < remaining_needed.unsqueeze(1)) & (topk_below_threshold2_values > -float("inf"))
109
+
110
+ # Flatten row indices for the new top-k below increase_note_threshold
111
+ flat_row_indices_below_threshold2 = torch.arange(onset_roll.size(0), device=onset_roll.device).unsqueeze(1).expand_as(topk_below_threshold2_indices)
112
+ flat_row_indices_below_threshold2 = flat_row_indices_below_threshold2[adjust_mask]
113
+
114
+ # Gather the valid indices and set them to increase_note_threshold
115
+ valid_below_threshold2_indices = topk_below_threshold2_indices[adjust_mask]
116
+
117
+ # Update the final_onset_roll to make sure we now have exactly num_notes[i] values >= increase_note_threshold
118
+ final_onset_roll = final_onset_roll.index_put_((flat_row_indices_below_threshold2, valid_below_threshold2_indices), torch.tensor(increase_note_threshold, device=onset_roll.device))
119
+
120
+ final_onset_roll = final_onset_roll.reshape(shape)
121
+ piano_roll_full[:,0,:,:] = final_onset_roll
122
+ return piano_roll_full
123
+
124
+ def X0EditFunc(x0, background_condition, sampler_device=device, reduce_extra_notes=True, rhythm_control="Yes"):
125
+ # Precompute all rotations of the major and minor chord/scale chroma templates
126
+ maj_chd = torch.tensor([[1.,0,0,0,1,0,0,1,0,0,0,0],[1,0,1,0,1,1,0,1,0,1,0,1]], device=sampler_device)
127
+ maj_chd = torch.tile(maj_chd, (1, 64 // maj_chd.size(1) + 1))
128
+ min_chd = torch.tensor([[1.,0,0,0,1,0,0,0,0,1,0,0],[1,0,1,0,1,1,0,1,0,1,0,1]], device=sampler_device)
129
+ min_chd = torch.tile(min_chd, (1, 64 // min_chd.size(1) + 1))
130
+
131
+ # all chords, with rotation
132
+ maj_chd_rotations = torch.stack([torch.roll(maj_chd, shifts=-i, dims=1) for i in range(12)], dim=0)[:,:,:64]
133
+ min_chd_rotations = torch.stack([torch.roll(min_chd, shifts=-i, dims=1) for i in range(12)], dim=0)[:,:,:64]
134
+
135
+ # combine all chords
136
+ # chd_scale_map is a tensor with shape (N, 2, 64), N is total number of chord types,
137
+ # 2 is (chord_chroma, corresponding_scale_chroma), 64 is number of possible notes
138
+ chd_scale_map = torch.concat([maj_chd_rotations, min_chd_rotations], axis=0)
139
+
140
+ # if using null rhythm condition, have to convert -2 to 1 and -1 to 0
141
+ if background_condition[:,:2,:,:].min()<0:
142
+ correct_chord_condition = -background_condition[:,:2,:,:]-1
143
+ else:
144
+ correct_chord_condition = background_condition[:,:2,:,:]
145
+ merged_chd_roll = torch.max(correct_chord_condition[:,0,:,:], correct_chord_condition[:,1,:,:]) # chd roll of our bg_cond
146
+ chd_chroma_ours = torch.clamp(merged_chd_roll, min=0.0, max=1.0) # chd chroma of our bg_cond
147
+ shape = chd_chroma_ours.shape
148
+ chd_chroma_ours = chd_chroma_ours.reshape(-1,64)
149
+ matches = (chd_scale_map[:, 0, :].unsqueeze(0) - chd_chroma_ours.unsqueeze(1)>=0).all(dim=-1)
150
+ seven_notes_chroma_ours = torch.einsum('ij,jk->ik', matches.float(), chd_scale_map[:, 1, :]).reshape(shape)
151
+ seven_notes_chroma_ours = seven_notes_chroma_ours.unsqueeze(1).repeat((1,2,1,1))
152
+
153
+ no_chd_match = torch.all(seven_notes_chroma_ours == 0, dim=-1)
154
+ seven_notes_chroma_ours[no_chd_match] = 1.
155
+
156
+ # edit notes based on chroma
157
+ x0 = torch.where((seven_notes_chroma_ours==0)&(x0>0), 0.0 , x0)
158
+ print("See Coming?")
159
+ # edit rhythm
160
+ if (background_condition[:,:2,:,:].min()>=0) and (rhythm_control=="Yes"): # only edit if rhythm is provided
161
+ num_onset_notes, _ = torch.max(background_condition[:,0,:,:], axis=-1)
162
+ x0 = edit_rhythm(x0, num_onset_notes, seven_notes_chroma_ours, reduce_extra_notes)
163
+
164
+ return x0
165
+
166
+ def expand_roll(roll, unit=4, contain_onset=False):
167
+ # roll: (Channel, T, H) -> (Channel, T * unit, H)
168
+ n_channel, length, height = roll.shape
169
+
170
+ expanded_roll = roll.repeat(unit, axis=1)
171
+ if contain_onset:
172
+ expanded_roll = expanded_roll.reshape((n_channel, length, unit, height))
173
+ expanded_roll[1::2, :, 1:] = np.maximum(expanded_roll[::2, :, 1:], expanded_roll[1::2, :, 1:])
174
+
175
+ expanded_roll[::2, :, 1:] = 0
176
+ expanded_roll = expanded_roll.reshape((n_channel, length * unit, height))
177
+ return expanded_roll
178
+
179
+ def cut_piano_roll(piano_roll, resolution=16, lowest=33, highest=96):
180
+ piano_roll_cut = piano_roll[:,:,lowest:highest+1]
181
+ return piano_roll_cut
182
+
183
+ def circular_extend(chd_roll, lowest=33, highest=96):
184
+ #chd_roll: 6*L*12->6*L*64
185
+ C4 = 60-lowest
186
+ C3 = C4-12
187
+ shape = chd_roll.shape
188
+ ext_chd = np.zeros((shape[0],shape[1],highest+1-lowest))
189
+ ext_chd[:,:,C4:C4+12] = chd_roll
190
+ ext_chd[:,:,C3:C3+12] = chd_roll
191
+ return ext_chd
192
+
193
+
194
+ def default_quantization(v):
195
+ return 1 if v > 0.5 else 0
196
+
197
+ def extend_piano_roll(piano_roll: np.ndarray, lowest=33, highest=96):
198
+ ## This function extends the cut piano rolls back to full 128-pitch piano rolls.
199
+ ## Recall that the piano rolls have dimensions (2, L, 64); we pad with zeros to obtain (2, L, 128).
200
+ padded_roll = np.pad(piano_roll, ((0, 0), (0, 0), (lowest, 127-highest)), mode='constant', constant_values=0)
201
+ return padded_roll
202
+
203
+
204
+
205
+ def piano_roll_to_note_mat(piano_roll: np.ndarray, quantization_func=None):
206
+ """
207
+ piano_roll: (2, L, 128), onset and sustain channel.
208
+ quantization_func: optional function mapping activation values to {0, 1}; defaults to a 0.5 threshold.
209
+ """
210
+ def convert_p(p_, note_list):
211
+ edit_note_flag = False
212
+ for t in range(n_step):
213
+ onset_state = quantization_func(piano_roll[0, t, p_])
214
+ sustain_state = quantization_func(piano_roll[1, t, p_])
215
+
216
+ is_onset = bool(onset_state)
217
+ is_sustain = bool(sustain_state) and not is_onset
218
+
219
+ pitch = p_
220
+
221
+ if is_onset:
222
+ edit_note_flag = True
223
+ note_list.append([t, pitch, 1])
224
+ elif is_sustain:
225
+ if edit_note_flag:
226
+ note_list[-1][-1] += 1
227
+ else:
228
+ edit_note_flag = False
229
+ return note_list
230
+
231
+ quantization_func = default_quantization if quantization_func is None else quantization_func
232
+ assert len(piano_roll.shape) == 3 and piano_roll.shape[0] == 2 and piano_roll.shape[2] == 128, f"{piano_roll.shape}"
233
+
234
+ n_step = piano_roll.shape[1]
235
+
236
+ notes = []
237
+ for p in range(128):
238
+ convert_p(p, notes)
239
+
240
+ return notes
241
+
242
+
243
+ def note_mat_to_notes(note_mat, bpm, unit=1/4, shift_beat=0., shift_sec=0., vel=100):
244
+ """Default use shift beat"""
245
+
246
+ beat_alpha = 60 / bpm
247
+ step_alpha = unit * beat_alpha
248
+
249
+ notes = []
250
+
251
+ shift_sec = shift_sec if shift_beat is None else shift_beat * beat_alpha
252
+
253
+ for note in note_mat:
254
+ onset, pitch, dur = note
255
+ start = onset * step_alpha + shift_sec
256
+ end = (onset + dur) * step_alpha + shift_sec
257
+
258
+ notes.append(pm.Note(vel, int(pitch), start, end))
259
+
260
+ return notes
261
+
262
+
263
+ def create_pm_object(bpm, piano_notes_list, chd_notes_list, lsh_notes_list=None):
264
+ midi = pm.PrettyMIDI(initial_tempo=bpm)
265
+
266
+ piano_program = pm.instrument_name_to_program('Acoustic Grand Piano')
267
+ piano = pm.Instrument(program=piano_program)
268
+ piano.notes+=piano_notes_list
269
+ midi.instruments.append(piano)
270
+
271
+ # chd_program = pm.instrument_name_to_program('Violin')
272
+ # chd = pm.Instrument(program=chd_program)
273
+ # chd.notes+=chd_notes_list
274
+ # midi.instruments.append(chd)
275
+
276
+ if lsh_notes_list is not None:
277
+ lsh_program = pm.instrument_name_to_program('Acoustic Grand Piano')
278
+ lsh = pm.Instrument(program=lsh_program)
279
+ lsh.notes+=lsh_notes_list
280
+ midi.instruments.append(lsh)
281
+
282
+ return midi
283
+
284
+ def piano_roll_to_midi(piano_roll: np.ndarray, chd_roll: np.ndarray, lsh_roll=None, bpm=80):
285
+ piano_mat = piano_roll_to_note_mat(piano_roll)
286
+ piano_notes = note_mat_to_notes(piano_mat, bpm)
287
+
288
+ chd_mat = piano_roll_to_note_mat(chd_roll)
289
+ chd_notes = note_mat_to_notes(chd_mat, bpm)
290
+
291
+ if lsh_roll is not None:
292
+ lsh_mat = piano_roll_to_note_mat(lsh_roll)
293
+ lsh_notes = note_mat_to_notes(lsh_mat, bpm)
294
+ else:
295
+ lsh_notes=None
296
+
297
+ piano_pm = create_pm_object(bpm=bpm, piano_notes_list=piano_notes,
298
+ chd_notes_list=chd_notes, lsh_notes_list=lsh_notes)
299
+ return piano_pm
300
+
301
+ def save_midi(pm, filename):
302
+ pm.write(filename)
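An end-to-end sketch of the conversion utilities above; the random rolls stand in for model output and the output filename is illustrative:

import numpy as np
from generation.gen_utils import extend_piano_roll, piano_roll_to_midi, save_midi

piano_64 = (np.random.rand(2, 64, 64) > 0.97).astype(np.float64)  # (onset, sustain) x 64 steps x 64 pitches
chord_64 = np.zeros((2, 64, 64))
piano_128 = extend_piano_roll(piano_64)   # pad the 64-pitch window back out to 128 MIDI pitches
chord_128 = extend_piano_roll(chord_64)
midi = piano_roll_to_midi(piano_128, chord_128, bpm=80)
save_midi(midi, 'demo_output.mid')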
model/__init__.py ADDED
@@ -0,0 +1,59 @@
1
+ from .latent_diffusion import LatentDiffusion
2
+ from .model_sdf import Diffpro_SDF
3
+ from .architecture.unet import UNetModel
4
+ import os
5
+
6
+
7
+ def init_ldm_model(params, debug_mode=False):
8
+ unet_model = UNetModel(
9
+ in_channels=params.in_channels,
10
+ out_channels=params.out_channels,
11
+ channels=params.channels,
12
+ attention_levels=params.attention_levels,
13
+ n_res_blocks=params.n_res_blocks,
14
+ channel_multipliers=params.channel_multipliers,
15
+ n_heads=params.n_heads,
16
+ tf_layers=params.tf_layers,
17
+ #d_cond=params.d_cond,
18
+ )
19
+
20
+
21
+ ldm_model = LatentDiffusion(
22
+ unet_model=unet_model,
23
+ #autoencoder=None,
24
+ #autoreg_cond_enc=autoreg_cond_enc,
25
+ #external_cond_enc=external_cond_enc,
26
+ latent_scaling_factor=params.latent_scaling_factor,
27
+ n_steps=params.n_steps,
28
+ linear_start=params.linear_start,
29
+ linear_end=params.linear_end,
30
+ debug_mode=debug_mode
31
+ )
32
+
33
+ return ldm_model
34
+
35
+
36
+
37
+ def init_diff_pro_sdf(ldm_model, params, device):
38
+ return Diffpro_SDF(ldm_model).to(device)
39
+
40
+
41
+ def get_model_path(model_dir, model_id=None):
42
+ model_desc = os.path.basename(model_dir)
43
+ if model_id is None:
44
+ model_path = os.path.join(model_dir, 'chkpts', 'weights.pt')
45
+
46
+ # retrieve real model_id from the actual file weights.pt is pointing to
47
+ model_id = os.path.basename(os.path.realpath(model_path)).split('-')[1].split('.')[0]
48
+
49
+ elif model_id == 'best':
50
+ model_path = os.path.join(model_dir, 'chkpts', 'weights_best.pt')
51
+ # retrieve real model_id from the actual file weights.pt is pointing to
52
+ model_id = os.path.basename(os.path.realpath(model_path)).split('-')[1].split('.')[0]
53
+ elif model_id == 'default':
54
+ model_path = os.path.join(model_dir, 'chkpts', 'weights_default.pt')
55
+ if not os.path.exists(model_path):
56
+ return get_model_path(model_dir, 'best')
57
+ else:
58
+ model_path = f"{model_dir}/chkpts/weights-{model_id}.pt"
59
+ return model_path, model_id, model_desc
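A sketch of wiring these factory functions together; the hyperparameter values below are illustrative placeholders, not the repository's real training configuration:

import torch
from types import SimpleNamespace
from model import init_ldm_model, init_diff_pro_sdf

params = SimpleNamespace(                    # illustrative values only
    in_channels=6, out_channels=2, channels=64, attention_levels=[2, 3],
    n_res_blocks=2, channel_multipliers=[1, 2, 4, 4], n_heads=4, tf_layers=1,
    latent_scaling_factor=1.0, n_steps=1000, linear_start=0.00085, linear_end=0.012,
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ldm = init_ldm_model(params)                 # build the U-Net wrapped in LatentDiffusion
model = init_diff_pro_sdf(ldm, params, device)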
model/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (1.53 kB). View file
 
model/__pycache__/latent_diffusion.cpython-39.pyc ADDED
Binary file (5.99 kB). View file
 
model/__pycache__/model_sdf.cpython-39.pyc ADDED
Binary file (2.39 kB). View file
 
model/__pycache__/sampler_sdf.cpython-39.pyc ADDED
Binary file (11.3 kB). View file
 
model/architecture/__pycache__/unet.cpython-39.pyc ADDED
Binary file (9.04 kB). View file
 
model/architecture/__pycache__/unet_attention.cpython-39.pyc ADDED
Binary file (8.98 kB). View file
 
model/architecture/unet.py ADDED
@@ -0,0 +1,364 @@
1
+ """
2
+ ---
3
+ title: U-Net for Stable Diffusion
4
+ summary: >
5
+ Annotated PyTorch implementation/tutorial of the U-Net in stable diffusion.
6
+ ---
7
+
8
+ # U-Net for [Stable Diffusion](../index.html)
9
+
10
+ This implements the U-Net that
11
+ gives $\epsilon_\text{cond}(x_t, c)$
12
+
13
+ We have kept to the model definition and naming unchanged from
14
+ [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
15
+ so that we can load the checkpoints directly.
16
+ """
17
+
18
+ import math
19
+ from typing import List
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+
26
+ from .unet_attention import SpatialTransformer
27
+
28
+
29
+ class UNetModel(nn.Module):
30
+ """
31
+ ## U-Net model
32
+ """
33
+ def __init__(
34
+ self,
35
+ *,
36
+ in_channels: int,
37
+ out_channels: int,
38
+ channels: int,
39
+ n_res_blocks: int,
40
+ attention_levels: List[int],
41
+ channel_multipliers: List[int],
42
+ n_heads: int,
43
+ tf_layers: int = 1,
44
+ #d_cond: int = 768
45
+ ):
46
+ """
47
+ :param in_channels: is the number of channels in the input feature map
48
+ :param out_channels: is the number of channels in the output feature map
49
+ :param channels: is the base channel count for the model
50
+ :param n_res_blocks: number of residual blocks at each level
51
+ :param attention_levels: are the levels at which attention should be performed
52
+ :param channel_multipliers: are the multiplicative factors for number of channels for each level
53
+ :param n_heads: the number of attention heads in the transformers
54
+ """
55
+ super().__init__()
56
+ self.channels = channels
57
+ self.out_channels = out_channels
58
+ #self.d_cond = d_cond
59
+
60
+ # Number of levels
61
+ levels = len(channel_multipliers)
62
+ # Size time embeddings
63
+ d_time_emb = channels * 4
64
+ self.time_embed = nn.Sequential(
65
+ nn.Linear(channels, d_time_emb),
66
+ nn.SiLU(),
67
+ nn.Linear(d_time_emb, d_time_emb),
68
+ )
69
+
70
+ # Input half of the U-Net
71
+ self.input_blocks = nn.ModuleList()
72
+ # Initial $3 \times 3$ convolution that maps the input to `channels`.
73
+ # The blocks are wrapped in `TimestepEmbedSequential` module because
74
+ # different modules have different forward function signatures;
75
+ # for example, convolution only accepts the feature map and
76
+ # residual blocks accept the feature map and time embedding.
77
+ # `TimestepEmbedSequential` calls them accordingly.
78
+ self.input_blocks.append(
79
+ TimestepEmbedSequential(nn.Conv2d(in_channels, channels, 3, padding=1))
80
+ )
81
+ # Number of channels at each block in the input half of U-Net
82
+ input_block_channels = [channels]
83
+ # Number of channels at each level
84
+ channels_list = [channels * m for m in channel_multipliers]
85
+ # Prepare levels
86
+ for i in range(levels):
87
+ # Add the residual blocks and attentions
88
+ for _ in range(n_res_blocks):
89
+ # Residual block maps from previous number of channels to the number of
90
+ # channels in the current level
91
+ layers = [ResBlock(channels, d_time_emb, out_channels=channels_list[i])]
92
+ channels = channels_list[i]
93
+ # Add transformer
94
+ if i in attention_levels:
95
+ layers.append(
96
+ SpatialTransformer(channels, n_heads, tf_layers)
97
+ )
98
+ # Add them to the input half of the U-Net and keep track of the number of channels of
99
+ # its output
100
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
101
+ input_block_channels.append(channels)
102
+ # Down sample at all levels except last
103
+ if i != levels - 1:
104
+ self.input_blocks.append(TimestepEmbedSequential(DownSample(channels)))
105
+ input_block_channels.append(channels)
106
+
107
+ # The middle of the U-Net
108
+ self.middle_block = TimestepEmbedSequential(
109
+ ResBlock(channels, d_time_emb),
110
+ SpatialTransformer(channels, n_heads, tf_layers),
111
+ ResBlock(channels, d_time_emb),
112
+ )
113
+
114
+ # Second half of the U-Net
115
+ self.output_blocks = nn.ModuleList([])
116
+ # Prepare levels in reverse order
117
+ for i in reversed(range(levels)):
118
+ # Add the residual blocks and attentions
119
+ for j in range(n_res_blocks + 1):
120
+ # Residual block maps from previous number of channels plus the
121
+ # skip connections from the input half of U-Net to the number of
122
+ # channels in the current level.
123
+ layers = [
124
+ ResBlock(
125
+ channels + input_block_channels.pop(),
126
+ d_time_emb,
127
+ out_channels=channels_list[i]
128
+ )
129
+ ]
130
+ channels = channels_list[i]
131
+ # Add transformer
132
+ if i in attention_levels:
133
+ layers.append(
134
+ SpatialTransformer(channels, n_heads, tf_layers)
135
+ )
136
+ # Up-sample at every level after last residual block
137
+ # except the last one.
138
+ # Note that we are iterating in reverse; i.e. `i == 0` is the last.
139
+ if i != 0 and j == n_res_blocks:
140
+ layers.append(UpSample(channels))
141
+ # Add to the output half of the U-Net
142
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
143
+
144
+ # Final normalization and $3 \times 3$ convolution
145
+ self.out = nn.Sequential(
146
+ normalization(channels),
147
+ nn.SiLU(),
148
+ nn.Conv2d(channels, out_channels, 3, padding=1),
149
+ )
150
+
151
+ def time_step_embedding(self, time_steps: torch.Tensor, max_period: int = 10000):
152
+ """
153
+ ## Create sinusoidal time step embeddings
154
+
155
+ :param time_steps: are the time steps of shape `[batch_size]`
156
+ :param max_period: controls the minimum frequency of the embeddings.
157
+ """
158
+ # $\frac{c}{2}$; half the channels are sin and the other half is cos,
159
+ half = self.channels // 2
160
+ # $\frac{1}{10000^{\frac{2i}{c}}}$
161
+ frequencies = torch.exp(
162
+ -math.log(max_period) *
163
+ torch.arange(start=0, end=half, dtype=torch.float32) / half
164
+ ).to(device=time_steps.device)
165
+ # $\frac{t}{10000^{\frac{2i}{c}}}$
166
+ args = time_steps[:, None].float() * frequencies[None]
167
+ # $\cos\Bigg(\frac{t}{10000^{\frac{2i}{c}}}\Bigg)$ and $\sin\Bigg(\frac{t}{10000^{\frac{2i}{c}}}\Bigg)$
168
+ return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
169
+
170
+ def forward(self, x: torch.Tensor, time_steps: torch.Tensor):
171
+ """
172
+ :param x: is the input feature map of shape `[batch_size, channels, width, height]`
173
+ :param time_steps: are the time steps of shape `[batch_size]`
174
+ :param cond: conditioning of shape `[batch_size, n_cond, d_cond]`
175
+ """
176
+ # To store the input half outputs for skip connections
177
+ x_input_block = []
178
+
179
+ # Get time step embeddings
180
+ t_emb = self.time_step_embedding(time_steps)
181
+ t_emb = self.time_embed(t_emb)
182
+
183
+ # Input half of the U-Net
184
+ for module in self.input_blocks:
185
+ ##########################
186
+ #print("x:", x.dtype,"t_emb:",t_emb.dtype)
187
+ ##########################
188
+ #x = x.to(torch.float16)
189
+ x = module(x, t_emb)
190
+ x_input_block.append(x)
191
+ # Middle of the U-Net
192
+ x = self.middle_block(x, t_emb)
193
+ # Output half of the U-Net
194
+ for module in self.output_blocks:
195
+ # print(x.size(), 'a')
196
+ x = torch.cat([x, x_input_block.pop()], dim=1)
197
+ # print(x.size(), 'b')
198
+ x = module(x, t_emb)
199
+
200
+ # Final normalization and $3 \times 3$ convolution
201
+ return self.out(x)
202
+
203
+
204
+ class TimestepEmbedSequential(nn.Sequential):
205
+ """
206
+ ### Sequential block for modules with different inputs
207
+
208
+ This sequential module can be composed of different modules such as `ResBlock`,
209
+ `nn.Conv` and `SpatialTransformer`, and calls each with its matching signature.
+ """
210
+ """
211
+ def forward(self, x, t_emb, cond=None):
212
+ for layer in self:
213
+ if isinstance(layer, ResBlock):
214
+ x = layer(x, t_emb)
215
+ elif isinstance(layer, SpatialTransformer):
216
+ x = layer(x)
217
+ else:
218
+ x = layer(x)
219
+ return x
220
+
221
+
222
+ class UpSample(nn.Module):
223
+ """
224
+ ### Up-sampling layer
225
+ """
226
+ def __init__(self, channels: int):
227
+ """
228
+ :param channels: is the number of channels
229
+ """
230
+ super().__init__()
231
+ # $3 \times 3$ convolution mapping
232
+ self.conv = nn.Conv2d(channels, channels, 3, padding=1)
233
+
234
+ def forward(self, x: torch.Tensor):
235
+ """
236
+ :param x: is the input feature map with shape `[batch_size, channels, height, width]`
237
+ """
238
+ # Up-sample by a factor of $2$
239
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
240
+ # Apply convolution
241
+ return self.conv(x)
242
+
243
+
244
+ class DownSample(nn.Module):
245
+ """
246
+ ## Down-sampling layer
247
+ """
248
+ def __init__(self, channels: int):
249
+ """
250
+ :param channels: is the number of channels
251
+ """
252
+ super().__init__()
253
+ # $3 \times 3$ convolution with stride length of $2$ to down-sample by a factor of $2$
254
+ self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
255
+
256
+ def forward(self, x: torch.Tensor):
257
+ """
258
+ :param x: is the input feature map with shape `[batch_size, channels, height, width]`
259
+ """
260
+ # Apply convolution
261
+ return self.op(x)
262
+
263
+
264
+ class ResBlock(nn.Module):
265
+ """
266
+ ## ResNet Block
267
+ """
268
+ def __init__(self, channels: int, d_t_emb: int, *, out_channels=None):
269
+ """
270
+ :param channels: the number of input channels
271
+ :param d_t_emb: the size of timestep embeddings
272
+ :param out_channels: is the number of output channels; defaults to `channels`.
273
+ """
274
+ super().__init__()
275
+ # `out_channels` not specified
276
+ if out_channels is None:
277
+ out_channels = channels
278
+
279
+ # First normalization and convolution
280
+ self.in_layers = nn.Sequential(
281
+ normalization(channels),
282
+ nn.SiLU(),
283
+ nn.Conv2d(channels, out_channels, 3, padding=1),
284
+ )
285
+
286
+ # Time step embeddings
287
+ self.emb_layers = nn.Sequential(
288
+ nn.SiLU(),
289
+ nn.Linear(d_t_emb, out_channels),
290
+ )
291
+ # Final convolution layer
292
+ self.out_layers = nn.Sequential(
293
+ normalization(out_channels), nn.SiLU(), nn.Dropout(0.),
294
+ nn.Conv2d(out_channels, out_channels, 3, padding=1)
295
+ )
296
+
297
+ # `channels` to `out_channels` mapping layer for residual connection
298
+ if out_channels == channels:
299
+ self.skip_connection = nn.Identity()
300
+ else:
301
+ self.skip_connection = nn.Conv2d(channels, out_channels, 1)
302
+
303
+ def forward(self, x: torch.Tensor, t_emb: torch.Tensor):
304
+ """
305
+ :param x: is the input feature map with shape `[batch_size, channels, height, width]`
306
+ :param t_emb: is the time step embeddings of shape `[batch_size, d_t_emb]`
307
+ """
308
+ # Initial convolution
309
+ h = self.in_layers(x)
310
+ # Time step embeddings
311
+ t_emb = self.emb_layers(t_emb).type(h.dtype)
312
+ # Add time step embeddings
313
+ h = h + t_emb[:, :, None, None]
314
+ # Final convolution
315
+ h = self.out_layers(h)
316
+ # Add skip connection
317
+ return self.skip_connection(x) + h
318
+
319
+
320
+ class GroupNorm32(nn.GroupNorm):
321
+ """
322
+ ### Group normalization with float32 casting
323
+ """
324
+ def forward(self, x):
325
+ return super().forward(x.float()).type(x.dtype)
326
+
327
+
328
+ def normalization(channels):
329
+ """
330
+ ### Group normalization
331
+
332
+ This is a helper function with a fixed number of groups.
333
+ """
334
+ return GroupNorm32(32, channels)
335
+
336
+
337
+ def _test_time_embeddings():
338
+ """
339
+ Test sinusoidal time step embeddings
340
+ """
341
+ import matplotlib.pyplot as plt
342
+
343
+ plt.figure(figsize=(15, 5))
344
+ m = UNetModel(
345
+ in_channels=1,
346
+ out_channels=1,
347
+ channels=320,
348
+ n_res_blocks=1,
349
+ attention_levels=[],
350
+ channel_multipliers=[],
351
+ n_heads=1,
352
+ tf_layers=1,
353
+ d_cond=1
354
+ )
355
+ te = m.time_step_embedding(torch.arange(0, 1000))
356
+ plt.plot(np.arange(1000), te[:, [50, 100, 190, 260]].numpy())
357
+ plt.legend(["dim %d" % p for p in [50, 100, 190, 260]])
358
+ plt.title("Time embeddings")
359
+ plt.show()
360
+
361
+
362
+ #
363
+ if __name__ == '__main__':
364
+ _test_time_embeddings()
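A minimal forward-pass sketch for the U-Net defined above; all sizes are illustrative:

import torch
from model.architecture.unet import UNetModel

net = UNetModel(in_channels=6, out_channels=2, channels=64, n_res_blocks=2,
                attention_levels=[2, 3], channel_multipliers=[1, 2, 4, 4],
                n_heads=4, tf_layers=1)
x = torch.randn(1, 6, 64, 64)        # [batch, channels, time, pitch]
t = torch.randint(0, 1000, (1,))     # one diffusion step index per sample
eps = net(x, t)
print(eps.shape)                     # torch.Size([1, 2, 64, 64])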
model/architecture/unet_attention.py ADDED
@@ -0,0 +1,321 @@
1
+ """
2
+ ---
3
+ title: Transformer for Stable Diffusion U-Net
4
+ summary: >
5
+ Annotated PyTorch implementation/tutorial of the transformer
6
+ for U-Net in stable diffusion.
7
+ ---
8
+
9
+ # Transformer for Stable Diffusion [U-Net](unet.html)
10
+
11
+ This implements the transformer module used in [U-Net](unet.html) that
12
+ gives $\epsilon_\text{cond}(x_t, c)$
13
+
14
+ We have kept to the model definition and naming unchanged from
15
+ [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
16
+ so that we can load the checkpoints directly.
17
+ """
18
+
19
+ from typing import Optional
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from torch import nn
24
+
25
+
26
+ class SpatialTransformer(nn.Module):
27
+ """
28
+ ## Spatial Transformer
29
+ """
30
+ def __init__(self, channels: int, n_heads: int, n_layers: int):
31
+ """
32
+ :param channels: is the number of channels in the feature map
33
+ :param n_heads: is the number of attention heads
34
+ :param n_layers: is the number of transformer layers
35
36
+ """
37
+ super().__init__()
38
+ # Initial group normalization
39
+ self.norm = torch.nn.GroupNorm(
40
+ num_groups=32, num_channels=channels, eps=1e-6, affine=True
41
+ )
42
+ # Initial $1 \times 1$ convolution
43
+ self.proj_in = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
44
+
45
+ # Transformer layers
46
+ self.transformer_blocks = nn.ModuleList(
47
+ [
48
+ BasicTransformerBlock(
49
+ channels, n_heads, channels // n_heads
50
+ ) for _ in range(n_layers)
51
+ ]
52
+ )
53
+
54
+ # Final $1 \times 1$ convolution
55
+ self.proj_out = nn.Conv2d(
56
+ channels, channels, kernel_size=1, stride=1, padding=0
57
+ )
58
+
59
+ def forward(self, x: torch.Tensor):
60
+ """
61
+ :param x: is the feature map of shape `[batch_size, channels, height, width]`
62
+ :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
63
+ """
64
+ # Get shape `[batch_size, channels, height, width]`
65
+ b, c, h, w = x.shape
66
+ # For residual connection
67
+ x_in = x
68
+ # Normalize
69
+ x = self.norm(x)
70
+ # Initial $1 \times 1$ convolution
71
+ x = self.proj_in(x)
72
+ # Transpose and reshape from `[batch_size, channels, height, width]`
73
+ # to `[batch_size, height * width, channels]`
74
+ x = x.permute(0, 2, 3, 1).view(b, h * w, c)
75
+ # Apply the transformer layers
76
+ for block in self.transformer_blocks:
77
+ x = block(x)
78
+ # Reshape and transpose from `[batch_size, height * width, channels]`
79
+ # to `[batch_size, channels, height, width]`
80
+ x = x.view(b, h, w, c).permute(0, 3, 1, 2)
81
+ # Final $1 \times 1$ convolution
82
+ x = self.proj_out(x)
83
+ # Add residual
84
+ return x + x_in
85
+
86
+
87
+ class BasicTransformerBlock(nn.Module):
88
+ """
89
+ ### Transformer Layer
90
+ """
91
+ def __init__(self, d_model: int, n_heads: int, d_head: int):
92
+ """
93
+ :param d_model: is the input embedding size
94
+ :param n_heads: is the number of attention heads
95
+ :param d_head: is the size of an attention head
96
97
+ """
98
+ super().__init__()
99
+ # Self-attention layer and pre-norm layer
100
+ self.attn1 = CrossAttention(d_model, d_model, n_heads, d_head)
101
+ self.norm1 = nn.LayerNorm(d_model)
102
+ # Cross attention layer and pre-norm layer
103
+ #self.attn2 = CrossAttention(d_model, d_cond, n_heads, d_head)
104
+ self.norm2 = nn.LayerNorm(d_model)
105
+ # Feed-forward network and pre-norm layer
106
+ self.ff = FeedForward(d_model)
107
+ self.norm3 = nn.LayerNorm(d_model)
108
+
109
+ def forward(self, x: torch.Tensor):
110
+ """
111
+ :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
112
+ :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
113
+ """
114
+ # Self attention
115
+ x = self.attn1(self.norm1(x)) + x
116
+ # Cross-attention with conditioning
117
+ # x = self.attn2(self.norm2(x), cond=cond) + x
118
+ # Feed-forward network
119
+ x = self.ff(self.norm3(x)) + x
120
+ #
121
+ return x
122
+
123
+
124
+ class CrossAttention(nn.Module):
125
+ """
126
+ ### Cross Attention Layer
127
+
128
+ This falls-back to self-attention when conditional embeddings are not specified.
129
+ """
130
+
131
+ use_flash_attention: bool = False
132
+
133
+ def __init__(
134
+ self,
135
+ d_model: int,
136
+ d_cond: int,
137
+ n_heads: int,
138
+ d_head: int,
139
+ is_inplace: bool = True
140
+ ):
141
+ """
142
+ :param d_model: is the input embedding size
143
+ :param n_heads: is the number of attention heads
144
+ :param d_head: is the size of an attention head
145
+ :param d_cond: is the size of the conditional embeddings
146
+ :param is_inplace: specifies whether to perform the attention softmax computation inplace to
147
+ save memory
148
+ """
149
+ super().__init__()
150
+
151
+ self.is_inplace = is_inplace
152
+ self.n_heads = n_heads
153
+ self.d_head = d_head
154
+
155
+ # Attention scaling factor
156
+ self.scale = d_head**-0.5
157
+
158
+ # Query, key and value mappings
159
+ d_attn = d_head * n_heads
160
+ self.to_q = nn.Linear(d_model, d_attn, bias=False)
161
+ self.to_k = nn.Linear(d_cond, d_attn, bias=False)
162
+ self.to_v = nn.Linear(d_cond, d_attn, bias=False)
163
+
164
+ # Final linear layer
165
+ self.to_out = nn.Sequential(nn.Linear(d_attn, d_model))
166
+
167
+ # Setup [flash attention](https://github.com/HazyResearch/flash-attention).
168
+ # Flash attention is only used if it's installed
169
+ # and `CrossAttention.use_flash_attention` is set to `True`.
170
+ try:
171
+ # You can install flash attention by cloning their Github repo,
172
+ # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
173
+ # and then running `python setup.py install`
174
+ from flash_attn.flash_attention import FlashAttention
175
+ self.flash = FlashAttention()
176
+ # Set the scale for scaled dot-product attention.
177
+ self.flash.softmax_scale = self.scale
178
+ # Set to `None` if it's not installed
179
+ except ImportError:
180
+ self.flash = None
181
+
182
+ def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):
183
+ """
184
+ :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
185
+ :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
186
+ """
187
+
188
+ # If `cond` is `None` we perform self attention
189
+ has_cond = cond is not None
190
+ if not has_cond:
191
+ cond = x
192
+
193
+ # Get query, key and value vectors
194
+ q = self.to_q(x)
195
+ k = self.to_k(cond)
196
+ v = self.to_v(cond)
197
+
198
+ # Use flash attention if it's available and the head size is less than or equal to `128`
199
+ if CrossAttention.use_flash_attention and self.flash is not None and not has_cond and self.d_head <= 128:
200
+ return self.flash_attention(q, k, v)
201
+ # Otherwise, fallback to normal attention
202
+ else:
203
+ return self.normal_attention(q, k, v)
204
+
205
+ def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
206
+ """
207
+ #### Flash Attention
208
+
209
+ :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
210
+ :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
211
+ :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
212
+ """
213
+
214
+ # Get batch size and number of elements along sequence axis (`width * height`)
215
+ batch_size, seq_len, _ = q.shape
216
+
217
+ # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
218
+ # shape `[batch_size, seq_len, 3, n_heads * d_head]`
219
+ qkv = torch.stack((q, k, v), dim=2)
220
+ # Split the heads
221
+ qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
222
+
223
+ # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
224
+ # fit this size.
225
+ if self.d_head <= 32:
226
+ pad = 32 - self.d_head
227
+ elif self.d_head <= 64:
228
+ pad = 64 - self.d_head
229
+ elif self.d_head <= 128:
230
+ pad = 128 - self.d_head
231
+ else:
232
+ raise ValueError(f'Head size ${self.d_head} too large for Flash Attention')
233
+
234
+ # Pad the heads
235
+ if pad:
236
+ qkv = torch.cat(
237
+ (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
238
+ )
239
+
240
+ # Compute attention
241
+ # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
242
+ # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
243
+ out, _ = self.flash(qkv)
244
+ # Truncate the extra head size
245
+ out = out[:, :, :, : self.d_head]
246
+ # Reshape to `[batch_size, seq_len, n_heads * d_head]`
247
+ out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
248
+
249
+ # Map to `[batch_size, height * width, d_model]` with a linear layer
250
+ return self.to_out(out)
251
+
252
+ def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
253
+ """
254
+ #### Normal Attention
255
+
256
+ :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
257
+ :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
258
+ :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
259
+ """
260
+
261
+ # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
262
+ q = q.view(*q.shape[: 2], self.n_heads, -1)
263
+ k = k.view(*k.shape[: 2], self.n_heads, -1)
264
+ v = v.view(*v.shape[: 2], self.n_heads, -1)
265
+
266
+ # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
267
+ attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale
268
+
269
+ # Compute softmax
270
+ # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
271
+ if self.is_inplace:
272
+ half = attn.shape[0] // 2
273
+ attn[half :] = attn[half :].softmax(dim=-1)
274
+ attn[: half] = attn[: half].softmax(dim=-1)
275
+ else:
276
+ attn = attn.softmax(dim=-1)
277
+
278
+ # Compute attention output
279
+ # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
280
+ out = torch.einsum('bhij,bjhd->bihd', attn, v)
281
+ # Reshape to `[batch_size, height * width, n_heads * d_head]`
282
+ out = out.reshape(*out.shape[: 2], -1)
283
+ # Map to `[batch_size, height * width, d_model]` with a linear layer
284
+ return self.to_out(out)
285
+
286
+
287
+ class FeedForward(nn.Module):
288
+ """
289
+ ### Feed-Forward Network
290
+ """
291
+ def __init__(self, d_model: int, d_mult: int = 4):
292
+ """
293
+ :param d_model: is the input embedding size
294
+ :param d_mult: is multiplicative factor for the hidden layer size
295
+ """
296
+ super().__init__()
297
+ self.net = nn.Sequential(
298
+ GeGLU(d_model, d_model * d_mult), nn.Dropout(0.),
299
+ nn.Linear(d_model * d_mult, d_model)
300
+ )
301
+
302
+ def forward(self, x: torch.Tensor):
303
+ return self.net(x)
304
+
305
+
306
+ class GeGLU(nn.Module):
307
+ """
308
+ ### GeGLU Activation
309
+
310
+ $$\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$$
311
+ """
312
+ def __init__(self, d_in: int, d_out: int):
313
+ super().__init__()
314
+ # Combined linear projections $xW + b$ and $xV + c$
315
+ self.proj = nn.Linear(d_in, d_out * 2)
316
+
317
+ def forward(self, x: torch.Tensor):
318
+ # Get $xW + b$ and $xV + c$
319
+ x, gate = self.proj(x).chunk(2, dim=-1)
320
+ # $\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$
321
+ return x * F.gelu(gate)
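A quick shape check for the self-attention transformer block above; sizes are illustrative, and channels must be divisible by 32 for the group norm:

import torch
from model.architecture.unet_attention import SpatialTransformer

st = SpatialTransformer(channels=64, n_heads=4, n_layers=1)
fmap = torch.randn(2, 64, 16, 16)    # [batch, channels, height, width]
out = st(fmap)
print(out.shape)                     # torch.Size([2, 64, 16, 16])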
model/latent_diffusion.py ADDED
@@ -0,0 +1,222 @@
1
+ """
2
+ ---
3
+ title: Latent Diffusion Models
4
+ summary: >
5
+ Annotated PyTorch implementation/tutorial of latent diffusion models from paper
6
+ High-Resolution Image Synthesis with Latent Diffusion Models
7
+ ---
8
+
9
+ # Latent Diffusion Models
10
+
11
+ Latent diffusion models use an auto-encoder to map between image space and
12
+ latent space. The diffusion model works on the latent space, which makes it
13
+ a lot easier to train.
14
+ It is based on paper
15
+ [High-Resolution Image Synthesis with Latent Diffusion Models](https://papers.labml.ai/paper/2112.10752).
16
+
17
+ They use a pre-trained auto-encoder and train the diffusion U-Net on the latent
18
+ space of the pre-trained auto-encoder.
19
+
20
+ For a simpler diffusion implementation refer to our [DDPM implementation](../ddpm/index.html).
21
+ We use the same notation for the $\alpha_t$, $\beta_t$ schedules, etc.
22
+ """
23
+
24
+ from typing import List, Tuple, Optional, Union
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ from .architecture.unet import UNetModel
29
+ import random
30
+
31
+
32
+ def gather(consts: torch.Tensor, t: torch.Tensor):
33
+ """Gather consts for $t$ and reshape to feature map shape"""
34
+ c = consts.gather(-1, t)
35
+ return c.reshape(-1, 1, 1, 1)
36
+
37
+
38
+ class LatentDiffusion(nn.Module):
39
+ """
40
+ ## Latent diffusion model
41
+
42
+ This contains following components:
43
+
44
+ * [AutoEncoder](model/autoencoder.html)
45
+ * [U-Net](model/unet.html) with [attention](model/unet_attention.html)
46
+ """
47
+ eps_model: UNetModel
48
+ #first_stage_model: Optional[Autoencoder] = None
49
+
50
+ def __init__(
51
+ self,
52
+ unet_model: UNetModel,
53
+ latent_scaling_factor: float,
54
+ n_steps: int,
55
+ linear_start: float,
56
+ linear_end: float,
57
+ debug_mode: Optional[bool] = False
58
+ ):
59
+ """
60
+ :param unet_model: is the [U-Net](model/unet.html) that predicts noise
61
+ $\epsilon_\text{cond}(x_t, c)$, in latent space
62
63
+ :param latent_scaling_factor: is the scaling factor for the latent space. The encodings of
64
+ the autoencoder are scaled by this before feeding into the U-Net.
65
+ :param n_steps: is the number of diffusion steps $T$.
66
+ :param linear_start: is the start of the $\beta$ schedule.
67
+ :param linear_end: is the end of the $\beta$ schedule.
68
+ """
69
+ super().__init__()
70
+ # Wrap the [U-Net](model/unet.html) to keep the same model structure as
71
+ # [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion).
72
+ self.eps_model = unet_model
73
+ self.latent_scaling_factor = latent_scaling_factor
74
+
75
+ # Number of steps $T$
76
+ self.n_steps = n_steps
77
+
78
+ # $\beta$ schedule
79
+ beta = torch.linspace(
80
+ linear_start**0.5, linear_end**0.5, n_steps, dtype=torch.float64
81
+ ) ** 2
82
+ # $\alpha_t = 1 - \beta_t$
83
+ alpha = 1. - beta
84
+ # $\bar\alpha_t = \prod_{s=1}^t \alpha_s$
85
+ alpha_bar = torch.cumprod(alpha, dim=0)
86
+ self.alpha = nn.Parameter(alpha.to(torch.float32), requires_grad=False)
87
+ self.beta = nn.Parameter(beta.to(torch.float32), requires_grad=False)
88
+ self.alpha_bar = nn.Parameter(alpha_bar.to(torch.float32), requires_grad=False)
89
+ self.alpha_bar_prev = torch.cat([alpha_bar.new_tensor([1.]), alpha_bar[:-1]])
90
+ self.sigma_ddim = torch.sqrt((1-self.alpha_bar_prev)/(1-self.alpha_bar)*(1-self.alpha_bar/self.alpha_bar_prev))
91
+ self.sigma2 = self.beta
92
+
93
+ self.debug_mode = debug_mode
94
+
95
+ @property
96
+ def device(self):
97
+ """
98
+ ### Get model device
99
+ """
100
+ return next(iter(self.eps_model.parameters())).device
101
+
102
+
103
+
104
+ def forward(self, x: torch.Tensor, t: torch.Tensor):
105
+ """
106
+ ### Predict noise
107
+
108
+ Predict noise given the latent representation $x_t$, time step $t$, and the
109
+ conditioning context $c$.
110
+
111
+ $$\epsilon_\text{cond}(x_t, c)$$
112
+ """
113
+ return self.eps_model(x, t)
114
+
115
+ def q_xt_x0(self, x0: torch.Tensor,
116
+ t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
117
+ """
118
+ #### Get $q(x_t|x_0)$ distribution
119
+ """
120
+
121
+ # [gather](utils.html) $\alpha_t$ and compute $\sqrt{\bar\alpha_t} x_0$
122
+ mean = gather(self.alpha_bar, t)**0.5 * x0
123
+ # $(1-\bar\alpha_t) \mathbf{I}$
124
+ var = 1 - gather(self.alpha_bar, t)
125
+ #
126
+ return mean, var
127
+
128
+ def q_sample(
129
+ self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None
130
+ ):
131
+ """
132
+ #### Sample from $q(x_t|x_0)$
133
+ """
134
+
135
+ # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
136
+ if eps is None:
137
+ eps = torch.randn_like(x0)
138
+
139
+ # get $q(x_t|x_0)$
140
+ mean, var = self.q_xt_x0(x0, t)
141
+ # Sample from $q(x_t|x_0)$
142
+ return mean + (var**0.5) * eps
143
+
144
+ def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
145
+ """
146
+ #### Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$
147
+ """
148
+
149
+ # $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$
150
+ eps_theta = self.eps_model(xt, t)
151
+ # [gather](utils.html) $\bar\alpha_t$
152
+ alpha_bar = gather(self.alpha_bar, t)
153
+ # [gather](utils.html) $\bar\alpha_t-1$
154
+ alpha_bar_prev = gather(self.alpha_bar_prev, t)
155
+ # [gather](utils.html) $\sigma_t$
156
+ sigma_ddim = gather(self.sigma_ddim, t)
157
+
158
+ # DDIM sampling
159
+ # $\frac{x_t-\sqrt{1-\bar\alpha_t}\epsilon}{\sqrt{\bar\alpha_t}}$
160
+ predicted_x0 = (xt - (1-alpha_bar)**0.5 * eps_theta) / (alpha_bar)**.5
161
+ # $\sqrt{1-\alpha_{t-1}-\sigma_t^2}$
162
+ direction_to_xt = (1 - alpha_bar_prev - sigma_ddim**2)**0.5 * eps_theta
163
+
164
+ # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
165
+ eps = torch.randn(xt.shape, device=xt.device)
166
+
167
+ # Sample
168
+ x_tm_1 = alpha_bar_prev**0.5 * predicted_x0 + direction_to_xt + sigma_ddim * eps
169
+ return x_tm_1
170
+
171
+ def loss(
172
+ self,
173
+ x0: torch.Tensor,
174
+ #autoreg_cond: Union[torch.Tensor, None], #This means it can be either a tensor or none
175
+ #external_cond: Union[torch.Tensor, None],
176
+ noise: Optional[torch.Tensor] = None,
177
+ ):
178
+ """
179
+ #### Simplified Loss
180
+ """
181
+ # Get batch size
182
+ batch_size = x0.shape[0]
183
+ # Get random $t$ for each sample in the batch
184
+ t = torch.randint(
185
+ 0, self.n_steps, (batch_size, ), device=x0.device, dtype=torch.long
186
+ )
187
+
188
+
189
+ #autoreg_cond = -torch.ones(x0.size(0), 1, self.eps_model.d_cond, device=x0.device, dtype=x0.dtype)
190
+ #cond = autoreg_cond
191
+
192
+ if x0.size(1) == self.eps_model.out_channels: # generating form
193
+ if self.debug_mode:
194
+ print('In the mode of root level:', x0.size())
195
+ if noise is None:
196
+ x0 = x0.to(torch.float32)
197
+ noise = torch.randn_like(x0)
198
+
199
+ xt = self.q_sample(x0, t, eps=noise)
200
+
201
+ eps_theta = self.eps_model(xt, t)
202
+
203
+ loss = F.mse_loss(noise, eps_theta)
204
+ else:
205
+ if self.debug_mode:
206
+ print('In the mode of non-root level:', x0.size())
207
+
208
+ if noise is None:
209
+ noise = torch.randn_like(x0[:, 0: 2])
210
+
211
+ front_t = self.q_sample(x0[:, 0: 2], t, eps=noise)
212
+
213
+ background_cond = x0[:, 2:]
214
+
215
+ xt = torch.cat([front_t, background_cond], 1)
216
+
217
+ eps_theta = self.eps_model(xt, t)
218
+
219
+ loss = F.mse_loss(noise, eps_theta)
220
+ if self.debug_mode:
221
+ print('loss:', loss)
222
+ return loss
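For reference, a minimal sketch of how this loss is typically driven during training. This is not the repo's actual training loop (that lives in the learner code, which is not part of this diff); the names `train_one_epoch` and `train_dl` are illustrative, and `ldm` is assumed to be an already constructed `LatentDiffusion` module receiving piano-roll batches of shape `[B, C, L, H]`. The gradient clipping value mirrors `max_grad_norm` in the bundled params.json.

import torch

def train_one_epoch(ldm, train_dl, lr=5e-5, device="cuda"):
    """Hypothetical training step driving LatentDiffusion.loss (sketch only)."""
    opt = torch.optim.Adam(ldm.parameters(), lr=lr)
    ldm.train()
    for batch in train_dl:
        x0 = batch.to(device=device, dtype=torch.float32)  # [B, C, L, H] piano-roll tensor
        opt.zero_grad()
        loss = ldm.loss(x0)   # samples a random t, noises x0 via q_sample, MSE on predicted noise
        loss.backward()
        torch.nn.utils.clip_grad_norm_(ldm.parameters(), max_norm=10.0)
        opt.step()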
model/model_sdf.py ADDED
@@ -0,0 +1,55 @@
1
+ import torch
2
+ import os
3
+ import torch.nn as nn
4
+ from .latent_diffusion import LatentDiffusion
5
+
6
+
7
+ class Diffpro_SDF(nn.Module):
8
+
9
+ def __init__(
10
+ self,
11
+ ldm: LatentDiffusion,
12
+ ):
13
+ """
14
+ A thin wrapper around `LatentDiffusion` that exposes the training loss and the
+ sampling helpers, and that can restore checkpoints saved with legacy key names
+ (`cond_enc`/`style_enc`, remapped in `load_trained`).
18
+ """
19
+ super(Diffpro_SDF, self).__init__()
20
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
21
+ self.ldm = ldm
22
+
23
+ @classmethod
24
+ def load_trained(
25
+ cls,
26
+ ldm,
27
+ chkpt_fpath,
28
+ ):
29
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+ model = cls(ldm)
31
+ trained_learner = torch.load(chkpt_fpath, map_location=device)
+ try:
+ model.load_state_dict(trained_learner["model"])
+ except RuntimeError:
+ model_dict = trained_learner["model"]
36
+ model_dict = {k.replace('cond_enc', 'autoreg_cond_enc'): v for k, v in model_dict.items()}
37
+ model_dict = {k.replace('style_enc', 'external_cond_enc'): v for k, v in model_dict.items()}
38
+ model.load_state_dict(model_dict)
39
+ return model
40
+
41
+ def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
42
+ return self.ldm.p_sample(xt, t)
43
+
44
+ def q_sample(self, x0: torch.Tensor, t: torch.Tensor):
45
+ return self.ldm.q_sample(x0, t)
46
+
47
+ def get_loss_dict(self, batch, step):
48
+ """
49
+ Compute the training loss; the batch tensor is what the diffusion model learns to denoise.
50
+ """
51
+ # x = batch.float().to(self.device)
52
+
53
+ x = batch
54
+ loss = self.ldm.loss(x)
55
+ return {"loss": loss}
model/sampler_sdf.py ADDED
@@ -0,0 +1,538 @@
1
+ from typing import Optional, List, Union
2
+ import numpy as np
3
+ import torch
4
+ from labml import monit
5
+ from .latent_diffusion import LatentDiffusion
6
+
7
+ def set_seed(seed):
8
+ np.random.seed(seed)
9
+ torch.manual_seed(seed)
10
+ if torch.cuda.is_available():
11
+ torch.cuda.manual_seed(seed)
12
+ torch.cuda.manual_seed_all(seed)
13
+
14
+ # Call the function to set the seed
15
+ # set_seed(42)
16
+
17
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
18
+ """
19
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
20
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
21
+ """
22
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
23
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
24
+ # rescale the results from guidance (fixes overexposure)
25
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
26
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
27
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
28
+ return noise_cfg
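The helper above is used together with classifier-free guidance inside `SDFSampler.p_sample` further down. Condensed into one sketch (the function name `guided_eps` is illustrative and assumes `rescale_noise_cfg` from this module is in scope), the combination is:

import torch

def guided_eps(e_uncond: torch.Tensor, e_cond: torch.Tensor,
               guidance_scale: float = 7.5, guidance_rescale: float = 0.7) -> torch.Tensor:
    """Classifier-free guidance followed by the rescaling above (what p_sample does)."""
    e = e_uncond + guidance_scale * (e_cond - e_uncond)  # standard CFG combination
    if guidance_rescale > 0:
        e = rescale_noise_cfg(e, e_cond, guidance_rescale=guidance_rescale)
    return e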
29
+
30
+
31
+ class DiffusionSampler:
32
+ """
33
+ ## Base class for sampling algorithms
34
+ """
35
+ model: LatentDiffusion
36
+
37
+ def __init__(self, model: LatentDiffusion):
38
+ """
39
+ :param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
40
+ """
41
+ super().__init__()
42
+ # Set the model $\epsilon_\text{cond}(x_t, c)$
43
+ self.model = model
44
+ # Get number of steps the model was trained with $T$
45
+ self.n_steps = model.n_steps
46
+
47
+
48
+ class SDFSampler(DiffusionSampler):
49
+ """
50
+ ## DDIM-style Sampler
+
+ This extends the [`DiffusionSampler` base class](index.html).
+
+ The sampler walks a fixed sub-sequence of time steps $\tau_1 < \tau_2 < \dots < \tau_S$
+ and samples $x_{\tau_{i-1}}$ from $x_{\tau_i}$ with the DDIM update
+ (using the DDPM-level noise $\sigma_{\tau_i}$, i.e. $\eta = 1$),
+
+ \begin{align}
+
+ \hat x_0 &= \frac{x_{\tau_i} - \sqrt{1-\bar\alpha_{\tau_i}}\,\epsilon_\theta(x_{\tau_i})}{\sqrt{\bar\alpha_{\tau_i}}} \\
+
+ x_{\tau_{i-1}} &= \sqrt{\bar\alpha_{\tau_{i-1}}}\,\hat x_0
+ + \sqrt{1-\bar\alpha_{\tau_{i-1}}-\sigma_{\tau_i}^2}\,\epsilon_\theta(x_{\tau_i})
+ + \sigma_{\tau_i}\,\epsilon \\
+
+ \sigma_{\tau_i}^2 &= \frac{1-\bar\alpha_{\tau_{i-1}}}{1-\bar\alpha_{\tau_i}}
+ \Big(1 - \frac{\bar\alpha_{\tau_i}}{\bar\alpha_{\tau_{i-1}}}\Big)
+
+ \end{align}
69
+ """
70
+
71
+ model: LatentDiffusion
72
+
73
+ def __init__(
74
+ self,
75
+ model: LatentDiffusion,
76
+ max_l,
77
+ h,
78
+ is_autocast=False,
79
+ is_show_image=False,
80
+ device=None,
81
+ debug_mode=False
82
+ ):
83
+ """
84
+ :param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
85
+ """
86
+ super().__init__(model)
87
+ if device is None:
88
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
89
+ else:
90
+ self.device = device
91
+ # selected time steps ($\tau$) $1, 2, \dots, T$
92
+ # self.time_steps = np.asarray(list(range(self.n_steps)), dtype=np.int32)
93
+ self.tau = torch.tensor([13, 53, 116, 193, 310, 443, 587, 730, 845, 999], device=self.device) # torch.tensor([999, 845, 730, 587, 443, 310, 193, 116, 53, 13])
94
+ # self.tau = torch.tensor(np.asarray(list(range(self.n_steps)), dtype=np.int32), device=self.device)
95
+ self.used_n_steps = len(self.tau)
96
+
97
+ self.is_show_image = is_show_image
98
+
99
+ self.autocast = torch.cuda.amp.autocast(enabled=is_autocast)
100
+
101
+ self.out_channel = self.model.eps_model.out_channels
102
+ self.max_l = max_l
103
+ self.h = h
104
+ self.debug_mode = debug_mode
105
+ self.guidance_scale = 7.5
106
+ self.guidance_rescale = 0.7
107
+
108
+ # now, we set the coefficients
109
+ with torch.no_grad():
110
+ # $\bar\alpha_t$
111
+ self.alpha_bar = self.model.alpha_bar
112
+ # $\beta_t$ schedule
113
+ beta = self.model.beta
114
+ # $\bar\alpha_{t-1}$
115
+ self.alpha_bar_prev = torch.cat([self.alpha_bar.new_tensor([1.]), self.alpha_bar[:-1]])
116
+ # $\sigma_t$ in DDIM
117
+ self.sigma_ddim = torch.sqrt((1-self.alpha_bar_prev)/(1-self.alpha_bar)*(1-self.alpha_bar/self.alpha_bar_prev)) # DDPM noise schedule
118
+
119
+ # $\frac{1}{\sqrt{\bar\alpha}}$
120
+ self.one_over_sqrt_alpha_bar = 1 / (self.alpha_bar ** 0.5)
121
+ # $\frac{\sqrt{1-\bar\alpha}}{\sqrt{\bar\alpha}}$
122
+ self.sqrt_1m_alpha_bar_over_sqrt_alpha_bar = (1 - self.alpha_bar)**0.5 / self.alpha_bar**0.5
123
+
124
+ # $\sqrt{\bar\alpha}$
125
+ self.sqrt_alpha_bar = self.alpha_bar ** 0.5
126
+ # $\sqrt{1 - \bar\alpha}$
127
+ self.sqrt_1m_alpha_bar = (1 - self.alpha_bar) ** 0.5
128
+ # # $\sqrt{\bar\alpha_{t-1}}$
129
+ # self.sqrt_alpha_bar_prev = self.alpha_bar_prev ** 0.5
130
+ # # $\sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}$
131
+ # self.sqrt_1m_alpha_bar_prev_m_sigma2 = (1 - self.alpha_bar_prev - self.sigma_ddim ** 2) ** 0.5
132
+
133
+ #@property
134
+ # def d_cond(self):
135
+ #return self.model.eps_model.d_cond
136
+
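The `self.tau` list defined in `__init__` above is a hard-coded 10-step sub-sequence of the 1000 training steps. If one wanted to experiment with a different step budget, an evenly spaced sub-sequence could be built as in this sketch (`make_tau` is illustrative; the checked-in sampler keeps the fixed list):

import torch

def make_tau(n_steps: int = 1000, n_sample_steps: int = 10, device: str = "cpu") -> torch.Tensor:
    """Evenly spaced DDIM sub-sequence tau_1 < ... < tau_S ending at n_steps - 1."""
    return torch.linspace(0, n_steps - 1, n_sample_steps, device=device).round().long()

# make_tau() -> tensor([  0, 111, 222, 333, 444, 555, 666, 777, 888, 999])
# compared with the hard-coded tau = [13, 53, 116, 193, 310, 443, 587, 730, 845, 999] above.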
137
+ def get_eps(
138
+ self,
139
+ x: torch.Tensor,
140
+ t: torch.Tensor,
141
+ background_cond: Optional[torch.Tensor],
142
+
143
+ uncond_scale: Optional[float],
144
+ ):
145
+ """
146
+ ## Get $\epsilon_\theta(x_t, t)$
+
+ :param x: is $x_t$ of shape `[batch_size, channels, height, width]`
+ :param t: is $t$ of shape `[batch_size]`
+ :param background_cond: background condition, concatenated to $x_t$ along the channel axis
+ :param uncond_scale: is the unconditional guidance scale $s$ (not used here; classifier-free
+ guidance is applied in `p_sample` as $\epsilon = \epsilon_u + s\,(\epsilon_c - \epsilon_u)$)
157
+ """
158
+ # When the scale $s = 1$
159
+ # $$\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c)$$
160
+
161
+ batch_size = x.size(0)
162
+
163
+ # if hasattr(self.model, 'style_enc'):
164
+ # if external_cond is not None:
165
+ # external_cond = self.model.external_cond_enc(external_cond)
166
+ # if uncond_scale is None or uncond_scale == 1:
167
+ # external_uncond = (-torch.ones_like(external_cond)).to(self.device)
168
+ # else:
169
+ # external_uncond = None
170
+ # # if random.random() < 0.2:
171
+ # # external_cond = (-torch.ones_like(external_cond)).to(self.device)
172
+ # else:
173
+ # external_cond = -torch.ones(batch_size, 4, self.d_cond, device=x.device, dtype=x.dtype)
174
+ # external_uncond = None
175
+ # cond = torch.cat([autoreg_cond, external_cond], 1)
176
+ # if external_uncond is None:
177
+ # uncond = None
178
+ # else:
179
+ # uncond = torch.cat([autoreg_cond, external_uncond], 1)
180
+ # else:
181
+ # cond = autoreg_cond
182
+ # uncond = None
183
+
184
+ if background_cond is not None:
185
+ x = torch.cat([x, background_cond], 1)
186
+
187
+ # if uncond is None:
188
+ # e_t = self.model(x, t, cond)
189
+ # else:
190
+ # e_t_cond = self.model(x, t, cond)
191
+ # e_t_uncond = self.model(x, t, uncond)
192
+ # e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)
193
+
194
+ e_t = self.model(x, t)
195
+ return e_t
196
+
197
+ @torch.no_grad()
198
+ def p_sample(
199
+ self,
200
+ x: torch.Tensor,
201
+ background_cond: Optional[torch.Tensor],
202
+ #autoreg_cond: Optional[torch.Tensor],
203
+ #external_cond: Optional[torch.Tensor],
204
+ t: torch.Tensor,
205
+ step: int,
206
+ repeat_noise: bool = False,
207
+ temperature: float = 1.,
208
+ uncond_scale: float = 1.,
209
+ same_noise_all_measure: bool = False,
210
+ X0EditFunc = None,
211
+ use_classifier_free_guidance = False,
212
+ use_lsh = False,
213
+ reduce_extra_notes=True,
214
+ rhythm_control="Yes",
215
+ ):
216
+ print("p_sample")
217
+ """
218
+ ### Sample $x_{t-1}$ from $p_\theta(x_{t-1} | x_t)$
219
+
220
+ :param x: is $x_t$ of shape `[batch_size, channels, height, width]`
221
+ :param background_cond: background condition
222
+ :param autoreg_cond: autoregressive condition
223
+ :param external_cond: external condition
224
+ :param t: is $t$ of shape `[batch_size]`
225
+ :param step: is the step $t$ as an integer
226
+ :param repeat_noise: specifies whether the noise should be the same for all samples in the batch
227
+ :param temperature: is the noise temperature (random noise gets multiplied by this)
228
+ :param uncond_scale: is the unconditional guidance scale $s$. This is used for
229
+ $\epsilon_\theta(x_t, c) = s\,\epsilon_\text{cond}(x_t, c) + (1 - s)\,\epsilon_\text{cond}(x_t, c_u)$
230
+ """
231
+ # Get current tau_i and tau_{i-1}
232
+ tau_i = self.tau[t]
233
+ step_tau_i = self.tau[step]
234
+
235
+ # Get $\epsilon_\theta$
236
+ with self.autocast:
237
+ if use_classifier_free_guidance:
238
+ if use_lsh:
239
+ assert background_cond.shape[1] == 6 # chd_onset, chd_sustain, null_chd_onset, null_chd_sustain, lsh_onset, lsh_sustain
240
+ null_lsh = -torch.ones_like(background_cond[:,4:,:,:])
241
+ null_background_cond = torch.cat([background_cond[:,2:4,:,:], null_lsh], axis=1)
242
+ real_background_cond = torch.cat([background_cond[:,:2,:,:], background_cond[:,4:,:,:]], axis=1)
243
+
244
+ e_tau_i_null = self.get_eps(x, tau_i, null_background_cond, uncond_scale=uncond_scale)
245
+ e_tau_i_real = self.get_eps(x, tau_i, real_background_cond, uncond_scale=uncond_scale)
246
+ e_tau_i = e_tau_i_null + self.guidance_scale * (e_tau_i_real-e_tau_i_null)
247
+ if self.guidance_rescale > 0:
248
+ e_tau_i = rescale_noise_cfg(e_tau_i, e_tau_i_real, guidance_rescale=self.guidance_rescale)
249
+ else:
250
+ assert background_cond.shape[1] == 4 # chd_onset, chd_sustain, null_chd_onset, null_chd_sustain
251
+ null_background_cond = background_cond[:,2:,:,:]
252
+ real_background_cond = background_cond[:,:2,:,:]
253
+ e_tau_i_null = self.get_eps(x, tau_i, null_background_cond, uncond_scale=uncond_scale)
254
+ e_tau_i_real = self.get_eps(x, tau_i, real_background_cond, uncond_scale=uncond_scale)
255
+ e_tau_i = e_tau_i_null + self.guidance_scale * (e_tau_i_real-e_tau_i_null)
256
+ if self.guidance_rescale > 0:
257
+ e_tau_i = rescale_noise_cfg(e_tau_i, e_tau_i_real, guidance_rescale=self.guidance_rescale)
258
+ else:
259
+ if use_lsh:
260
+ assert background_cond.shape[1] == 4 # chd_onset, chd_sustain, lsh_onset, lsh_sustain
261
+ e_tau_i = self.get_eps(x, tau_i, background_cond, uncond_scale=uncond_scale)
262
+ else:
263
+ assert background_cond.shape[1] == 2 # chd_onset, chd_sustain
264
+ e_tau_i = self.get_eps(x, tau_i, background_cond, uncond_scale=uncond_scale)
265
+
266
+ # Get batch size
267
+ bs = x.shape[0]
268
+
269
+ # $\frac{1}{\sqrt{\bar\alpha}}$
270
+ one_over_sqrt_alpha_bar = x.new_full(
271
+ (bs, 1, 1, 1), self.one_over_sqrt_alpha_bar[step_tau_i]
272
+ )
273
+ # $\frac{\sqrt{1-\bar\alpha}}{\sqrt{\bar\alpha}}$
274
+ sqrt_1m_alpha_bar_over_sqrt_alpha_bar = x.new_full(
275
+ (bs, 1, 1, 1), self.sqrt_1m_alpha_bar_over_sqrt_alpha_bar[step_tau_i]
276
+ )
277
+
278
+ # $\sigma_t$ in DDIM
279
+ sigma_ddim = x.new_full(
280
+ (bs, 1, 1, 1), self.sigma_ddim[step_tau_i]
281
+ )
282
+
283
+
284
+ # Calculate $x_0$ with current $\epsilon_\theta$
285
+ #
286
+ # predicted x_0 in DDIM
287
+ predicted_x0 = one_over_sqrt_alpha_bar * x[:, 0: e_tau_i.size(1)] - sqrt_1m_alpha_bar_over_sqrt_alpha_bar * e_tau_i
288
+
289
+ # edit predicted x_0
290
+ if X0EditFunc is not None:
291
+ predicted_x0 = X0EditFunc(predicted_x0, background_cond, reduce_extra_notes=reduce_extra_notes, rhythm_control=rhythm_control)
292
+ e_tau_i = (one_over_sqrt_alpha_bar * x[:, 0: e_tau_i.size(1)] - predicted_x0) / sqrt_1m_alpha_bar_over_sqrt_alpha_bar
293
+
294
+ # Do not add noise when $t = 1$ (final step sampling process).
295
+ # Note that `step` is `0` when $t = 1$.
296
+ if step == 0:
297
+ noise = 0
298
+ # If same noise is used for all samples in the batch
299
+ elif repeat_noise:
300
+ if same_noise_all_measure:
301
+ noise = torch.randn((1, predicted_x0.shape[1], 16, predicted_x0.shape[3]), device=self.device).repeat(1,1,int(predicted_x0.shape[2]/16),1)
302
+ else:
303
+ noise = torch.randn((1, *predicted_x0.shape[1:]), device=self.device)
304
+ # Different noise for each sample
305
+ else:
306
+ if same_noise_all_measure:
307
+ noise = torch.randn(predicted_x0.shape[0], predicted_x0.shape[1], 16, predicted_x0.shape[3], device=self.device).repeat(1,1,int(predicted_x0.shape[2]/16),1)
308
+ else:
309
+ noise = torch.randn(predicted_x0.shape, device=self.device)
310
+
311
+ # Multiply noise by the temperature
312
+ noise = noise * temperature
313
+
314
+ if step > 0:
315
+ step_tau_i_m_1 = self.tau[step-1]
316
+ # $\sqrt{\bar\alpha_{\tau_i-1}}$
317
+ sqrt_alpha_bar_prev = x.new_full(
318
+ (bs, 1, 1, 1), self.sqrt_alpha_bar[step_tau_i_m_1]
319
+ )
320
+ # $\sqrt{1-\bar\alpha_{\tau_i-1}-\sigma_\tau^2}$
321
+ sqrt_1m_alpha_bar_prev_m_sigma2 = x.new_full(
322
+ (bs, 1, 1, 1), (1 - self.alpha_bar[step_tau_i_m_1] - self.sigma_ddim[step_tau_i] ** 2) ** 0.5
323
+ )
324
+ direction_to_xt = sqrt_1m_alpha_bar_prev_m_sigma2 * e_tau_i
325
+ x_prev = sqrt_alpha_bar_prev * predicted_x0 + direction_to_xt + sigma_ddim * noise
326
+ else:
327
+ x_prev = predicted_x0 + sigma_ddim * noise
328
+
329
+ # Sample from,
330
+ #
331
+ # $$p_\theta(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \tilde\beta_t \mathbf{I} \big)$$
332
+
333
+ #
334
+ return x_prev, predicted_x0, e_tau_i
335
+
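Written out, each step above first forms the DDIM estimate of $x_0$, optionally edits it with `X0EditFunc` (the note/rhythm control hook), re-derives a noise estimate consistent with the edited $\hat x_0$, and only then takes the DDIM step:

\begin{align}
\hat x_0 &\leftarrow \mathrm{X0EditFunc}\!\left(\frac{x_{\tau_i}}{\sqrt{\bar\alpha_{\tau_i}}}
        - \frac{\sqrt{1-\bar\alpha_{\tau_i}}}{\sqrt{\bar\alpha_{\tau_i}}}\,\epsilon_\theta\right) \\
\epsilon_\theta &\leftarrow
        \left(\frac{x_{\tau_i}}{\sqrt{\bar\alpha_{\tau_i}}} - \hat x_0\right)
        \frac{\sqrt{\bar\alpha_{\tau_i}}}{\sqrt{1-\bar\alpha_{\tau_i}}} \\
x_{\tau_{i-1}} &= \sqrt{\bar\alpha_{\tau_{i-1}}}\,\hat x_0
        + \sqrt{1-\bar\alpha_{\tau_{i-1}}-\sigma_{\tau_i}^2}\,\epsilon_\theta
        + \sigma_{\tau_i}\,(\text{temperature}\cdot\epsilon)
\end{align}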
336
+ @torch.no_grad()
337
+ def q_sample(
338
+ self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None
339
+ ):
340
+ """
341
+ ### Sample from $q(x_t|x_0)$
342
+
343
+ $$q(x_t|x_0) = \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$$
344
+
345
+ :param x0: is $x_0$ of shape `[batch_size, channels, height, width]`
346
+ :param index: is the time step $t$ index
347
+ :param noise: is the noise, $\epsilon$
348
+ """
349
+
350
+ # Random noise, if noise is not specified
351
+ if noise is None:
352
+ noise = torch.randn_like(x0, device=self.device)
353
+
354
+ # Sample from $\mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$
355
+ return self.sqrt_alpha_bar[index] * x0 + self.sqrt_1m_alpha_bar[index] * noise
356
+
357
+ @torch.no_grad()
358
+ def sample(
359
+ self,
360
+ shape: List[int],
361
+ background_cond: Optional[torch.Tensor] = None,
362
+ #autoreg_cond: Optional[torch.Tensor] = None,
363
+ #external_cond: Optional[torch.Tensor] = None,
364
+ repeat_noise: bool = False,
365
+ temperature: float = 1.,
366
+ uncond_scale: float = 1.,
367
+ x_last: Optional[torch.Tensor] = None,
368
+ t_start: int = 0,
369
+ same_noise_all_measure: bool = False,
370
+ X0EditFunc = None,
371
+ use_classifier_free_guidance = False,
372
+ use_lsh = False,
373
+ reduce_extra_notes=True,
374
+ rhythm_control="Yes",
375
+ ):
376
+ """
377
+ ### Sampling Loop
378
+
379
+ :param shape: is the shape of the generated images in the
380
+ form `[batch_size, channels, height, width]`
381
+ :param background_cond: background condition
382
+ :param autoreg_cond: autoregressive condition
383
+ :param external_cond: external condition
384
+ :param repeat_noise: specifies whether the noise should be the same for all samples in the batch
385
+ :param temperature: is the noise temperature (random noise gets multiplied by this)
386
+ :param x_last: is $x_T$. If not provided random noise will be used.
387
+ :param uncond_scale: is the unconditional guidance scale $s$. This is used for
388
+ $\epsilon_\theta(x_t, c) = s\,\epsilon_\text{cond}(x_t, c) + (1 - s)\,\epsilon_\text{cond}(x_t, c_u)$
389
+ :param t_start: t_start
390
+ """
391
+
392
+ # Get device and batch size
393
+ bs = shape[0]
394
+
395
+ if self.debug_mode:
+ print(shape)
398
+
399
+
400
+ # Get $x_T$
401
+ if same_noise_all_measure:
402
+ x = x_last if x_last is not None else torch.randn(shape[0],shape[1],16,shape[3], device=self.device).repeat(1,1,int(shape[2]/16),1)
403
+ else:
404
+ x = x_last if x_last is not None else torch.randn(shape, device=self.device)
405
+
406
+ # Time steps to sample at $T - t', T - t' - 1, \dots, 1$
407
+ time_steps = np.flip(np.asarray(list(range(self.used_n_steps)), dtype=np.int32))[t_start:]
408
+
409
+ # Sampling loop
410
+ for step in monit.iterate('Sample', time_steps):
411
+ # Time step $t$
412
+ ts = x.new_full((bs, ), step, dtype=torch.long)
413
+
414
+ x, pred_x0, e_t = self.p_sample(
415
+ x,
416
+ background_cond,
417
+ #autoreg_cond,
418
+ #external_cond,
419
+ ts,
420
+ step,
421
+ repeat_noise=repeat_noise,
422
+ temperature=temperature,
423
+ uncond_scale=uncond_scale,
424
+ same_noise_all_measure=same_noise_all_measure,
425
+ X0EditFunc = X0EditFunc,
426
+ use_classifier_free_guidance = use_classifier_free_guidance,
427
+ use_lsh=use_lsh,
428
+ reduce_extra_notes=reduce_extra_notes,
429
+ rhythm_control=rhythm_control
430
+ )
431
+
432
+ s1 = step + 1
433
+
434
+ # if self.is_show_image:
435
+ # if s1 % 100 == 0 or (s1 <= 100 and s1 % 25 == 0):
436
+ # show_image(x, f"exp/img/x{s1}.png")
437
+
438
+ # Return $x_0$
439
+ # if self.is_show_image:
440
+ # show_image(x, f"exp/img/x0.png")
441
+
442
+ return x
443
+
444
+ @torch.no_grad()
445
+ def paint(
446
+ self,
447
+ x: Optional[torch.Tensor] = None,
448
+ background_cond: Optional[torch.Tensor] = None,
449
+ #autoreg_cond: Optional[torch.Tensor] = None,
450
+ #external_cond: Optional[torch.Tensor] = None,
451
+ t_start: int = 0,
452
+ orig: Optional[torch.Tensor] = None,
453
+ mask: Optional[torch.Tensor] = None,
454
+ orig_noise: Optional[torch.Tensor] = None,
455
+ uncond_scale: float = 1.,
456
+ same_noise_all_measure: bool = False,
457
+ X0EditFunc = None,
458
+ use_classifier_free_guidance = False,
459
+ use_lsh = False,
460
+ ):
461
+ """
462
+ ### Painting Loop
463
+
464
+ :param x: is $x_{S'}$ of shape `[batch_size, channels, height, width]`
465
+ :param background_cond: background condition
466
+ :param autoreg_cond: autoregressive condition
467
+ :param external_cond: external condition
468
+ :param t_start: is the sampling step to start from, $S'$
469
+ :param orig: is the original image in latent space that we are in-painting.
470
+ If this is not provided, it'll be an image to image transformation.
471
+ :param mask: is the mask to keep the original image.
472
+ :param orig_noise: is fixed noise to be added to the original image.
473
+ :param uncond_scale: is the unconditional guidance scale $s$. This is used for
474
+ $\epsilon_\theta(x_t, c) = s\,\epsilon_\text{cond}(x_t, c) + (1 - s)\,\epsilon_\text{cond}(x_t, c_u)$
475
+ """
476
+ # Get batch size
477
+ bs = orig.size(0)
478
+
479
+ if x is None:
480
+ x = torch.randn(orig.shape, device=self.device)
481
+
482
+ # Time steps to sample at $\tau_{S'}, \tau_{S'-1}, \dots, \tau_1$
483
+ # time_steps = np.flip(self.time_steps[: t_start])
484
+ time_steps = np.flip(np.asarray(list(range(self.used_n_steps)), dtype=np.int32))[t_start:]
485
+
486
+ for i, step in monit.enum('Paint', time_steps):
487
+ # Index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
488
+ # index = len(time_steps) - i - 1
489
+ # Time step $\tau_i$
490
+ ts = x.new_full((bs, ), step, dtype=torch.long)
491
+
492
+ # Sample $x_{\tau_{i-1}}$
493
+ x, _, _ = self.p_sample(
494
+ x,
495
+ background_cond,
496
+ #autoreg_cond,
497
+ #external_cond,
498
+ t=ts,
499
+ step=step,
500
+ uncond_scale=uncond_scale,
501
+ same_noise_all_measure=same_noise_all_measure,
502
+ X0EditFunc = X0EditFunc,
503
+ use_classifier_free_guidance = use_classifier_free_guidance,
504
+ use_lsh=use_lsh,
505
+ )
506
+
507
+ # Replace the masked area with original image
508
+ if orig is not None:
509
+ assert mask is not None
510
+ # Get the $q_{\sigma,\tau}(x_{\tau_i}|x_0)$ for original image in latent space
511
+ orig_t = self.q_sample(orig, self.tau[step], noise=orig_noise)
512
+ # Replace the masked area
513
+ x = orig_t * mask + x * (1 - mask)
514
+
515
+ s1 = step + 1
516
+
517
+ # if self.is_show_image:
518
+ # if s1 % 100 == 0 or (s1 <= 100 and s1 % 25 == 0):
519
+ # show_image(x, f"exp/img/x{s1}.png")
520
+
521
+ # if self.is_show_image:
522
+ # show_image(x, f"exp/img/x0.png")
523
+ return x
524
+
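A hedged sketch of how `paint` could be used for mask-based inpainting in latent space. The helper name `inpaint_second_half` and the tensor shapes are illustrative, and `background_cond` must match whatever CFG/`use_lsh` mode the checkpoint expects.

import torch

def inpaint_second_half(sampler, orig, background_cond):
    """Keep the first half of a piece (mask == 1) and regenerate the rest (sketch)."""
    mask = torch.zeros_like(orig)
    mask[:, :, : orig.shape[2] // 2, :] = 1.0  # 1 = keep original region, 0 = regenerate
    return sampler.paint(
        x=None, background_cond=background_cond,
        orig=orig, mask=mask, orig_noise=torch.randn_like(orig),
    )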
525
+ def generate(self, background_cond=None, batch_size=1, uncond_scale=None,
526
+ same_noise_all_measure=False, X0EditFunc=None,
527
+ use_classifier_free_guidance=False, use_lsh=False, reduce_extra_notes=True, rhythm_control="Yes"):
528
+
529
+ shape = [batch_size, self.out_channel, self.max_l, self.h]
530
+
531
+ if self.debug_mode:
532
+ return torch.randn(shape, dtype=torch.float)
533
+
534
+ return self.sample(shape, background_cond, uncond_scale=uncond_scale, same_noise_all_measure=same_noise_all_measure,
535
+ X0EditFunc=X0EditFunc, use_classifier_free_guidance=use_classifier_free_guidance, use_lsh=use_lsh,
536
+ reduce_extra_notes=reduce_extra_notes, rhythm_control=rhythm_control
537
+ )
538
+
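End to end, generation might look like the sketch below. The function name, the 128x128 piano-roll shape, and the 6-channel conditioning layout (used by the lsh + classifier-free-guidance branch of `p_sample`) are illustrative assumptions; the Space's actual entry point is `app.py`, which is not part of this excerpt.

def generate_accompaniment(model, background_cond, max_l=128, h=128):
    """Sketch: build a sampler from a trained Diffpro_SDF wrapper and sample one piece."""
    sampler = SDFSampler(model.ldm, max_l=max_l, h=h, device=background_cond.device)
    # background_cond: [B, 6, max_l, h] = chd_onset, chd_sustain, null_chd_onset,
    #                  null_chd_sustain, lsh_onset, lsh_sustain
    return sampler.generate(
        background_cond=background_cond,
        batch_size=background_cond.shape[0],
        uncond_scale=1.0,
        use_classifier_free_guidance=True,
        use_lsh=True,
    )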
output_0.mid ADDED
Binary file (376 Bytes). View file
 
output_0.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9abffb8b039f86161f025cab6419eecf93ec741ea67e66964ca8e79d333c9d4
3
+ size 2469720
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ TiMidity++
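TiMidity++ is the system package (installed via packages.txt/Aptfile) presumably used to render the generated MIDI (e.g. `output_0.mid`) to audio. A minimal sketch of such a rendering step, assuming the `timidity` binary is on the PATH; the actual rendering code in `app.py` is not shown here and may use different flags or a different synthesizer.

import subprocess

def midi_to_wav(midi_path="output_0.mid", wav_path="output_0.wav"):
    """Render a MIDI file to WAV with TiMidity++ (sketch; assumes `timidity` is installed)."""
    subprocess.run(["timidity", midi_path, "-Ow", "-o", wav_path], check=True)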
piano_roll.png ADDED

Git LFS Details

  • SHA256: cf4d433689089c7895ed3ebf569e2dda9284a8d443481f6d8aa9f9575089cc37
  • Pointer size: 130 Bytes
  • Size of remote file: 16.4 kB
requirements.txt ADDED
@@ -0,0 +1,22 @@
1
+ gradio#==4.44.0
2
+ imageio#==2.35.1
3
+ imageio[ffmpeg]
4
+ labml#==0.5.3
5
+ librosa
6
+ mir_eval
7
+ matplotlib
8
+ music21
9
+ numba==0.53.1
10
+ numpy==1.19.5
11
+ opencv-python
12
+ # pandas==1.2.5
13
+ pretty_midi
14
+ pydub
15
+ requests
16
+ soundfile
17
+ fluidsynth
18
+ scikit-learn
19
+ torch==2.4.1
20
+ torchvision
21
+ tqdm
22
+ tensorboard
results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/chkpts/weights_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:852cfd7b011bb1ba0ce1d0d05a7acd672c7cd4934756b2e0d357d8002b5ecb6b
3
+ size 441623592
results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/events.out.tfevents.1726894943.berkeleyaisim3.16517.0 ADDED
Binary file (157 Bytes). View file
 
results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/train_grad_norm/events.out.tfevents.1726894943.berkeleyaisim3.16517.2 ADDED
Binary file (10.2 kB). View file
 
results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/train_loss/events.out.tfevents.1726894943.berkeleyaisim3.16517.1 ADDED
Binary file (10.2 kB). View file
 
results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/val_grad_norm/events.out.tfevents.1726895010.berkeleyaisim3.16517.4 ADDED
Binary file (592 Bytes). View file
 
results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/logs/val_loss/events.out.tfevents.1726895010.berkeleyaisim3.16517.3 ADDED
Binary file (592 Bytes). View file
 
results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/params.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"batch_size": 16, "max_epoch": 10, "learning_rate": 5e-05, "max_grad_norm": 10, "fp16": true, "in_channels": 6, "out_channels": 2, "channels": 64, "attention_levels": [2, 3], "n_res_blocks": 2, "channel_multipliers": [1, 2, 4, 4], "n_heads": 4, "tf_layers": 1, "d_cond": 2, "linear_start": 0.00085, "linear_end": 0.012, "n_steps": 1000, "latent_scaling_factor": 0.18215}
rhythm_plot_0.png ADDED

Git LFS Details

  • SHA256: 967207ba25f584c5f1559beea33807766b5c7472a025e4e9e182b82e3876e143
  • Pointer size: 130 Bytes
  • Size of remote file: 11.5 kB
runtime.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python-3.9
samples/control_vs_uncontrol/example_1_acc_control.jpg ADDED

Git LFS Details

  • SHA256: 2b1addf1c8a495c2cd6dc3f4c2fe2d86139e6808ad839ab8fbc7070bf8d02314
  • Pointer size: 131 Bytes
  • Size of remote file: 123 kB
samples/control_vs_uncontrol/example_1_acc_control.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98254caf2f2943c0f4bf4a07cfddf0c0dfa8bc33f41e5aafcd2626f33763b680
3
+ size 2469720
samples/control_vs_uncontrol/example_1_acc_uncontrol.jpg ADDED

Git LFS Details

  • SHA256: 0c857985742375d62e30b6a7ad86c2ac27978a5d2f9417a01d85cadf7aa87042
  • Pointer size: 131 Bytes
  • Size of remote file: 136 kB
samples/control_vs_uncontrol/example_1_acc_uncontrol.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:582e9d559424e40216f3d7e2ded5b07844311da1783fd77792b58a27fa27ba8c
3
+ size 2469720
samples/control_vs_uncontrol/example_1_mel_chd.jpg ADDED

Git LFS Details

  • SHA256: 7308a7c0255420a712837437cbb37a567b2c01606c0d020caa4d44f94a1f8464
  • Pointer size: 130 Bytes
  • Size of remote file: 71.5 kB
samples/control_vs_uncontrol/example_1_mel_chd.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cedfa10bc44605e9cd1c0a04a82d90f7b6c31dfe68d790a7d9c8e6e27117c90e
3
+ size 5292046
samples/control_vs_uncontrol/example_2_acc_control.jpg ADDED

Git LFS Details

  • SHA256: 5e7a9691db8538d58c3d4fff42184c3aac32a73849cc4f2c14d4ec540e9f5d30
  • Pointer size: 130 Bytes
  • Size of remote file: 96.6 kB
samples/control_vs_uncontrol/example_2_acc_control.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6e70d13d7d97c0cb3cbf38d9ad2473f5df5eb04c69a54ddb228aca1134f819a
3
+ size 2364108
samples/control_vs_uncontrol/example_2_acc_uncontrol.jpg ADDED

Git LFS Details

  • SHA256: 01da8567d60a0ab49c4d0e9a78b21ab1a102e196c2eccb39c30a5a721f1998d9
  • Pointer size: 131 Bytes
  • Size of remote file: 109 kB
samples/control_vs_uncontrol/example_2_acc_uncontrol.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a60053aca94dac302379e536774832a3f90bf2fd7a0d7a23c2dc97b742f78910
3
+ size 2364108
samples/control_vs_uncontrol/example_2_mel_chd.jpg ADDED

Git LFS Details

  • SHA256: ad9825c1cff525d4f348d60e77e8a28769e58a4ef93730a4c8fcfc8394f40e92
  • Pointer size: 130 Bytes
  • Size of remote file: 60.2 kB