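"""Gradio demo for FGG (Fine-Grained Guidance) diffusion-based symbolic music generation.

The app loads a trained checkpoint, generates a piano accompaniment conditioned on a
bundled melody/chord example, renders it to MIDI and WAV, and displays a piano roll.
"""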
import gradio as gr
import pretty_midi
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import cv2
import imageio

import sys
import subprocess
import os
import torch
from model import init_ldm_model
from model.model_sdf import Diffpro_SDF
from model.sampler_sdf import SDFSampler

import pickle
from train.train_params import params_chord_lsh_cond
# The star import provides CHORD_DICTIONARY, extend_piano_roll, piano_roll_to_midi,
# save_midi, X0EditFunc, and other generation helpers used below.
from generation.gen_utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = 'results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/chkpts/weights_best.pt'
chord_list = list(CHORD_DICTIONARY.keys())


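# Utility: return the dimensions of a .jpg image or .mp4 video on disk.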
def get_shape(file_path):
    if file_path.endswith('.jpg'):
        img = cv2.imread(file_path)
        return img.shape
    elif file_path.endswith('.mp4'):
        vid = imageio.get_reader(file_path)
        return vid.get_meta_data()['size']
    else:
        raise ValueError("Unsupported file type")


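# Synthesize a pretty_midi.PrettyMIDI object to a WAV file; pretty_midi's fluidsynth()
# requires fluidsynth to be installed.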
def midi_to_wav(midi, output_file):
    audio_data = midi.fluidsynth()
    sf.write(output_file, audio_data, samplerate=44100)


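# Map the selected example prompt to its MuseScore-rendered image of the
# melody/chord condition.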
def update_musescore_image(selected_prompt):
    if selected_prompt == "example 1":
        return "samples/diy_examples/example1/example1.jpg"
    elif selected_prompt == "example 2":
        return "samples/diy_examples/example2/example2.jpg"
    elif selected_prompt == "example 3":
        return "samples/diy_examples/example3/example3.jpg"
    elif selected_prompt == "example 4":
        return "samples/diy_examples/example4/example4.jpg"
    elif selected_prompt == "example 5":
        return "samples/diy_examples/example5/example5.jpg"
    elif selected_prompt == "example 6":
        return "samples/diy_examples/example6/example6.jpg"


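# Generate accompaniment(s) with the trained diffusion model.
# Based on how the channels are used below, the background condition appears to be laid
# out as: channels 0:2 rhythm/onset, channels 2:4 chord (stored in a negated encoding),
# and channels 4:6 the optional lead-sheet (lsh) condition.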
def generate_music(prompt, tempo, num_samples=1, mode="example", rhythm_control="Yes"):
    ldm_model = init_ldm_model(params_chord_lsh_cond, debug_mode=False)
    model = Diffpro_SDF.load_trained(ldm_model, model_path).to(device)
    sampler = SDFSampler(model.ldm, 64, 64, is_autocast=False, device=device, debug_mode=False)

    if mode == "example":
        if prompt == "example 1":
            background_condition = np.load("samples/diy_examples/example1/example1.npy")
            tempo = 70
        elif prompt == "example 2":
            background_condition = np.load("samples/diy_examples/example2/example2.npy")
        elif prompt == "example 3":
            background_condition = np.load("samples/diy_examples/example3/example3.npy")
        elif prompt == "example 4":
            background_condition = np.load("samples/diy_examples/example4/example4.npy")
        else:
            raise ValueError(f"Unknown example prompt: {prompt}")

        background_condition = np.tile(background_condition, (num_samples, 1, 1, 1))
        background_condition = torch.Tensor(background_condition).to(device)
    else:
        # In non-example mode, `prompt` is treated as the condition array itself.
        background_condition = np.tile(prompt, (num_samples, 1, 1, 1))
        background_condition = torch.Tensor(background_condition).to(device)

    if rhythm_control != "Yes":
        # Disable rhythm control by overwriting the rhythm/onset channels with the chord channels.
        background_condition[:, 0:2] = background_condition[:, 2:4]

    output_x = sampler.generate(background_cond=background_condition, batch_size=num_samples,
                                same_noise_all_measure=False, X0EditFunc=X0EditFunc,
                                use_classifier_free_guidance=True, use_lsh=True, reduce_extra_notes=False,
                                rhythm_control=rhythm_control)
    output_x = torch.clamp(output_x, min=0, max=1)
    output_x = output_x.cpu().numpy()

    for i in range(num_samples):
        full_roll = extend_piano_roll(output_x[i])
        # The chord channels use a negated, offset encoding; -x - 1 recovers the piano-roll values.
        full_chd_roll = extend_piano_roll(-background_condition[i, 2:4, :, :].cpu().numpy() - 1)
        full_lsh_roll = None
        if background_condition.shape[1] >= 6:
            if background_condition[:, 4:6, :, :].min() >= 0:
                full_lsh_roll = extend_piano_roll(background_condition[i, 4:6, :, :].cpu().numpy())
        midi_file = piano_roll_to_midi(full_roll, full_chd_roll, full_lsh_roll, bpm=tempo)
        filename = f"output_{i}.mid"
        save_midi(midi_file, filename)
        # Render the MIDI to WAV with the timidity command-line tool (must be installed).
        subprocess.Popen(['timidity', f'output_{i}.mid', '-Ow', '-o', f'output_{i}.wav']).communicate()

    return 'output_0.mid', 'output_0.wav', midi_file


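# Plot the generated MIDI as a piano-roll image for display in the Gradio UI.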
def visualize_midi(midi):
    roll = midi.get_piano_roll(fs=100)

    plt.figure(figsize=(10, 4))
    plt.imshow(roll, aspect='auto', origin='lower', cmap='gray_r', interpolation='nearest')
    plt.title("Piano Roll")
    plt.xlabel("Time")
    plt.ylabel("Pitch")
    plt.colorbar()

    output_image_path = "piano_roll.png"
    plt.savefig(output_image_path)
    plt.close()  # release the figure so repeated requests do not accumulate open figures
    return output_image_path


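# Callback for the "Generate" button: run generation for the selected example and
# return the synthesized audio plus a piano-roll visualization.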
def generate_from_example(prompt):
    midi_output, audio_output, midi = generate_music(prompt, tempo=80, mode="example", rhythm_control="No")
    piano_roll_image = visualize_midi(midi)
    return audio_output, piano_roll_image


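# A direct (non-example) call is also possible: `prompt` is then treated as the condition
# array itself. The sketch below is hypothetical and assumes `my_condition` is a numpy
# array shaped (channels, 64, 64) in the same format as the bundled example .npy files:
#
#     midi_path, wav_path, midi = generate_music(my_condition, tempo=90,
#                                                mode="custom", rhythm_control="Yes")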
prompt_list = ["example 1", "example 2", "example 3", "example 4"]

custom_css = """
.custom-purple {
    background-color: #d7bde2;
    padding: 10px;
    border-radius: 5px;
}
.audio_waveform-container {
    display: none !important;
}
"""

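# Build the Gradio interface.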
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# <div style='text-align: center;font-size:40px'> Efficient Fine-Grained Guidance for Diffusion Model Based Symbolic Music Generation </div>")

    gr.Markdown("<div style='text-align: center;font-size:20px'>Tingyu Zhu<sup>*</sup>, Haoyu Liu<sup>*</sup>, Ziyu Wang, Zhimin Jiang, Zeyu Zheng</div>")
    gr.Markdown("<div style='text-align: center;font-size:20px'><a href='https://arxiv.org/abs/2410.08435'>[Paper]</a> <a href='https://github.com/huajianduzhuo-code/FGG-music-code'>[Code Repo]</a></div>")

    gr.Markdown("<span style='font-size:25px;'> For detailed information and demonstrations of our method, please visit our [GitHub Pages site](https://huajianduzhuo-code.github.io/FGG-diffusion-music/) to explore:"
                "\n&nbsp;&nbsp;&nbsp;&nbsp;1. Accompaniment Generation given Melody and Chord"
                "\n&nbsp;&nbsp;&nbsp;&nbsp;2. Style-Controlled Music Generation"
                "\n&nbsp;&nbsp;&nbsp;&nbsp;3. Demonstrating the Effectiveness of Sampling Control by Comparison</span>")

    gr.HTML("<div style='height: 50px;'></div>")
    gr.Markdown("\n\n\n")

    gr.Markdown("# <span style='color: red;'> Interactive Demo </span>")
    gr.Markdown(
        "<span style='font-size:20px;'>"
        "🎵 Try out our interactive tool to generate music with our model!<br>"
        "You can create new accompaniments conditioned on a given melody and chord progression."
        "</span>"
    )

    gr.Markdown(
        "<span style='color:blue; font-size:20px;'>"
        "⚠️ This Space currently runs on a Hugging Face-provided CPU. On average, it takes ~15 seconds to generate a 4-measure music segment.<br>"
        "If multiple users are generating at the same time, you may enter a queue, which can cause delays.<br><br>"
        "On our local server (NVIDIA RTX 6000 Ada GPU), the same generation takes only 0.4 seconds.<br><br>"
        "To speed things up, you can:<br>"
        "• Fork this Space and select a different hardware configuration<br>"
        "• 🧑‍💻 Clone our <a href='https://github.com/huajianduzhuo-code/FGG-music-code'>[Code Repo]</a> and run the generation notebooks locally after installing dependencies and downloading the model weights."
        "</span>"
    )

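    # Example selector and condition image on the left; generation results on the right.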
    with gr.Column(elem_classes="custom-purple"):
        gr.Markdown("### Select an example to generate music given melody and chord condition")
        with gr.Row():
            with gr.Column():
                prompt_selector = gr.Dropdown(choices=prompt_list, label="Select an example", value="example 1")
                gr.Markdown("### This is the melody to be conditioned on:")
                condition_musescore = gr.Image("samples/diy_examples/example1/example1.jpg", label="melody, chord, and rhythm condition")
                prompt_selector.change(fn=update_musescore_image, inputs=prompt_selector, outputs=condition_musescore)

            with gr.Column():
                generate_button = gr.Button("Generate")
                gr.Markdown("### Generation results:")
                audio_output = gr.Audio(label="Generated Music")
                piano_roll_output = gr.Image(label="Generated Piano Roll")

    generate_button.click(
        fn=generate_from_example,
        inputs=[prompt_selector],
        outputs=[audio_output, piano_roll_output]
    )


if __name__ == "__main__":
    demo.launch()