oKen38461 commited on
Commit
d9a2a3d
·
1 Parent(s): 0f839d2

テストスクリプトの削除に伴い、`tests/`を`.gitignore`に追加しました。また、`README.md`のAPIドキュメントセクションを更新しました。

Browse files
Files changed (4) hide show
  1. .huggingface.yaml +9 -0
  2. CLAUDE.md +120 -0
  3. app_streaming.py +195 -0
  4. test_streaming.py +140 -0
.huggingface.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # .huggingface.yaml
2
+ sdk: gradio
3
+ python_version: "3.10"
4
+ hardware: "A100"
5
+ timeout_seconds: 600 # 初回ロード時間を確保
6
+ accelerator: gpu
7
+ python:
8
+ pip_install:
9
+ - -r requirements.txt
CLAUDE.md ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CLAUDE.md
2
+
3
+ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4
+
5
+ ## Commands
6
+
7
+ ### Setup and Installation
8
+ ```bash
9
+ # Initial setup - creates necessary directories
10
+ ./setup.sh
11
+
12
+ # Install Python dependencies
13
+ pip install -r requirements.txt
14
+
15
+ # Pre-installation requirements (if needed)
16
+ pip install -r pre-requirements.txt
17
+ ```
18
+
19
+ ### Running the Application
20
+ ```bash
21
+ # Run the optimized Gradio interface (recommended)
22
+ python app_optimized.py
23
+
24
+ # Run the original Gradio interface
25
+ python app.py
26
+
27
+ # Run the FastAPI server for API access
28
+ python api_server.py
29
+ ```
30
+
31
+ ### Testing
32
+ ```bash
33
+ # Run basic API tests
34
+ python test_api.py
35
+
36
+ # Run API client tests
37
+ python test_api_client.py
38
+
39
+ # Run performance tests
40
+ python test_performance.py
41
+
42
+ # Run optimized performance tests
43
+ python test_performance_optimized.py
44
+
45
+ # Run real-world performance tests
46
+ python test_performance_real.py
47
+ ```
48
+
49
+ ## Architecture Overview
50
+
51
+ This is a **Talking Head Generation System** that creates lip-synced videos from audio and source images. The project is structured in three phases with Phase 3 focusing on performance optimization.
52
+
53
+ ### Core Processing Pipeline
54
+ 1. **Input**: Audio file (WAV) + Source image (PNG/JPG)
55
+ 2. **Audio Processing**: Extract features using HuBERT model
56
+ 3. **Motion Generation**: Generate facial motion from audio features
57
+ 4. **Image Warping**: Apply motion to source image
58
+ 5. **Video Generation**: Create final video with audio sync
59
+
60
+ ### Key Components
61
+
62
+ #### Model Management (`model_manager.py`)
63
+ - Downloads models from Hugging Face on first run (~2.5GB)
64
+ - Manages PyTorch and TensorRT model variants
65
+ - Caches models in `/tmp/ditto_models`
66
+
67
+ #### Core Processing (`/core/`)
68
+ - **atomic_components/**: Basic processing units
69
+ - `audio2motion.py`: Audio to motion conversion
70
+ - `warping.py`: Image warping logic
71
+ - **aux_models/**: Supporting models (face detection, landmarks, HuBERT)
72
+ - **models/**: Main neural network architectures
73
+ - **optimization/**: Phase 3 performance optimizations
74
+
75
+ #### Phase 3 Optimizations (`/core/optimization/`)
76
+ - **resolution_optimization.py**: Fixed 320×320 processing
77
+ - **gpu_optimization.py**: Mixed precision, torch.compile
78
+ - **avatar_cache.py**: Pre-cached avatar system with tokens
79
+ - **cold_start_optimization.py**: Optimized model loading
80
+ - **inference_cache.py**: Result caching
81
+ - **parallel_processing.py**: CPU-GPU parallel execution
82
+
83
+ ### Performance Targets
84
+ - Process 16 seconds of audio in ~15 seconds (50-65% faster with Phase 3)
85
+ - First Frame Delay (FFD): <400ms on A100
86
+ - Real-time factor (RTF): <1.0
87
+ - Latest target (2025-07-18): 2-second streaming chunks
88
+
89
+ ### API Endpoints
90
+
91
+ #### Gradio API
92
+ - `/process_talking_head`: Main processing endpoint
93
+ - `/process_talking_head_optimized`: Optimized with caching
94
+ - `/preload_avatar`: Upload and cache avatars
95
+ - `/clear_cache`: Clear inference cache
96
+
97
+ #### FastAPI (api_server.py)
98
+ - `POST /generate`: Generate video from audio/image
99
+ - `GET /health`: Health check
100
+ - Additional endpoints for streaming support
101
+
102
+ ### Important Notes
103
+
104
+ 1. **GPU Requirements**: Requires NVIDIA GPU with CUDA support. Optimized for A100.
105
+
106
+ 2. **First Run**: Models are downloaded automatically on first run. Ensure sufficient disk space.
107
+
108
+ 3. **Caching**: The system uses multiple cache levels:
109
+ - Avatar cache: Pre-processed source images
110
+ - Inference cache: Recent generation results
111
+ - Model cache: Downloaded models
112
+
113
+ 4. **Testing**: Always run performance tests after optimization changes to verify improvements.
114
+
115
+ 5. **Streaming**: Latest SOW targets 2-second chunk processing for real-time streaming applications.
116
+
117
+ 6. **File Formats**:
118
+ - Audio: WAV format required
119
+ - Images: PNG or JPG (will be resized to 320×320)
120
+ - Output: MP4 video
app_streaming.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, tempfile, queue, threading, time, numpy as np, soundfile as sf
2
+ import gradio as gr
3
+ from stream_pipeline_offline import StreamSDK
4
+ import torch
5
+ from PIL import Image
6
+ from pathlib import Path
7
+ import cv2
8
+
9
+ # モデル設定
10
+ CFG_PKL = "checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl"
11
+ DATA_ROOT = "checkpoints/ditto_pytorch"
12
+
13
+ # サンプルファイルのディレクトリ
14
+ EXAMPLES_DIR = (Path(__file__).parent / "example").resolve()
15
+
16
+ # グローバルで一度だけロード(concurrency_count=1 前提)
17
+ sdk: StreamSDK | None = None
18
+ def init_sdk():
19
+ global sdk
20
+ if sdk is None:
21
+ sdk = StreamSDK(CFG_PKL, DATA_ROOT)
22
+ return sdk
23
+
24
+ # 音声チャンクサイズ(秒)
25
+ CHUNK_SEC = 0.20 # 16000*0.20 = 3200 sample ≒ 5 フレーム
26
+
27
+ def generator(mic, src_img):
28
+ """
29
+ Gradio 生成関数
30
+ mic : (sr, np.ndarray) 形式 (Gradio Audio streaming=True)
31
+ src_img : 画像ファイルパス
32
+ Yields : PIL.Image (現在フレーム) または (最後に mp4)
33
+ """
34
+ if mic is None:
35
+ yield None, None, "マイク入力を開始してください"
36
+ return
37
+
38
+ if src_img is None:
39
+ yield None, None, "ソース画像をアップロードしてください"
40
+ return
41
+
42
+ try:
43
+ sr, wav_full = mic
44
+ sdk = init_sdk()
45
+
46
+ # setup: online_mode=True でストリーミング
47
+ tmp_out = tempfile.mktemp(suffix=".mp4")
48
+ sdk.setup(src_img, tmp_out, online_mode=True, max_size=1024)
49
+ N_total = int(np.ceil(len(wav_full) / sr * 25)) # 概算フレーム数
50
+ sdk.setup_Nd(N_total)
51
+
52
+ # 処理開始時刻
53
+ start_time = time.time()
54
+ frame_count = 0
55
+
56
+ # 音声を CHUNK_SEC ごとに送り込む
57
+ hop = int(sr * CHUNK_SEC)
58
+ for start_idx in range(0, len(wav_full), hop):
59
+ chunk = wav_full[start_idx : start_idx + hop]
60
+ if len(chunk) < hop:
61
+ chunk = np.pad(chunk, (0, hop - len(chunk)))
62
+ sdk.run_chunk(chunk)
63
+
64
+ # 直近で書き込まれたフレームをキューから取得
65
+ frames_processed = 0
66
+ while sdk.writer_queue.qsize() > 0 and frames_processed < 5:
67
+ try:
68
+ frame = sdk.writer_queue.get_nowait()
69
+ if frame is not None:
70
+ # numpy array (H, W, 3) を PIL Image に変換
71
+ pil_frame = Image.fromarray(frame)
72
+ frame_count += 1
73
+ elapsed = time.time() - start_time
74
+ fps = frame_count / elapsed if elapsed > 0 else 0
75
+ yield pil_frame, None, f"処理中... フレーム: {frame_count}, FPS: {fps:.1f}"
76
+ frames_processed += 1
77
+ except queue.Empty:
78
+ break
79
+
80
+ # 少し待機(CPU負荷調整)
81
+ time.sleep(0.01)
82
+
83
+ # 残りのフレームを処理
84
+ print("音声チャンクの送信完了、残りフレームを処理中...")
85
+ timeout_count = 0
86
+ while timeout_count < 50: # 最大5秒待機
87
+ if sdk.writer_queue.qsize() > 0:
88
+ try:
89
+ frame = sdk.writer_queue.get_nowait()
90
+ if frame is not None:
91
+ pil_frame = Image.fromarray(frame)
92
+ frame_count += 1
93
+ elapsed = time.time() - start_time
94
+ fps = frame_count / elapsed if elapsed > 0 else 0
95
+ yield pil_frame, None, f"処理中... フレーム: {frame_count}, FPS: {fps:.1f}"
96
+ timeout_count = 0
97
+ except queue.Empty:
98
+ time.sleep(0.1)
99
+ timeout_count += 1
100
+ else:
101
+ time.sleep(0.1)
102
+ timeout_count += 1
103
+
104
+ # SDKを閉じて最終的なMP4を生成
105
+ print("SDKを閉じて最終的なMP4を生成中...")
106
+ sdk.close() # ワーカー join & mp4 結合
107
+
108
+ # 処理完了
109
+ elapsed_total = time.time() - start_time
110
+ yield None, gr.Video(tmp_out), f"✅ 完了! 総フレーム数: {frame_count}, 処理時間: {elapsed_total:.1f}秒"
111
+
112
+ except Exception as e:
113
+ import traceback
114
+ error_msg = f"❌ エラー: {str(e)}\n{traceback.format_exc()}"
115
+ print(error_msg)
116
+ yield None, None, error_msg
117
+
118
+ # Gradio UI
119
+ with gr.Blocks(title="DittoTalkingHead Streaming") as demo:
120
+ gr.Markdown("""
121
+ # DittoTalkingHead - ストリーミング版
122
+
123
+ 音声をリアルタイムで処理し、生成されたフレームを逐次表示します。
124
+
125
+ ## 使い方
126
+ 1. **ソース画像**(PNG/JPG形式)をアップロード
127
+ 2. **Start**ボタンをクリックしてマイク録音開始
128
+ 3. 録音中、ライブフレームが更新されます
129
+ 4. 録音停止後、最終的なMP4が生成されます
130
+ """)
131
+
132
+ with gr.Row():
133
+ with gr.Column():
134
+ img_in = gr.Image(
135
+ type="filepath",
136
+ label="ソース画像 / Source Image",
137
+ value=str(EXAMPLES_DIR / "reference.png") if (EXAMPLES_DIR / "reference.png").exists() else None
138
+ )
139
+ mic_in = gr.Audio(
140
+ sources=["microphone"],
141
+ streaming=True,
142
+ label="マイク入力 (16 kHz)",
143
+ format="wav"
144
+ )
145
+
146
+ with gr.Column():
147
+ live_img = gr.Image(label="ライブフレーム", type="pil")
148
+ final_mp4 = gr.Video(label="最終結果 (MP4)")
149
+ status_text = gr.Textbox(label="ステータス", value="待機中...")
150
+
151
+ btn = gr.Button("Start Streaming", variant="primary")
152
+
153
+ # ストリーミング処理を開始
154
+ btn.click(
155
+ fn=generator,
156
+ inputs=[mic_in, img_in],
157
+ outputs=[live_img, final_mp4, status_text],
158
+ stream_every=0.1 # 100msごとに更新
159
+ )
160
+
161
+ # サンプル
162
+ if EXAMPLES_DIR.exists():
163
+ gr.Examples(
164
+ examples=[
165
+ [str(EXAMPLES_DIR / "reference.png")]
166
+ ],
167
+ inputs=[img_in],
168
+ label="サンプル画像"
169
+ )
170
+
171
+ # 起動設定
172
+ if __name__ == "__main__":
173
+ # GPU最適化設定
174
+ if torch.cuda.is_available():
175
+ torch.cuda.empty_cache()
176
+ torch.backends.cudnn.benchmark = True
177
+
178
+ # 環境変数設定
179
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
180
+
181
+ print("=== DittoTalkingHead ストリーミング版 起動 ===")
182
+ print(f"- チャンクサイズ: {CHUNK_SEC}秒")
183
+ print(f"- 最大解像度: 1024px")
184
+ print(f"- GPU: {'利用可能' if torch.cuda.is_available() else '利用不可'}")
185
+
186
+ # モデルの事前ロード
187
+ print("モデルを事前ロード中...")
188
+ init_sdk()
189
+ print("✅ モデルロード完了")
190
+
191
+ demo.queue(concurrency_count=1, max_size=8).launch(
192
+ server_name="0.0.0.0",
193
+ server_port=7860,
194
+ share=False
195
+ )
test_streaming.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ストリーミング実装のテストスクリプト
3
+ """
4
+ import numpy as np
5
+ import soundfile as sf
6
+ import tempfile
7
+ import time
8
+ from pathlib import Path
9
+ from stream_pipeline_offline import StreamSDK
10
+
11
+ # テスト設定
12
+ CFG_PKL = "checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl"
13
+ DATA_ROOT = "checkpoints/ditto_pytorch"
14
+ EXAMPLES_DIR = Path("example")
15
+
16
+ def test_streaming():
17
+ """ストリーミング機能の基本テスト"""
18
+ print("=== ストリーミング機能テスト開始 ===")
19
+
20
+ # テスト用の音声を生成(3秒のサイン波)
21
+ duration = 3.0 # seconds
22
+ sample_rate = 16000
23
+ t = np.linspace(0, duration, int(sample_rate * duration))
24
+ audio_data = np.sin(2 * np.pi * 440 * t) * 0.5 # 440Hz
25
+
26
+ # SDKの初期化
27
+ print("1. SDK初期化...")
28
+ sdk = StreamSDK(CFG_PKL, DATA_ROOT)
29
+ print("✅ SDK初期化完了")
30
+
31
+ # セットアップ
32
+ print("\n2. ストリーミングモードでセットアップ...")
33
+ src_img = str(EXAMPLES_DIR / "reference.png")
34
+ tmp_out = tempfile.mktemp(suffix=".mp4")
35
+
36
+ sdk.setup(src_img, tmp_out, online_mode=True, max_size=1024)
37
+ N_total = int(np.ceil(duration * 25)) # 25fps
38
+ sdk.setup_Nd(N_total)
39
+ print("✅ セットアップ完了")
40
+
41
+ # チャンク単位で音声を送信
42
+ print("\n3. チャンク単位で音声送信...")
43
+ chunk_sec = 0.2 # 200ms
44
+ chunk_samples = int(sample_rate * chunk_sec)
45
+ chunks_sent = 0
46
+ frames_received = 0
47
+
48
+ start_time = time.time()
49
+
50
+ for i in range(0, len(audio_data), chunk_samples):
51
+ chunk = audio_data[i:i + chunk_samples]
52
+ if len(chunk) < chunk_samples:
53
+ chunk = np.pad(chunk, (0, chunk_samples - len(chunk)))
54
+
55
+ sdk.run_chunk(chunk)
56
+ chunks_sent += 1
57
+
58
+ # キューからフレームを確認
59
+ while sdk.writer_queue.qsize() > 0:
60
+ try:
61
+ frame = sdk.writer_queue.get_nowait()
62
+ if frame is not None:
63
+ frames_received += 1
64
+ print(f" フレーム {frames_received} 受信 (チャンク {chunks_sent})")
65
+ except:
66
+ break
67
+
68
+ time.sleep(0.05) # 少し待機
69
+
70
+ # 残りのフレームを待つ
71
+ print("\n4. 残りのフレームを処理...")
72
+ timeout = 5.0 # 5秒タイムアウト
73
+ timeout_start = time.time()
74
+
75
+ while time.time() - timeout_start < timeout:
76
+ if sdk.writer_queue.qsize() > 0:
77
+ try:
78
+ frame = sdk.writer_queue.get_nowait()
79
+ if frame is not None:
80
+ frames_received += 1
81
+ print(f" フレーム {frames_received} 受信")
82
+ except:
83
+ pass
84
+ else:
85
+ time.sleep(0.1)
86
+
87
+ # クローズ
88
+ print("\n5. SDKクローズ...")
89
+ sdk.close()
90
+
91
+ elapsed = time.time() - start_time
92
+
93
+ # 結果
94
+ print("\n=== テスト結果 ===")
95
+ print(f"✅ 送信チャンク数: {chunks_sent}")
96
+ print(f"✅ 受信フレーム数: {frames_received}")
97
+ print(f"✅ 処理時間: {elapsed:.2f}秒")
98
+ print(f"✅ 出力ファイル: {tmp_out}")
99
+
100
+ # 期待される結果の確認
101
+ expected_frames = int(duration * 25) # 25fps
102
+ if frames_received >= expected_frames * 0.8: # 80%以上
103
+ print("✅ テスト成功!")
104
+ else:
105
+ print(f"⚠️ 期待フレーム数 ({expected_frames}) に対して受信数が少ない")
106
+
107
+ return True
108
+
109
+
110
+ def test_writer_queue():
111
+ """writer_queueの動作確認"""
112
+ print("\n=== writer_queue 動作確認 ===")
113
+
114
+ sdk = StreamSDK(CFG_PKL, DATA_ROOT)
115
+
116
+ # キューの存在確認
117
+ if hasattr(sdk, 'writer_queue'):
118
+ print("✅ writer_queue が存在します")
119
+ print(f" キューサイズ: {sdk.writer_queue.qsize()}")
120
+ print(f" 最大サイズ: {sdk.writer_queue.maxsize}")
121
+ else:
122
+ print("❌ writer_queue が見つかりません")
123
+ return False
124
+
125
+ return True
126
+
127
+
128
+ if __name__ == "__main__":
129
+ # writer_queueの確認
130
+ if not test_writer_queue():
131
+ print("基本的な要件が満たされていません")
132
+ exit(1)
133
+
134
+ # ストリーミングテスト
135
+ try:
136
+ test_streaming()
137
+ except Exception as e:
138
+ print(f"❌ エラー: {e}")
139
+ import traceback
140
+ traceback.print_exc()