Spaces:
Paused
Paused
# -*- coding: utf-8 -*- | |
# Copyright (c) Alibaba, Inc. and its affiliates. | |
import os | |
import copy | |
import time | |
import inspect | |
import argparse | |
import importlib | |
from configs import VACE_PREPROCCESS_CONFIGS | |
import annotators | |
from annotators.utils import read_image, read_mask, read_video_frames, save_one_video, save_one_image | |
def parse_bboxes(s):
    """Parse a whitespace-separated list of comma-separated bounding boxes.

    Args:
        s: Text such as ``"x1,y1,x2,y2 x1,y1,x2,y2"``.

    Returns:
        A list with one ``[x1, y1, x2, y2]`` list of floats per box, in input order.

    Raises:
        ValueError: If a box does not contain exactly four values, or a value
            cannot be converted to ``float``.
    """
    parsed = []
    for token in s.split():
        values = [float(part) for part in token.split(',')]
        count = len(values)
        if count != 4:
            raise ValueError(f"The bounding box requires 4 values, but the input is {count}.")
        parsed.append(values)
    return parsed
def validate_args(args):
    """Check that the arguments name a known task and provide at least one input.

    Args:
        args: Namespace-like object exposing ``task``, ``video``, ``image``
            and ``bbox`` attributes.

    Returns:
        The same ``args`` object, unchanged.

    Raises:
        AssertionError: If ``args.task`` is not a key of
            ``VACE_PREPROCCESS_CONFIGS``, or if ``video``, ``image`` and
            ``bbox`` are all ``None``.
    """
    # Raise explicitly instead of using ``assert`` so the checks still run
    # under ``python -O``; AssertionError is kept so callers see the same type.
    if args.task not in VACE_PREPROCCESS_CONFIGS:
        raise AssertionError(f"Unsupport task: [{args.task}]")
    if args.video is None and args.image is None and args.bbox is None:
        raise AssertionError("Please specify the video or image or bbox.")
    return args
def get_parser():
    """Build the command-line parser for the VACE preprocessing script.

    Returns:
        An ``argparse.ArgumentParser`` with the task selector, the input
        sources (video / image / mask / bbox / label / caption), the
        task-specific tuning options and the output options registered.
    """
    cli = argparse.ArgumentParser(description="Data processing carried out by VACE")

    # ``--task`` is the only option restricted to a fixed set of choices.
    cli.add_argument(
        "--task",
        type=str,
        default='',
        choices=list(VACE_PREPROCCESS_CONFIGS.keys()),
        help="The task to run.")

    # (flag, type, default, help) for every remaining option, in display order.
    option_specs = (
        ("--video", str, None,
         "The path of the videos to be processed, separated by commas if there are multiple."),
        ("--image", str, None,
         "The path of the images to be processed, separated by commas if there are multiple."),
        ("--mode", str, None,
         "The specific mode of the task, such as firstframe, mask, bboxtrack, label..."),
        ("--mask", str, None,
         "The path of the mask images to be processed, separated by commas if there are multiple."),
        ("--bbox", parse_bboxes, None,
         "Enter the bounding box, with each four numbers separated by commas (x1, y1, x2, y2), and each pair separated by a space."),
        ("--label", str, None,
         "Enter the label to be processed, separated by commas if there are multiple."),
        ("--caption", str, None,
         "Enter the caption to be processed."),
        ("--direction", str, None,
         "The direction of outpainting includes any combination of left, right, up, down, with multiple combinations separated by commas."),
        ("--expand_ratio", float, None,
         "The outpainting's outward expansion ratio."),
        ("--expand_num", int, None,
         "The number of frames extended by the extension task."),
        ("--maskaug_mode", str, None,
         "The mode of mask augmentation, such as original, original_expand, hull, hull_expand, bbox, bbox_expand."),
        ("--maskaug_ratio", float, None,
         "The ratio of mask augmentation."),
        ("--pre_save_dir", str, None,
         "The path to save the processed data."),
        ("--save_fps", int, 16,
         "The fps to save the processed data."),
    )
    for flag, value_type, fallback, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback, help=help_text)
    return cli
def preproccess():
    """Placeholder hook; it currently performs no work and returns ``None``."""
def proccess():
    """Placeholder hook; it currently performs no work and returns ``None``."""
def postproccess():
    """Placeholder hook; it currently performs no work and returns ``None``."""
def _require(condition, message):
    """Raise ``AssertionError(message)`` when ``condition`` is falsy.

    Used in place of ``assert`` so the checks are not stripped under
    ``python -O`` while callers still see the same exception type.
    """
    if not condition:
        raise AssertionError(message)


def main(args):
    """Run one VACE preprocessing task and save its outputs to disk.

    The task named by ``args.task`` is looked up in
    ``VACE_PREPROCCESS_CONFIGS``; its ``INPUTS`` entry decides which arguments
    are loaded and forwarded to the annotator, and its ``OUTPUTS`` entry
    decides which results are written.

    Args:
        args: An ``argparse.Namespace`` (or a dict with the same keys) as
            produced by ``get_parser()``.

    Returns:
        A dict of saved paths. It may contain ``'src_video'``, ``'src_mask'``
        and ``'src_ref_images'`` (a single path, a comma-joined string of
        paths, or ``None`` when no reference image was saved).

    Raises:
        AssertionError: If validation fails, a required video/image path is
            missing, or a video/image/mask cannot be read.
    """
    if isinstance(args, dict):
        args = argparse.Namespace(**args)
    args = validate_args(args)

    task_name = args.task
    video_path = args.video
    image_path = args.image
    mask_path = args.mask
    bbox = args.bbox
    label = args.label
    save_fps = args.save_fps

    # init class: work on a copy so popping keys never mutates the shared config
    task_cfg = copy.deepcopy(VACE_PREPROCCESS_CONFIGS)[task_name]
    class_name = task_cfg.pop("NAME")
    input_params = task_cfg.pop("INPUTS")
    output_params = task_cfg.pop("OUTPUTS")

    # input data: start from the configured values, then override with CLI values
    fps = None
    input_data = copy.deepcopy(input_params)

    # 'video' and 'frames' both need the first video decoded; read it only once.
    if 'video' in input_params or 'frames' in input_params:
        _require(video_path is not None, "Please set video or check configs")
        frames, fps, _, _, _ = read_video_frames(video_path.split(",")[0], use_type='cv2', info=True)
        _require(frames is not None, "Video read error")
        input_data['frames'] = frames
        if 'video' in input_params:
            input_data['video'] = video_path

    # The second video is optional: skip quietly when it was not supplied
    # (previously a missing --video crashed here with AttributeError).
    if 'frames_2' in input_params and video_path is not None:
        video_list = video_path.split(",")
        if len(video_list) >= 2:
            frames, fps, _, _, _ = read_video_frames(video_list[1], use_type='cv2', info=True)
            _require(frames is not None, "Video read error")
            input_data['frames_2'] = frames

    if 'image' in input_params:
        _require(image_path is not None, "Please set image or check configs")
        image, _, _ = read_image(image_path.split(",")[0], use_type='pil', info=True)
        _require(image is not None, "Image read error")
        input_data['image'] = image

    # The second image is optional, same as the second video above.
    if 'image_2' in input_params and image_path is not None:
        image_list = image_path.split(",")
        if len(image_list) >= 2:
            image, _, _ = read_image(image_list[1], use_type='pil', info=True)
            _require(image is not None, "Image read error")
            input_data['image_2'] = image

    if 'images' in input_params:
        _require(image_path is not None, "Please set image or check configs")
        input_data['images'] = [
            read_image(path, use_type='pil', info=True)[0]
            for path in image_path.split(",")
        ]

    if 'mask' in input_params and mask_path is not None:
        mask, _, _ = read_mask(mask_path.split(",")[0], use_type='pil', info=True)
        _require(mask is not None, "Mask read error")
        input_data['mask'] = mask

    if 'bbox' in input_params and bbox is not None:
        # A single box is passed bare; several boxes are passed as a list.
        input_data['bbox'] = bbox[0] if len(bbox) == 1 else bbox

    if 'label' in input_params:
        input_data['label'] = label.split(',') if label is not None else None
    if 'caption' in input_params:
        input_data['caption'] = args.caption
    if 'mode' in input_params:
        input_data['mode'] = args.mode

    # Optional tuning values only override the configured value when given.
    if 'direction' in input_params and args.direction is not None:
        input_data['direction'] = args.direction.split(',')
    if 'expand_ratio' in input_params and args.expand_ratio is not None:
        input_data['expand_ratio'] = args.expand_ratio
    if 'expand_num' in input_params and args.expand_num is not None:
        input_data['expand_num'] = args.expand_num
    if 'mask_cfg' in input_params and args.maskaug_mode is not None:
        mask_cfg = {"mode": args.maskaug_mode}
        if args.maskaug_ratio is not None:
            mask_cfg["kwargs"] = {'expand_ratio': args.maskaug_ratio, 'expand_iters': 5}
        input_data['mask_cfg'] = mask_cfg

    # processing: the device index comes from the RANK environment variable (default 0)
    pre_ins = getattr(annotators, class_name)(cfg=task_cfg, device=f'cuda:{os.getenv("RANK", 0)}')
    results = pre_ins.forward(**input_data)

    # output data: prefer the fps of the last video read over --save_fps
    if fps is not None:
        save_fps = fps
    pre_save_dir = args.pre_save_dir
    if pre_save_dir is None:
        timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
        pre_save_dir = os.path.join('processed', task_name, timestamp)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(pre_save_dir, exist_ok=True)

    def pick(key):
        # Annotators return either a dict of named outputs or one bare result.
        return results[key] if isinstance(results, dict) else results

    ret_data = {}

    # Video-like outputs: result key -> returned key (also the file name prefix).
    for result_key, ret_key in (('frames', 'src_video'), ('masks', 'src_mask')):
        if result_key not in output_params:
            continue
        out_frames = pick(result_key)
        if out_frames is not None:
            save_path = os.path.join(pre_save_dir, f'{ret_key}-{task_name}.mp4')
            save_one_video(save_path, out_frames, fps=save_fps)
            print(f"Save frames result to {save_path}")
            ret_data[ret_key] = save_path

    if 'image' in output_params:
        ret_image = pick('image')
        if ret_image is not None:
            save_path = os.path.join(pre_save_dir, f'src_ref_image-{task_name}.png')
            save_one_image(save_path, ret_image, use_type='pil')
            print(f"Save image result to {save_path}")
            ret_data['src_ref_images'] = save_path

    if 'images' in output_params:
        ret_images = pick('images')
        if ret_images is not None:
            src_ref_images = []
            for i, img in enumerate(ret_images):
                if img is None:
                    continue
                save_path = os.path.join(pre_save_dir, f'src_ref_image_{i}-{task_name}.png')
                save_one_image(save_path, img, use_type='pil')
                print(f"Save image result to {save_path}")
                src_ref_images.append(save_path)
            ret_data['src_ref_images'] = ','.join(src_ref_images) if src_ref_images else None

    if 'mask' in output_params:
        ret_mask = pick('mask')
        if ret_mask is not None:
            save_path = os.path.join(pre_save_dir, f'src_mask-{task_name}.png')
            save_one_image(save_path, ret_mask, use_type='pil')
            print(f"Save mask result to {save_path}")
            # NOTE(review): this path is not added to ret_data, unlike the
            # other outputs — kept as-is; confirm whether that is intentional.

    return ret_data
if __name__ == "__main__":
    # Script entry point: parse the command line and run the selected task.
    main(get_parser().parse_args())