Spaces:
Runtime error
Runtime error
""" | |
ストリーミング実装のテストスクリプト | |
""" | |
import numpy as np | |
import soundfile as sf | |
import tempfile | |
import time | |
from pathlib import Path | |
from stream_pipeline_offline import StreamSDK | |
# テスト設定 | |
CFG_PKL = "checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl" | |
DATA_ROOT = "checkpoints/ditto_pytorch" | |
EXAMPLES_DIR = Path("example") | |
def test_streaming(): | |
"""ストリーミング機能の基本テスト""" | |
print("=== ストリーミング機能テスト開始 ===") | |
# テスト用の音声を生成(3秒のサイン波) | |
duration = 3.0 # seconds | |
sample_rate = 16000 | |
t = np.linspace(0, duration, int(sample_rate * duration)) | |
audio_data = np.sin(2 * np.pi * 440 * t) * 0.5 # 440Hz | |
# SDKの初期化 | |
print("1. SDK初期化...") | |
sdk = StreamSDK(CFG_PKL, DATA_ROOT) | |
print("✅ SDK初期化完了") | |
# セットアップ | |
print("\n2. ストリーミングモードでセットアップ...") | |
src_img = str(EXAMPLES_DIR / "reference.png") | |
tmp_out = tempfile.mktemp(suffix=".mp4") | |
sdk.setup(src_img, tmp_out, online_mode=True, max_size=1024) | |
N_total = int(np.ceil(duration * 20)) # 20fps | |
sdk.setup_Nd(N_total) | |
print("✅ セットアップ完了") | |
# チャンク単位で音声を送信 | |
print("\n3. チャンク単位で音声送信...") | |
chunk_sec = 0.2 # 200ms | |
chunk_samples = int(sample_rate * chunk_sec) | |
chunks_sent = 0 | |
frames_received = 0 | |
start_time = time.time() | |
for i in range(0, len(audio_data), chunk_samples): | |
chunk = audio_data[i:i + chunk_samples] | |
if len(chunk) < chunk_samples: | |
chunk = np.pad(chunk, (0, chunk_samples - len(chunk))) | |
sdk.run_chunk(chunk) | |
chunks_sent += 1 | |
# キューからフレームを確認 | |
while sdk.writer_queue.qsize() > 0: | |
try: | |
frame = sdk.writer_queue.get_nowait() | |
if frame is not None: | |
frames_received += 1 | |
print(f" フレーム {frames_received} 受信 (チャンク {chunks_sent})") | |
except: | |
break | |
time.sleep(0.05) # 少し待機 | |
# 残りのフレームを待つ | |
print("\n4. 残りのフレームを処理...") | |
timeout = 5.0 # 5秒タイムアウト | |
timeout_start = time.time() | |
while time.time() - timeout_start < timeout: | |
if sdk.writer_queue.qsize() > 0: | |
try: | |
frame = sdk.writer_queue.get_nowait() | |
if frame is not None: | |
frames_received += 1 | |
print(f" フレーム {frames_received} 受信") | |
except: | |
pass | |
else: | |
time.sleep(0.1) | |
# クローズ | |
print("\n5. SDKクローズ...") | |
sdk.close() | |
elapsed = time.time() - start_time | |
# 結果 | |
print("\n=== テスト結果 ===") | |
print(f"✅ 送信チャンク数: {chunks_sent}") | |
print(f"✅ 受信フレーム数: {frames_received}") | |
print(f"✅ 処理時間: {elapsed:.2f}秒") | |
print(f"✅ 出力ファイル: {tmp_out}") | |
# 期待される結果の確認 | |
expected_frames = int(duration * 20) # 20fps | |
if frames_received >= expected_frames * 0.8: # 80%以上 | |
print("✅ テスト成功!") | |
else: | |
print(f"⚠️ 期待フレーム数 ({expected_frames}) に対して受信数が少ない") | |
return True | |
def test_writer_queue(): | |
"""writer_queueの動作確認""" | |
print("\n=== writer_queue 動作確認 ===") | |
sdk = StreamSDK(CFG_PKL, DATA_ROOT) | |
# キューの存在確認 | |
if hasattr(sdk, 'writer_queue'): | |
print("✅ writer_queue が存在します") | |
print(f" キューサイズ: {sdk.writer_queue.qsize()}") | |
print(f" 最大サイズ: {sdk.writer_queue.maxsize}") | |
else: | |
print("❌ writer_queue が見つかりません") | |
return False | |
return True | |
if __name__ == "__main__": | |
# writer_queueの確認 | |
if not test_writer_queue(): | |
print("基本的な要件が満たされていません") | |
exit(1) | |
# ストリーミングテスト | |
try: | |
test_streaming() | |
except Exception as e: | |
print(f"❌ エラー: {e}") | |
import traceback | |
traceback.print_exc() |