# yourmt3 / test_local.py
# Last commit c207bc4 by asdd12e2ad: "asd"
#!/usr/bin/env python3
"""
Quick test script for YourMT3+ instrument conditioning
Run this to test if everything is working before launching the full interface
"""
import sys
import os
from pathlib import Path
# Make the project modules under ``amt/src`` (model_helper, html_helper)
# importable. The directory is resolved against the current working
# directory; main() separately checks that the script is run from the
# YourMT3 checkout (where app.py lives).
sys.path.append(os.path.abspath(os.path.join('amt', 'src')))
def test_basic_import():
    """Check that the libraries and project helpers used by the app can be imported.

    Imports torch, torchaudio and gradio, then the YourMT3 helpers
    ``model_helper`` (load_model_checkpoint, transcribe) and ``html_helper``
    (create_html_from_midi, to_data_url). A line is printed after each
    successful import, so on failure the output shows which import broke.

    Returns:
        True if every import succeeded, False otherwise (the error message is
        printed).
    """
    print("🔍 Testing basic imports...")
    # The imported names are deliberately unused: the goal is only to prove
    # they resolve. Exception (not just ImportError) is caught so that any
    # failure during module initialisation is reported instead of crashing.
    try:
        import torch  # noqa: F401
        print("✅ torch")
        import torchaudio  # noqa: F401
        print("✅ torchaudio")
        import gradio as gr  # noqa: F401
        print("✅ gradio")
        # YourMT3 helpers, found via the amt/src entry added to sys.path.
        from model_helper import load_model_checkpoint, transcribe  # noqa: F401
        print("✅ model_helper")
        from html_helper import create_html_from_midi, to_data_url  # noqa: F401
        print("✅ html_helper")
    except Exception as e:
        print(f"❌ Import error: {e}")
        return False
    return True
def test_model_loading(device="cpu"):
    """Load the 'YPTF.MoE+Multi (noPS)' checkpoint and print its task configuration.

    Args:
        device: Device string passed to ``load_model_checkpoint``
            (default "cpu", matching the original behaviour).

    Returns:
        The loaded model, or None if loading failed (the error and traceback
        are printed). If only the optional ``debug_model_task_config`` step
        fails, a warning is printed and the loaded model is still returned.
    """
    print("\n🔍 Testing model loading...")
    try:
        from model_helper import load_model_checkpoint
        # Use the same args as app.py. The flag meanings are defined by
        # model_helper / the checkpoint loader, not in this script.
        model_name = 'YPTF.MoE+Multi (noPS)'
        precision = '16'
        project = '2024'
        checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
        args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
                '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
                '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
                '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
        print(f"Loading {model_name}...")
        model = load_model_checkpoint(args=args, device=device)
        print("✅ Model loaded successfully!")
    except Exception as e:
        print(f"❌ Model loading failed: {e}")
        import traceback
        traceback.print_exc()
        return None

    # The debug dump is best-effort. It lives in its own try block so that a
    # missing or failing helper does not discard a successfully loaded model
    # (which would make main() report "check model weights").
    try:
        from model_helper import debug_model_task_config
        debug_model_task_config(model)
    except Exception as e:
        print(f"⚠️ Could not print model task config: {e}")
    return model
def test_instrument_conditioning(model):
"""Test the instrument conditioning with a sample file"""
print("\nπŸ” Testing instrument conditioning...")
# Find a test audio file
example_files = list(Path("examples").glob("*.wav"))
if not example_files:
print("❌ No example files found")
return False
test_file = example_files[0]
print(f"Using test file: {test_file}")
try:
import torchaudio
from model_helper import transcribe
# Create audio info
info = torchaudio.info(str(test_file))
audio_info = {
"filepath": str(test_file),
"track_name": test_file.stem + "_test",
"sample_rate": int(info.sample_rate),
"bits_per_sample": int(info.bits_per_sample) if hasattr(info, 'bits_per_sample') else 16,
"num_channels": int(info.num_channels),
"num_frames": int(info.num_frames),
"duration": int(info.num_frames / info.sample_rate),
"encoding": str.lower(str(info.encoding)),
}
print("\n--- Testing normal transcription ---")
midifile1 = transcribe(model, audio_info, instrument_hint=None)
print(f"Normal transcription result: {midifile1}")
print("\n--- Testing vocals conditioning ---")
midifile2 = transcribe(model, audio_info, instrument_hint="vocals")
print(f"Vocals transcription result: {midifile2}")
print("βœ… Instrument conditioning test completed!")
return True
except Exception as e:
print(f"❌ Instrument conditioning test failed: {e}")
import traceback
traceback.print_exc()
return False
def main():
    """Run the quick checks in order and print how to launch the app.

    Exits with status 1 in three cases: ``app.py`` is not in the current
    working directory, the basic imports fail, or the model cannot be loaded.
    If only the instrument-conditioning check fails, a warning is printed and
    the script still suggests running app.py.
    """
    print("🎡 YourMT3+ Quick Test")
    print("=" * 40)

    # Relative paths in this script (amt/src, examples/) assume the YourMT3
    # checkout is the working directory; app.py is used as the marker.
    if not Path("app.py").exists():
        print("❌ Please run this from the YourMT3 directory")
        sys.exit(1)
    print(f"📁 Working directory: {os.getcwd()}")

    if not test_basic_import():
        print("\n❌ Basic imports failed - install dependencies first")
        sys.exit(1)

    model = test_model_loading()
    if model is None:
        print("\n❌ Model loading failed - check model weights")
        sys.exit(1)

    if test_instrument_conditioning(model):
        print("\n🎉 All tests passed!")
        print("\nYou can now run:")
        print(" python app.py")
        print("\nThen visit: http://127.0.0.1:7860")
    else:
        print("\n⚠️ Some tests failed but basic functionality should work")
        print("You can still try running: python app.py")
if __name__ == '__main__':
    # Run the checks only when executed as a script, not when imported.
    main()