oKen38461 commited on
Commit
43f5a2b
·
1 Parent(s): ac7cda5

`.gitignore`に`docs/`フォルダを追加して、無視するファイルを更新

Browse files
Files changed (6) hide show
  1. .gitignore +2 -0
  2. README_hf_space.md +50 -0
  3. app.py +130 -0
  4. model_manager.py +267 -0
  5. requirements.txt +46 -0
  6. setup.sh +18 -0
.gitignore CHANGED
@@ -37,6 +37,8 @@ log/*
37
  # Folders to ignore
38
  example/
39
  ToDo/
 
 
40
 
41
  !example/audio.wav
42
  !example/image.png
 
37
  # Folders to ignore
38
  example/
39
  ToDo/
40
+ docs/
41
+
42
 
43
  !example/audio.wav
44
  !example/image.png
README_hf_space.md ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: DittoTalkingHead
3
+ emoji: 🗣️
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 4.19.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ hardware: a100-large
12
+ ---
13
+
14
+ # DittoTalkingHead - Talking Head Generation
15
+
16
+ 音声とソース画像から、リアルなTalking Headビデオを生成します。
17
+
18
+ ## 特徴
19
+
20
+ - 高品質なリップシンク
21
+ - 自然な表情と頭部の動き
22
+ - TensorRTによる高速推論
23
+ - 自動モデルダウンロード機能
24
+
25
+ ## 使い方
26
+
27
+ 1. **音声ファイル**(WAV形式)をアップロード
28
+ 2. **ソース画像**(PNG/JPG形式)をアップロード
29
+ 3. **生成**ボタンをクリック
30
+
31
+ ## 技術仕様
32
+
33
+ - **GPU**: NVIDIA A100(推奨)
34
+ - **フレームワーク**: PyTorch
35
+ - **モデル**: DittoTalkingHead (PyTorch版)
36
+ - **モデルサイズ**: 約2.5GB
37
+
38
+ ## 注意事項
39
+
40
+ - 初回実行時は、モデルの自動ダウンロードのため時間がかかります(約10-15分)
41
+ - GPU(A100)環境での実行を推奨します
42
+ - 音声ファイルは16kHz WAV形式が推奨です
43
+
44
+ ## モデルソース
45
+
46
+ モデルは[digital-avatar/ditto-talkinghead](https://huggingface.co/digital-avatar/ditto-talkinghead)から自動的にダウンロードされます。
47
+
48
+ ## ライセンス
49
+
50
+ Apache License 2.0
app.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import tempfile
4
+ import shutil
5
+ from pathlib import Path
6
+ from model_manager import ModelManager
7
+ from stream_pipeline_offline import StreamSDK
8
+ from inference import run, seed_everything
9
+
10
+ # モデルの初期化
11
+ print("=== モデルの初期化開始 ===")
12
+
13
+ # PyTorchモデルを使用(TensorRTモデルは非常に大きいため)
14
+ USE_PYTORCH = True
15
+
16
+ model_manager = ModelManager(cache_dir="/tmp/ditto_models", use_pytorch=USE_PYTORCH)
17
+ if not model_manager.setup_models():
18
+ raise RuntimeError("モデルのセットアップに失敗しました。")
19
+
20
+ # SDKの初期化
21
+ if USE_PYTORCH:
22
+ data_root = "./checkpoints/ditto_pytorch"
23
+ cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl"
24
+ else:
25
+ data_root = "./checkpoints/ditto_trt_Ampere_Plus"
26
+ cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl"
27
+
28
+ try:
29
+ SDK = StreamSDK(cfg_pkl, data_root)
30
+ print("✅ SDK初期化成功")
31
+ except Exception as e:
32
+ print(f"❌ SDK初期化エラー: {e}")
33
+ raise
34
+
35
+ def process_talking_head(audio_file, source_image):
36
+ """音声とソース画像からTalking Headビデオを生成"""
37
+
38
+ if audio_file is None:
39
+ return None, "音声ファイルをアップロードしてください。"
40
+
41
+ if source_image is None:
42
+ return None, "ソース画像をアップロードしてください。"
43
+
44
+ try:
45
+ # 一時ファイルの作成
46
+ with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_output:
47
+ output_path = tmp_output.name
48
+
49
+ # 処理実行
50
+ print(f"処理開始: audio={audio_file}, image={source_image}")
51
+ seed_everything(1024)
52
+ run(SDK, audio_file, source_image, output_path)
53
+
54
+ # 結果の確認
55
+ if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
56
+ return output_path, "✅ 処理が完了しました!"
57
+ else:
58
+ return None, "❌ 処理に失敗しました。出力ファイルが生成されませんでした。"
59
+
60
+ except Exception as e:
61
+ import traceback
62
+ error_msg = f"❌ エラーが発生しました: {str(e)}\n{traceback.format_exc()}"
63
+ print(error_msg)
64
+ return None, error_msg
65
+
66
+ # Gradio UI
67
+ with gr.Blocks(title="DittoTalkingHead") as demo:
68
+ gr.Markdown("""
69
+ # DittoTalkingHead - Talking Head Generation
70
+
71
+ 音声とソース画像から、リアルなTalking Headビデオを生成します。
72
+
73
+ ## 使い方
74
+ 1. **音声ファイル**(WAV形式)をアップロード
75
+ 2. **ソース画像**(PNG/JPG形式)をアップロード
76
+ 3. **生成**ボタンをクリック
77
+
78
+ ⚠️ 初回実行時は、モデルのダウンロードのため時間がかかります(約2.5GB)。
79
+
80
+ ### 技術仕様
81
+ - **モデル**: DittoTalkingHead (PyTorch版)
82
+ - **GPU**: NVIDIA A100推奨
83
+ - **モデル提供**: [digital-avatar/ditto-talkinghead](https://huggingface.co/digital-avatar/ditto-talkinghead)
84
+ """)
85
+
86
+ with gr.Row():
87
+ with gr.Column():
88
+ audio_input = gr.Audio(
89
+ label="音声ファイル (WAV)",
90
+ type="filepath"
91
+ )
92
+ image_input = gr.Image(
93
+ label="ソース画像",
94
+ type="filepath"
95
+ )
96
+ generate_btn = gr.Button("生成", variant="primary")
97
+
98
+ with gr.Column():
99
+ video_output = gr.Video(
100
+ label="生成されたビデオ"
101
+ )
102
+ status_output = gr.Textbox(
103
+ label="ステータス",
104
+ lines=3
105
+ )
106
+
107
+ # サンプル
108
+ gr.Examples(
109
+ examples=[
110
+ ["example/audio.wav", "example/image.png"]
111
+ ],
112
+ inputs=[audio_input, image_input],
113
+ outputs=[video_output, status_output],
114
+ fn=process_talking_head,
115
+ cache_examples=True
116
+ )
117
+
118
+ # イベントハンドラ
119
+ generate_btn.click(
120
+ fn=process_talking_head,
121
+ inputs=[audio_input, image_input],
122
+ outputs=[video_output, status_output]
123
+ )
124
+
125
+ if __name__ == "__main__":
126
+ demo.launch(
127
+ server_name="0.0.0.0",
128
+ server_port=7860,
129
+ share=False
130
+ )
model_manager.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import requests
4
+ from tqdm import tqdm
5
+ from pathlib import Path
6
+ import hashlib
7
+ import json
8
+ import time
9
+
10
+ class ModelManager:
11
+ def __init__(self, cache_dir="/tmp/models", use_pytorch=False):
12
+ self.cache_dir = Path(cache_dir)
13
+ self.cache_dir.mkdir(parents=True, exist_ok=True)
14
+ self.use_pytorch = use_pytorch
15
+
16
+ # Hugging Face公式リポジトリからモデルを取得
17
+ base_url = "https://huggingface.co/digital-avatar/ditto-talkinghead/resolve/main"
18
+
19
+ if use_pytorch:
20
+ # PyTorchモデルの設定
21
+ self.model_configs = [
22
+ {
23
+ "name": "appearance_extractor.pth",
24
+ "url": f"{base_url}/checkpoints/ditto_pytorch/models/appearance_extractor.pth",
25
+ "dest_dir": "checkpoints/ditto_pytorch/models",
26
+ "dest_file": "appearance_extractor.pth",
27
+ "type": "file"
28
+ },
29
+ {
30
+ "name": "decoder.pth",
31
+ "url": f"{base_url}/checkpoints/ditto_pytorch/models/decoder.pth",
32
+ "dest_dir": "checkpoints/ditto_pytorch/models",
33
+ "dest_file": "decoder.pth",
34
+ "type": "file"
35
+ },
36
+ {
37
+ "name": "lmdm_v0.4_hubert.pth",
38
+ "url": f"{base_url}/checkpoints/ditto_pytorch/models/lmdm_v0.4_hubert.pth",
39
+ "dest_dir": "checkpoints/ditto_pytorch/models",
40
+ "dest_file": "lmdm_v0.4_hubert.pth",
41
+ "type": "file"
42
+ },
43
+ {
44
+ "name": "motion_extractor.pth",
45
+ "url": f"{base_url}/checkpoints/ditto_pytorch/models/motion_extractor.pth",
46
+ "dest_dir": "checkpoints/ditto_pytorch/models",
47
+ "dest_file": "motion_extractor.pth",
48
+ "type": "file"
49
+ },
50
+ {
51
+ "name": "stitch_network.pth",
52
+ "url": f"{base_url}/checkpoints/ditto_pytorch/models/stitch_network.pth",
53
+ "dest_dir": "checkpoints/ditto_pytorch/models",
54
+ "dest_file": "stitch_network.pth",
55
+ "type": "file"
56
+ },
57
+ {
58
+ "name": "warp_network.pth",
59
+ "url": f"{base_url}/checkpoints/ditto_pytorch/models/warp_network.pth",
60
+ "dest_dir": "checkpoints/ditto_pytorch/models",
61
+ "dest_file": "warp_network.pth",
62
+ "type": "file"
63
+ },
64
+ {
65
+ "name": "v0.4_hubert_cfg.pkl",
66
+ "url": f"{base_url}/checkpoints/ditto_cfg/v0.4_hubert_cfg.pkl",
67
+ "dest_dir": "checkpoints/ditto_cfg",
68
+ "dest_file": "v0.4_hubert_cfg.pkl",
69
+ "type": "file"
70
+ }
71
+ ]
72
+ else:
73
+ # TensorRTモデルの設定
74
+ self.model_configs = [
75
+ {
76
+ "name": "ditto_trt_models",
77
+ "url": os.environ.get("DITTO_TRT_URL", f"{base_url}/checkpoints/ditto_trt_Ampere_Plus.tar.gz"),
78
+ "dest_dir": "checkpoints",
79
+ "type": "archive",
80
+ "extract_subdir": "ditto_trt_Ampere_Plus"
81
+ },
82
+ {
83
+ "name": "v0.4_hubert_cfg_trt.pkl",
84
+ "url": f"{base_url}/checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl",
85
+ "dest_dir": "checkpoints/ditto_cfg",
86
+ "dest_file": "v0.4_hubert_cfg_trt.pkl",
87
+ "type": "file"
88
+ }
89
+ ]
90
+
91
+ self.progress_file = self.cache_dir / "download_progress.json"
92
+ self.download_progress = self.load_progress()
93
+
94
+ def load_progress(self):
95
+ """ダウンロード進捗の読み込み"""
96
+ if self.progress_file.exists():
97
+ with open(self.progress_file, 'r') as f:
98
+ return json.load(f)
99
+ return {}
100
+
101
+ def save_progress(self):
102
+ """ダウンロード進捗の保存"""
103
+ with open(self.progress_file, 'w') as f:
104
+ json.dump(self.download_progress, f)
105
+
106
+ def get_file_hash(self, filepath):
107
+ """ファイルのハッシュ値を計算"""
108
+ sha256_hash = hashlib.sha256()
109
+ with open(filepath, "rb") as f:
110
+ for byte_block in iter(lambda: f.read(4096), b""):
111
+ sha256_hash.update(byte_block)
112
+ return sha256_hash.hexdigest()
113
+
114
+ def download_file(self, url, dest_path, retries=3):
115
+ """ファイルのダウンロード(レジューム対応)"""
116
+ dest_path = Path(dest_path)
117
+ dest_path.parent.mkdir(parents=True, exist_ok=True)
118
+
119
+ headers = {}
120
+ mode = 'wb'
121
+ resume_pos = 0
122
+
123
+ # レジューム処理
124
+ if dest_path.exists():
125
+ resume_pos = dest_path.stat().st_size
126
+ headers['Range'] = f'bytes={resume_pos}-'
127
+ mode = 'ab'
128
+
129
+ for attempt in range(retries):
130
+ try:
131
+ response = requests.get(url, headers=headers, stream=True, timeout=30)
132
+ response.raise_for_status()
133
+
134
+ total_size = int(response.headers.get('content-length', 0))
135
+ if resume_pos > 0:
136
+ total_size += resume_pos
137
+
138
+ with open(dest_path, mode) as f:
139
+ with tqdm(total=total_size, initial=resume_pos, unit='B', unit_scale=True, desc=dest_path.name) as pbar:
140
+ for chunk in response.iter_content(chunk_size=8192):
141
+ if chunk:
142
+ f.write(chunk)
143
+ pbar.update(len(chunk))
144
+
145
+ return True
146
+
147
+ except Exception as e:
148
+ print(f"ダウンロードエラー (試行 {attempt + 1}/{retries}): {e}")
149
+ if attempt < retries - 1:
150
+ time.sleep(5) # 再試行前に待機
151
+ else:
152
+ raise
153
+
154
+ return False
155
+
156
+ def extract_archive(self, archive_path, dest_dir, extract_subdir=None):
157
+ """アーカイブの展開"""
158
+ import tarfile
159
+ import zipfile
160
+
161
+ archive_path = Path(archive_path)
162
+ dest_dir = Path(dest_dir)
163
+ temp_dir = dest_dir / "temp_extract"
164
+
165
+ try:
166
+ if archive_path.suffix == '.gz' or archive_path.suffix == '.tar' or str(archive_path).endswith('.tar.gz'):
167
+ with tarfile.open(archive_path, 'r:*') as tar:
168
+ if extract_subdir:
169
+ # 一時ディレクトリに展開してから移動
170
+ temp_dir.mkdir(exist_ok=True)
171
+ tar.extractall(temp_dir)
172
+ # 特定のサブディレクトリを移動
173
+ src_dir = temp_dir / extract_subdir
174
+ if src_dir.exists():
175
+ shutil.move(str(src_dir), str(dest_dir / extract_subdir))
176
+ shutil.rmtree(temp_dir)
177
+ else:
178
+ tar.extractall(dest_dir)
179
+ elif archive_path.suffix == '.zip':
180
+ with zipfile.ZipFile(archive_path, 'r') as zip_ref:
181
+ zip_ref.extractall(dest_dir)
182
+ else:
183
+ raise ValueError(f"Unsupported archive format: {archive_path.suffix}")
184
+ except Exception as e:
185
+ if temp_dir.exists():
186
+ shutil.rmtree(temp_dir)
187
+ raise e
188
+
189
+ def check_models_exist(self):
190
+ """必要なモデルが存在するかチェック"""
191
+ missing_models = []
192
+ for config in self.model_configs:
193
+ if config['type'] == 'file':
194
+ dest_path = Path(config['dest_dir']) / config['dest_file']
195
+ if not dest_path.exists():
196
+ missing_models.append(config)
197
+ else: # archive
198
+ dest_dir = Path(config['dest_dir'])
199
+ if not dest_dir.exists() or not any(dest_dir.iterdir()):
200
+ missing_models.append(config)
201
+ return missing_models
202
+
203
+ def download_models(self):
204
+ """必要なモデルをダウンロード"""
205
+ missing_models = self.check_models_exist()
206
+
207
+ if not missing_models:
208
+ print("すべてのモデルが既に存在します。")
209
+ return True
210
+
211
+ print(f"{len(missing_models)}個のモデルをダウンロードします...")
212
+
213
+ for config in missing_models:
214
+ size_info = config.get('size', '不明')
215
+ print(f"\n{config['name']} をダウンロード中... (サイズ: {size_info})")
216
+
217
+ # キャッシュパスの設定
218
+ cache_filename = f"{config['name']}.download"
219
+ cache_path = self.cache_dir / cache_filename
220
+
221
+ try:
222
+ # ダウンロード
223
+ if not cache_path.exists() or self.download_progress.get(config['name'], {}).get('status') != 'completed':
224
+ self.download_file(config['url'], cache_path)
225
+ self.download_progress[config['name']] = {'status': 'completed'}
226
+ self.save_progress()
227
+
228
+ # 展開またはコピー
229
+ if config['type'] == 'file':
230
+ dest_dir = Path(config['dest_dir'])
231
+ dest_dir.mkdir(parents=True, exist_ok=True)
232
+ dest_path = dest_dir / config['dest_file']
233
+ shutil.copy2(cache_path, dest_path)
234
+ else: # archive
235
+ dest_dir = Path(config['dest_dir'])
236
+ dest_dir.mkdir(parents=True, exist_ok=True)
237
+ print(f"{config['name']} を展開中...")
238
+ extract_subdir = config.get('extract_subdir')
239
+ self.extract_archive(cache_path, dest_dir, extract_subdir)
240
+
241
+ print(f"{config['name']} のセットアップ完了")
242
+
243
+ except Exception as e:
244
+ print(f"エラー: {config['name']} のダウンロード中にエラーが発生しました: {e}")
245
+ return False
246
+
247
+ return True
248
+
249
+ def setup_models(self):
250
+ """モデルのセットアップ(メイン処理)"""
251
+ print("=== DittoTalkingHead モデルセットアップ ===")
252
+ print(f"キャッシュディレクトリ: {self.cache_dir}")
253
+
254
+ success = self.download_models()
255
+
256
+ if success:
257
+ print("\n✅ すべてのモデルのセットアップが完了しました!")
258
+ else:
259
+ print("\n❌ モデルのセットアップ中にエラーが発生しました。")
260
+
261
+ return success
262
+
263
+
264
+ if __name__ == "__main__":
265
+ # テスト実行
266
+ manager = ModelManager()
267
+ manager.setup_models()
requirements.txt ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core dependencies
2
+ torch==2.5.1
3
+ torchvision==0.20.1
4
+ torchaudio==2.5.1
5
+ numpy==2.0.1
6
+ pillow==11.0.0
7
+
8
+ # Audio processing
9
+ librosa==0.10.2.post1
10
+ soundfile==0.13.0
11
+ audioread==3.0.1
12
+ soxr==0.5.0.post1
13
+
14
+ # Video/Image processing
15
+ opencv-python-headless==4.10.0.84
16
+ imageio==2.36.1
17
+ imageio-ffmpeg==0.5.1
18
+ scikit-image==0.25.0
19
+
20
+ # Machine learning
21
+ scikit-learn==1.6.0
22
+ scipy==1.15.0
23
+ numba==0.60.0
24
+
25
+ # TensorRT (GPU acceleration)
26
+ tensorrt==8.6.1
27
+ tensorrt-bindings==8.6.1
28
+ tensorrt-libs==8.6.1
29
+ polygraphy
30
+ colored
31
+
32
+ # Web interface
33
+ gradio==4.19.0
34
+
35
+ # Utilities
36
+ tqdm==4.67.1
37
+ requests==2.32.3
38
+ pyyaml==6.0.2
39
+ joblib==1.4.2
40
+ cython==3.0.11
41
+
42
+ # CUDA dependencies
43
+ cuda-python==12.6.2.post1
44
+ nvidia-cublas-cu12==12.6.4.1
45
+ nvidia-cuda-runtime-cu12==12.6.77
46
+ nvidia-cudnn-cu12==9.6.0.74
setup.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Setup script for Hugging Face Space
4
+ echo "=== DittoTalkingHead Setup Script ==="
5
+
6
+ # Create necessary directories
7
+ mkdir -p checkpoints/ditto_cfg
8
+ mkdir -p tmp
9
+ mkdir -p output
10
+
11
+ # Install system dependencies if needed
12
+ # apt-get update && apt-get install -y ffmpeg
13
+
14
+ # Run model download (PyTorch models)
15
+ echo "Starting model download (PyTorch models)..."
16
+ python -c "from model_manager import ModelManager; manager = ModelManager(use_pytorch=True); manager.setup_models()"
17
+
18
+ echo "Setup complete!"