""" YourMT3+ with Instrument Conditioning - Google Colab Version Instructions for use in Google Colab: 1. First, run this cell to install dependencies: !pip install torch torchaudio transformers gradio pytorch-lightning 2. Clone the YourMT3 repository: !git clone https://github.com/mimbres/YourMT3.git %cd YourMT3 3. Copy this code to a cell and run it to launch the interface 4. The Gradio interface will provide a public URL you can access """ import sys import os # Add the amt/src directory to Python path sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'amt/src'))) import subprocess from typing import Tuple, Dict, Literal from ctypes import ArgumentError from html_helper import * from model_helper import * import torchaudio import glob import gradio as gr from gradio_log import Log from pathlib import Path # Create log file log_file = 'amt/log.txt' Path(log_file).touch() # Model Configuration model_name = 'YPTF.MoE+Multi (noPS)' # You can change this precision = '16' project = '2024' print(f"Loading model: {model_name}") # Get model arguments based on selection if model_name == "YMT3+": checkpoint = "notask_all_cross_v6_xk2_amp0811_gm_ext_plus_nops_b72@model.ckpt" args = [checkpoint, '-p', project, '-pr', precision] elif model_name == "YPTF+Single (noPS)": checkpoint = "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt" args = [checkpoint, '-p', project, '-enc', 'perceiver-tf', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision] elif model_name == "YPTF+Multi (PS)": checkpoint = "mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k@model.ckpt" args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5', '-nl', '26', '-enc', 'perceiver-tf', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision] elif model_name == "YPTF.MoE+Multi (noPS)": checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt" args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5', '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe', '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope', '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision] elif model_name == "YPTF.MoE+Multi (PS)": checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt" args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5', '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe', '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope', '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision] else: raise ValueError(f"Unknown model: {model_name}") # Load model print("Loading model checkpoint...") try: model = load_model_checkpoint(args=args, device="cpu") model.to("cuda") print("✓ Model loaded successfully!") except Exception as e: print(f"✗ Error loading model: {e}") print("Make sure the model checkpoints are available in amt/logs/") # Helper functions def prepare_media(source_path_or_url: os.PathLike, source_type: Literal['audio_filepath', 'youtube_url'], delete_video: bool = True, simulate = False) -> Dict: """prepare media from source path or youtube, and return audio info""" if source_type == 'audio_filepath': audio_file = source_path_or_url elif source_type == 'youtube_url': if os.path.exists('/content/yt_audio.mp3'): # Colab path os.remove('/content/yt_audio.mp3') # Download from youtube with open(log_file, 'w') as lf: audio_file 
def process_audio(audio_filepath, instrument_hint=None):
    """Process an uploaded audio file with optional instrument conditioning."""
    if audio_filepath is None:
        return None
    try:
        audio_info = prepare_media(audio_filepath, source_type='audio_filepath')
        # Note: this assumes a model_helper.transcribe extended to accept an
        # instrument hint; the stock YourMT3 transcribe(model, audio_info)
        # takes no third argument.
        midifile = transcribe(model, audio_info, instrument_hint)
        midifile = to_data_url(midifile)
        return create_html_from_midi(midifile)
    except Exception as e:
        return f"Error processing audio: {str(e)}"
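
# Quick smoke test (optional; the filename is hypothetical — run in a
# separate cell after the interface has loaded):
#
#   html = process_audio('examples/piano_take1.wav', instrument_hint='piano')
#   # -> an HTML snippet embedding the MIDI player, or an error string.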
" def process_video(youtube_url, instrument_hint=None): """Process YouTube video with optional instrument conditioning""" if 'youtu' not in youtube_url: return None try: audio_info = prepare_media(youtube_url, source_type='youtube_url') midifile = transcribe(model, audio_info, instrument_hint) midifile = to_data_url(midifile) return create_html_from_midi(midifile) except Exception as e: return f"

Error processing YouTube video: {str(e)}

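
# Tip: prepare_media's simulate flag maps to yt-dlp's -s (simulate) option,
# which exercises the URL without downloading anything:
#
#   prepare_media('https://youtu.be/...', source_type='youtube_url', simulate=True)
#
# torchaudio.info will then fail on the missing file, so use this only to
# inspect the yt-dlp log written to amt/log.txt.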
" def play_video(youtube_url): if 'youtu' not in youtube_url: return None return create_html_youtube_player(youtube_url) # Get example files AUDIO_EXAMPLES = glob.glob('examples/*.*', recursive=True) YOUTUBE_EXAMPLES = ["https://youtu.be/5vJBhdjvVcE?si=s3NFG_SlVju0Iklg", "https://youtu.be/mw5VIEIvuMI?si=Dp9UFVw00Tl8CXe2", "https://youtu.be/OXXRoa1U6xU?si=dpYMun4LjZHNydSb"] # Gradio theme theme = gr.Theme.from_hub("gradio/dracula_revamped") css = """ .gradio-container { background: linear-gradient(-45deg, #ee7752, #e73c7e, #23a6d5, #23d5ab); background-size: 400% 400%; animation: gradient 15s ease infinite; } @keyframes gradient { 0% {background-position: 0% 50%;} 50% {background-position: 100% 50%;} 100% {background-position: 0% 50%;} } """ # Create Gradio interface with gr.Blocks(theme=theme, css=css) as demo: gr.Markdown(f""" # 🎶 YourMT3+ with Instrument Conditioning **Enhanced music transcription with instrument-specific control!** **New Feature**: Select which instrument you want to transcribe from the dropdown menu. This solves the problem of the model switching between instruments mid-track. **Model**: `{model_name}` | **Running in**: Google Colab --- """) with gr.Tabs(): with gr.Tab("🎵 Upload Audio"): with gr.Row(): with gr.Column(): audio_input = gr.Audio( label="Upload Audio File", type="filepath", format="wav" ) instrument_selector = gr.Dropdown( choices=[ "Auto (detect all instruments)", "Vocals/Singing", "Guitar", "Piano", "Violin", "Drums", "Bass", "Saxophone", "Flute" ], value="Auto (detect all instruments)", label="🎯 Target Instrument", info="NEW! Choose the specific instrument you want to transcribe" ) transcribe_button = gr.Button("🎼 Transcribe", variant="primary", size="lg") if AUDIO_EXAMPLES: gr.Examples(examples=AUDIO_EXAMPLES[:5], inputs=audio_input) with gr.Row(): output_audio = gr.HTML(label="Transcription Result") with gr.Tab("📺 YouTube"): with gr.Row(): with gr.Column(): youtube_input = gr.Textbox( label="YouTube URL", placeholder="https://youtu.be/..." 
                    youtube_instrument_selector = gr.Dropdown(
                        choices=[
                            "Auto (detect all instruments)",
                            "Vocals/Singing",
                            "Guitar",
                            "Piano",
                            "Violin",
                            "Drums",
                            "Bass",
                            "Saxophone",
                            "Flute"
                        ],
                        value="Auto (detect all instruments)",
                        label="🎯 Target Instrument",
                        info="Choose the specific instrument you want to transcribe"
                    )
                    with gr.Row():
                        play_button = gr.Button("▶️ Preview Video", variant="secondary")
                        transcribe_yt_button = gr.Button("🎼 Transcribe", variant="primary")
                    gr.Examples(examples=YOUTUBE_EXAMPLES, inputs=youtube_input)
            with gr.Row():
                with gr.Column():
                    youtube_player = gr.HTML(label="Video Preview")
                with gr.Column():
                    output_youtube = gr.HTML(label="Transcription Result")

    # Event handlers: map dropdown labels to the hint strings passed to transcribe()
    INSTRUMENT_MAP = {
        "Auto (detect all instruments)": None,
        "Vocals/Singing": "vocals",
        "Guitar": "guitar",
        "Piano": "piano",
        "Violin": "violin",
        "Drums": "drums",
        "Bass": "bass",
        "Saxophone": "saxophone",
        "Flute": "flute",
    }

    def process_with_instrument_audio(audio_file, instrument_choice):
        return process_audio(audio_file, INSTRUMENT_MAP.get(instrument_choice))

    def process_with_instrument_youtube(url, instrument_choice):
        return process_video(url, INSTRUMENT_MAP.get(instrument_choice))

    # Connect events
    transcribe_button.click(
        process_with_instrument_audio,
        inputs=[audio_input, instrument_selector],
        outputs=output_audio
    )
    transcribe_yt_button.click(
        process_with_instrument_youtube,
        inputs=[youtube_input, youtube_instrument_selector],
        outputs=output_youtube
    )
    play_button.click(play_video, inputs=youtube_input, outputs=youtube_player)

print("🚀 Launching YourMT3+ with Instrument Conditioning...")
print("📝 Tips:")
print("  • Try 'Vocals/Singing' for vocal tracks to avoid instrument switching")
print("  • Use 'Guitar' for guitar solos to get complete transcriptions")
print("  • 'Auto' works like the original YourMT3+")

# Launch with share=True for a public Colab URL
demo.launch(share=True, debug=True)
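
# Optional: batch-transcribe the bundled examples from a separate cell
# (a minimal sketch; assumes transcribe() returns the path of the written
# MIDI file, as implied by its use in process_audio above):
#
#   for f in AUDIO_EXAMPLES:
#       info = prepare_media(f, source_type='audio_filepath')
#       print(f, '->', transcribe(model, info, None))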