""" | |
DittoTalkingHead Streaming Client | |
WebSocketを使用したストリーミングクライアントの実装例 | |
""" | |
import asyncio
import websockets
import numpy as np
import soundfile as sf
import base64
import json
import cv2
from typing import Optional, Callable
import pyaudio
import threading
import queue
import time
from pathlib import Path
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
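
# Wire protocol assumed by this client (see _receive_frames below):
#   - the client sends one JSON config message, then raw float32 PCM audio chunks,
#     and finally a JSON {"action": "stop"} message;
#   - the server replies with JSON messages whose "type" is one of
#     "ready", "frame", "progress", "processing", "final_video", or "error",
#     with frames and the final video delivered as Base64-encoded payloads.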

class DittoStreamingClient:
    """DittoTalkingHead streaming client"""

    def __init__(self, server_url="ws://localhost:8000/ws/generate"):
        self.server_url = server_url
        self.sample_rate = 16000
        self.chunk_duration = 0.2  # 200 ms per audio chunk
        self.chunk_size = int(self.sample_rate * self.chunk_duration)  # 3200 samples
        self.websocket = None
        self.is_connected = False
        self.frame_callback: Optional[Callable] = None
        self.final_video_callback: Optional[Callable] = None

    async def connect(self, source_image_path: str):
        """Connect to the server and start a session"""
        try:
            # Base64-encode the source image
            with open(source_image_path, "rb") as f:
                image_b64 = base64.b64encode(f.read()).decode('utf-8')

            # Open the WebSocket connection
            self.websocket = await websockets.connect(self.server_url)
            self.is_connected = True

            # Send the initial configuration
            await self.websocket.send(json.dumps({
                "source_image": image_b64,
                "sample_rate": self.sample_rate,
                "chunk_duration": self.chunk_duration
            }))

            # Wait for the server's response
            response = await self.websocket.recv()
            data = json.loads(response)

            if data["type"] == "ready":
                logger.info(f"Connected to server: {data['message']}")
                return True
            else:
                logger.error(f"Connection failed: {data}")
                return False

        except Exception as e:
            logger.error(f"Connection error: {e}")
            self.is_connected = False
            raise

    async def disconnect(self):
        """Close the connection"""
        if self.websocket:
            await self.websocket.close()
            self.is_connected = False
            logger.info("Disconnected from server")

    async def stream_audio_file(self, audio_path: str, source_image_path: str):
        """Stream an audio file to the server"""
        try:
            # Connect
            await self.connect(source_image_path)

            # Load the audio (downmixed to mono, resampled to the expected rate)
            audio_data, sr = sf.read(audio_path)
            if audio_data.ndim > 1:
                audio_data = audio_data.mean(axis=1)
            if sr != self.sample_rate:
                import librosa
                audio_data = librosa.resample(
                    audio_data,
                    orig_sr=sr,
                    target_sr=self.sample_rate
                )

            # Start the frame-receiving task
            receive_task = asyncio.create_task(self._receive_frames())

            # Send the audio chunk by chunk
            total_chunks = 0
            for i in range(0, len(audio_data), self.chunk_size):
                chunk = audio_data[i:i + self.chunk_size]
                if len(chunk) < self.chunk_size:
                    chunk = np.pad(chunk, (0, self.chunk_size - len(chunk)))

                # Send as raw float32 PCM
                await self.websocket.send(chunk.astype(np.float32).tobytes())
                total_chunks += 1

                # Simulate real-time pacing
                await asyncio.sleep(self.chunk_duration)

                # Report progress
                progress = min(i + self.chunk_size, len(audio_data)) / len(audio_data) * 100
                logger.info(f"Streaming progress: {progress:.1f}%")

            # Send the stop command
            await self.websocket.send(json.dumps({"action": "stop"}))
            logger.info(f"Sent {total_chunks} audio chunks")

            # Wait for the remaining frames
            await receive_task

        finally:
            await self.disconnect()

    async def stream_microphone(self, source_image_path: str, duration: Optional[float] = None):
        """Stream audio from the microphone in real time"""
        try:
            # Connect
            await self.connect(source_image_path)

            # Start the frame-receiving task
            receive_task = asyncio.create_task(self._receive_frames())

            # Queue used to hand audio from the recording thread to the async loop
            audio_queue = queue.Queue()
            stop_event = threading.Event()

            # Microphone recording thread
            def record_audio():
                p = pyaudio.PyAudio()
                stream = p.open(
                    format=pyaudio.paFloat32,
                    channels=1,
                    rate=self.sample_rate,
                    input=True,
                    frames_per_buffer=self.chunk_size
                )
                logger.info("Recording started... Press Ctrl+C to stop")
                try:
                    start_time = time.monotonic()
                    while not stop_event.is_set():
                        if duration and (time.monotonic() - start_time) > duration:
                            break
                        audio_chunk = stream.read(self.chunk_size, exception_on_overflow=False)
                        audio_queue.put(audio_chunk)
                except Exception as e:
                    logger.error(f"Recording error: {e}")
                finally:
                    stream.stop_stream()
                    stream.close()
                    p.terminate()
                    logger.info("Recording stopped")

            # Start the recording thread
            record_thread = threading.Thread(target=record_audio)
            record_thread.start()

            try:
                # Forward recorded audio to the server
                # (the blocking get() stalls the event loop for up to 0.1 s per
                # iteration, which is acceptable for this example)
                while record_thread.is_alive() or not audio_queue.empty():
                    try:
                        audio_chunk = audio_queue.get(timeout=0.1)
                        audio_array = np.frombuffer(audio_chunk, dtype=np.float32)
                        await self.websocket.send(audio_array.tobytes())
                    except queue.Empty:
                        continue
                    except KeyboardInterrupt:
                        logger.info("Stopping recording...")
                        break
            finally:
                stop_event.set()
                record_thread.join()

            # Send the stop command
            await self.websocket.send(json.dumps({"action": "stop"}))

            # Wait for the remaining frames
            await receive_task

        finally:
            await self.disconnect()

    async def _receive_frames(self):
        """Receive frames and status messages from the server"""
        frame_count = 0
        try:
            while True:
                message = await self.websocket.recv()
                data = json.loads(message)

                if data["type"] == "frame":
                    frame_count += 1
                    logger.info(f"Received frame {data['frame_id']} (FPS: {data.get('fps', 0)})")
                    if self.frame_callback:
                        # Decode the image-encoded frame
                        frame_data = base64.b64decode(data["data"])
                        nparr = np.frombuffer(frame_data, np.uint8)
                        frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
                        self.frame_callback(frame, data)

                elif data["type"] == "progress":
                    logger.info(f"Progress: {data['duration_seconds']:.1f}s processed")

                elif data["type"] == "processing":
                    logger.info(f"Server: {data['message']}")

                elif data["type"] == "final_video":
                    logger.info(f"Received final video ({data['size_bytes']} bytes, {data['duration_seconds']:.1f}s)")
                    if self.final_video_callback:
                        video_data = base64.b64decode(data["data"])
                        self.final_video_callback(video_data, data)
                    break

                elif data["type"] == "error":
                    logger.error(f"Server error: {data['message']}")
                    break

        except websockets.exceptions.ConnectionClosed:
            logger.info("Connection closed by server")
        except Exception as e:
            logger.error(f"Receive error: {e}")

        logger.info(f"Total frames received: {frame_count}")

    def set_frame_callback(self, callback: Callable):
        """Set the callback invoked for each received frame"""
        self.frame_callback = callback

    def set_final_video_callback(self, callback: Callable):
        """Set the callback invoked when the final video is received"""
        self.final_video_callback = callback
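

# Optional sketch (not part of the original example): a frame-callback factory
# that writes received frames straight to an MP4 file with cv2.VideoWriter,
# instead of (or in addition to) displaying them. The 25 fps default is an
# assumption; adjust it to whatever rate the server actually produces.
def make_video_writer_callback(output_path: str, fps: float = 25.0):
    """Return (callback, close); the callback appends frames to output_path."""
    state = {"writer": None}

    def callback(frame, metadata):
        # Lazily create the writer once the first frame's size is known
        if state["writer"] is None:
            h, w = frame.shape[:2]
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            state["writer"] = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
        state["writer"].write(frame)

    def close():
        # Finalize the file after streaming ends
        if state["writer"] is not None:
            state["writer"].release()

    return callback, close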


# Usage example and test
async def main():
    """Usage example"""
    client = DittoStreamingClient()

    # Callback that displays each received frame
    def display_frame(frame, metadata):
        cv2.imshow("Live Frame", frame)
        cv2.waitKey(1)

    # Callback that saves the final video
    def save_video(video_data, metadata):
        output_path = "output_streaming.mp4"
        with open(output_path, "wb") as f:
            f.write(video_data)
        logger.info(f"Video saved to {output_path}")

    client.set_frame_callback(display_frame)
    client.set_final_video_callback(save_video)

    # Paths to the test image and sample audio
    source_image = "example/reference.png"
    audio_file = "example/audio.wav"

    # Make sure the files exist
    if not Path(source_image).exists():
        logger.error(f"Source image not found: {source_image}")
        return

    # Stream from the audio file
    if Path(audio_file).exists():
        logger.info("=== Testing audio file streaming ===")
        await client.stream_audio_file(audio_file, source_image)
    else:
        logger.warning(f"Audio file not found: {audio_file}")

    # Stream from the microphone (for 5 seconds)
    # logger.info("\n=== Testing microphone streaming (5 seconds) ===")
    # await client.stream_microphone(source_image, duration=5.0)

    cv2.destroyAllWindows()


# Batch-processing client
class BatchStreamingClient:
    """Client that processes multiple requests in parallel"""

    def __init__(self, server_url="ws://localhost:8000/ws/generate", max_parallel=3):
        self.server_url = server_url
        self.max_parallel = max_parallel

    async def process_batch(self, tasks: list):
        """Process a batch of tasks, limiting concurrency with a semaphore"""
        semaphore = asyncio.Semaphore(self.max_parallel)

        async def process_with_limit(task):
            async with semaphore:
                client = DittoStreamingClient(self.server_url)
                await client.stream_audio_file(
                    task["audio_path"],
                    task["image_path"]
                )
                return task["id"]

        results = await asyncio.gather(
            *[process_with_limit(task) for task in tasks],
            return_exceptions=True
        )
        return results


if __name__ == "__main__":
    # Test the single client
    asyncio.run(main())

    # Batch-processing example
    # batch_client = BatchStreamingClient()
    # tasks = [
    #     {"id": 1, "audio_path": "audio1.wav", "image_path": "image1.png"},
    #     {"id": 2, "audio_path": "audio2.wav", "image_path": "image2.png"},
    # ]
    # asyncio.run(batch_client.process_batch(tasks))