Spaces:
Runtime error
Runtime error
""" | |
FastAPI server for DittoTalkingHead with Phase 3 optimizations | |
Implements /prepare_avatar and /generate_video endpoints | |
""" | |
from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks | |
from fastapi.responses import StreamingResponse, JSONResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
import os | |
import tempfile | |
import shutil | |
from pathlib import Path | |
import torch | |
import time | |
from typing import Optional, Dict, Any | |
import io | |
import asyncio | |
from datetime import datetime | |
import uvicorn | |
from model_manager import ModelManager | |
from core.optimization import ( | |
FixedResolutionProcessor, | |
GPUOptimizer, | |
AvatarCache, | |
AvatarTokenManager, | |
ColdStartOptimizer | |
) | |
# Create the FastAPI application.
app = FastAPI(
    title="DittoTalkingHead API",
    description="High-performance talking head generation API with Phase 3 optimizations",
    version="3.0.0",
)

# CORS configuration: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global initialisation (runs at import time).
print("=== API Server Phase 3 - 初期化開始 ===")

# 1. Resolution optimisation: the fixed edge length used by the endpoints.
resolution_optimizer = FixedResolutionProcessor()
FIXED_RESOLUTION = resolution_optimizer.get_max_dim()

# 2. GPU optimisation.
gpu_optimizer = GPUOptimizer()

# 3. Cold-start optimisation.
cold_start_optimizer = ColdStartOptimizer(
    persistent_dir="/tmp/persistent_model_cache",
)

# 4. Avatar cache and the token manager built on top of it.
avatar_cache = AvatarCache(
    cache_dir="/tmp/avatar_cache",
    ttl_days=14,
)
token_manager = AvatarTokenManager(avatar_cache)

# Model manager and SDK handle; SDK stays None until startup_event() assigns it.
USE_PYTORCH = True
model_manager = ModelManager(
    cache_dir="/tmp/ditto_models",
    use_pytorch=USE_PYTORCH,
)
SDK = None
# Initialisation routine.
# NOTE(review): no registration of this hook on `app` is visible in this view;
# confirm it is wired to application startup.
async def startup_event():
    """Set up the model cache, the checkpoints and the global ``SDK``.

    Raises:
        RuntimeError: if ``model_manager.setup_models()`` returns a falsy value.
        Exception: any error raised while importing or constructing
            ``StreamSDK`` is printed and re-raised unchanged.
    """
    global SDK

    print("Starting model initialization...")

    # Cold-start optimisation.
    cold_start_optimizer.setup_persistent_model_cache("./checkpoints")

    # Model setup.
    models_ready = model_manager.setup_models()
    if not models_ready:
        raise RuntimeError("Failed to setup models")

    # Pick the checkpoint directory and config that match the backend.
    if USE_PYTORCH:
        data_root, cfg_pkl = (
            "./checkpoints/ditto_pytorch",
            "./checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl",
        )
    else:
        data_root, cfg_pkl = (
            "./checkpoints/ditto_trt_Ampere_Plus",
            "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl",
        )

    try:
        from stream_pipeline_offline import StreamSDK

        SDK = StreamSDK(cfg_pkl, data_root)

        # GPU optimisation is applied only when the decoder is a torch.nn.Module.
        has_decoder = hasattr(SDK, 'decode_f3d') and hasattr(SDK.decode_f3d, 'decoder')
        if has_decoder:
            try:
                import torch.nn as nn

                if not isinstance(SDK.decode_f3d.decoder, nn.Module):
                    print("ℹ️ Decoder is not nn.Module, skipping optimization")
                else:
                    SDK.decode_f3d.decoder = gpu_optimizer.optimize_model(SDK.decode_f3d.decoder)
                    print("✅ Decoder model optimized")
            except Exception as exc:
                # A failed optimisation is not fatal; the SDK is used as-is.
                print(f"⚠️ Skipping GPU optimization: {exc}")

        print("✅ SDK initialized with optimizations")
    except Exception as exc:
        print(f"❌ SDK initialization error: {exc}")
        raise
# Health-check endpoint.
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
async def health_check():
    """Report server status, CUDA availability and avatar-cache information."""
    gpu_available = torch.cuda.is_available()
    cache_info = avatar_cache.get_cache_info()
    return {
        "status": "healthy",
        "gpu_available": gpu_available,
        "cache_info": cache_info,
        "optimization_enabled": True,
    }
# Avatar preparation endpoint.
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
async def prepare_avatar(file: UploadFile = File(...)):
    """Upload an image ahead of time and obtain an avatar token for it.

    Args:
        file: The uploaded image; its content type must start with ``image/``.

    Returns:
        A JSON response with ``avatar_token``, ``expires``, ``cached`` and
        ``resolution``.

    Raises:
        HTTPException: 400 when the upload is not declared as an image or
            cannot be decoded as one; 500 for any other failure.
    """
    # content_type is None when the client sends no Content-Type for the part.
    content_type = file.content_type or ""
    if not content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        image_data = await file.read()

        from PIL import Image
        import numpy as np

        # Decode and normalise the image. Undecodable data is a client error,
        # so it is reported as 400 rather than falling through to the 500 below.
        try:
            img = Image.open(io.BytesIO(image_data))
            img = img.convert('RGB')
            img = img.resize((FIXED_RESOLUTION, FIXED_RESOLUTION))
        except OSError as exc:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid image data: {exc}"
            ) from exc

        # TODO: use the SDK's real appearance_extractor. Until then the
        # embedding is a random placeholder and the preprocessed `img` above
        # is not consumed.
        def encode_appearance(img_data):
            return np.random.randn(512).astype(np.float32)

        # Create (or look up) the token for the raw image bytes.
        result = token_manager.prepare_avatar(
            image_data,
            encode_appearance
        )

        return JSONResponse(content={
            "avatar_token": result['avatar_token'],
            "expires": result['expires'],
            "cached": result['cached'],
            "resolution": f"{FIXED_RESOLUTION}x{FIXED_RESOLUTION}"
        })
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 400 above) intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
def _remove_temp_files(*paths):
    """Best-effort removal of temporary files; unset or missing paths are skipped."""
    for path in paths:
        if not path:
            continue
        try:
            os.unlink(path)
        except OSError:
            # Cleanup must never mask the real outcome of the request.
            pass


# Video generation endpoint.
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
async def generate_video(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    avatar_token: Optional[str] = None,
    avatar_image: Optional[UploadFile] = None
):
    """Generate a talking-head video from audio plus an avatar token or image.

    Args:
        background_tasks: Used to delete the temporary input files after the
            response has been sent.
        file: The audio upload; its content type must start with ``audio/``.
            It is stored with a ``.wav`` suffix.
        avatar_token: Token of a previously prepared avatar (optional).
        avatar_image: Avatar image, used when no ``avatar_token`` is given.

    Returns:
        A streaming ``video/mp4`` response; the output file is deleted once
        streaming ends.

    Raises:
        HTTPException: 400 for a non-audio upload, a missing avatar input or
            an invalid/expired token; 500 for any other failure.
    """
    # content_type is None when the client sends no Content-Type for the part.
    content_type = file.content_type or ""
    if not content_type.startswith("audio/"):
        raise HTTPException(status_code=400, detail="File must be an audio file")

    # An empty token counts as "not provided", matching the truthiness test
    # used below to choose between the token and image branches.
    if not avatar_token and avatar_image is None:
        raise HTTPException(
            status_code=400,
            detail="Either avatar_token or avatar_image must be provided"
        )

    # Pre-bind the paths so the error cleanup never touches unbound names.
    audio_path = image_path = output_path = None

    try:
        start_time = time.time()

        # Persist the uploaded audio.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio:
            audio_path = tmp_audio.name
            tmp_audio.write(await file.read())

        if avatar_token:
            # Only tokens with a cached embedding are accepted.
            embedding = avatar_cache.load_embedding(avatar_token)
            if embedding is None:
                raise HTTPException(
                    status_code=400,
                    detail="Invalid or expired avatar_token"
                )
            print(f"✅ Using cached embedding: {avatar_token[:8]}...")

            # NOTE(review): the cached embedding is not passed on yet; a blank
            # placeholder image is handed to the pipeline instead.
            from PIL import Image
            with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_img:
                image_path = tmp_img.name
            Image.new('RGB', (FIXED_RESOLUTION, FIXED_RESOLUTION), 'white').save(image_path)
        else:
            # Persist the uploaded avatar image.
            with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_img:
                image_path = tmp_img.name
                tmp_img.write(await avatar_image.read())

        # Reserve the output file.
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_output:
            output_path = tmp_output.name

        # Resolution-optimisation settings.
        setup_kwargs = {
            "max_size": FIXED_RESOLUTION,
            "sampling_timesteps": resolution_optimizer.get_diffusion_steps()
        }

        from inference import run, seed_everything
        seed_everything(1024)

        # Run the blocking pipeline in the default executor so the event loop
        # stays responsive.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            run,
            SDK,
            audio_path,
            image_path,
            output_path,
            {"setup_kwargs": setup_kwargs}
        )

        process_time = time.time() - start_time
        print(f"✅ Video generated in {process_time:.2f}s")

        # The input files are removed after the response has been sent.
        background_tasks.add_task(_remove_temp_files, audio_path, image_path)

        def iterfile():
            # `finally` also runs when the client disconnects mid-stream, so
            # the output file is not left behind.
            try:
                with open(output_path, 'rb') as f:
                    # Fixed-size chunks: binary data has no meaningful lines.
                    for chunk in iter(lambda: f.read(64 * 1024), b""):
                        yield chunk
            finally:
                _remove_temp_files(output_path)

        return StreamingResponse(
            iterfile(),
            media_type="video/mp4",
            headers={
                "Content-Disposition": f"attachment; filename=talking_head_{int(time.time())}.mp4",
                "X-Process-Time": str(process_time),
                "X-Resolution": f"{FIXED_RESOLUTION}x{FIXED_RESOLUTION}"
            }
        )

    except Exception as e:
        # Remove whatever was created before the failure.
        _remove_temp_files(audio_path, image_path, output_path)
        if isinstance(e, HTTPException):
            # Keep deliberate HTTP errors (e.g. the 400 for a bad token) intact.
            raise
        raise HTTPException(status_code=500, detail=str(e)) from e
# Cache information endpoint.
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
async def get_cache_info():
    """Collect statistics from the avatar cache, GPU optimizer and cold-start optimizer."""
    avatar_stats = avatar_cache.get_cache_info()
    gpu_stats = gpu_optimizer.get_memory_stats()
    cold_start_stats = cold_start_optimizer.get_optimization_stats()
    return {
        "avatar_cache": avatar_stats,
        "gpu_memory": gpu_stats,
        "cold_start_stats": cold_start_stats,
    }
# Token validation endpoint.
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
async def validate_token(token: str):
    """Return the token manager's info for ``token``.

    Raises:
        HTTPException: 404 when the token manager has no info for the token.
    """
    token_info = token_manager.get_token_info(token)
    if token_info is not None:
        return token_info
    raise HTTPException(status_code=404, detail="Token not found")
# Performance test endpoint.
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
async def run_benchmark(duration_seconds: int = 16):
    """Time one generation run on silent audio and a blank image.

    Args:
        duration_seconds: Length of the generated test audio in seconds;
            must be positive.

    Returns:
        A dict with the audio duration, the processing time, the realtime
        factor, the performance validation result and the optimisation config.

    Raises:
        HTTPException: 400 when ``duration_seconds`` is not positive;
            500 for any failure during the benchmark.
    """
    # Zero would divide by zero below and a negative value cannot size the
    # audio buffer, so reject both up front as client errors.
    if duration_seconds <= 0:
        raise HTTPException(
            status_code=400,
            detail="duration_seconds must be a positive integer"
        )

    # Every temp file is recorded as soon as it exists so that `finally`
    # removes it on success and on failure alike.
    temp_paths = []

    try:
        import numpy as np
        from scipy.io import wavfile
        from PIL import Image

        # Test audio: silence at 16 kHz.
        sample_rate = 16000
        audio_data = np.zeros(duration_seconds * sample_rate, dtype=np.float32)
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio:
            audio_path = tmp_audio.name
            temp_paths.append(audio_path)
        wavfile.write(audio_path, sample_rate, audio_data)

        # Test image: a blank white square at the fixed resolution.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_img:
            image_path = tmp_img.name
            temp_paths.append(image_path)
        Image.new('RGB', (FIXED_RESOLUTION, FIXED_RESOLUTION), 'white').save(image_path)

        # Output path.
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_output:
            output_path = tmp_output.name
            temp_paths.append(output_path)

        start_time = time.time()

        from inference import run, seed_everything
        seed_everything(1024)

        setup_kwargs = {
            "max_size": FIXED_RESOLUTION,
            "sampling_timesteps": resolution_optimizer.get_diffusion_steps()
        }

        # Run the blocking pipeline in the default executor so the event loop
        # stays responsive, as generate_video does.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            run,
            SDK,
            audio_path,
            image_path,
            output_path,
            {"setup_kwargs": setup_kwargs}
        )

        process_time = time.time() - start_time

        # Compare against the estimated original time of 1.9x the audio length.
        perf_result = resolution_optimizer.validate_performance_improvement(
            original_time=duration_seconds * 1.9,
            optimized_time=process_time
        )

        return {
            "audio_duration_seconds": duration_seconds,
            "process_time_seconds": process_time,
            "realtime_factor": process_time / duration_seconds,
            "performance": perf_result,
            "optimization_config": resolution_optimizer.get_performance_config()
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        for path in temp_paths:
            try:
                os.unlink(path)
            except OSError:
                # Best-effort cleanup; a missing file is not an error here.
                pass
if __name__ == "__main__":
    # Start the server. A single worker is used because the GPU is in use.
    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1, log_level="info")