File size: 16,925 Bytes
1a447e6
3cb17a3
1a447e6
dff2739
 
 
 
 
 
 
 
5b61f28
 
3cb17a3
5b61f28
3cb17a3
 
 
e77f2ba
dff2739
5b61f28
e77f2ba
 
 
1a447e6
 
3cb17a3
1a447e6
dff2739
 
1a447e6
 
3cb17a3
1a447e6
 
e77f2ba
 
 
 
 
 
 
 
 
 
dff2739
e77f2ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0090ad3
5b61f28
1a447e6
5b61f28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a447e6
5b61f28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6be96f1
5b61f28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a447e6
 
 
 
 
 
e77f2ba
1a447e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a842d4
1a447e6
 
 
7a842d4
1a447e6
 
 
7a842d4
1a447e6
 
 
 
 
 
 
 
 
 
 
 
e77f2ba
 
 
 
 
 
 
 
1a447e6
 
 
 
 
 
6be96f1
dff2739
 
e77f2ba
 
 
 
1a447e6
 
3328cca
5b61f28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a447e6
dff2739
6be96f1
 
0090ad3
 
dff2739
6be96f1
1a447e6
7a842d4
6be96f1
dff2739
 
 
6be96f1
dff2739
6be96f1
 
 
5b61f28
dff2739
6be96f1
 
 
 
 
 
 
 
 
 
 
 
7491096
1a447e6
 
 
 
 
 
 
4642fe2
6be96f1
 
 
 
 
 
 
1a447e6
0090ad3
dff2739
 
0090ad3
 
1a447e6
 
 
0090ad3
 
 
1a447e6
dff2739
6be96f1
 
 
dff2739
6be96f1
1a447e6
 
 
 
 
0090ad3
4642fe2
0090ad3
 
dff2739
 
 
 
6be96f1
dff2739
 
 
 
6be96f1
dff2739
6be96f1
dff2739
 
0090ad3
 
dff2739
6be96f1
0090ad3
 
 
1a447e6
 
6be96f1
 
1a447e6
7491096
1a447e6
 
 
6be96f1
0090ad3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
import os
import sys
import subprocess
import numpy as np
import torch
import imageio
from skimage.transform import resize
from skimage import img_as_ubyte
import gradio as gr
from PIL import Image
import tempfile
import requests
from io import BytesIO

# Đảm bảo cài đặt các thư viện cần thiết
subprocess.check_call([sys.executable, "-m", "pip", "install", "scikit-learn"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "scikit-image==0.19.3"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "face-alignment==1.3.5"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "PyYAML==5.3.1"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "imageio-ffmpeg==0.4.5"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "requests"])

# Cài đặt ffmpeg trong môi trường Ubuntu
os.system("apt-get update && apt-get install -y ffmpeg")

# Clone repo nếu chưa có
if not os.path.exists('first_order_model'):
    subprocess.call(['git', 'clone', 'https://github.com/AliaksandrSiarohin/first-order-model.git'])
    if os.path.exists('first-order-model'):
        os.rename('first-order-model', 'first_order_model')

# Thêm đường dẫn vào PYTHONPATH
sys.path.append('.')
sys.path.append('first_order_model')

# Tạo file helper với hàm load_checkpoints
with open('load_helper.py', 'w') as f:
    f.write("""
import yaml
import torch
from first_order_model.modules.generator import OcclusionAwareGenerator
from first_order_model.modules.keypoint_detector import KPDetector

def load_checkpoints(config_path, checkpoint_path, device='cpu'):
    with open(config_path) as f:
        config = yaml.safe_load(f)
        
    generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
                                        **config['model_params']['common_params'])
    generator.to(device)
    
    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                              **config['model_params']['common_params'])
    kp_detector.to(device)
    
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])
    
    generator.eval()
    kp_detector.eval()
    
    return generator, kp_detector
    
def normalize_kp(kp_source, kp_driving, kp_driving_initial, 
                 use_relative_movement=True, use_relative_jacobian=True, adapt_movement_scale=True):
    from first_order_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
    
    kp_new = {k: v for k, v in kp_driving.items()}
    
    if use_relative_movement:
        kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
        kp_value_diff_abs = torch.abs(kp_value_diff)
        
        if adapt_movement_scale:
            distance = torch.max(kp_value_diff_abs, dim=2, keepdim=True)[0]
            distance = torch.max(distance, dim=1, keepdim=True)[0]
            
            kp_source_diff = torch.abs(kp_source['value'])
            kp_source_max = torch.max(kp_source_diff, dim=2, keepdim=True)[0]
            kp_source_max = torch.max(kp_source_max, dim=1, keepdim=True)[0]
            
            movement_scale = kp_source_max / (distance + 1e-6)
            
            kp_new['value'] = kp_source['value'] + movement_scale * kp_value_diff
        else:
            kp_new['value'] = kp_source['value'] + kp_value_diff
            
    if use_relative_jacobian:
        jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
        kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
        
    return kp_new
""")

# Import hàm load_checkpoints từ file helper
from load_helper import load_checkpoints, normalize_kp

# Tải mô hình từ GitHub hoặc mirrors của first-order-model
def download_model():
    # URLs trực tiếp từ sources khác
    checkpoint_urls = [
        "https://github.com/AliaksandrSiarohin/first-order-model/releases/download/v1.0.0/vox-cpk.pth.tar",
        "https://raw.githubusercontent.com/jiupinjia/stylized-neural-painting/main/checkpoints/vox-cpk.pth.tar",
        "https://github.com/snap-research/articulated-animation/raw/master/checkpoints/vox.pth.tar"
    ]
    
    config_urls = [
        "https://raw.githubusercontent.com/AliaksandrSiarohin/first-order-model/master/config/vox-256.yaml",
        "https://gist.githubusercontent.com/anonymous/raw/vox-256.yaml"
    ]
    
    # Tạo thư mục
    model_path = 'checkpoints/vox-cpk.pth.tar'
    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints', exist_ok=True)
    
    config_path = 'first_order_model/config/vox-256.yaml'
    if not os.path.exists('first_order_model/config'):
        os.makedirs('first_order_model/config', exist_ok=True)
    
    # Tải model checkpoint
    success = False
    for url in checkpoint_urls:
        try:
            print(f"Đang thử tải mô hình từ: {url}")
            response = requests.get(url, stream=True, timeout=30)
            if response.status_code == 200:
                with open(model_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
                
                # Kiểm tra kích thước file (checkpoint mô hình thường > 100MB)
                if os.path.getsize(model_path) > 100000000:
                    success = True
                    break
        except Exception as e:
            print(f"Lỗi khi tải từ {url}: {str(e)}")
    
    if not success:
        raise Exception("Không thể tải mô hình checkpoint từ bất kỳ nguồn nào")
    
    # Tải file cấu hình
    config_success = False
    for url in config_urls:
        try:
            print(f"Đang thử tải file cấu hình từ: {url}")
            response = requests.get(url, timeout=30)
            if response.status_code == 200:
                with open(config_path, 'wb') as f:
                    f.write(response.content)
                
                if os.path.getsize(config_path) > 1000:
                    config_success = True
                    break
        except Exception as e:
            print(f"Lỗi khi tải cấu hình từ {url}: {str(e)}")
    
    if not config_success:
        # Tạo file cấu hình đơn giản nếu không tải được
        create_simple_config(config_path)
    
    return config_path, model_path

# Tạo file cấu hình đơn giản nếu không tải được
def create_simple_config(config_path):
    with open(config_path, 'w') as f:
        f.write("""
model_params:
  common_params:
    num_kp: 10
    num_channels: 3
    estimate_jacobian: true
  kp_detector_params:
     temperature: 0.1
     block_expansion: 32
     max_features: 1024
     scale_factor: 0.25
     num_blocks: 5
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 2
    num_bottleneck_blocks: 6
    estimate_occlusion_map: true
    dense_motion_params:
      block_expansion: 64
      max_features: 1024
      num_blocks: 5
      scale_factor: 0.25
        """)
    print("Đã tạo file cấu hình đơn giản")

# Hàm tạo animation
def make_animation(source_image, driving_video, relative=True, adapt_movement_scale=True):
    config_path, checkpoint_path = download_model()
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Tải mô hình và cấu hình
    generator, kp_detector = load_checkpoints(config_path, checkpoint_path, device=device)
    
    # Đọc source_image và driving_video
    source = imageio.imread(source_image)
    reader = imageio.get_reader(driving_video)
    fps = reader.get_meta_data()['fps']
    driving = []
    try:
        for im in reader:
            driving.append(im)
    except RuntimeError:
        pass
    reader.close()
    
    # Tiền xử lý
    source = resize(source, (256, 256))[..., :3]
    driving = [resize(frame, (256, 256))[..., :3] for frame in driving]
    
    # Chuyển đổi thành tensor
    source = torch.tensor(source[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(device)
    driving = torch.tensor(np.array(driving).astype(np.float32)).permute(0, 3, 1, 2).to(device)
    
    # Trích xuất keypoints
    kp_source = kp_detector(source)
    kp_driving_initial = kp_detector(driving[0:1])
    
    # Tạo animation
    with torch.no_grad():
        predictions = []
        for frame_idx in range(driving.shape[0]):
            driving_frame = driving[frame_idx:frame_idx+1]
            kp_driving = kp_detector(driving_frame)
            
            # Chuẩn hóa keypoints
            kp_norm = normalize_kp(
                kp_source=kp_source,
                kp_driving=kp_driving,
                kp_driving_initial=kp_driving_initial,
                use_relative_movement=relative,
                use_relative_jacobian=relative,
                adapt_movement_scale=adapt_movement_scale
            )
            
            # Tạo frame
            out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
            predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
    
    # Lưu video kết quả
    output_path = f'result_{int(np.random.rand() * 10000)}.mp4'
    if os.path.exists(output_path):
        os.remove(output_path)  # Xóa video nếu tồn tại
    
    # Lưu frames thành video sử dụng imageio
    frames = [img_as_ubyte(frame) for frame in predictions]
    imageio.mimsave(output_path, frames, fps=fps)
    
    return output_path

# Tải video mẫu
def download_sample_video():
    sample_urls = [
        "https://github.com/AliaksandrSiarohin/first-order-model/raw/master/driving.mp4",
        "https://raw.githubusercontent.com/jiupinjia/stylized-neural-painting/main/sample/driving.mp4"
    ]
    
    sample_path = "sample_driving.mp4"
    
    for url in sample_urls:
        try:
            print(f"Đang thử tải video mẫu từ: {url}")
            response = requests.get(url, timeout=30)
            if response.status_code == 200:
                with open(sample_path, 'wb') as f:
                    f.write(response.content)
                
                if os.path.getsize(sample_path) > 10000:  # Kiểm tra kích thước file
                    return sample_path
        except Exception as e:
            print(f"Lỗi khi tải video mẫu từ {url}: {str(e)}")
    
    # Nếu không tải được, tạo video đơn giản
    create_simple_video(sample_path)
    return sample_path

# Tạo video đơn giản nếu không tải được video mẫu
def create_simple_video(output_path):
    import cv2
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 10, (256, 256))
    
    # Tạo 100 khung hình với chuyển động đơn giản
    for i in range(100):
        frame = np.zeros((256, 256, 3), dtype=np.uint8)
        
        # Vẽ khuôn mặt đơn giản chuyển động
        x_center = 128 + int(np.sin(i/10) * 20)
        y_center = 128 + int(np.cos(i/20) * 10)
        
        # Vẽ khuôn mặt
        cv2.circle(frame, (x_center, y_center), 60, (200, 200, 200), -1)  # Mặt
        cv2.circle(frame, (x_center - 20, y_center - 15), 10, (0, 0, 0), -1)  # Mắt trái
        cv2.circle(frame, (x_center + 20, y_center - 15), 10, (0, 0, 0), -1)  # Mắt phải
        
        # Vẽ miệng
        mouth_y = y_center + 20 + int(np.sin(i/5) * 5)
        cv2.ellipse(frame, (x_center, mouth_y), (20, 10), 0, 0, 180, (0, 0, 0), -1)
        
        out.write(frame)
    
    out.release()
    print("Đã tạo video đơn giản")

# Định nghĩa giao diện Gradio
def animate_fomm(source_image, driving_video_file, relative=True, adapt_scale=True):
    if source_image is None:
        return None, "Vui lòng tải lên ảnh nguồn."
    
    try:
        # Lưu tạm ảnh nguồn
        source_path = f"source_image_{int(np.random.rand() * 10000)}.jpg"
        source_image.save(source_path)
        
        # Xử lý video tham chiếu
        print(f"Type of driving_video: {type(driving_video_file)}")
        
        # Tạo file tạm cho video
        driving_path = f"driving_video_{int(np.random.rand() * 10000)}.mp4"
        
        # Kiểm tra nếu đã chọn sử dụng video mẫu
        if driving_video_file is None:
            # Tải và sử dụng video mẫu
            driving_path = download_sample_video()
        else:
            # Xử lý video được tải lên
            if isinstance(driving_video_file, str):
                # Nếu là đường dẫn, copy file
                if os.path.exists(driving_video_file):
                    import shutil
                    shutil.copyfile(driving_video_file, driving_path)
                else:
                    return None, f"Không tìm thấy file video tại đường dẫn: {driving_video_file}"
            else:
                # Ghi dữ liệu nhị phân vào file
                with open(driving_path, 'wb') as f:
                    f.write(driving_video_file)
        
        # Tạo animation
        result_path = make_animation(
            source_path, 
            driving_path,
            relative=relative,
            adapt_movement_scale=adapt_scale
        )
        
        # Xóa file tạm nếu cần
        if os.path.exists(source_path) and source_path != "source_image.jpg":
            os.remove(source_path)
        
        if os.path.exists(driving_path) and driving_path != "sample_driving.mp4" and driving_path != "driving_video.mp4":
            os.remove(driving_path)
        
        return result_path, "Video được tạo thành công!"
    except Exception as e:
        import traceback
        return None, f"Lỗi: {str(e)}\n{traceback.format_exc()}"

# Tạo giao diện Gradio
with gr.Blocks(title="First Order Motion Model - Tạo video người chuyển động") as demo:
    gr.Markdown("# First Order Motion Model")
    gr.Markdown("Tạo video người chuyển động từ một ảnh tĩnh và video tham chiếu")
    
    with gr.Row():
        with gr.Column():
            source_image = gr.Image(type="pil", label="Tải lên ảnh nguồn")
            
            # Thêm tùy chọn sử dụng video mẫu
            use_sample = gr.Checkbox(label="Sử dụng video mẫu có sẵn", value=True)
            
            # Thay đổi từ gr.Video sang gr.File để xử lý lỗi binary
            driving_video_file = gr.File(label="Tải lên video tham chiếu (.mp4)", visible=False)
            
            with gr.Row():
                relative = gr.Checkbox(value=True, label="Chuyển động tương đối")
                adapt_scale = gr.Checkbox(value=True, label="Điều chỉnh tỷ lệ chuyển động")
            
            submit_btn = gr.Button("Tạo video")
            
        with gr.Column():
            output_video = gr.Video(label="Video kết quả")
            output_message = gr.Textbox(label="Thông báo", lines=5)
    
    # Xử lý sự kiện khi checkbox được chọn
    def toggle_video_upload(use_sample_video):
        return gr.update(visible=not use_sample_video)
    
    use_sample.change(fn=toggle_video_upload, inputs=[use_sample], outputs=[driving_video_file])
    
    # Cập nhật hàm xử lý khi nhấn nút
    def process_inputs(source_img, use_sample_vid, driving_vid, rel, adapt):
        if use_sample_vid:
            return animate_fomm(source_img, None, rel, adapt)
        else:
            return animate_fomm(source_img, driving_vid, rel, adapt)
    
    submit_btn.click(
        fn=process_inputs,
        inputs=[source_image, use_sample, driving_video_file, relative, adapt_scale],
        outputs=[output_video, output_message]
    )
    
    gr.Markdown("### Cách sử dụng")
    gr.Markdown("1. Tải lên **ảnh nguồn** - ảnh chứa người/đối tượng bạn muốn làm chuyển động")
    gr.Markdown("2. Chọn sử dụng video mẫu có sẵn hoặc tải lên video tham chiếu của riêng bạn")
    gr.Markdown("3. Nhấn **Tạo video** và chờ kết quả")
    
    gr.Markdown("### Lưu ý")
    gr.Markdown("- Ảnh nguồn và video tham chiếu nên có đối tượng tương tự (người với người, mặt với mặt)")
    gr.Markdown("- Đối tượng nên ở vị trí tương tự trong cả ảnh nguồn và khung đầu tiên của video tham chiếu")
    gr.Markdown("- Quá trình tạo video có thể mất vài phút")
    gr.Markdown("- Nếu gặp vấn đề với việc tải lên video, hãy sử dụng video mẫu có sẵn")

demo.launch()