Spaces:
Running
Running
import os | |
import sys | |
import subprocess | |
import numpy as np | |
import torch | |
import imageio | |
from skimage.transform import resize | |
from skimage import img_as_ubyte | |
import gradio as gr | |
from PIL import Image | |
import tempfile | |
import requests | |
from io import BytesIO | |
# Đảm bảo cài đặt các thư viện cần thiết | |
subprocess.check_call([sys.executable, "-m", "pip", "install", "scikit-learn"]) | |
subprocess.check_call([sys.executable, "-m", "pip", "install", "scikit-image==0.19.3"]) | |
subprocess.check_call([sys.executable, "-m", "pip", "install", "face-alignment==1.3.5"]) | |
subprocess.check_call([sys.executable, "-m", "pip", "install", "PyYAML==5.3.1"]) | |
subprocess.check_call([sys.executable, "-m", "pip", "install", "imageio-ffmpeg==0.4.5"]) | |
subprocess.check_call([sys.executable, "-m", "pip", "install", "requests"]) | |
# Cài đặt ffmpeg trong môi trường Ubuntu | |
os.system("apt-get update && apt-get install -y ffmpeg") | |
# Clone repo nếu chưa có | |
if not os.path.exists('first_order_model'): | |
subprocess.call(['git', 'clone', 'https://github.com/AliaksandrSiarohin/first-order-model.git']) | |
if os.path.exists('first-order-model'): | |
os.rename('first-order-model', 'first_order_model') | |
# Thêm đường dẫn vào PYTHONPATH | |
sys.path.append('.') | |
sys.path.append('first_order_model') | |
# Tạo file helper với hàm load_checkpoints | |
with open('load_helper.py', 'w') as f: | |
f.write(""" | |
import yaml | |
import torch | |
from first_order_model.modules.generator import OcclusionAwareGenerator | |
from first_order_model.modules.keypoint_detector import KPDetector | |
def load_checkpoints(config_path, checkpoint_path, device='cpu'): | |
with open(config_path) as f: | |
config = yaml.safe_load(f) | |
generator = OcclusionAwareGenerator(**config['model_params']['generator_params'], | |
**config['model_params']['common_params']) | |
generator.to(device) | |
kp_detector = KPDetector(**config['model_params']['kp_detector_params'], | |
**config['model_params']['common_params']) | |
kp_detector.to(device) | |
checkpoint = torch.load(checkpoint_path, map_location=device) | |
generator.load_state_dict(checkpoint['generator']) | |
kp_detector.load_state_dict(checkpoint['kp_detector']) | |
generator.eval() | |
kp_detector.eval() | |
return generator, kp_detector | |
def normalize_kp(kp_source, kp_driving, kp_driving_initial, | |
use_relative_movement=True, use_relative_jacobian=True, adapt_movement_scale=True): | |
from first_order_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d | |
kp_new = {k: v for k, v in kp_driving.items()} | |
if use_relative_movement: | |
kp_value_diff = (kp_driving['value'] - kp_driving_initial['value']) | |
kp_value_diff_abs = torch.abs(kp_value_diff) | |
if adapt_movement_scale: | |
distance = torch.max(kp_value_diff_abs, dim=2, keepdim=True)[0] | |
distance = torch.max(distance, dim=1, keepdim=True)[0] | |
kp_source_diff = torch.abs(kp_source['value']) | |
kp_source_max = torch.max(kp_source_diff, dim=2, keepdim=True)[0] | |
kp_source_max = torch.max(kp_source_max, dim=1, keepdim=True)[0] | |
movement_scale = kp_source_max / (distance + 1e-6) | |
kp_new['value'] = kp_source['value'] + movement_scale * kp_value_diff | |
else: | |
kp_new['value'] = kp_source['value'] + kp_value_diff | |
if use_relative_jacobian: | |
jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian'])) | |
kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian']) | |
return kp_new | |
""") | |
# Import hàm load_checkpoints từ file helper | |
from load_helper import load_checkpoints, normalize_kp | |
# Tải mô hình từ GitHub hoặc mirrors của first-order-model | |
def download_model(): | |
# URLs trực tiếp từ sources khác | |
checkpoint_urls = [ | |
"https://github.com/AliaksandrSiarohin/first-order-model/releases/download/v1.0.0/vox-cpk.pth.tar", | |
"https://raw.githubusercontent.com/jiupinjia/stylized-neural-painting/main/checkpoints/vox-cpk.pth.tar", | |
"https://github.com/snap-research/articulated-animation/raw/master/checkpoints/vox.pth.tar" | |
] | |
config_urls = [ | |
"https://raw.githubusercontent.com/AliaksandrSiarohin/first-order-model/master/config/vox-256.yaml", | |
"https://gist.githubusercontent.com/anonymous/raw/vox-256.yaml" | |
] | |
# Tạo thư mục | |
model_path = 'checkpoints/vox-cpk.pth.tar' | |
if not os.path.exists('checkpoints'): | |
os.makedirs('checkpoints', exist_ok=True) | |
config_path = 'first_order_model/config/vox-256.yaml' | |
if not os.path.exists('first_order_model/config'): | |
os.makedirs('first_order_model/config', exist_ok=True) | |
# Tải model checkpoint | |
success = False | |
for url in checkpoint_urls: | |
try: | |
print(f"Đang thử tải mô hình từ: {url}") | |
response = requests.get(url, stream=True, timeout=30) | |
if response.status_code == 200: | |
with open(model_path, 'wb') as f: | |
for chunk in response.iter_content(chunk_size=8192): | |
f.write(chunk) | |
# Kiểm tra kích thước file (checkpoint mô hình thường > 100MB) | |
if os.path.getsize(model_path) > 100000000: | |
success = True | |
break | |
except Exception as e: | |
print(f"Lỗi khi tải từ {url}: {str(e)}") | |
if not success: | |
raise Exception("Không thể tải mô hình checkpoint từ bất kỳ nguồn nào") | |
# Tải file cấu hình | |
config_success = False | |
for url in config_urls: | |
try: | |
print(f"Đang thử tải file cấu hình từ: {url}") | |
response = requests.get(url, timeout=30) | |
if response.status_code == 200: | |
with open(config_path, 'wb') as f: | |
f.write(response.content) | |
if os.path.getsize(config_path) > 1000: | |
config_success = True | |
break | |
except Exception as e: | |
print(f"Lỗi khi tải cấu hình từ {url}: {str(e)}") | |
if not config_success: | |
# Tạo file cấu hình đơn giản nếu không tải được | |
create_simple_config(config_path) | |
return config_path, model_path | |
# Tạo file cấu hình đơn giản nếu không tải được | |
def create_simple_config(config_path): | |
with open(config_path, 'w') as f: | |
f.write(""" | |
model_params: | |
common_params: | |
num_kp: 10 | |
num_channels: 3 | |
estimate_jacobian: true | |
kp_detector_params: | |
temperature: 0.1 | |
block_expansion: 32 | |
max_features: 1024 | |
scale_factor: 0.25 | |
num_blocks: 5 | |
generator_params: | |
block_expansion: 64 | |
max_features: 512 | |
num_down_blocks: 2 | |
num_bottleneck_blocks: 6 | |
estimate_occlusion_map: true | |
dense_motion_params: | |
block_expansion: 64 | |
max_features: 1024 | |
num_blocks: 5 | |
scale_factor: 0.25 | |
""") | |
print("Đã tạo file cấu hình đơn giản") | |
# Hàm tạo animation | |
def make_animation(source_image, driving_video, relative=True, adapt_movement_scale=True): | |
config_path, checkpoint_path = download_model() | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
print(f"Using device: {device}") | |
# Tải mô hình và cấu hình | |
generator, kp_detector = load_checkpoints(config_path, checkpoint_path, device=device) | |
# Đọc source_image và driving_video | |
source = imageio.imread(source_image) | |
reader = imageio.get_reader(driving_video) | |
fps = reader.get_meta_data()['fps'] | |
driving = [] | |
try: | |
for im in reader: | |
driving.append(im) | |
except RuntimeError: | |
pass | |
reader.close() | |
# Tiền xử lý | |
source = resize(source, (256, 256))[..., :3] | |
driving = [resize(frame, (256, 256))[..., :3] for frame in driving] | |
# Chuyển đổi thành tensor | |
source = torch.tensor(source[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(device) | |
driving = torch.tensor(np.array(driving).astype(np.float32)).permute(0, 3, 1, 2).to(device) | |
# Trích xuất keypoints | |
kp_source = kp_detector(source) | |
kp_driving_initial = kp_detector(driving[0:1]) | |
# Tạo animation | |
with torch.no_grad(): | |
predictions = [] | |
for frame_idx in range(driving.shape[0]): | |
driving_frame = driving[frame_idx:frame_idx+1] | |
kp_driving = kp_detector(driving_frame) | |
# Chuẩn hóa keypoints | |
kp_norm = normalize_kp( | |
kp_source=kp_source, | |
kp_driving=kp_driving, | |
kp_driving_initial=kp_driving_initial, | |
use_relative_movement=relative, | |
use_relative_jacobian=relative, | |
adapt_movement_scale=adapt_movement_scale | |
) | |
# Tạo frame | |
out = generator(source, kp_source=kp_source, kp_driving=kp_norm) | |
predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]) | |
# Lưu video kết quả | |
output_path = f'result_{int(np.random.rand() * 10000)}.mp4' | |
if os.path.exists(output_path): | |
os.remove(output_path) # Xóa video nếu tồn tại | |
# Lưu frames thành video sử dụng imageio | |
frames = [img_as_ubyte(frame) for frame in predictions] | |
imageio.mimsave(output_path, frames, fps=fps) | |
return output_path | |
# Tải video mẫu | |
def download_sample_video(): | |
sample_urls = [ | |
"https://github.com/AliaksandrSiarohin/first-order-model/raw/master/driving.mp4", | |
"https://raw.githubusercontent.com/jiupinjia/stylized-neural-painting/main/sample/driving.mp4" | |
] | |
sample_path = "sample_driving.mp4" | |
for url in sample_urls: | |
try: | |
print(f"Đang thử tải video mẫu từ: {url}") | |
response = requests.get(url, timeout=30) | |
if response.status_code == 200: | |
with open(sample_path, 'wb') as f: | |
f.write(response.content) | |
if os.path.getsize(sample_path) > 10000: # Kiểm tra kích thước file | |
return sample_path | |
except Exception as e: | |
print(f"Lỗi khi tải video mẫu từ {url}: {str(e)}") | |
# Nếu không tải được, tạo video đơn giản | |
create_simple_video(sample_path) | |
return sample_path | |
# Tạo video đơn giản nếu không tải được video mẫu | |
def create_simple_video(output_path): | |
import cv2 | |
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 10, (256, 256)) | |
# Tạo 100 khung hình với chuyển động đơn giản | |
for i in range(100): | |
frame = np.zeros((256, 256, 3), dtype=np.uint8) | |
# Vẽ khuôn mặt đơn giản chuyển động | |
x_center = 128 + int(np.sin(i/10) * 20) | |
y_center = 128 + int(np.cos(i/20) * 10) | |
# Vẽ khuôn mặt | |
cv2.circle(frame, (x_center, y_center), 60, (200, 200, 200), -1) # Mặt | |
cv2.circle(frame, (x_center - 20, y_center - 15), 10, (0, 0, 0), -1) # Mắt trái | |
cv2.circle(frame, (x_center + 20, y_center - 15), 10, (0, 0, 0), -1) # Mắt phải | |
# Vẽ miệng | |
mouth_y = y_center + 20 + int(np.sin(i/5) * 5) | |
cv2.ellipse(frame, (x_center, mouth_y), (20, 10), 0, 0, 180, (0, 0, 0), -1) | |
out.write(frame) | |
out.release() | |
print("Đã tạo video đơn giản") | |
# Định nghĩa giao diện Gradio | |
def animate_fomm(source_image, driving_video_file, relative=True, adapt_scale=True): | |
if source_image is None: | |
return None, "Vui lòng tải lên ảnh nguồn." | |
try: | |
# Lưu tạm ảnh nguồn | |
source_path = f"source_image_{int(np.random.rand() * 10000)}.jpg" | |
source_image.save(source_path) | |
# Xử lý video tham chiếu | |
print(f"Type of driving_video: {type(driving_video_file)}") | |
# Tạo file tạm cho video | |
driving_path = f"driving_video_{int(np.random.rand() * 10000)}.mp4" | |
# Kiểm tra nếu đã chọn sử dụng video mẫu | |
if driving_video_file is None: | |
# Tải và sử dụng video mẫu | |
driving_path = download_sample_video() | |
else: | |
# Xử lý video được tải lên | |
if isinstance(driving_video_file, str): | |
# Nếu là đường dẫn, copy file | |
if os.path.exists(driving_video_file): | |
import shutil | |
shutil.copyfile(driving_video_file, driving_path) | |
else: | |
return None, f"Không tìm thấy file video tại đường dẫn: {driving_video_file}" | |
else: | |
# Ghi dữ liệu nhị phân vào file | |
with open(driving_path, 'wb') as f: | |
f.write(driving_video_file) | |
# Tạo animation | |
result_path = make_animation( | |
source_path, | |
driving_path, | |
relative=relative, | |
adapt_movement_scale=adapt_scale | |
) | |
# Xóa file tạm nếu cần | |
if os.path.exists(source_path) and source_path != "source_image.jpg": | |
os.remove(source_path) | |
if os.path.exists(driving_path) and driving_path != "sample_driving.mp4" and driving_path != "driving_video.mp4": | |
os.remove(driving_path) | |
return result_path, "Video được tạo thành công!" | |
except Exception as e: | |
import traceback | |
return None, f"Lỗi: {str(e)}\n{traceback.format_exc()}" | |
# Tạo giao diện Gradio | |
with gr.Blocks(title="First Order Motion Model - Tạo video người chuyển động") as demo: | |
gr.Markdown("# First Order Motion Model") | |
gr.Markdown("Tạo video người chuyển động từ một ảnh tĩnh và video tham chiếu") | |
with gr.Row(): | |
with gr.Column(): | |
source_image = gr.Image(type="pil", label="Tải lên ảnh nguồn") | |
# Thêm tùy chọn sử dụng video mẫu | |
use_sample = gr.Checkbox(label="Sử dụng video mẫu có sẵn", value=True) | |
# Thay đổi từ gr.Video sang gr.File để xử lý lỗi binary | |
driving_video_file = gr.File(label="Tải lên video tham chiếu (.mp4)", visible=False) | |
with gr.Row(): | |
relative = gr.Checkbox(value=True, label="Chuyển động tương đối") | |
adapt_scale = gr.Checkbox(value=True, label="Điều chỉnh tỷ lệ chuyển động") | |
submit_btn = gr.Button("Tạo video") | |
with gr.Column(): | |
output_video = gr.Video(label="Video kết quả") | |
output_message = gr.Textbox(label="Thông báo", lines=5) | |
# Xử lý sự kiện khi checkbox được chọn | |
def toggle_video_upload(use_sample_video): | |
return gr.update(visible=not use_sample_video) | |
use_sample.change(fn=toggle_video_upload, inputs=[use_sample], outputs=[driving_video_file]) | |
# Cập nhật hàm xử lý khi nhấn nút | |
def process_inputs(source_img, use_sample_vid, driving_vid, rel, adapt): | |
if use_sample_vid: | |
return animate_fomm(source_img, None, rel, adapt) | |
else: | |
return animate_fomm(source_img, driving_vid, rel, adapt) | |
submit_btn.click( | |
fn=process_inputs, | |
inputs=[source_image, use_sample, driving_video_file, relative, adapt_scale], | |
outputs=[output_video, output_message] | |
) | |
gr.Markdown("### Cách sử dụng") | |
gr.Markdown("1. Tải lên **ảnh nguồn** - ảnh chứa người/đối tượng bạn muốn làm chuyển động") | |
gr.Markdown("2. Chọn sử dụng video mẫu có sẵn hoặc tải lên video tham chiếu của riêng bạn") | |
gr.Markdown("3. Nhấn **Tạo video** và chờ kết quả") | |
gr.Markdown("### Lưu ý") | |
gr.Markdown("- Ảnh nguồn và video tham chiếu nên có đối tượng tương tự (người với người, mặt với mặt)") | |
gr.Markdown("- Đối tượng nên ở vị trí tương tự trong cả ảnh nguồn và khung đầu tiên của video tham chiếu") | |
gr.Markdown("- Quá trình tạo video có thể mất vài phút") | |
gr.Markdown("- Nếu gặp vấn đề với việc tải lên video, hãy sử dụng video mẫu có sẵn") | |
demo.launch() |