#!/usr/bin/env python3
"""
Quick test script for YourMT3+ instrument conditioning.

Run this to test if everything is working before launching the full interface.
"""
import os
import sys
from pathlib import Path

# Add amt/src to the import path so the project-local helper modules
# (model_helper, html_helper) can be imported below. The relative path is
# resolved against the current working directory at import time, so the
# script has to be started from the directory that contains ``amt/``.
sys.path.append(os.path.abspath('amt/src'))


def test_basic_import():
    """Check that every module this script relies on can be imported.

    The imports are attempted in a fixed order -- torch, torchaudio, gradio,
    then the project-local ``model_helper`` and ``html_helper`` -- and a
    confirmation line is printed after each one that succeeds.

    Returns:
        bool: True when every import succeeds, False as soon as one fails
        (the error text is printed).
    """
    print("🔍 Testing basic imports...")
    try:
        # The imported names are intentionally unused: the import itself is
        # the check.
        import torch  # noqa: F401
        print("✅ torch")
        import torchaudio  # noqa: F401
        print("✅ torchaudio")
        import gradio as gr  # noqa: F401
        print("✅ gradio")

        # Test YourMT3 imports
        from model_helper import load_model_checkpoint, transcribe  # noqa: F401
        print("✅ model_helper")
        from html_helper import create_html_from_midi, to_data_url  # noqa: F401
        print("✅ html_helper")
        return True
    except Exception as e:
        # Kept broad on purpose: this is a diagnostic boundary, and a broken
        # installation may raise something other than ImportError while a
        # package initialises.
        print(f"❌ Import error: {e}")
        return False


def test_model_loading():
    """Load the YourMT3+ checkpoint on CPU and print its debug information.

    The checkpoint name and command-line style arguments are hard-coded
    below (the original author notes they mirror ``app.py`` -- confirm
    against that file when either changes).

    After a successful load the optional ``debug_model_task_config`` helper
    from ``model_helper`` is invoked. A failure of that helper is reported
    as a warning only, so it cannot be mistaken for a failure to load the
    model itself.

    Returns:
        The object returned by ``load_model_checkpoint`` on success, or
        ``None`` when importing the loader or loading the checkpoint raises
        (the traceback is printed).
    """
    print("\n🔍 Testing model loading...")
    try:
        from model_helper import load_model_checkpoint

        # Use the same args as app.py
        model_name = 'YPTF.MoE+Multi (noPS)'
        precision = '16'
        project = '2024'
        checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
        args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
                '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
                '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
                '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]

        print(f"Loading {model_name}...")
        model = load_model_checkpoint(args=args, device="cpu")
        print("✅ Model loaded successfully!")
    except Exception as e:
        print(f"❌ Model loading failed: {e}")
        import traceback
        traceback.print_exc()
        return None

    # Test our debug function. It is isolated from the load step so that a
    # missing or failing helper does not discard a model that loaded fine.
    try:
        from model_helper import debug_model_task_config
        debug_model_task_config(model)
    except Exception as e:
        print(f"⚠️ Debug output unavailable (model loaded fine): {e}")

    return model


def test_instrument_conditioning(model):
    """Transcribe one example file with and without an instrument hint.

    The first ``*.wav`` file (in sorted order) under ``examples/`` is
    transcribed twice through ``model_helper.transcribe``: once with
    ``instrument_hint=None`` and once with ``instrument_hint="vocals"``.
    Both results are printed; they are not compared or validated here.

    Args:
        model: The object returned by ``load_model_checkpoint``; it is
            passed straight through to ``transcribe``.

    Returns:
        bool: True when both transcriptions finish without raising, False
        when no example file exists or any step raises (the traceback is
        printed).
    """
    print("\n🔍 Testing instrument conditioning...")

    # Find a test audio file. Sorting makes the choice reproducible: glob
    # results otherwise follow directory-listing order.
    example_files = sorted(Path("examples").glob("*.wav"))
    if not example_files:
        print("❌ No example files found")
        return False

    test_file = example_files[0]
    print(f"Using test file: {test_file}")

    try:
        import torchaudio
        from model_helper import transcribe

        # Create audio info
        info = torchaudio.info(str(test_file))
        audio_info = {
            "filepath": str(test_file),
            "track_name": test_file.stem + "_test",
            "sample_rate": int(info.sample_rate),
            # Fall back to 16 when the metadata object lacks the attribute.
            "bits_per_sample": int(getattr(info, 'bits_per_sample', 16)),
            "num_channels": int(info.num_channels),
            "num_frames": int(info.num_frames),
            # Whole seconds, truncated.
            "duration": int(info.num_frames / info.sample_rate),
            "encoding": str(info.encoding).lower(),
        }

        print("\n--- Testing normal transcription ---")
        midifile1 = transcribe(model, audio_info, instrument_hint=None)
        print(f"Normal transcription result: {midifile1}")

        print("\n--- Testing vocals conditioning ---")
        midifile2 = transcribe(model, audio_info, instrument_hint="vocals")
        print(f"Vocals transcription result: {midifile2}")

        print("✅ Instrument conditioning test completed!")
        return True
    except Exception as e:
        print(f"❌ Instrument conditioning test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


def main():
    """Run the three checks in order and report the overall outcome.

    Steps:
      1. Verify the current directory contains ``app.py``.
      2. ``test_basic_import`` -- exit with status 1 on failure.
      3. ``test_model_loading`` -- exit with status 1 when it returns None.
      4. ``test_instrument_conditioning`` -- a failure here only prints a
         warning; the process still ends normally.
    """
    print("🎵 YourMT3+ Quick Test")
    print("=" * 40)

    # Check if we're in the right directory
    if not Path("app.py").exists():
        print("❌ Please run this from the YourMT3 directory")
        sys.exit(1)

    print(f"📁 Working directory: {os.getcwd()}")

    # Test imports
    if not test_basic_import():
        print("\n❌ Basic imports failed - install dependencies first")
        sys.exit(1)

    # Test model loading
    model = test_model_loading()
    if model is None:
        print("\n❌ Model loading failed - check model weights")
        sys.exit(1)

    # Test instrument conditioning
    if test_instrument_conditioning(model):
        print("\n🎉 All tests passed!")
        print("\nYou can now run:")
        print(" python app.py")
        print("\nThen visit: http://127.0.0.1:7860")
    else:
        # Deliberately non-fatal: the earlier checks passed.
        print("\n⚠️ Some tests failed but basic functionality should work")
        print("You can still try running: python app.py")


# Run the checks only when executed as a script, not when imported.
if __name__ == "__main__":
    main()