import gradio as gr
import pretty_midi
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import cv2
import imageio
import sys
import subprocess
import os
import torch
from model import init_ldm_model
from model.model_sdf import Diffpro_SDF
from model.sampler_sdf import SDFSampler
import pickle
from train.train_params import params_chord_lsh_cond
from generation.gen_utils import *
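# Global configuration: compute device, path to the trained diffusion checkpoint, and the
# chord vocabulary (CHORD_DICTIONARY is provided by the wildcard import above).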
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = 'results/test/model_with_chord_lsh_cond_and_rhythm_onset_and_null_sep/chkpts/weights_best.pt'
chord_list = list(CHORD_DICTIONARY.keys())
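# Small media-inspection helper; it is not called in the demo flow below.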
def get_shape(file_path):
    if file_path.endswith('.jpg'):
        img = cv2.imread(file_path)
        return img.shape  # (height, width, channels)
    elif file_path.endswith('.mp4'):
        vid = imageio.get_reader(file_path)
        return vid.get_meta_data()['size']  # (width, height)
    else:
        raise ValueError("Unsupported file type")
# Function to convert MIDI to WAV
def midi_to_wav(midi, output_file):
    # Synthesize the waveform from the MIDI using pretty_midi
    audio_data = midi.fluidsynth()
    # Write the waveform to a WAV file
    sf.write(output_file, audio_data, samplerate=44100)
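# Note: PrettyMIDI.fluidsynth() requires fluidsynth support (the pyfluidsynth package) to be
# installed; the demo below renders audio with timidity instead of this helper. A minimal
# usage sketch, assuming 'output_0.mid' exists:
#   pm = pretty_midi.PrettyMIDI('output_0.mid')
#   midi_to_wav(pm, 'output_0.wav')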
def update_musescore_image(selected_prompt):
    # Return the score image file that corresponds to the selected example prompt
    example_images = {
        f"example {i}": f"samples/diy_examples/example{i}/example{i}.jpg"
        for i in range(1, 7)
    }
    return example_images.get(selected_prompt)
# Load the trained diffusion model and generate accompaniment given the background condition
def generate_music(prompt, tempo, num_samples=1, mode="example", rhythm_control="Yes"):
    ldm_model = init_ldm_model(params_chord_lsh_cond, debug_mode=False)
    model = Diffpro_SDF.load_trained(ldm_model, model_path).to(device)
    sampler = SDFSampler(model.ldm, 64, 64, is_autocast=False, device=device, debug_mode=False)

    if mode == "example":
        # Pre-computed melody/chord/rhythm conditions shipped with the demo
        if prompt == "example 1":
            background_condition = np.load("samples/diy_examples/example1/example1.npy")
            tempo = 70
        elif prompt == "example 2":
            background_condition = np.load("samples/diy_examples/example2/example2.npy")
        elif prompt == "example 3":
            background_condition = np.load("samples/diy_examples/example3/example3.npy")
        elif prompt == "example 4":
            background_condition = np.load("samples/diy_examples/example4/example4.npy")
        background_condition = np.tile(background_condition, (num_samples, 1, 1, 1))
        background_condition = torch.Tensor(background_condition).to(device)
    else:
        # Custom mode: `prompt` is already a background-condition array
        background_condition = np.tile(prompt, (num_samples, 1, 1, 1))
        background_condition = torch.Tensor(background_condition).to(device)

    if rhythm_control != "Yes":
        # When rhythm control is off, overwrite the first two condition channels with the chord channels
        background_condition[:, 0:2] = background_condition[:, 2:4]

    # generate samples
    output_x = sampler.generate(background_cond=background_condition, batch_size=num_samples,
                                same_noise_all_measure=False, X0EditFunc=X0EditFunc,
                                use_classifier_free_guidance=True, use_lsh=True,
                                reduce_extra_notes=False, rhythm_control=rhythm_control)
    output_x = torch.clamp(output_x, min=0, max=1)
    output_x = output_x.cpu().numpy()

    # save samples
    for i in range(num_samples):
        full_roll = extend_piano_roll(output_x[i])  # accompaniment roll
        # chord channels are stored in negated form; recover the chord roll
        full_chd_roll = extend_piano_roll(-background_condition[i, 2:4, :, :].cpu().numpy() - 1)
        full_lsh_roll = None
        if background_condition.shape[1] >= 6:
            if background_condition[:, 4:6, :, :].min() >= 0:
                full_lsh_roll = extend_piano_roll(background_condition[i, 4:6, :, :].cpu().numpy())
        midi_file = piano_roll_to_midi(full_roll, full_chd_roll, full_lsh_roll, bpm=tempo)
        filename = f"output_{i}.mid"
        save_midi(midi_file, filename)
        # Render the saved MIDI to WAV with timidity
        subprocess.Popen(['timidity', f'output_{i}.mid', '-Ow', '-o', f'output_{i}.wav']).communicate()
    return 'output_0.mid', 'output_0.wav', midi_file
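# Sketch of custom (non-"example") usage, where `prompt` is a background-condition array with
# the same layout as the bundled .npy examples (an assumption for illustration only):
#   cond = np.load("samples/diy_examples/example1/example1.npy")
#   midi_path, wav_path, midi = generate_music(cond, tempo=90, mode="custom")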
# Function to visualize MIDI notes
def visualize_midi(midi):
    # Get piano roll from MIDI (100 frames per second)
    roll = midi.get_piano_roll(fs=100)
    # Plot the piano roll
    plt.figure(figsize=(10, 4))
    plt.imshow(roll, aspect='auto', origin='lower', cmap='gray_r', interpolation='nearest')
    plt.title("Piano Roll")
    plt.xlabel("Time (frames, fs=100)")
    plt.ylabel("Pitch (MIDI note number)")
    plt.colorbar(label="Velocity")
    # Save the plot as an image and release the figure so repeated calls do not leak memory
    output_image_path = "piano_roll.png"
    plt.savefig(output_image_path)
    plt.close()
    return output_image_path
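# Used by the Gradio callback below to render the generated MIDI as a piano-roll image.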
# Gradio callback for the Generate button
def generate_from_example(prompt):
    midi_output, audio_output, midi = generate_music(prompt, tempo=80, mode="example", rhythm_control="No")
    piano_roll_image = visualize_midi(midi)
    return audio_output, piano_roll_image
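# Example: generate audio and a piano-roll image for the first bundled example
#   audio_path, image_path = generate_from_example("example 1")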
# Prompt list
prompt_list = ["example 1", "example 2", "example 3", "example 4"]
custom_css = """
.custom-purple {
    background-color: #d7bde2;
    padding: 10px;
    border-radius: 5px;
}
.audio_waveform-container {
    display: none !important;
}
"""
with gr.Blocks(css=custom_css) as demo:
gr.Markdown("# <div style='text-align: center;font-size:40px'> Efficient Fine-Grained Guidance for Diffusion Model Based Symbolic Music Generation <div style='text-align: center;'>")
gr.Markdown("<div style='text-align: center;font-size:20px'>Tingyu Zhu<sup>*</sup>, Haoyu Liu<sup>*</sup>, Ziyu Wang, Zhimin Jiang, Zeyu Zheng</div>")
gr.Markdown("<div style='text-align: center;font-size:20px'><a href='https://arxiv.org/abs/2410.08435'>[Paper]</a> <a href='https://github.com/huajianduzhuo-code/FGG-music-code'>[Code Repo]</a></div>")
gr.Markdown("<span style='font-size:25px;'> For detailed information and demonstrations of our method, please visit our [GitHub Pages site](https://huajianduzhuo-code.github.io/FGG-diffusion-music/) to explore:\
\n &emsp; 1. Accompaniment Generation given Melody and Chord\
\n &emsp; 2. Style-Controlled Music Generation\
\n &emsp; 3. Demonstrating the Effectiveness of Sampling Control by Comparison</span>")
gr.HTML("<div style='height: 50px;'></div>")
gr.Markdown("\n\n\n")
gr.Markdown("# <span style='color: red;'> Interactive Demo </span>")
gr.Markdown(
"<span style='font-size:20px;'>"
"🎡 Try out our interactive tool to generate music with our model!<br>"
"You can create new accompaniments conditioned on a given melody and chord progression."
"</span>"
)
gr.Markdown(
"<span style='color:blue; font-size:20px;'>"
"⚠️ This Space currently runs on a Hugging Face-provided CPU. On average, it takes ~15 seconds to generate a 4-measure music segment.<br>"
"If multiple users are generating at the same time, you may enter a queue, which can cause delays.<br><br>"
"πŸš€ On our local server (NVIDIA RTX 6000 Ada GPU), the same generation takes only 0.4 seconds.<br><br>"
"To speed things up, you can: <br>"
"β€’ πŸ” Fork this Space and select a different hardware configuration<br>"
"β€’ πŸ§‘β€πŸ’» Clone our <a href='https://github.com/huajianduzhuo-code/FGG-music-code'>[Code Repo]</a> and run the generation notebooks locally after installing dependencies and downloading the model weights."
"</span>"
)
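    # Example selection and generation panel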
    with gr.Column(elem_classes="custom-purple"):
        gr.Markdown("### Select an example to generate music given melody and chord condition")
        with gr.Row():
            with gr.Column():
                prompt_selector = gr.Dropdown(choices=prompt_list, label="Select an example", value="example 1")
                gr.Markdown("### This is the melody to be conditioned on:")
                condition_musescore = gr.Image("samples/diy_examples/example1/example1.jpg", label="melody, chord, and rhythm condition")
                prompt_selector.change(fn=update_musescore_image, inputs=prompt_selector, outputs=condition_musescore)
            with gr.Column():
                generate_button = gr.Button("Generate")
                gr.Markdown("### Generation results:")
                audio_output = gr.Audio(label="Generated Music")
                piano_roll_output = gr.Image(label="Generated Piano Roll")

        generate_button.click(
            fn=generate_from_example,
            inputs=[prompt_selector],
            outputs=[audio_output, piano_roll_output]
        )
# Launch Gradio interface
if __name__ == "__main__":
    demo.launch()