import torch
import numpy as np
import json
import imageio
from PIL import Image
from torchvision.transforms import v2
from einops import rearrange
import torchvision
import logging
from config import TEST_DATA_DIR
from camera_utils import Camera, parse_matrix, get_relative_pose

logger = logging.getLogger(__name__)

class VideoProcessor:
    def __init__(self, pipe):
        self.pipe = pipe
        self.default_height = 480
        self.default_width = 832
    
    def crop_and_resize(self, image, height, width):
        """Resize image so it covers the target dimensions, preserving aspect ratio (cropping happens later)"""
        width_img, height_img = image.size
        # Scale so that both dimensions end up at least as large as the target
        scale = max(width / width_img, height / height_img)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height_img * scale), round(width_img * scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        return image
    
    def load_video_frames(self, video_path, num_frames=81, height=480, width=832):
        """Load and process video frames"""
        reader = imageio.get_reader(video_path)
        frames = []
        
        # Create frame processor with specified dimensions
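        # Note: crop_and_resize() has already scaled each frame to cover
        # (height, width), so CenterCrop trims it exactly to size and the
        # following Resize is effectively a no-op safeguard.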
        frame_process = v2.Compose([
            v2.CenterCrop(size=(height, width)),
            v2.Resize(size=(height, width), antialias=True),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        
        for i in range(num_frames):
            try:
                frame = reader.get_data(i)
                frame = Image.fromarray(frame)
                frame = self.crop_and_resize(frame, height, width)
                frame = frame_process(frame)
                frames.append(frame)
            except IndexError:
                # Source video is shorter than num_frames: pad by repeating the last frame
                if frames:
                    frames.append(frames[-1])
                else:
                    raise ValueError("Video is too short!")
        
        reader.close()
        
        frames = torch.stack(frames, dim=0)               # (T, C, H, W)
        frames = rearrange(frames, "T C H W -> C T H W")  # (C, T, H, W)
        video_tensor = frames.unsqueeze(0)                # add batch dim: (1, C, T, H, W)
        
        return video_tensor
    
    def load_camera_trajectory(self, cam_type, num_frames=81):
        """Load camera trajectory for the selected type"""
        tgt_camera_path = "./camera_trajectories/camera_extrinsics.json"
        with open(tgt_camera_path, 'r') as file:
            cam_data = json.load(file)
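        
        # The JSON maps per-frame keys ("frame0", "frame1", ...) to dicts of
        # camera IDs ("cam01", "cam02", ...), each holding a serialized 4x4
        # extrinsic matrix that parse_matrix() decodes.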
        
        # Get camera trajectory for selected type
        cam_idx = list(range(num_frames))[::4]  # Sample every 4 frames
        traj = [parse_matrix(cam_data[f"frame{idx}"][f"cam{int(cam_type):02d}"]) for idx in cam_idx]
        traj = np.stack(traj).transpose(0, 2, 1)
        
        c2ws = []
        for c2w in traj:
            # Reorder the axis columns and flip the Y axis to match the model's
            # camera convention, then rescale the translation (divide by 100)
            c2w = c2w[:, [1, 2, 0, 3]]
            c2w[:3, 1] *= -1.
            c2w[:3, 3] /= 100
            c2ws.append(c2w)
        
        tgt_cam_params = [Camera(cam_param) for cam_param in c2ws]
        relative_poses = []
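        # Express every pose relative to the first camera in the trajectory,
        # keeping only the top 3x4 of the 4x4 matrix for the current camera
        # (index 1 of the pair passed to get_relative_pose)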
        for i in range(len(tgt_cam_params)):
            relative_pose = get_relative_pose([tgt_cam_params[0], tgt_cam_params[i]])
            relative_poses.append(torch.as_tensor(relative_pose)[:,:3,:][1])
        
        pose_embedding = torch.stack(relative_poses, dim=0)  # e.g. 21x3x4 for num_frames=81
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')  # flatten each 3x4 pose into a 12-dim vector
        camera_tensor = pose_embedding.to(torch.bfloat16).unsqueeze(0)  # add batch dimension
        
        return camera_tensor
    
    def process_video(self, video_path, text_prompt, cam_type, num_frames=81, height=480, width=832, seed=0, num_inference_steps=50, cfg_scale=5.0):
        """Process video through ReCamMaster model"""
        
        # Load video frames
        video_tensor = self.load_video_frames(video_path, num_frames, height, width)
        
        # Load camera trajectory
        camera_tensor = self.load_camera_trajectory(cam_type, num_frames)
        
        # Generate video with ReCamMaster
        video = self.pipe(
            prompt=[text_prompt],
            negative_prompt=["worst quality, low quality, blurry, jittery, distorted"],
            source_video=video_tensor,
            target_camera=camera_tensor,
            height=height,
            width=width,
            num_frames=num_frames,
            cfg_scale=cfg_scale,
            num_inference_steps=num_inference_steps,
            seed=seed,
            tiled=True
        )
        
        return video
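
# Example usage (a minimal sketch: `load_recammaster_pipeline` and the paths
# below are hypothetical placeholders, not part of this module):
#
#   pipe = load_recammaster_pipeline()  # hypothetical: load the ReCamMaster pipeline
#   processor = VideoProcessor(pipe)
#   video = processor.process_video(
#       video_path="input.mp4",
#       text_prompt="a person walking through a park",
#       cam_type=1,          # camera trajectory index in camera_extrinsics.json
#       num_frames=81,
#       seed=42,
#   )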