imageToVideo / app.py
alvinichi's picture
update
5b61f28
import os
import sys
import subprocess
import numpy as np
import torch
import imageio
from skimage.transform import resize
from skimage import img_as_ubyte
import gradio as gr
from PIL import Image
import tempfile
import requests
from io import BytesIO
# Đảm bảo cài đặt các thư viện cần thiết
subprocess.check_call([sys.executable, "-m", "pip", "install", "scikit-learn"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "scikit-image==0.19.3"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "face-alignment==1.3.5"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "PyYAML==5.3.1"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "imageio-ffmpeg==0.4.5"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "requests"])
# Cài đặt ffmpeg trong môi trường Ubuntu
os.system("apt-get update && apt-get install -y ffmpeg")
# Clone repo nếu chưa có
if not os.path.exists('first_order_model'):
subprocess.call(['git', 'clone', 'https://github.com/AliaksandrSiarohin/first-order-model.git'])
if os.path.exists('first-order-model'):
os.rename('first-order-model', 'first_order_model')
# Thêm đường dẫn vào PYTHONPATH
sys.path.append('.')
sys.path.append('first_order_model')
# Tạo file helper với hàm load_checkpoints
with open('load_helper.py', 'w') as f:
f.write("""
import yaml
import torch
from first_order_model.modules.generator import OcclusionAwareGenerator
from first_order_model.modules.keypoint_detector import KPDetector
def load_checkpoints(config_path, checkpoint_path, device='cpu'):
with open(config_path) as f:
config = yaml.safe_load(f)
generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
**config['model_params']['common_params'])
generator.to(device)
kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
**config['model_params']['common_params'])
kp_detector.to(device)
checkpoint = torch.load(checkpoint_path, map_location=device)
generator.load_state_dict(checkpoint['generator'])
kp_detector.load_state_dict(checkpoint['kp_detector'])
generator.eval()
kp_detector.eval()
return generator, kp_detector
def normalize_kp(kp_source, kp_driving, kp_driving_initial,
use_relative_movement=True, use_relative_jacobian=True, adapt_movement_scale=True):
from first_order_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
kp_new = {k: v for k, v in kp_driving.items()}
if use_relative_movement:
kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
kp_value_diff_abs = torch.abs(kp_value_diff)
if adapt_movement_scale:
distance = torch.max(kp_value_diff_abs, dim=2, keepdim=True)[0]
distance = torch.max(distance, dim=1, keepdim=True)[0]
kp_source_diff = torch.abs(kp_source['value'])
kp_source_max = torch.max(kp_source_diff, dim=2, keepdim=True)[0]
kp_source_max = torch.max(kp_source_max, dim=1, keepdim=True)[0]
movement_scale = kp_source_max / (distance + 1e-6)
kp_new['value'] = kp_source['value'] + movement_scale * kp_value_diff
else:
kp_new['value'] = kp_source['value'] + kp_value_diff
if use_relative_jacobian:
jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
return kp_new
""")
# Import hàm load_checkpoints từ file helper
from load_helper import load_checkpoints, normalize_kp
# Tải mô hình từ GitHub hoặc mirrors của first-order-model
def download_model():
# URLs trực tiếp từ sources khác
checkpoint_urls = [
"https://github.com/AliaksandrSiarohin/first-order-model/releases/download/v1.0.0/vox-cpk.pth.tar",
"https://raw.githubusercontent.com/jiupinjia/stylized-neural-painting/main/checkpoints/vox-cpk.pth.tar",
"https://github.com/snap-research/articulated-animation/raw/master/checkpoints/vox.pth.tar"
]
config_urls = [
"https://raw.githubusercontent.com/AliaksandrSiarohin/first-order-model/master/config/vox-256.yaml",
"https://gist.githubusercontent.com/anonymous/raw/vox-256.yaml"
]
# Tạo thư mục
model_path = 'checkpoints/vox-cpk.pth.tar'
if not os.path.exists('checkpoints'):
os.makedirs('checkpoints', exist_ok=True)
config_path = 'first_order_model/config/vox-256.yaml'
if not os.path.exists('first_order_model/config'):
os.makedirs('first_order_model/config', exist_ok=True)
# Tải model checkpoint
success = False
for url in checkpoint_urls:
try:
print(f"Đang thử tải mô hình từ: {url}")
response = requests.get(url, stream=True, timeout=30)
if response.status_code == 200:
with open(model_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
# Kiểm tra kích thước file (checkpoint mô hình thường > 100MB)
if os.path.getsize(model_path) > 100000000:
success = True
break
except Exception as e:
print(f"Lỗi khi tải từ {url}: {str(e)}")
if not success:
raise Exception("Không thể tải mô hình checkpoint từ bất kỳ nguồn nào")
# Tải file cấu hình
config_success = False
for url in config_urls:
try:
print(f"Đang thử tải file cấu hình từ: {url}")
response = requests.get(url, timeout=30)
if response.status_code == 200:
with open(config_path, 'wb') as f:
f.write(response.content)
if os.path.getsize(config_path) > 1000:
config_success = True
break
except Exception as e:
print(f"Lỗi khi tải cấu hình từ {url}: {str(e)}")
if not config_success:
# Tạo file cấu hình đơn giản nếu không tải được
create_simple_config(config_path)
return config_path, model_path
# Tạo file cấu hình đơn giản nếu không tải được
def create_simple_config(config_path):
with open(config_path, 'w') as f:
f.write("""
model_params:
common_params:
num_kp: 10
num_channels: 3
estimate_jacobian: true
kp_detector_params:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25
num_blocks: 5
generator_params:
block_expansion: 64
max_features: 512
num_down_blocks: 2
num_bottleneck_blocks: 6
estimate_occlusion_map: true
dense_motion_params:
block_expansion: 64
max_features: 1024
num_blocks: 5
scale_factor: 0.25
""")
print("Đã tạo file cấu hình đơn giản")
# Hàm tạo animation
def make_animation(source_image, driving_video, relative=True, adapt_movement_scale=True):
config_path, checkpoint_path = download_model()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Tải mô hình và cấu hình
generator, kp_detector = load_checkpoints(config_path, checkpoint_path, device=device)
# Đọc source_image và driving_video
source = imageio.imread(source_image)
reader = imageio.get_reader(driving_video)
fps = reader.get_meta_data()['fps']
driving = []
try:
for im in reader:
driving.append(im)
except RuntimeError:
pass
reader.close()
# Tiền xử lý
source = resize(source, (256, 256))[..., :3]
driving = [resize(frame, (256, 256))[..., :3] for frame in driving]
# Chuyển đổi thành tensor
source = torch.tensor(source[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(device)
driving = torch.tensor(np.array(driving).astype(np.float32)).permute(0, 3, 1, 2).to(device)
# Trích xuất keypoints
kp_source = kp_detector(source)
kp_driving_initial = kp_detector(driving[0:1])
# Tạo animation
with torch.no_grad():
predictions = []
for frame_idx in range(driving.shape[0]):
driving_frame = driving[frame_idx:frame_idx+1]
kp_driving = kp_detector(driving_frame)
# Chuẩn hóa keypoints
kp_norm = normalize_kp(
kp_source=kp_source,
kp_driving=kp_driving,
kp_driving_initial=kp_driving_initial,
use_relative_movement=relative,
use_relative_jacobian=relative,
adapt_movement_scale=adapt_movement_scale
)
# Tạo frame
out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
# Lưu video kết quả
output_path = f'result_{int(np.random.rand() * 10000)}.mp4'
if os.path.exists(output_path):
os.remove(output_path) # Xóa video nếu tồn tại
# Lưu frames thành video sử dụng imageio
frames = [img_as_ubyte(frame) for frame in predictions]
imageio.mimsave(output_path, frames, fps=fps)
return output_path
# Tải video mẫu
def download_sample_video():
sample_urls = [
"https://github.com/AliaksandrSiarohin/first-order-model/raw/master/driving.mp4",
"https://raw.githubusercontent.com/jiupinjia/stylized-neural-painting/main/sample/driving.mp4"
]
sample_path = "sample_driving.mp4"
for url in sample_urls:
try:
print(f"Đang thử tải video mẫu từ: {url}")
response = requests.get(url, timeout=30)
if response.status_code == 200:
with open(sample_path, 'wb') as f:
f.write(response.content)
if os.path.getsize(sample_path) > 10000: # Kiểm tra kích thước file
return sample_path
except Exception as e:
print(f"Lỗi khi tải video mẫu từ {url}: {str(e)}")
# Nếu không tải được, tạo video đơn giản
create_simple_video(sample_path)
return sample_path
# Tạo video đơn giản nếu không tải được video mẫu
def create_simple_video(output_path):
import cv2
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 10, (256, 256))
# Tạo 100 khung hình với chuyển động đơn giản
for i in range(100):
frame = np.zeros((256, 256, 3), dtype=np.uint8)
# Vẽ khuôn mặt đơn giản chuyển động
x_center = 128 + int(np.sin(i/10) * 20)
y_center = 128 + int(np.cos(i/20) * 10)
# Vẽ khuôn mặt
cv2.circle(frame, (x_center, y_center), 60, (200, 200, 200), -1) # Mặt
cv2.circle(frame, (x_center - 20, y_center - 15), 10, (0, 0, 0), -1) # Mắt trái
cv2.circle(frame, (x_center + 20, y_center - 15), 10, (0, 0, 0), -1) # Mắt phải
# Vẽ miệng
mouth_y = y_center + 20 + int(np.sin(i/5) * 5)
cv2.ellipse(frame, (x_center, mouth_y), (20, 10), 0, 0, 180, (0, 0, 0), -1)
out.write(frame)
out.release()
print("Đã tạo video đơn giản")
# Định nghĩa giao diện Gradio
def animate_fomm(source_image, driving_video_file, relative=True, adapt_scale=True):
if source_image is None:
return None, "Vui lòng tải lên ảnh nguồn."
try:
# Lưu tạm ảnh nguồn
source_path = f"source_image_{int(np.random.rand() * 10000)}.jpg"
source_image.save(source_path)
# Xử lý video tham chiếu
print(f"Type of driving_video: {type(driving_video_file)}")
# Tạo file tạm cho video
driving_path = f"driving_video_{int(np.random.rand() * 10000)}.mp4"
# Kiểm tra nếu đã chọn sử dụng video mẫu
if driving_video_file is None:
# Tải và sử dụng video mẫu
driving_path = download_sample_video()
else:
# Xử lý video được tải lên
if isinstance(driving_video_file, str):
# Nếu là đường dẫn, copy file
if os.path.exists(driving_video_file):
import shutil
shutil.copyfile(driving_video_file, driving_path)
else:
return None, f"Không tìm thấy file video tại đường dẫn: {driving_video_file}"
else:
# Ghi dữ liệu nhị phân vào file
with open(driving_path, 'wb') as f:
f.write(driving_video_file)
# Tạo animation
result_path = make_animation(
source_path,
driving_path,
relative=relative,
adapt_movement_scale=adapt_scale
)
# Xóa file tạm nếu cần
if os.path.exists(source_path) and source_path != "source_image.jpg":
os.remove(source_path)
if os.path.exists(driving_path) and driving_path != "sample_driving.mp4" and driving_path != "driving_video.mp4":
os.remove(driving_path)
return result_path, "Video được tạo thành công!"
except Exception as e:
import traceback
return None, f"Lỗi: {str(e)}\n{traceback.format_exc()}"
# Tạo giao diện Gradio
with gr.Blocks(title="First Order Motion Model - Tạo video người chuyển động") as demo:
gr.Markdown("# First Order Motion Model")
gr.Markdown("Tạo video người chuyển động từ một ảnh tĩnh và video tham chiếu")
with gr.Row():
with gr.Column():
source_image = gr.Image(type="pil", label="Tải lên ảnh nguồn")
# Thêm tùy chọn sử dụng video mẫu
use_sample = gr.Checkbox(label="Sử dụng video mẫu có sẵn", value=True)
# Thay đổi từ gr.Video sang gr.File để xử lý lỗi binary
driving_video_file = gr.File(label="Tải lên video tham chiếu (.mp4)", visible=False)
with gr.Row():
relative = gr.Checkbox(value=True, label="Chuyển động tương đối")
adapt_scale = gr.Checkbox(value=True, label="Điều chỉnh tỷ lệ chuyển động")
submit_btn = gr.Button("Tạo video")
with gr.Column():
output_video = gr.Video(label="Video kết quả")
output_message = gr.Textbox(label="Thông báo", lines=5)
# Xử lý sự kiện khi checkbox được chọn
def toggle_video_upload(use_sample_video):
return gr.update(visible=not use_sample_video)
use_sample.change(fn=toggle_video_upload, inputs=[use_sample], outputs=[driving_video_file])
# Cập nhật hàm xử lý khi nhấn nút
def process_inputs(source_img, use_sample_vid, driving_vid, rel, adapt):
if use_sample_vid:
return animate_fomm(source_img, None, rel, adapt)
else:
return animate_fomm(source_img, driving_vid, rel, adapt)
submit_btn.click(
fn=process_inputs,
inputs=[source_image, use_sample, driving_video_file, relative, adapt_scale],
outputs=[output_video, output_message]
)
gr.Markdown("### Cách sử dụng")
gr.Markdown("1. Tải lên **ảnh nguồn** - ảnh chứa người/đối tượng bạn muốn làm chuyển động")
gr.Markdown("2. Chọn sử dụng video mẫu có sẵn hoặc tải lên video tham chiếu của riêng bạn")
gr.Markdown("3. Nhấn **Tạo video** và chờ kết quả")
gr.Markdown("### Lưu ý")
gr.Markdown("- Ảnh nguồn và video tham chiếu nên có đối tượng tương tự (người với người, mặt với mặt)")
gr.Markdown("- Đối tượng nên ở vị trí tương tự trong cả ảnh nguồn và khung đầu tiên của video tham chiếu")
gr.Markdown("- Quá trình tạo video có thể mất vài phút")
gr.Markdown("- Nếu gặp vấn đề với việc tải lên video, hãy sử dụng video mẫu có sẵn")
demo.launch()