Spaces:
Runtime error
Runtime error
`.gitignore`に`docs/`フォルダを追加して、無視するファイルを更新
Browse files- .gitignore +2 -0
- README_hf_space.md +50 -0
- app.py +130 -0
- model_manager.py +267 -0
- requirements.txt +46 -0
- setup.sh +18 -0
.gitignore
CHANGED
@@ -37,6 +37,8 @@ log/*
|
|
37 |
# Folders to ignore
|
38 |
example/
|
39 |
ToDo/
|
|
|
|
|
40 |
|
41 |
!example/audio.wav
|
42 |
!example/image.png
|
|
|
37 |
# Folders to ignore
|
38 |
example/
|
39 |
ToDo/
|
40 |
+
docs/
|
41 |
+
|
42 |
|
43 |
!example/audio.wav
|
44 |
!example/image.png
|
README_hf_space.md
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: DittoTalkingHead
|
3 |
+
emoji: 🗣️
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: purple
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.19.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: apache-2.0
|
11 |
+
hardware: a100-large
|
12 |
+
---
|
13 |
+
|
14 |
+
# DittoTalkingHead - Talking Head Generation
|
15 |
+
|
16 |
+
音声とソース画像から、リアルなTalking Headビデオを生成します。
|
17 |
+
|
18 |
+
## 特徴
|
19 |
+
|
20 |
+
- 高品質なリップシンク
|
21 |
+
- 自然な表情と頭部の動き
|
22 |
+
- TensorRTによる高速推論
|
23 |
+
- 自動モデルダウンロード機能
|
24 |
+
|
25 |
+
## 使い方
|
26 |
+
|
27 |
+
1. **音声ファイル**(WAV形式)をアップロード
|
28 |
+
2. **ソース画像**(PNG/JPG形式)をアップロード
|
29 |
+
3. **生成**ボタンをクリック
|
30 |
+
|
31 |
+
## 技術仕様
|
32 |
+
|
33 |
+
- **GPU**: NVIDIA A100(推奨)
|
34 |
+
- **フレームワーク**: PyTorch
|
35 |
+
- **モデル**: DittoTalkingHead (PyTorch版)
|
36 |
+
- **モデルサイズ**: 約2.5GB
|
37 |
+
|
38 |
+
## 注意事項
|
39 |
+
|
40 |
+
- 初回実行時は、モデルの自動ダウンロードのため時間がかかります(約10-15分)
|
41 |
+
- GPU(A100)環境での実行を推奨します
|
42 |
+
- 音声ファイルは16kHz WAV形式が推奨です
|
43 |
+
|
44 |
+
## モデルソース
|
45 |
+
|
46 |
+
モデルは[digital-avatar/ditto-talkinghead](https://huggingface.co/digital-avatar/ditto-talkinghead)から自動的にダウンロードされます。
|
47 |
+
|
48 |
+
## ライセンス
|
49 |
+
|
50 |
+
Apache License 2.0
|
app.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
import os
import tempfile
import shutil
from pathlib import Path
from model_manager import ModelManager
from stream_pipeline_offline import StreamSDK
from inference import run, seed_everything

# --- Model initialization (runs once, at module import) ---
print("=== モデルの初期化開始 ===")

# Use the PyTorch checkpoints (the TensorRT models are much larger).
USE_PYTORCH = True

# Download/verify checkpoints into a temp cache before building the SDK.
model_manager = ModelManager(cache_dir="/tmp/ditto_models", use_pytorch=USE_PYTORCH)
if not model_manager.setup_models():
    raise RuntimeError("モデルのセットアップに失敗しました。")

# Select checkpoint/config paths matching the chosen backend.
if USE_PYTORCH:
    data_root = "./checkpoints/ditto_pytorch"
    cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl"
else:
    data_root = "./checkpoints/ditto_trt_Ampere_Plus"
    cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl"

# Build the offline streaming SDK; fail fast (re-raise) if it cannot load,
# since the whole app is unusable without it.
try:
    SDK = StreamSDK(cfg_pkl, data_root)
    print("✅ SDK初期化成功")
except Exception as e:
    print(f"❌ SDK初期化エラー: {e}")
    raise
+
def process_talking_head(audio_file, source_image):
    """Generate a talking-head video from an audio clip and a source image.

    Args:
        audio_file: Filesystem path to the input audio (WAV), or None.
        source_image: Filesystem path to the source face image, or None.

    Returns:
        Tuple of (video_path_or_None, status_message) for the Gradio outputs.
    """
    if audio_file is None:
        return None, "音声ファイルをアップロードしてください。"

    if source_image is None:
        return None, "ソース画像をアップロードしてください。"

    output_path = None
    try:
        # Reserve a named output file; delete=False so the SDK can reopen
        # it by path after this handle is closed.
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_output:
            output_path = tmp_output.name

        print(f"処理開始: audio={audio_file}, image={source_image}")
        seed_everything(1024)  # fixed seed -> reproducible generation
        run(SDK, audio_file, source_image, output_path)

        # Success only if the pipeline actually wrote a non-empty file.
        if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
            return output_path, "✅ 処理が完了しました!"
        else:
            return None, "❌ 処理に失敗しました。出力ファイルが生成されませんでした。"

    except Exception as e:
        import traceback
        error_msg = f"❌ エラーが発生しました: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return None, error_msg
    finally:
        # Bug fix: remove the reserved temp file when nothing was written,
        # so failed runs don't leak empty .mp4 files in the temp directory.
        if output_path and os.path.exists(output_path) and os.path.getsize(output_path) == 0:
            try:
                os.remove(output_path)
            except OSError:
                pass
66 |
+
# --- Gradio UI: layout, bundled example, and event wiring ---
with gr.Blocks(title="DittoTalkingHead") as demo:
    gr.Markdown("""
    # DittoTalkingHead - Talking Head Generation

    音声とソース画像から、リアルなTalking Headビデオを生成します。

    ## 使い方
    1. **音声ファイル**(WAV形式)をアップロード
    2. **ソース画像**(PNG/JPG形式)をアップロード
    3. **生成**ボタンをクリック

    ⚠️ 初回実行時は、モデルのダウンロードのため時間がかかります(約2.5GB)。

    ### 技術仕様
    - **モデル**: DittoTalkingHead (PyTorch版)
    - **GPU**: NVIDIA A100推奨
    - **モデル提供**: [digital-avatar/ditto-talkinghead](https://huggingface.co/digital-avatar/ditto-talkinghead)
    """)

    with gr.Row():
        with gr.Column():
            # Both inputs are handed to the pipeline as file paths.
            audio_input = gr.Audio(
                label="音声ファイル (WAV)",
                type="filepath"
            )
            image_input = gr.Image(
                label="ソース画像",
                type="filepath"
            )
            generate_btn = gr.Button("生成", variant="primary")

        with gr.Column():
            video_output = gr.Video(
                label="生成されたビデオ"
            )
            status_output = gr.Textbox(
                label="ステータス",
                lines=3
            )

    # Bundled example pair; cache_examples=True precomputes the result
    # at startup (runs the model once during app launch).
    gr.Examples(
        examples=[
            ["example/audio.wav", "example/image.png"]
        ],
        inputs=[audio_input, image_input],
        outputs=[video_output, status_output],
        fn=process_talking_head,
        cache_examples=True
    )

    # Wire the generate button to the processing function.
    generate_btn.click(
        fn=process_talking_head,
        inputs=[audio_input, image_input],
        outputs=[video_output, status_output]
    )

if __name__ == "__main__":
    # Bind on all interfaces for the Space container; port 7860 is the
    # Hugging Face Spaces default; no public share tunnel.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
|
model_manager.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import requests
|
4 |
+
from tqdm import tqdm
|
5 |
+
from pathlib import Path
|
6 |
+
import hashlib
|
7 |
+
import json
|
8 |
+
import time
|
9 |
+
|
10 |
+
class ModelManager:
    """Download and stage DittoTalkingHead checkpoints.

    Files are first fetched into ``cache_dir`` (resumable downloads), then
    copied or extracted into the relative ``checkpoints/`` tree that the
    SDK expects.
    """

    def __init__(self, cache_dir="/tmp/models", use_pytorch=False):
        """
        Args:
            cache_dir: Directory for raw downloads and progress bookkeeping.
            use_pytorch: If True, stage the PyTorch checkpoints; otherwise
                the TensorRT archive + its config pickle.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.use_pytorch = use_pytorch

        # Models are fetched from the official Hugging Face repository.
        base_url = "https://huggingface.co/digital-avatar/ditto-talkinghead/resolve/main"

        if use_pytorch:
            # Individual PyTorch weight files + config pickle.
            self.model_configs = [
                {
                    "name": "appearance_extractor.pth",
                    "url": f"{base_url}/checkpoints/ditto_pytorch/models/appearance_extractor.pth",
                    "dest_dir": "checkpoints/ditto_pytorch/models",
                    "dest_file": "appearance_extractor.pth",
                    "type": "file"
                },
                {
                    "name": "decoder.pth",
                    "url": f"{base_url}/checkpoints/ditto_pytorch/models/decoder.pth",
                    "dest_dir": "checkpoints/ditto_pytorch/models",
                    "dest_file": "decoder.pth",
                    "type": "file"
                },
                {
                    "name": "lmdm_v0.4_hubert.pth",
                    "url": f"{base_url}/checkpoints/ditto_pytorch/models/lmdm_v0.4_hubert.pth",
                    "dest_dir": "checkpoints/ditto_pytorch/models",
                    "dest_file": "lmdm_v0.4_hubert.pth",
                    "type": "file"
                },
                {
                    "name": "motion_extractor.pth",
                    "url": f"{base_url}/checkpoints/ditto_pytorch/models/motion_extractor.pth",
                    "dest_dir": "checkpoints/ditto_pytorch/models",
                    "dest_file": "motion_extractor.pth",
                    "type": "file"
                },
                {
                    "name": "stitch_network.pth",
                    "url": f"{base_url}/checkpoints/ditto_pytorch/models/stitch_network.pth",
                    "dest_dir": "checkpoints/ditto_pytorch/models",
                    "dest_file": "stitch_network.pth",
                    "type": "file"
                },
                {
                    "name": "warp_network.pth",
                    "url": f"{base_url}/checkpoints/ditto_pytorch/models/warp_network.pth",
                    "dest_dir": "checkpoints/ditto_pytorch/models",
                    "dest_file": "warp_network.pth",
                    "type": "file"
                },
                {
                    "name": "v0.4_hubert_cfg.pkl",
                    "url": f"{base_url}/checkpoints/ditto_cfg/v0.4_hubert_cfg.pkl",
                    "dest_dir": "checkpoints/ditto_cfg",
                    "dest_file": "v0.4_hubert_cfg.pkl",
                    "type": "file"
                }
            ]
        else:
            # TensorRT engine archive (URL overridable via env) + config.
            self.model_configs = [
                {
                    "name": "ditto_trt_models",
                    "url": os.environ.get("DITTO_TRT_URL", f"{base_url}/checkpoints/ditto_trt_Ampere_Plus.tar.gz"),
                    "dest_dir": "checkpoints",
                    "type": "archive",
                    "extract_subdir": "ditto_trt_Ampere_Plus"
                },
                {
                    "name": "v0.4_hubert_cfg_trt.pkl",
                    "url": f"{base_url}/checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl",
                    "dest_dir": "checkpoints/ditto_cfg",
                    "dest_file": "v0.4_hubert_cfg_trt.pkl",
                    "type": "file"
                }
            ]

        # Persisted per-model download status, survives process restarts.
        self.progress_file = self.cache_dir / "download_progress.json"
        self.download_progress = self.load_progress()

    def load_progress(self):
        """Load the download-progress map; empty dict if none saved yet."""
        if self.progress_file.exists():
            with open(self.progress_file, 'r') as f:
                return json.load(f)
        return {}

    def save_progress(self):
        """Persist the download-progress map to the cache directory."""
        with open(self.progress_file, 'w') as f:
            json.dump(self.download_progress, f)

    def get_file_hash(self, filepath):
        """Return the SHA-256 hex digest of *filepath* (streamed read)."""
        sha256_hash = hashlib.sha256()
        with open(filepath, "rb") as f:
            for byte_block in iter(lambda: f.read(4096), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()

    def download_file(self, url, dest_path, retries=3):
        """Download *url* to *dest_path*, resuming a partial file if present.

        Raises the last exception if all *retries* attempts fail.
        """
        dest_path = Path(dest_path)
        dest_path.parent.mkdir(parents=True, exist_ok=True)

        for attempt in range(retries):
            # Bug fix: recompute resume state on every attempt — a failed
            # attempt may have grown the partial file, so stale headers/mode
            # would request the wrong byte range.
            resume_pos = dest_path.stat().st_size if dest_path.exists() else 0
            headers = {'Range': f'bytes={resume_pos}-'} if resume_pos > 0 else {}
            mode = 'ab' if resume_pos > 0 else 'wb'

            try:
                response = requests.get(url, headers=headers, stream=True, timeout=30)
                response.raise_for_status()

                # Bug fix: a server that ignores the Range header replies 200
                # with the full body; appending it to the partial file would
                # corrupt it, so restart from scratch instead.
                if resume_pos > 0 and response.status_code != 206:
                    resume_pos = 0
                    mode = 'wb'

                total_size = int(response.headers.get('content-length', 0)) + resume_pos

                with open(dest_path, mode) as f:
                    with tqdm(total=total_size, initial=resume_pos, unit='B',
                              unit_scale=True, desc=dest_path.name) as pbar:
                        for chunk in response.iter_content(chunk_size=8192):
                            if chunk:
                                f.write(chunk)
                                pbar.update(len(chunk))

                return True

            except Exception as e:
                print(f"ダウンロードエラー (試行 {attempt + 1}/{retries}): {e}")
                if attempt < retries - 1:
                    time.sleep(5)  # back off before retrying
                else:
                    raise

        return False

    def extract_archive(self, archive_path, dest_dir, extract_subdir=None):
        """Extract a tar(.gz)/zip archive into *dest_dir*.

        If *extract_subdir* is given, the archive is unpacked into a scratch
        dir and only that subdirectory is moved into place.
        """
        import tarfile
        import zipfile

        archive_path = Path(archive_path)
        dest_dir = Path(dest_dir)
        temp_dir = dest_dir / "temp_extract"

        try:
            # Bug fix: detect the format by content, not filename suffix —
            # cached downloads are named "<name>.download", so suffix-based
            # dispatch always rejected them.
            if tarfile.is_tarfile(archive_path):
                # NOTE(review): extractall trusts member paths inside the
                # archive; source is the official HF repo — confirm if the
                # URL ever becomes user-configurable.
                with tarfile.open(archive_path, 'r:*') as tar:
                    if extract_subdir:
                        temp_dir.mkdir(exist_ok=True)
                        tar.extractall(temp_dir)
                        src_dir = temp_dir / extract_subdir
                        if src_dir.exists():
                            shutil.move(str(src_dir), str(dest_dir / extract_subdir))
                        # Bug fix: always drop the scratch dir, even when the
                        # expected subdir was absent from the archive.
                        shutil.rmtree(temp_dir, ignore_errors=True)
                    else:
                        tar.extractall(dest_dir)
            elif zipfile.is_zipfile(archive_path):
                with zipfile.ZipFile(archive_path, 'r') as zip_ref:
                    zip_ref.extractall(dest_dir)
            else:
                raise ValueError(f"Unsupported archive format: {archive_path.suffix}")
        except Exception:
            if temp_dir.exists():
                shutil.rmtree(temp_dir, ignore_errors=True)
            raise

    def check_models_exist(self):
        """Return the configs whose staged files are not yet in place."""
        missing_models = []
        for config in self.model_configs:
            if config['type'] == 'file':
                dest_path = Path(config['dest_dir']) / config['dest_file']
                if not dest_path.exists():
                    missing_models.append(config)
            else:  # archive
                # Bug fix: check the extracted subdirectory itself, not the
                # shared parent dir ("checkpoints"), which also holds other
                # checkpoints and would mask a missing archive.
                target = Path(config['dest_dir'])
                subdir = config.get('extract_subdir')
                if subdir:
                    target = target / subdir
                if not target.exists() or not any(target.iterdir()):
                    missing_models.append(config)
        return missing_models

    def download_models(self):
        """Download and stage every missing model; return True on success."""
        missing_models = self.check_models_exist()

        if not missing_models:
            print("すべてのモデルが既に存在します。")
            return True

        print(f"{len(missing_models)}個のモデルをダウンロードします...")

        for config in missing_models:
            size_info = config.get('size', '不明')
            print(f"\n{config['name']} をダウンロード中... (サイズ: {size_info})")

            # Raw download is kept in the cache under a ".download" name so
            # staging can be repeated without re-fetching.
            cache_filename = f"{config['name']}.download"
            cache_path = self.cache_dir / cache_filename

            try:
                # Skip the network round-trip if this model already finished.
                if not cache_path.exists() or self.download_progress.get(config['name'], {}).get('status') != 'completed':
                    self.download_file(config['url'], cache_path)
                    self.download_progress[config['name']] = {'status': 'completed'}
                    self.save_progress()

                # Stage: plain files are copied, archives are extracted.
                dest_dir = Path(config['dest_dir'])
                dest_dir.mkdir(parents=True, exist_ok=True)
                if config['type'] == 'file':
                    shutil.copy2(cache_path, dest_dir / config['dest_file'])
                else:  # archive
                    print(f"{config['name']} を展開中...")
                    self.extract_archive(cache_path, dest_dir, config.get('extract_subdir'))

                print(f"{config['name']} のセットアップ完了")

            except Exception as e:
                print(f"エラー: {config['name']} のダウンロード中にエラーが発生しました: {e}")
                return False

        return True

    def setup_models(self):
        """Main entry point: ensure all models are staged; return success."""
        print("=== DittoTalkingHead モデルセットアップ ===")
        print(f"キャッシュディレクトリ: {self.cache_dir}")

        success = self.download_models()

        if success:
            print("\n✅ すべてのモデルのセットアップが完了しました!")
        else:
            print("\n❌ モデルのセットアップ中にエラーが発生しました。")

        return success
|
262 |
+
|
263 |
+
|
264 |
+
if __name__ == "__main__":
    # Manual smoke test: stage the default (TensorRT) model set.
    mgr = ModelManager()
    mgr.setup_models()
|
requirements.txt
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Core dependencies
|
2 |
+
torch==2.5.1
|
3 |
+
torchvision==0.20.1
|
4 |
+
torchaudio==2.5.1
|
5 |
+
numpy==2.0.1
|
6 |
+
pillow==11.0.0
|
7 |
+
|
8 |
+
# Audio processing
|
9 |
+
librosa==0.10.2.post1
|
10 |
+
soundfile==0.13.0
|
11 |
+
audioread==3.0.1
|
12 |
+
soxr==0.5.0.post1
|
13 |
+
|
14 |
+
# Video/Image processing
|
15 |
+
opencv-python-headless==4.10.0.84
|
16 |
+
imageio==2.36.1
|
17 |
+
imageio-ffmpeg==0.5.1
|
18 |
+
scikit-image==0.25.0
|
19 |
+
|
20 |
+
# Machine learning
|
21 |
+
scikit-learn==1.6.0
|
22 |
+
scipy==1.15.0
|
23 |
+
numba==0.60.0
|
24 |
+
|
25 |
+
# TensorRT (GPU acceleration)
|
26 |
+
tensorrt==8.6.1
|
27 |
+
tensorrt-bindings==8.6.1
|
28 |
+
tensorrt-libs==8.6.1
|
29 |
+
polygraphy
|
30 |
+
colored
|
31 |
+
|
32 |
+
# Web interface
|
33 |
+
gradio==4.19.0
|
34 |
+
|
35 |
+
# Utilities
|
36 |
+
tqdm==4.67.1
|
37 |
+
requests==2.32.3
|
38 |
+
pyyaml==6.0.2
|
39 |
+
joblib==1.4.2
|
40 |
+
cython==3.0.11
|
41 |
+
|
42 |
+
# CUDA dependencies
|
43 |
+
cuda-python==12.6.2.post1
|
44 |
+
nvidia-cublas-cu12==12.6.4.1
|
45 |
+
nvidia-cuda-runtime-cu12==12.6.77
|
46 |
+
nvidia-cudnn-cu12==9.6.0.74
|
setup.sh
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/bin/bash

# Hugging Face Space bootstrap: prepare directories, then prefetch models.
echo "=== DittoTalkingHead Setup Script ==="

# Working directories expected by the app.
for dir in checkpoints/ditto_cfg tmp output; do
    mkdir -p "$dir"
done

# System packages, if the base image lacks them:
# apt-get update && apt-get install -y ffmpeg

# Prefetch the PyTorch checkpoints so the first app start is fast.
echo "Starting model download (PyTorch models)..."
python -c "from model_manager import ModelManager; manager = ModelManager(use_pytorch=True); manager.setup_models()"

echo "Setup complete!"