import argparse
import os

import torch
import torch.distributed as dist
from einops import rearrange
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from torchvision.io import write_video
from tqdm import tqdm

from pipeline import (
    CausalDiffusionInferencePipeline,
    CausalInferencePipeline,
)
from utils.dataset import TextDataset, TextImagePairDataset
from utils.misc import set_seed

parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, help="Path to the config file")
parser.add_argument("--checkpoint_path", type=str, help="Path to the checkpoint folder")
parser.add_argument("--data_path", type=str, help="Path to the dataset")
parser.add_argument("--extended_prompt_path", type=str, help="Path to the extended prompts file")
parser.add_argument("--output_folder", type=str, help="Output folder")
parser.add_argument("--num_output_frames", type=int, default=21,
                    help="Number of output latent frames to generate")
parser.add_argument("--i2v", action="store_true", help="Whether to perform I2V (T2V by default)")
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA parameters")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument("--num_samples", type=int, default=1, help="Number of samples to generate per prompt")
parser.add_argument("--save_with_index", action="store_true",
                    help="Whether to save the video using the index instead of the prompt as the filename")
args = parser.parse_args()
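
# Illustrative launch commands (script name, paths, and GPU count are placeholders,
# not taken from the repo):
#   python inference.py --config_path configs/my_config.yaml --checkpoint_path checkpoints/model.pt \
#       --data_path prompts.txt --output_folder outputs/
#   torchrun --nproc_per_node=8 inference.py --config_path configs/my_config.yaml \
#       --checkpoint_path checkpoints/model.pt --data_path prompts.txt --output_folder outputs/
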
# When launched with torchrun, LOCAL_RANK is set and we run distributed inference;
# otherwise fall back to a single GPU.
if "LOCAL_RANK" in os.environ:
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")
    world_size = dist.get_world_size()
    # Offset the seed by rank so processes are not identically seeded.
    set_seed(args.seed + local_rank)
else:
    device = torch.device("cuda")
    local_rank = 0
    world_size = 1
    set_seed(args.seed)
torch.set_grad_enabled(False)

config = OmegaConf.load(args.config_path)
default_config = OmegaConf.load("configs/default_config.yaml")
config = OmegaConf.merge(default_config, config)
# Pick the pipeline based on whether the config specifies an explicit denoising_step_list.
if hasattr(config, 'denoising_step_list'):
    pipeline = CausalInferencePipeline(config, device=device)
else:
    pipeline = CausalDiffusionInferencePipeline(config, device=device)
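
# CausalInferencePipeline above is chosen whenever the config defines `denoising_step_list`.
# An illustrative config snippet that would trigger it (values are placeholders, not taken
# from the repo's configs):
#   denoising_step_list: [1000, 750, 500, 250]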
if args.checkpoint_path:
    state_dict = torch.load(args.checkpoint_path, map_location="cpu")
    # Load either the regular or the EMA generator weights.
    pipeline.generator.load_state_dict(state_dict['generator' if not args.use_ema else 'generator_ema'])

pipeline = pipeline.to(device=device, dtype=torch.bfloat16)
if args.i2v:
    assert not dist.is_initialized(), "I2V does not support distributed inference yet"
    # Resize conditioning images to 480x832 and normalize to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize((480, 832)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    dataset = TextImagePairDataset(args.data_path, transform=transform)
else:
    dataset = TextDataset(prompt_path=args.data_path, extended_prompt_path=args.extended_prompt_path)
num_prompts = len(dataset)
print(f"Number of prompts: {num_prompts}")
if dist.is_initialized():
    # Shard prompts across ranks; drop the tail so every rank gets the same number of batches.
    sampler = DistributedSampler(dataset, shuffle=False, drop_last=True)
else:
    sampler = SequentialSampler(dataset)
dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False)
if local_rank == 0:
    os.makedirs(args.output_folder, exist_ok=True)

# Make sure the output folder exists before any rank starts writing.
if dist.is_initialized():
    dist.barrier()
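# Main generation loop: for each prompt, build a batch of `num_samples` identical prompts,
# sample latent-space noise, run the pipeline once, and write every sample to an .mp4.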
for i, batch_data in tqdm(enumerate(dataloader), disable=(local_rank != 0)):
    # The DataLoader may wrap the sample in a list depending on the dataset; unwrap first.
    if isinstance(batch_data, dict):
        batch = batch_data
    elif isinstance(batch_data, list):
        batch = batch_data[0]
    idx = batch['idx'].item()

    all_video = []
    num_generated_frames = 0
    if args.i2v:
        # Use the same prompt for every sample in the batch.
        prompt = batch['prompts'][0]
        prompts = [prompt] * args.num_samples

        # Reshape the batched image to [1, C, 1, H, W] (a single-frame video) for the VAE.
        image = batch['image'].squeeze(0).unsqueeze(0).unsqueeze(2).to(device=device, dtype=torch.bfloat16)

        # Encode the conditioning image once and repeat it for each sample.
        initial_latent = pipeline.vae.encode_to_latent(image).to(device=device, dtype=torch.bfloat16)
        initial_latent = initial_latent.repeat(args.num_samples, 1, 1, 1, 1)

        # The encoded image provides the first latent frame, so noise covers the remaining ones.
        sampled_noise = torch.randn(
            [args.num_samples, args.num_output_frames - 1, 16, 60, 104], device=device, dtype=torch.bfloat16
        )
    else:
        prompt = batch['prompts'][0]
        # Prefer the extended prompt when one is provided.
        extended_prompt = batch['extended_prompts'][0] if 'extended_prompts' in batch else None
        if extended_prompt is not None:
            prompts = [extended_prompt] * args.num_samples
        else:
            prompts = [prompt] * args.num_samples
        initial_latent = None

        sampled_noise = torch.randn(
            [args.num_samples, args.num_output_frames, 16, 60, 104], device=device, dtype=torch.bfloat16
        )
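
    # Note: the trailing noise dims 16 x 60 x 104 are latent (channels, height, width);
    # 60x104 matches 480x832 pixels at 8x spatial VAE downsampling, and 16 is assumed
    # to be the latent channel count for this model family.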
    video, latents = pipeline.inference(
        noise=sampled_noise,
        text_prompts=prompts,
        return_latents=True,
        initial_latent=initial_latent,
    )
    # [B, T, C, H, W] -> [B, T, H, W, C] on CPU, as expected by write_video.
    current_video = rearrange(video, 'b t c h w -> b t h w c').cpu()
    all_video.append(current_video)
    num_generated_frames += latents.shape[1]

    # Pipeline output is assumed to be in [0, 1]; scale to the 0-255 range for saving.
    video = 255.0 * torch.cat(all_video, dim=1)

    # Reset the VAE's internal cache before the next prompt.
    pipeline.vae.model.clear_cache()
    # Guard against indices outside the dataset (e.g. padding from distributed sampling).
    if idx < num_prompts:
        model = "regular" if not args.use_ema else "ema"
        for seed_idx in range(args.num_samples):
            if args.save_with_index:
                output_path = os.path.join(args.output_folder, f'{idx}-{seed_idx}_{model}.mp4')
            else:
                output_path = os.path.join(args.output_folder, f'{prompt[:100]}-{seed_idx}.mp4')
            write_video(output_path, video[seed_idx], fps=16)