"""
YourMT3+ with Instrument Conditioning - Google Colab Version
Instructions for use in Google Colab:
1. First, run this cell to install dependencies:
!pip install torch torchaudio transformers gradio pytorch-lightning
2. Clone the YourMT3 repository:
!git clone https://github.com/mimbres/YourMT3.git
%cd YourMT3
3. Copy this code to a cell and run it to launch the interface
4. The Gradio interface will provide a public URL you can access
"""
import sys
import os

# Add the amt/src directory to the Python path. __file__ is undefined when this
# code is pasted into a Colab cell, so fall back to the current working
# directory (the cloned repo root after %cd YourMT3).
if '__file__' in globals():
    base_dir = os.path.dirname(os.path.abspath(__file__))
else:
    base_dir = os.getcwd()
sys.path.append(os.path.join(base_dir, 'amt/src'))
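
# Sanity check (a convenience sketch, assuming the layout described in the
# docstring): make sure we're inside the cloned repo before importing its
# helper modules.
if not os.path.isdir('amt/src'):
    raise RuntimeError("amt/src not found - run the git clone / %cd steps first.")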
import subprocess
from typing import Dict, Literal

# Helper modules from amt/src (found via the sys.path entry above)
from html_helper import *
from model_helper import *

import torchaudio
import glob
import gradio as gr
from pathlib import Path
# Create the log file used to capture yt-dlp output
log_file = 'amt/log.txt'
Path(log_file).parent.mkdir(parents=True, exist_ok=True)
Path(log_file).touch()
# Model Configuration
model_name = 'YPTF.MoE+Multi (noPS)'  # Any of the five model names handled below
precision = '16'
project = '2024'

print(f"Loading model: {model_name}")
# Get model arguments based on selection
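# The short flags below are forwarded to YourMT3's argument parser in amt/src;
# judging by their values they appear to select the encoder (-enc), decoder
# (-dec), tokenizer/task (-tk), audio codec (-ac), hop size (-hop), and
# precision (-pr). See amt/src for the authoritative definitions.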
if model_name == "YMT3+":
    checkpoint = "notask_all_cross_v6_xk2_amp0811_gm_ext_plus_nops_b72@model.ckpt"
    args = [checkpoint, '-p', project, '-pr', precision]
elif model_name == "YPTF+Single (noPS)":
    checkpoint = "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt"
    args = [checkpoint, '-p', project, '-enc', 'perceiver-tf', '-ac', 'spec',
            '-hop', '300', '-atc', '1', '-pr', precision]
elif model_name == "YPTF+Multi (PS)":
    checkpoint = "mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k@model.ckpt"
    args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256',
            '-dec', 'multi-t5', '-nl', '26', '-enc', 'perceiver-tf',
            '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
elif model_name == "YPTF.MoE+Multi (noPS)":
    checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
    args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
            '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
            '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
            '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
elif model_name == "YPTF.MoE+Multi (PS)":
    checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt"
    args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
            '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
            '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
            '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
else:
    raise ValueError(f"Unknown model: {model_name}")
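
# Optional sanity check: the model is moved to CUDA below, so a GPU runtime is
# required in Colab (Runtime > Change runtime type > GPU).
import torch

if not torch.cuda.is_available():
    print("Warning: no CUDA GPU detected; model.to('cuda') below will fail.")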
# Load model
print("Loading model checkpoint...")
try:
    model = load_model_checkpoint(args=args, device="cpu")
    model.to("cuda")
    print("✓ Model loaded successfully!")
except Exception as e:
    print(f"✗ Error loading model: {e}")
    print("Make sure the model checkpoints are available in amt/logs/")
    raise  # re-raise so the rest of the script doesn't run with an undefined model
# Helper functions
def prepare_media(source_path_or_url: str,
                  source_type: Literal['audio_filepath', 'youtube_url'],
                  delete_video: bool = True,
                  simulate: bool = False) -> Dict:
    """Prepare media from a local path or a YouTube URL, and return audio info."""
    if source_type == 'audio_filepath':
        audio_file = source_path_or_url
    elif source_type == 'youtube_url':
        if os.path.exists('/content/yt_audio.mp3'):  # Colab path
            os.remove('/content/yt_audio.mp3')
        # Download audio from YouTube: -x extracts audio, -f bestaudio picks the
        # best audio-only format, --audio-format mp3 converts it, and -s (added
        # when simulate=True) only simulates the download.
        with open(log_file, 'w') as lf:
            audio_file = '/content/yt_audio'  # Colab path
            command = ['yt-dlp', '-x', source_path_or_url, '-f', 'bestaudio',
                       '-o', audio_file, '--audio-format', 'mp3', '--restrict-filenames',
                       '--extractor-retries', '10', '--force-overwrites']
            if simulate:
                command = command + ['-s']
            process = subprocess.Popen(command,
                                       stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
            for line in iter(process.stdout.readline, ''):
                print(line, end='')  # lines already include a trailing newline
                lf.write(line)
                lf.flush()
            process.stdout.close()
            process.wait()
            audio_file += '.mp3'
    else:
        raise ValueError(source_type)

    # Collect audio metadata via torchaudio
    info = torchaudio.info(audio_file)
    return {
        "filepath": audio_file,
        "track_name": os.path.basename(audio_file).split('.')[0],
        "sample_rate": int(info.sample_rate),
        "bits_per_sample": int(info.bits_per_sample),
        "num_channels": int(info.num_channels),
        "num_frames": int(info.num_frames),
        "duration": int(info.num_frames / info.sample_rate),
        "encoding": str.lower(info.encoding),
    }
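
# Illustrative example of the returned dict (hypothetical 44.1 kHz stereo WAV):
#   prepare_media('examples/song.wav', source_type='audio_filepath')
#   -> {'filepath': 'examples/song.wav', 'track_name': 'song',
#       'sample_rate': 44100, 'bits_per_sample': 16, 'num_channels': 2,
#       'num_frames': 8070300, 'duration': 183, 'encoding': 'pcm_s'}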
def process_audio(audio_filepath, instrument_hint=None):
    """Process uploaded audio with optional instrument conditioning."""
    if audio_filepath is None:
        return None
    try:
        audio_info = prepare_media(audio_filepath, source_type='audio_filepath')
        midifile = transcribe(model, audio_info, instrument_hint)
        midifile = to_data_url(midifile)
        return create_html_from_midi(midifile)
    except Exception as e:
        return f"<p style='color: red;'>Error processing audio: {str(e)}</p>"
def process_video(youtube_url, instrument_hint=None):
    """Process a YouTube video with optional instrument conditioning."""
    if not youtube_url or 'youtu' not in youtube_url:
        return None
    try:
        audio_info = prepare_media(youtube_url, source_type='youtube_url')
        midifile = transcribe(model, audio_info, instrument_hint)
        midifile = to_data_url(midifile)
        return create_html_from_midi(midifile)
    except Exception as e:
        return f"<p style='color: red;'>Error processing YouTube video: {str(e)}</p>"

def play_video(youtube_url):
    if not youtube_url or 'youtu' not in youtube_url:
        return None
    return create_html_youtube_player(youtube_url)
# Get example files
AUDIO_EXAMPLES = glob.glob('examples/*.*')
YOUTUBE_EXAMPLES = ["https://youtu.be/5vJBhdjvVcE?si=s3NFG_SlVju0Iklg",
                    "https://youtu.be/mw5VIEIvuMI?si=Dp9UFVw00Tl8CXe2",
                    "https://youtu.be/OXXRoa1U6xU?si=dpYMun4LjZHNydSb"]
# Gradio theme
theme = gr.Theme.from_hub("gradio/dracula_revamped")
css = """
.gradio-container {
background: linear-gradient(-45deg, #ee7752, #e73c7e, #23a6d5, #23d5ab);
background-size: 400% 400%;
animation: gradient 15s ease infinite;
}
@keyframes gradient {
0% {background-position: 0% 50%;}
50% {background-position: 100% 50%;}
100% {background-position: 0% 50%;}
}
"""
# Map UI labels to the hint strings passed to transcribe();
# None means no conditioning (the original YourMT3+ behavior).
INSTRUMENT_MAP = {
    "Auto (detect all instruments)": None,
    "Vocals/Singing": "vocals",
    "Guitar": "guitar",
    "Piano": "piano",
    "Violin": "violin",
    "Drums": "drums",
    "Bass": "bass",
    "Saxophone": "saxophone",
    "Flute": "flute",
}
INSTRUMENT_CHOICES = list(INSTRUMENT_MAP.keys())

# Create Gradio interface
with gr.Blocks(theme=theme, css=css) as demo:
    gr.Markdown(f"""
# 🎶 YourMT3+ with Instrument Conditioning

**Enhanced music transcription with instrument-specific control!**

**New Feature**: Select which instrument you want to transcribe from the dropdown menu.
This solves the problem of the model switching between instruments mid-track.

**Model**: `{model_name}` | **Running in**: Google Colab

---
""")

    with gr.Tabs():
        with gr.Tab("🎵 Upload Audio"):
            with gr.Row():
                with gr.Column():
                    audio_input = gr.Audio(
                        label="Upload Audio File",
                        type="filepath",
                        format="wav"
                    )
                    instrument_selector = gr.Dropdown(
                        choices=INSTRUMENT_CHOICES,
                        value="Auto (detect all instruments)",
                        label="🎯 Target Instrument",
                        info="NEW! Choose the specific instrument you want to transcribe"
                    )
                    transcribe_button = gr.Button("🎼 Transcribe", variant="primary", size="lg")
                    if AUDIO_EXAMPLES:
                        gr.Examples(examples=AUDIO_EXAMPLES[:5], inputs=audio_input)
            with gr.Row():
                output_audio = gr.HTML(label="Transcription Result")

        with gr.Tab("📺 YouTube"):
            with gr.Row():
                with gr.Column():
                    youtube_input = gr.Textbox(
                        label="YouTube URL",
                        placeholder="https://youtu.be/..."
                    )
                    youtube_instrument_selector = gr.Dropdown(
                        choices=INSTRUMENT_CHOICES,
                        value="Auto (detect all instruments)",
                        label="🎯 Target Instrument",
                        info="Choose the specific instrument you want to transcribe"
                    )
                    with gr.Row():
                        play_button = gr.Button("▶️ Preview Video", variant="secondary")
                        transcribe_yt_button = gr.Button("🎼 Transcribe", variant="primary")
                    gr.Examples(examples=YOUTUBE_EXAMPLES, inputs=youtube_input)
            with gr.Row():
                with gr.Column():
                    youtube_player = gr.HTML(label="Video Preview")
                with gr.Column():
                    output_youtube = gr.HTML(label="Transcription Result")
    # Event handlers: look up the dropdown label in INSTRUMENT_MAP; unknown
    # labels fall back to None, which behaves like "Auto".
    def process_with_instrument_audio(audio_file, instrument_choice):
        instrument_hint = INSTRUMENT_MAP.get(instrument_choice)
        return process_audio(audio_file, instrument_hint)

    def process_with_instrument_youtube(url, instrument_choice):
        instrument_hint = INSTRUMENT_MAP.get(instrument_choice)
        return process_video(url, instrument_hint)
    # Connect events
    transcribe_button.click(
        process_with_instrument_audio,
        inputs=[audio_input, instrument_selector],
        outputs=output_audio
    )
    transcribe_yt_button.click(
        process_with_instrument_youtube,
        inputs=[youtube_input, youtube_instrument_selector],
        outputs=output_youtube
    )
    play_button.click(play_video, inputs=youtube_input, outputs=youtube_player)
print("๐Ÿš€ Launching YourMT3+ with Instrument Conditioning...")
print("๐Ÿ“ Tips:")
print(" โ€ข Try 'Vocals/Singing' for vocal tracks to avoid instrument switching")
print(" โ€ข Use 'Guitar' for guitar solos to get complete transcriptions")
print(" โ€ข 'Auto' works like the original YourMT3+")
# Launch with share=True for Colab public URL
demo.launch(share=True, debug=True)
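
# Gradio prints a local URL plus, with share=True, a public *.gradio.live link
# that can be opened from any browser; debug=True keeps the Colab cell attached
# so server logs stream until the cell is interrupted.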