+
+***Track-Anything*** is a flexible and interactive tool for video object tracking and segmentation. It is developed upon [Segment Anything](https://github.com/facebookresearch/segment-anything), can specify anything to track and segment via user clicks only. During tracking, users can flexibly change the objects they wanna track or correct the region of interest if there are any ambiguities. These characteristics enable ***Track-Anything*** to be suitable for:
+- Video object tracking and segmentation with shot changes.
+- Visualized development and data annnotation for video object tracking and segmentation.
+- Object-centric downstream video tasks, such as video inpainting and editing.
+
+
+
+
+
+
+
+## :rocket: Updates
+- 2023/04/25: We are delighted to introduce [Caption-Anything](https://github.com/ttengwang/Caption-Anything) :writing_hand:, an inventive project from our lab that combines the capabilities of Segment Anything, Visual Captioning, and ChatGPT.
+
+- 2023/04/20: We deployed [[DEMO]](https://huggingface.co/spaces/watchtowerss/Track-Anything) on Hugging Face :hugs:!
+
+## Demo
+
+https://user-images.githubusercontent.com/28050374/232842703-8395af24-b13e-4b8e-aafb-e94b61e6c449.MP4
+
+### Multiple Object Tracking and Segmentation (with [XMem](https://github.com/hkchengrex/XMem))
+
+https://user-images.githubusercontent.com/39208339/233035206-0a151004-6461-4deb-b782-d1dbfe691493.mp4
+
+### Video Object Tracking and Segmentation with Shot Changes (with [XMem](https://github.com/hkchengrex/XMem))
+
+https://user-images.githubusercontent.com/30309970/232848349-f5e29e71-2ea4-4529-ac9a-94b9ca1e7055.mp4
+
+### Video Inpainting (with [E2FGVI](https://github.com/MCG-NKU/E2FGVI))
+
+https://user-images.githubusercontent.com/28050374/232959816-07f2826f-d267-4dda-8ae5-a5132173b8f4.mp4
+
+## Get Started
+#### Linux
+```bash
+# Clone the repository:
+git clone https://github.com/gaomingqi/Track-Anything.git
+cd Track-Anything
+
+# Install dependencies:
+pip install -r requirements.txt
+
+# Run the Track-Anything gradio demo.
+python app.py --device cuda:0 --sam_model_type vit_h --port 12212
+```
+
+## Citation
+If you find this work useful for your research or applications, please cite using this BibTeX:
+```bibtex
+@misc{yang2023track,
+ title={Track Anything: Segment Anything Meets Videos},
+ author={Jinyu Yang and Mingqi Gao and Zhe Li and Shang Gao and Fangjing Wang and Feng Zheng},
+ year={2023},
+ eprint={2304.11968},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+```
+
+## Acknowledgements
+
+The project is based on [Segment Anything](https://github.com/facebookresearch/segment-anything), [XMem](https://github.com/hkchengrex/XMem), and [E2FGVI](https://github.com/MCG-NKU/E2FGVI). Thanks for the authors for their efforts.
diff --git a/XMem-s012.pth b/XMem-s012.pth
new file mode 100644
index 0000000000000000000000000000000000000000..043c62f4abf18499fa7ca0a9937d4689b5b695b6
--- /dev/null
+++ b/XMem-s012.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:16205ad04bfc55b442bd4d7af894382e09868b35e10721c5afc09a24ea8d72d9
+size 249026057
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..053a5d454e024f1ea019804988493c06398f6569
--- /dev/null
+++ b/app.py
@@ -0,0 +1,665 @@
+import gradio as gr
+import argparse
+import gdown
+import cv2
+import numpy as np
+import os
+import sys
+sys.path.append(sys.path[0]+"/tracker")
+sys.path.append(sys.path[0]+"/tracker/model")
+from track_anything import TrackingAnything
+from track_anything import parse_augment, save_image_to_userfolder, read_image_from_userfolder
+import requests
+import json
+import torchvision
+import torch
+from tools.painter import mask_painter
+import psutil
+import time
+try:
+ from mmcv.cnn import ConvModule
+except:
+ os.system("mim install mmcv")
+
+# download checkpoints
+def download_checkpoint(url, folder, filename):
+ os.makedirs(folder, exist_ok=True)
+ filepath = os.path.join(folder, filename)
+
+ if not os.path.exists(filepath):
+ print("download checkpoints ......")
+ response = requests.get(url, stream=True)
+ with open(filepath, "wb") as f:
+ for chunk in response.iter_content(chunk_size=8192):
+ if chunk:
+ f.write(chunk)
+
+ print("download successfully!")
+
+ return filepath
+
+def download_checkpoint_from_google_drive(file_id, folder, filename):
+ os.makedirs(folder, exist_ok=True)
+ filepath = os.path.join(folder, filename)
+
+ if not os.path.exists(filepath):
+ print("Downloading checkpoints from Google Drive... tips: If you cannot see the progress bar, please try to download it manuall \
+ and put it in the checkpointes directory. E2FGVI-HQ-CVPR22.pth: https://github.com/MCG-NKU/E2FGVI(E2FGVI-HQ model)")
+ url = f"https://drive.google.com/uc?id={file_id}"
+ gdown.download(url, filepath, quiet=False)
+ print("Downloaded successfully!")
+
+ return filepath
+
+# convert points input to prompt state
+def get_prompt(click_state, click_input):
+ inputs = json.loads(click_input)
+ points = click_state[0]
+ labels = click_state[1]
+ for input in inputs:
+ points.append(input[:2])
+ labels.append(input[2])
+ click_state[0] = points
+ click_state[1] = labels
+ prompt = {
+ "prompt_type":["click"],
+ "input_point":click_state[0],
+ "input_label":click_state[1],
+ "multimask_output":"True",
+ }
+ return prompt
+
+
+
+# extract frames from upload video
+def get_frames_from_video(video_input, video_state):
+ """
+ Args:
+ video_path:str
+ timestamp:float64
+ Return
+ [[0:nearest_frame], [nearest_frame:], nearest_frame]
+ """
+ video_path = video_input
+ frames = [] # save image path
+ user_name = time.time()
+ video_state["video_name"] = os.path.split(video_path)[-1]
+ video_state["user_name"] = user_name
+
+ os.makedirs(os.path.join("/tmp/{}/originimages/{}".format(video_state["user_name"], video_state["video_name"])), exist_ok=True)
+ os.makedirs(os.path.join("/tmp/{}/paintedimages/{}".format(video_state["user_name"], video_state["video_name"])), exist_ok=True)
+ operation_log = [("",""),("Upload video already. Try click the image for adding targets to track and inpaint.","Normal")]
+ try:
+ cap = cv2.VideoCapture(video_path)
+ fps = cap.get(cv2.CAP_PROP_FPS)
+ if not cap.isOpened():
+ operation_log = [("No frames extracted, please input video file with '.mp4.' '.mov'.", "Error")]
+ print("No frames extracted, please input video file with '.mp4.' '.mov'.")
+ return None, None, None, None, \
+ None, None, None, None, \
+ None, None, None, None, \
+ None, None, gr.update(visible=True, value=operation_log)
+ image_index = 0
+ while cap.isOpened():
+ ret, frame = cap.read()
+ if ret == True:
+ current_memory_usage = psutil.virtual_memory().percent
+
+ # try solve memory usage problem, save image to disk instead of memory
+ frames.append(save_image_to_userfolder(video_state, image_index, frame, True))
+ image_index +=1
+ # frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+ if current_memory_usage > 90:
+ operation_log = [("Memory usage is too high (>90%). Stop the video extraction. Please reduce the video resolution or frame rate.", "Error")]
+ print("Memory usage is too high (>90%). Please reduce the video resolution or frame rate.")
+ break
+ else:
+ break
+
+ except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
+ # except:
+ operation_log = [("read_frame_source:{} error. {}\n".format(video_path, str(e)), "Error")]
+ print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
+ return None, None, None, None, \
+ None, None, None, None, \
+ None, None, None, None, \
+ None, None, gr.update(visible=True, value=operation_log)
+ first_image = read_image_from_userfolder(frames[0])
+ image_size = (first_image.shape[0], first_image.shape[1])
+ # initialize video_state
+ video_state = {
+ "user_name": user_name,
+ "video_name": os.path.split(video_path)[-1],
+ "origin_images": frames,
+ "painted_images": frames.copy(),
+ "masks": [np.zeros((image_size[0], image_size[1]), np.uint8)]*len(frames),
+ "logits": [None]*len(frames),
+ "select_frame_number": 0,
+ "fps": fps
+ }
+ video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size)
+ model.samcontroler.sam_controler.reset_image()
+ model.samcontroler.sam_controler.set_image(first_image)
+ return video_state, video_info, first_image, gr.update(visible=True, maximum=len(frames), value=1), \
+ gr.update(visible=True, maximum=len(frames), value=len(frames)), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
+ gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
+ gr.update(visible=True), gr.update(visible=True), gr.update(visible=True, value=operation_log),
+
+def run_example(example):
+ return example
+# get the select frame from gradio slider
+def select_template(image_selection_slider, video_state, interactive_state):
+
+ # images = video_state[1]
+ image_selection_slider -= 1
+ video_state["select_frame_number"] = image_selection_slider
+
+ # once select a new template frame, set the image in sam
+
+ model.samcontroler.sam_controler.reset_image()
+ model.samcontroler.sam_controler.set_image(read_image_from_userfolder(video_state["origin_images"][image_selection_slider]))
+
+ # update the masks when select a new template frame
+ # if video_state["masks"][image_selection_slider] is not None:
+ # video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider])
+ operation_log = [("",""), ("Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider),"Normal")]
+
+ return read_image_from_userfolder(video_state["painted_images"][image_selection_slider]), video_state, interactive_state, operation_log
+
+# set the tracking end frame
+def get_end_number(track_pause_number_slider, video_state, interactive_state):
+ track_pause_number_slider -= 1
+ interactive_state["track_end_number"] = track_pause_number_slider
+ operation_log = [("",""),("Set the tracking finish at frame {}".format(track_pause_number_slider),"Normal")]
+
+ return read_image_from_userfolder(video_state["painted_images"][track_pause_number_slider]),interactive_state, operation_log
+
+def get_resize_ratio(resize_ratio_slider, interactive_state):
+ interactive_state["resize_ratio"] = resize_ratio_slider
+
+ return interactive_state
+
+# use sam to get the mask
+def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
+ """
+ Args:
+ template_frame: PIL.Image
+ point_prompt: flag for positive or negative button click
+ click_state: [[points], [labels]]
+ """
+ if point_prompt == "Positive":
+ coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
+ interactive_state["positive_click_times"] += 1
+ else:
+ coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
+ interactive_state["negative_click_times"] += 1
+
+ # prompt for sam model
+ model.samcontroler.sam_controler.reset_image()
+ model.samcontroler.sam_controler.set_image(read_image_from_userfolder(video_state["origin_images"][video_state["select_frame_number"]]))
+ prompt = get_prompt(click_state=click_state, click_input=coordinate)
+
+ mask, logit, painted_image = model.first_frame_click(
+ image=read_image_from_userfolder(video_state["origin_images"][video_state["select_frame_number"]]),
+ points=np.array(prompt["input_point"]),
+ labels=np.array(prompt["input_label"]),
+ multimask=prompt["multimask_output"],
+ )
+ video_state["masks"][video_state["select_frame_number"]] = mask
+ video_state["logits"][video_state["select_frame_number"]] = logit
+ video_state["painted_images"][video_state["select_frame_number"]] = save_image_to_userfolder(video_state, index=video_state["select_frame_number"], image=cv2.cvtColor(np.asarray(painted_image),cv2.COLOR_BGR2RGB),type=False)
+
+ operation_log = [("",""), ("Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment","Normal")]
+ return painted_image, video_state, interactive_state, operation_log
+
+def add_multi_mask(video_state, interactive_state, mask_dropdown):
+ try:
+ mask = video_state["masks"][video_state["select_frame_number"]]
+ interactive_state["multi_mask"]["masks"].append(mask)
+ interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
+ mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
+ select_frame, run_status = show_mask(video_state, interactive_state, mask_dropdown)
+
+ operation_log = [("",""),("Added a mask, use the mask select for target tracking or inpainting.","Normal")]
+ except:
+ operation_log = [("Please click the left image to generate mask.", "Error"), ("","")]
+ return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log
+
+def clear_click(video_state, click_state):
+ click_state = [[],[]]
+ template_frame = read_image_from_userfolder(video_state["origin_images"][video_state["select_frame_number"]])
+ operation_log = [("",""), ("Clear points history and refresh the image.","Normal")]
+ return template_frame, click_state, operation_log
+
+def remove_multi_mask(interactive_state, mask_dropdown):
+ interactive_state["multi_mask"]["mask_names"]= []
+ interactive_state["multi_mask"]["masks"] = []
+
+ operation_log = [("",""), ("Remove all mask, please add new masks","Normal")]
+ return interactive_state, gr.update(choices=[],value=[]), operation_log
+
+def show_mask(video_state, interactive_state, mask_dropdown):
+ mask_dropdown.sort()
+ select_frame = read_image_from_userfolder(video_state["origin_images"][video_state["select_frame_number"]])
+
+ for i in range(len(mask_dropdown)):
+ mask_number = int(mask_dropdown[i].split("_")[1]) - 1
+ mask = interactive_state["multi_mask"]["masks"][mask_number]
+ select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
+
+ operation_log = [("",""), ("Select {} for tracking or inpainting".format(mask_dropdown),"Normal")]
+ return select_frame, operation_log
+
+# tracking vos
+def vos_tracking_video(video_state, interactive_state, mask_dropdown):
+ operation_log = [("",""), ("Track the selected masks, and then you can select the masks for inpainting.","Normal")]
+ model.xmem.clear_memory()
+ if interactive_state["track_end_number"]:
+ following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
+ else:
+ following_frames = video_state["origin_images"][video_state["select_frame_number"]:]
+
+ if interactive_state["multi_mask"]["masks"]:
+ if len(mask_dropdown) == 0:
+ mask_dropdown = ["mask_001"]
+ mask_dropdown.sort()
+ template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
+ for i in range(1,len(mask_dropdown)):
+ mask_number = int(mask_dropdown[i].split("_")[1]) - 1
+ template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1)
+ video_state["masks"][video_state["select_frame_number"]]= template_mask
+ else:
+ template_mask = video_state["masks"][video_state["select_frame_number"]]
+ fps = video_state["fps"]
+
+ # operation error
+ if len(np.unique(template_mask))==1:
+ template_mask[0][0]=1
+ operation_log = [("Error! Please add at least one mask to track by clicking the left image.","Error"), ("","")]
+ # return video_output, video_state, interactive_state, operation_error
+ masks, logits, painted_images_path = model.generator(images=following_frames, template_mask=template_mask, video_state=video_state)
+ # clear GPU memory
+ model.xmem.clear_memory()
+
+ if interactive_state["track_end_number"]:
+ video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
+ video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits
+ video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = painted_images_path
+ else:
+ video_state["masks"][video_state["select_frame_number"]:] = masks
+ video_state["logits"][video_state["select_frame_number"]:] = logits
+ video_state["painted_images"][video_state["select_frame_number"]:] = painted_images_path
+
+ video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/track/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
+ interactive_state["inference_times"] += 1
+
+ print("For generating this tracking result, inference times: {}, click times: {}, positive: {}, negative: {}".format(interactive_state["inference_times"],
+ interactive_state["positive_click_times"]+interactive_state["negative_click_times"],
+ interactive_state["positive_click_times"],
+ interactive_state["negative_click_times"]))
+
+ #### shanggao code for mask save
+ if interactive_state["mask_save"]:
+ if not os.path.exists('./result/mask/{}'.format(video_state["video_name"].split('.')[0])):
+ os.makedirs('./result/mask/{}'.format(video_state["video_name"].split('.')[0]))
+ i = 0
+ print("save mask")
+ for mask in video_state["masks"]:
+ np.save(os.path.join('./result/mask/{}'.format(video_state["video_name"].split('.')[0]), '{:05d}.npy'.format(i)), mask)
+ i+=1
+ #### shanggao code for mask save
+ return video_output, video_state, interactive_state, operation_log
+
+
+
+# inpaint
+def inpaint_video(video_state, interactive_state, mask_dropdown):
+ operation_log = [("",""), ("Removed the selected masks.","Normal")]
+
+ # solve memory
+ frames = np.asarray(video_state["origin_images"])
+ fps = video_state["fps"]
+ inpaint_masks = np.asarray(video_state["masks"])
+ if len(mask_dropdown) == 0:
+ mask_dropdown = ["mask_001"]
+ mask_dropdown.sort()
+ # convert mask_dropdown to mask numbers
+ inpaint_mask_numbers = [int(mask_dropdown[i].split("_")[1]) for i in range(len(mask_dropdown))]
+ # interate through all masks and remove the masks that are not in mask_dropdown
+ unique_masks = np.unique(inpaint_masks)
+ num_masks = len(unique_masks) - 1
+ for i in range(1, num_masks + 1):
+ if i in inpaint_mask_numbers:
+ continue
+ inpaint_masks[inpaint_masks==i] = 0
+ # inpaint for videos
+
+ try:
+ inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3
+ video_output = generate_video_from_paintedframes(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps)
+ except:
+ operation_log = [("Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size.","Error"), ("","")]
+ inpainted_frames = video_state["origin_images"]
+ video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
+ return video_output, operation_log
+
+
+# generate video after vos inference
+def generate_video_from_frames(frames_path, output_path, fps=30):
+ """
+ Generates a video from a list of frames.
+
+ Args:
+ frames (list of numpy arrays): The frames to include in the video.
+ output_path (str): The path to save the generated video.
+ fps (int, optional): The frame rate of the output video. Defaults to 30.
+ """
+ # height, width, layers = frames[0].shape
+ # fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+ # video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+ # print(output_path)
+ # for frame in frames:
+ # video.write(frame)
+
+ # video.release()
+ frames = []
+ for file in frames_path:
+ frames.append(read_image_from_userfolder(file))
+ frames = torch.from_numpy(np.asarray(frames))
+ if not os.path.exists(os.path.dirname(output_path)):
+ os.makedirs(os.path.dirname(output_path))
+ torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
+ return output_path
+
+def generate_video_from_paintedframes(frames, output_path, fps=30):
+ """
+ Generates a video from a list of frames.
+
+ Args:
+ frames (list of numpy arrays): The frames to include in the video.
+ output_path (str): The path to save the generated video.
+ fps (int, optional): The frame rate of the output video. Defaults to 30.
+ """
+ # height, width, layers = frames[0].shape
+ # fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+ # video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+ # print(output_path)
+ # for frame in frames:
+ # video.write(frame)
+
+ # video.release()
+ frames = torch.from_numpy(np.asarray(frames))
+ if not os.path.exists(os.path.dirname(output_path)):
+ os.makedirs(os.path.dirname(output_path))
+ torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
+ return output_path
+
+
+# args, defined in track_anything.py
+args = parse_augment()
+
+# check and download checkpoints if needed
+SAM_checkpoint_dict = {
+ 'vit_h': "sam_vit_h_4b8939.pth",
+ 'vit_l': "sam_vit_l_0b3195.pth",
+ "vit_b": "sam_vit_b_01ec64.pth"
+}
+SAM_checkpoint_url_dict = {
+ 'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
+ 'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
+ 'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
+}
+sam_checkpoint = SAM_checkpoint_dict[args.sam_model_type]
+sam_checkpoint_url = SAM_checkpoint_url_dict[args.sam_model_type]
+xmem_checkpoint = "XMem-s012.pth"
+xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
+e2fgvi_checkpoint = "E2FGVI-HQ-CVPR22.pth"
+e2fgvi_checkpoint_id = "10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3"
+
+
+folder ="./checkpoints"
+SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
+xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
+e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
+# args.port = 12213
+# args.device = "cuda:8"
+# args.mask_save = True
+
+# initialize sam, xmem, e2fgvi models
+model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)
+
+
+title = """
Track-Anything
+ """
+description = """
Gradio demo for Track Anything, a flexible and interactive tool for video object tracking, segmentation, and inpainting. To use it, simply upload your video, or click one of the examples to load them. Code: Track-Anything If you stuck in unknown errors, please feel free to watch the Tutorial video.
"""
+
+
+with gr.Blocks() as iface:
+ """
+ state for
+ """
+ click_state = gr.State([[],[]])
+ interactive_state = gr.State({
+ "inference_times": 0,
+ "negative_click_times" : 0,
+ "positive_click_times": 0,
+ "mask_save": args.mask_save,
+ "multi_mask": {
+ "mask_names": [],
+ "masks": []
+ },
+ "track_end_number": None,
+ "resize_ratio": 0.6
+ }
+ )
+
+ video_state = gr.State(
+ {
+ "user_name": "",
+ "video_name": "",
+ "origin_images": None,
+ "painted_images": None,
+ "masks": None,
+ "inpaint_masks": None,
+ "logits": None,
+ "select_frame_number": 0,
+ "fps": 30
+ }
+ )
+ gr.Markdown(title)
+ gr.Markdown(description)
+ with gr.Row():
+ with gr.Column():
+ with gr.Tab("Test"):
+ # for user video input
+ with gr.Column():
+ with gr.Row(scale=0.4):
+ video_input = gr.Video(autosize=True)
+ with gr.Column():
+ video_info = gr.Textbox(label="Video Info")
+ resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to git clone the repo and use a machine with more VRAM locally. \
+ Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.", label="Tips for running this demo.")
+ resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=0.6, label="Resize ratio", visible=True)
+
+
+ with gr.Row():
+ # put the template frame under the radio button
+ with gr.Column():
+ # extract frames
+ with gr.Column():
+ extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
+
+ # click points settins, negative or positive, mode continuous or single
+ with gr.Row():
+ with gr.Row():
+ point_prompt = gr.Radio(
+ choices=["Positive", "Negative"],
+ value="Positive",
+ label="Point prompt",
+ interactive=True,
+ visible=False)
+ remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
+ clear_button_click = gr.Button(value="Clear clicks", interactive=True, visible=False).style(height=160)
+ Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False)
+ template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360)
+ image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track start frame", visible=False)
+ track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False)
+
+ with gr.Column():
+ run_status = gr.HighlightedText(value=[("Run","Error"),("Status","Normal")], visible=True)
+ mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
+ video_output = gr.Video(autosize=True, visible=False).style(height=360)
+ with gr.Row():
+ tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
+ inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False)
+ # set example
+ gr.Markdown("## Examples")
+ gr.Examples(
+ examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample8.mp4","test-sample4.mp4", \
+ "test-sample2.mp4","test-sample13.mp4"]],
+ fn=run_example,
+ inputs=[
+ video_input
+ ],
+ outputs=[video_input],
+ # cache_examples=True,
+ )
+
+ with gr.Tab("Tutorial"):
+ with gr.Column():
+ with gr.Row(scale=0.4):
+ video_demo_operation = gr.Video(autosize=True)
+
+ # set example
+ gr.Markdown("## Operation tutorial video")
+ gr.Examples(
+ examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["huggingface_demo_operation.mp4"]],
+ fn=run_example,
+ inputs=[
+ video_demo_operation
+ ],
+ outputs=[video_demo_operation],
+ # cache_examples=True,
+ )
+
+ # first step: get the video information
+ extract_frames_button.click(
+ fn=get_frames_from_video,
+ inputs=[
+ video_input, video_state
+ ],
+ outputs=[video_state, video_info, template_frame, image_selection_slider,
+ track_pause_number_slider,point_prompt, clear_button_click, Add_mask_button,
+ template_frame, tracking_video_predict_button, video_output, mask_dropdown,
+ remove_mask_button, inpaint_video_predict_button, run_status]
+ )
+
+ # second step: select images from slider
+ image_selection_slider.release(fn=select_template,
+ inputs=[image_selection_slider, video_state, interactive_state],
+ outputs=[template_frame, video_state, interactive_state, run_status], api_name="select_image")
+ track_pause_number_slider.release(fn=get_end_number,
+ inputs=[track_pause_number_slider, video_state, interactive_state],
+ outputs=[template_frame, interactive_state, run_status], api_name="end_image")
+ resize_ratio_slider.release(fn=get_resize_ratio,
+ inputs=[resize_ratio_slider, interactive_state],
+ outputs=[interactive_state], api_name="resize_ratio")
+
+ # click select image to get mask using sam
+ template_frame.select(
+ fn=sam_refine,
+ inputs=[video_state, point_prompt, click_state, interactive_state],
+ outputs=[template_frame, video_state, interactive_state, run_status]
+ )
+
+ # add different mask
+ Add_mask_button.click(
+ fn=add_multi_mask,
+ inputs=[video_state, interactive_state, mask_dropdown],
+ outputs=[interactive_state, mask_dropdown, template_frame, click_state, run_status]
+ )
+
+ remove_mask_button.click(
+ fn=remove_multi_mask,
+ inputs=[interactive_state, mask_dropdown],
+ outputs=[interactive_state, mask_dropdown, run_status]
+ )
+
+ # tracking video from select image and mask
+ tracking_video_predict_button.click(
+ fn=vos_tracking_video,
+ inputs=[video_state, interactive_state, mask_dropdown],
+ outputs=[video_output, video_state, interactive_state, run_status]
+ )
+
+ # inpaint video from select image and mask
+ inpaint_video_predict_button.click(
+ fn=inpaint_video,
+ inputs=[video_state, interactive_state, mask_dropdown],
+ outputs=[video_output, run_status]
+ )
+
+ # click to get mask
+ mask_dropdown.change(
+ fn=show_mask,
+ inputs=[video_state, interactive_state, mask_dropdown],
+ outputs=[template_frame, run_status]
+ )
+
+ # clear input
+ video_input.clear(
+ lambda: (
+ {
+ "user_name": "",
+ "video_name": "",
+ "origin_images": None,
+ "painted_images": None,
+ "masks": None,
+ "inpaint_masks": None,
+ "logits": None,
+ "select_frame_number": 0,
+ "fps": 30
+ },
+ {
+ "inference_times": 0,
+ "negative_click_times" : 0,
+ "positive_click_times": 0,
+ "mask_save": args.mask_save,
+ "multi_mask": {
+ "mask_names": [],
+ "masks": []
+ },
+ "track_end_number": 0,
+ "resize_ratio": 0.6
+ },
+ [[],[]],
+ None,
+ None,
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), \
+ gr.update(visible=False), gr.update(visible=True)
+
+ ),
+ [],
+ [
+ video_state,
+ interactive_state,
+ click_state,
+ video_output,
+ template_frame,
+ tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
+ Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button, run_status
+ ],
+ queue=False,
+ show_progress=False)
+
+ # points clear
+ clear_button_click.click(
+ fn = clear_click,
+ inputs = [video_state, click_state,],
+ outputs = [template_frame,click_state, run_status],
+ )
+iface.queue(concurrency_count=1)
+# iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
+iface.launch(debug=True, enable_queue=True)
\ No newline at end of file
diff --git a/app_save.py b/app_save.py
new file mode 100644
index 0000000000000000000000000000000000000000..1625dff5cd655e01fce51654f1341832b9d72859
--- /dev/null
+++ b/app_save.py
@@ -0,0 +1,381 @@
+import gradio as gr
+from demo import automask_image_app, automask_video_app, sahi_autoseg_app
+import argparse
+import cv2
+import time
+from PIL import Image
+import numpy as np
+import os
+import sys
+sys.path.append(sys.path[0]+"/tracker")
+sys.path.append(sys.path[0]+"/tracker/model")
+from track_anything import TrackingAnything
+from track_anything import parse_augment
+import requests
+import json
+import torchvision
+import torch
+import concurrent.futures
+import queue
+
+def download_checkpoint(url, folder, filename):
+ os.makedirs(folder, exist_ok=True)
+ filepath = os.path.join(folder, filename)
+
+ if not os.path.exists(filepath):
+ print("download checkpoints ......")
+ response = requests.get(url, stream=True)
+ with open(filepath, "wb") as f:
+ for chunk in response.iter_content(chunk_size=8192):
+ if chunk:
+ f.write(chunk)
+
+ print("download successfully!")
+
+ return filepath
+
+def pause_video(play_state):
+ print("user pause_video")
+ play_state.append(time.time())
+ return play_state
+
+def play_video(play_state):
+ print("user play_video")
+ play_state.append(time.time())
+ return play_state
+
+# convert points input to prompt state
+def get_prompt(click_state, click_input):
+ inputs = json.loads(click_input)
+ points = click_state[0]
+ labels = click_state[1]
+ for input in inputs:
+ points.append(input[:2])
+ labels.append(input[2])
+ click_state[0] = points
+ click_state[1] = labels
+ prompt = {
+ "prompt_type":["click"],
+ "input_point":click_state[0],
+ "input_label":click_state[1],
+ "multimask_output":"True",
+ }
+ return prompt
+
+def get_frames_from_video(video_input, play_state):
+ """
+ Args:
+ video_path:str
+ timestamp:float64
+ Return
+ [[0:nearest_frame], [nearest_frame:], nearest_frame]
+ """
+ video_path = video_input
+ # video_name = video_path.split('/')[-1]
+
+ try:
+ timestamp = play_state[1] - play_state[0]
+ except:
+ timestamp = 0
+ frames = []
+ try:
+ cap = cv2.VideoCapture(video_path)
+ fps = cap.get(cv2.CAP_PROP_FPS)
+ while cap.isOpened():
+ ret, frame = cap.read()
+ if ret == True:
+ frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+ else:
+ break
+ except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
+ print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
+
+ # for index, frame in enumerate(frames):
+ # frames[index] = np.asarray(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
+
+ key_frame_index = int(timestamp * fps)
+ nearest_frame = frames[key_frame_index]
+ frames_split = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
+ # output_path='./seperate.mp4'
+ # torchvision.io.write_video(output_path, frames[1], fps=fps, video_codec="libx264")
+
+ # set image in sam when select the template frame
+ model.samcontroler.sam_controler.set_image(nearest_frame)
+ return frames_split, nearest_frame, nearest_frame, fps
+
+def generate_video_from_frames(frames, output_path, fps=30):
+ """
+ Generates a video from a list of frames.
+
+ Args:
+ frames (list of numpy arrays): The frames to include in the video.
+ output_path (str): The path to save the generated video.
+ fps (int, optional): The frame rate of the output video. Defaults to 30.
+ """
+ # height, width, layers = frames[0].shape
+ # fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+ # video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+ # for frame in frames:
+ # video.write(frame)
+
+ # video.release()
+ frames = torch.from_numpy(np.asarray(frames))
+ output_path='./output.mp4'
+ torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
+ return output_path
+
+def model_reset():
+ model.xmem.clear_memory()
+ return None
+
+def sam_refine(origin_frame, point_prompt, click_state, logit, evt:gr.SelectData):
+ """
+ Args:
+ template_frame: PIL.Image
+ point_prompt: flag for positive or negative button click
+ click_state: [[points], [labels]]
+ """
+ if point_prompt == "Positive":
+ coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
+ else:
+ coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
+
+ # prompt for sam model
+ prompt = get_prompt(click_state=click_state, click_input=coordinate)
+
+ # default value
+ # points = np.array([[evt.index[0],evt.index[1]]])
+ # labels= np.array([1])
+ if len(logit)==0:
+ logit = None
+
+ mask, logit, painted_image = model.first_frame_click(
+ image=origin_frame,
+ points=np.array(prompt["input_point"]),
+ labels=np.array(prompt["input_label"]),
+ multimask=prompt["multimask_output"],
+ )
+ return painted_image, click_state, logit, mask
+
+
+
+def vos_tracking_video(video_state, template_mask,fps,video_input):
+
+ masks, logits, painted_images = model.generator(images=video_state[1], template_mask=template_mask)
+ video_output = generate_video_from_frames(painted_images, output_path="./output.mp4", fps=fps)
+ # image_selection_slider = gr.Slider(minimum=1, maximum=len(video_state[1]), value=1, label="Image Selection", interactive=True)
+ video_name = video_input.split('/')[-1].split('.')[0]
+ result_path = os.path.join('/hhd3/gaoshang/Track-Anything/results/'+video_name)
+ if not os.path.exists(result_path):
+ os.makedirs(result_path)
+ i=0
+ for mask in masks:
+ np.save(os.path.join(result_path,'{:05}.npy'.format(i)), mask)
+ i+=1
+ return video_output, painted_images, masks, logits
+
+def vos_tracking_image(image_selection_slider, painted_images):
+
+ # images = video_state[1]
+ percentage = image_selection_slider / 100
+ select_frame_num = int(percentage * len(painted_images))
+ return painted_images[select_frame_num], select_frame_num
+
+def interactive_correction(video_state, point_prompt, click_state, select_correction_frame, evt: gr.SelectData):
+ """
+ Args:
+ template_frame: PIL.Image
+ point_prompt: flag for positive or negative button click
+ click_state: [[points], [labels]]
+ """
+ refine_image = video_state[1][select_correction_frame]
+ if point_prompt == "Positive":
+ coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
+ else:
+ coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
+
+ # prompt for sam model
+ prompt = get_prompt(click_state=click_state, click_input=coordinate)
+ model.samcontroler.seg_again(refine_image)
+ corrected_mask, corrected_logit, corrected_painted_image = model.first_frame_click(
+ image=refine_image,
+ points=np.array(prompt["input_point"]),
+ labels=np.array(prompt["input_label"]),
+ multimask=prompt["multimask_output"],
+ )
+ return corrected_painted_image, [corrected_mask, corrected_logit, corrected_painted_image]
+
+def correct_track(video_state, select_correction_frame, corrected_state, masks, logits, painted_images, fps, video_input):
+ model.xmem.clear_memory()
+ # inference the following images
+ following_images = video_state[1][select_correction_frame:]
+ corrected_masks, corrected_logits, corrected_painted_images = model.generator(images=following_images, template_mask=corrected_state[0])
+ masks = masks[:select_correction_frame] + corrected_masks
+ logits = logits[:select_correction_frame] + corrected_logits
+ painted_images = painted_images[:select_correction_frame] + corrected_painted_images
+ video_output = generate_video_from_frames(painted_images, output_path="./output.mp4", fps=fps)
+
+ video_name = video_input.split('/')[-1].split('.')[0]
+ result_path = os.path.join('/hhd3/gaoshang/Track-Anything/results/'+video_name)
+ if not os.path.exists(result_path):
+ os.makedirs(result_path)
+ i=0
+ for mask in masks:
+ np.save(os.path.join(result_path,'{:05}.npy'.format(i)), mask)
+ i+=1
+ return video_output, painted_images, logits, masks
+
+# check and download checkpoints if needed
+SAM_checkpoint = "sam_vit_h_4b8939.pth"
+sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
+xmem_checkpoint = "XMem-s012.pth"
+xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
+folder ="./checkpoints"
+SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
+xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
+
+# args, defined in track_anything.py
+args = parse_augment()
+args.port = 12207
+args.device = "cuda:5"
+
+model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
+
+with gr.Blocks() as iface:
+ """
+ state for
+ """
+ state = gr.State([])
+ play_state = gr.State([])
+ video_state = gr.State([[],[],[]])
+ click_state = gr.State([[],[]])
+ logits = gr.State([])
+ masks = gr.State([])
+ painted_images = gr.State([])
+ origin_image = gr.State(None)
+ template_mask = gr.State(None)
+ select_correction_frame = gr.State(None)
+ corrected_state = gr.State([[],[],[]])
+ fps = gr.State([])
+ # video_name = gr.State([])
+ # queue value for image refresh, origin image, mask, logits, painted image
+
+
+
+ with gr.Row():
+
+ # for user video input
+ with gr.Column(scale=1.0):
+ video_input = gr.Video().style(height=720)
+
+ # listen to the user action for play and pause input video
+ video_input.play(fn=play_video, inputs=play_state, outputs=play_state, scroll_to_output=True, show_progress=True)
+ video_input.pause(fn=pause_video, inputs=play_state, outputs=play_state)
+
+
+ with gr.Row(scale=1):
+ # put the template frame under the radio button
+ with gr.Column(scale=0.5):
+ # click points settins, negative or positive, mode continuous or single
+ with gr.Row():
+ with gr.Row(scale=0.5):
+ point_prompt = gr.Radio(
+ choices=["Positive", "Negative"],
+ value="Positive",
+ label="Point Prompt",
+ interactive=True)
+ click_mode = gr.Radio(
+ choices=["Continuous", "Single"],
+ value="Continuous",
+ label="Clicking Mode",
+ interactive=True)
+ with gr.Row(scale=0.5):
+ clear_button_clike = gr.Button(value="Clear Clicks", interactive=True).style(height=160)
+ clear_button_image = gr.Button(value="Clear Image", interactive=True)
+ template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame").style(height=360)
+ with gr.Column():
+ template_select_button = gr.Button(value="Template select", interactive=True, variant="primary")
+
+
+
+ with gr.Column(scale=0.5):
+
+
+ # for intermedia result check and correction
+ # intermedia_image = gr.Image(type="pil", interactive=True, elem_id="intermedia_frame").style(height=360)
+ video_output = gr.Video().style(height=360)
+ tracking_video_predict_button = gr.Button(value="Tracking")
+
+ image_output = gr.Image(type="pil", interactive=True, elem_id="image_output").style(height=360)
+ image_selection_slider = gr.Slider(minimum=0, maximum=100, step=0.1, value=0, label="Image Selection", interactive=True)
+ correct_track_button = gr.Button(value="Interactive Correction")
+
+ template_frame.select(
+ fn=sam_refine,
+ inputs=[
+ origin_image, point_prompt, click_state, logits
+ ],
+ outputs=[
+ template_frame, click_state, logits, template_mask
+ ]
+ )
+
+ template_select_button.click(
+ fn=get_frames_from_video,
+ inputs=[
+ video_input,
+ play_state
+ ],
+ # outputs=[video_state, template_frame, origin_image, fps, video_name],
+ outputs=[video_state, template_frame, origin_image, fps],
+ )
+
+ tracking_video_predict_button.click(
+ fn=vos_tracking_video,
+ inputs=[video_state, template_mask, fps, video_input],
+ outputs=[video_output, painted_images, masks, logits]
+ )
+ image_selection_slider.release(fn=vos_tracking_image,
+ inputs=[image_selection_slider, painted_images], outputs=[image_output, select_correction_frame], api_name="select_image")
+ # correction
+ image_output.select(
+ fn=interactive_correction,
+ inputs=[video_state, point_prompt, click_state, select_correction_frame],
+ outputs=[image_output, corrected_state]
+ )
+ correct_track_button.click(
+ fn=correct_track,
+ inputs=[video_state, select_correction_frame, corrected_state, masks, logits, painted_images, fps,video_input],
+ outputs=[video_output, painted_images, logits, masks ]
+ )
+
+
+
+ # clear input
+ video_input.clear(
+ lambda: ([], [], [[], [], []],
+ None, "", "", "", "", "", "", "", [[],[]],
+ None),
+ [],
+ [ state, play_state, video_state,
+ template_frame, video_output, image_output, origin_image, template_mask, painted_images, masks, logits, click_state,
+ select_correction_frame],
+ queue=False,
+ show_progress=False
+ )
+ clear_button_image.click(
+ fn=model_reset
+ )
+ clear_button_clike.click(
+ lambda: ([[],[]]),
+ [],
+ [click_state],
+ queue=False,
+ show_progress=False
+ )
+iface.queue(concurrency_count=1)
+iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
+
+
+
diff --git a/app_test.py b/app_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd10fe77cec552dffba84c6516ec33a6622b6c38
--- /dev/null
+++ b/app_test.py
@@ -0,0 +1,46 @@
+# import gradio as gr
+
+# def update_iframe(slider_value):
+# return f'''
+#
+#
+# '''
+
+# iface = gr.Interface(
+# fn=update_iframe,
+# inputs=gr.inputs.Slider(minimum=0, maximum=100, step=1, default=50),
+# outputs=gr.outputs.HTML(),
+# allow_flagging=False,
+# )
+
+# iface.launch(server_name='0.0.0.0', server_port=12212)
+
+import gradio as gr
+
+
+def change_mask(drop):
+ return gr.update(choices=["hello", "kitty"])
+
+with gr.Blocks() as iface:
+ drop = gr.Dropdown(
+ choices=["cat", "dog", "bird"], label="Animal", info="Will add more animals later!"
+ )
+ radio = gr.Radio(["park", "zoo", "road"], label="Location", info="Where did they go?")
+ multi_drop = gr.Dropdown(
+ ["ran", "swam", "ate", "slept"], value=["swam", "slept"], multiselect=True, label="Activity", info="Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed auctor, nisl eget ultricies aliquam, nunc nisl aliquet nunc, eget aliquam nisl nunc vel nisl."
+ )
+
+ multi_drop.change(
+ fn=change_mask,
+ inputs = multi_drop,
+ outputs=multi_drop
+ )
+
+iface.launch(server_name='0.0.0.0', server_port=1223)
\ No newline at end of file
diff --git a/assets/avengers.gif b/assets/avengers.gif
new file mode 100644
index 0000000000000000000000000000000000000000..672be575d675a8e3d755640ed370284f024bc890
--- /dev/null
+++ b/assets/avengers.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5e07b86ee4cf002b3481c71e2038c03f4420883c3be78220dafbc4b59abfb32d
+size 30038625
diff --git a/assets/demo_version_1.MP4 b/assets/demo_version_1.MP4
new file mode 100644
index 0000000000000000000000000000000000000000..69684a7905c07bfc149ef66aad0d147d9e1010bf
--- /dev/null
+++ b/assets/demo_version_1.MP4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b61b54bc6eb0d0f7416f95aa3cd6a48d850ca7473022ec1aff48310911b0233
+size 27053146
diff --git a/assets/inpainting.gif b/assets/inpainting.gif
new file mode 100644
index 0000000000000000000000000000000000000000..d30fb2551a0fe2b6eabe61d2d1df39e23270e62f
--- /dev/null
+++ b/assets/inpainting.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5e99bd697bccaed7a0dded7f00855f222031b7dcefd8f64f22f374fcdab390d2
+size 22228969
diff --git a/assets/poster_demo_version_1.png b/assets/poster_demo_version_1.png
new file mode 100644
index 0000000000000000000000000000000000000000..0c4196ac8250d94e215c555cb8cbf13abe061011
Binary files /dev/null and b/assets/poster_demo_version_1.png differ
diff --git a/assets/qingming.mp4 b/assets/qingming.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..60205d6cc3ee087277e096d244e8c6fada6446b4
--- /dev/null
+++ b/assets/qingming.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:58b34bbce0bd0a18ab5fc5450d4046e1cfc6bd55c508046695545819d8fc46dc
+size 4483842
diff --git a/assets/track-anything-logo.jpg b/assets/track-anything-logo.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b1f1103d8f5c822e37376be7bcfda176bd508a78
Binary files /dev/null and b/assets/track-anything-logo.jpg differ
diff --git a/checkpoints/E2FGVI-HQ-CVPR22.pth b/checkpoints/E2FGVI-HQ-CVPR22.pth
new file mode 100644
index 0000000000000000000000000000000000000000..79dfff57e7206e37a86bf9a7e8e7306296ba97a5
--- /dev/null
+++ b/checkpoints/E2FGVI-HQ-CVPR22.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:afff989d41205598a79ce24630b9c83af4b0a06f45b137979a25937d94c121a5
+size 164535938
diff --git a/demo.py b/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf5d4d2129751906128f6db9b37070f41b89ac1a
--- /dev/null
+++ b/demo.py
@@ -0,0 +1,87 @@
+from metaseg import SegAutoMaskPredictor, SegManualMaskPredictor, SahiAutoSegmentation, sahi_sliced_predict
+
+# For image
+
+def automask_image_app(image_path, model_type, points_per_side, points_per_batch, min_area):
+ SegAutoMaskPredictor().image_predict(
+ source=image_path,
+ model_type=model_type, # vit_l, vit_h, vit_b
+ points_per_side=points_per_side,
+ points_per_batch=points_per_batch,
+ min_area=min_area,
+ output_path="output.png",
+ show=False,
+ save=True,
+ )
+ return "output.png"
+
+
+# For video
+
+def automask_video_app(video_path, model_type, points_per_side, points_per_batch, min_area):
+ SegAutoMaskPredictor().video_predict(
+ source=video_path,
+ model_type=model_type, # vit_l, vit_h, vit_b
+ points_per_side=points_per_side,
+ points_per_batch=points_per_batch,
+ min_area=min_area,
+ output_path="output.mp4",
+ )
+ return "output.mp4"
+
+
+# For manuel box and point selection
+
+def manual_app(image_path, model_type, input_point, input_label, input_box, multimask_output, random_color):
+ SegManualMaskPredictor().image_predict(
+ source=image_path,
+ model_type=model_type, # vit_l, vit_h, vit_b
+ input_point=input_point,
+ input_label=input_label,
+ input_box=input_box,
+ multimask_output=multimask_output,
+ random_color=random_color,
+ output_path="output.png",
+ show=False,
+ save=True,
+ )
+ return "output.png"
+
+
+# For sahi sliced prediction
+
+def sahi_autoseg_app(
+ image_path,
+ sam_model_type,
+ detection_model_type,
+ detection_model_path,
+ conf_th,
+ image_size,
+ slice_height,
+ slice_width,
+ overlap_height_ratio,
+ overlap_width_ratio,
+):
+ boxes = sahi_sliced_predict(
+ image_path=image_path,
+ detection_model_type=detection_model_type, # yolov8, detectron2, mmdetection, torchvision
+ detection_model_path=detection_model_path,
+ conf_th=conf_th,
+ image_size=image_size,
+ slice_height=slice_height,
+ slice_width=slice_width,
+ overlap_height_ratio=overlap_height_ratio,
+ overlap_width_ratio=overlap_width_ratio,
+ )
+
+ SahiAutoSegmentation().predict(
+ source=image_path,
+ model_type=sam_model_type,
+ input_box=boxes,
+ multimask_output=False,
+ random_color=False,
+ show=False,
+ save=True,
+ )
+
+ return "output.png"
diff --git a/images/groceries.jpg b/images/groceries.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..85f791c45610e5a3c230fddb1e712dbc602f79d0
Binary files /dev/null and b/images/groceries.jpg differ
diff --git a/images/mask_painter.png b/images/mask_painter.png
new file mode 100644
index 0000000000000000000000000000000000000000..e27dbf37d56ea25005d8067b7aed0845902adea2
Binary files /dev/null and b/images/mask_painter.png differ
diff --git a/images/painter_input_image.jpg b/images/painter_input_image.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..deeafdbc1d4ac40426f75ee7395ecd82025d6e95
Binary files /dev/null and b/images/painter_input_image.jpg differ
diff --git a/images/painter_input_mask.jpg b/images/painter_input_mask.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0720afed9caf1e4e8b1864a86a7004c43307d845
Binary files /dev/null and b/images/painter_input_mask.jpg differ
diff --git a/images/painter_output_image.png b/images/painter_output_image.png
new file mode 100644
index 0000000000000000000000000000000000000000..3ffbfaeb3181857f8940ff71e151eff3e1b4eb74
Binary files /dev/null and b/images/painter_output_image.png differ
diff --git a/images/painter_output_image__.png b/images/painter_output_image__.png
new file mode 100644
index 0000000000000000000000000000000000000000..cf39379ff16fa027fe6231c94dde51254ee60783
Binary files /dev/null and b/images/painter_output_image__.png differ
diff --git a/images/point_painter.png b/images/point_painter.png
new file mode 100644
index 0000000000000000000000000000000000000000..c3f40aff6478633b9e0c90375fab9cf79ae3f79d
Binary files /dev/null and b/images/point_painter.png differ
diff --git a/images/point_painter_1.png b/images/point_painter_1.png
new file mode 100644
index 0000000000000000000000000000000000000000..6b1c0facec30ef1a94677c2b1179a12d531d7467
Binary files /dev/null and b/images/point_painter_1.png differ
diff --git a/images/point_painter_2.png b/images/point_painter_2.png
new file mode 100644
index 0000000000000000000000000000000000000000..c9bcb1b1b1125aa8e35656bb2576919588e54423
Binary files /dev/null and b/images/point_painter_2.png differ
diff --git a/images/truck.jpg b/images/truck.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6b98688c3c84981200c06259b8d54820ebf85660
Binary files /dev/null and b/images/truck.jpg differ
diff --git a/images/truck_both.jpg b/images/truck_both.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..53e663f40da9247bee0f3c97fcf964199ed176b3
Binary files /dev/null and b/images/truck_both.jpg differ
diff --git a/images/truck_mask.jpg b/images/truck_mask.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d97832f1138c81e1c855586079caece524af911f
Binary files /dev/null and b/images/truck_mask.jpg differ
diff --git a/images/truck_point.jpg b/images/truck_point.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..23648aa26abffef539b83ee0fbdddb678bfb2fc9
Binary files /dev/null and b/images/truck_point.jpg differ
diff --git a/inpainter/.DS_Store b/inpainter/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..1b89560b276dfab91ed4f11ab4b0a68cffd450f8
Binary files /dev/null and b/inpainter/.DS_Store differ
diff --git a/inpainter/base_inpainter.py b/inpainter/base_inpainter.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f2f3147bc21f5dec52b82271d40ac95e0405817
--- /dev/null
+++ b/inpainter/base_inpainter.py
@@ -0,0 +1,287 @@
+import os
+import glob
+from PIL import Image
+import torch
+import yaml
+import cv2
+import importlib
+import numpy as np
+from tqdm import tqdm
+from inpainter.util.tensor_util import resize_frames, resize_masks
+
+def read_image_from_split(videp_split_path):
+ # if type:
+ image = np.asarray([np.asarray(Image.open(path)) for path in videp_split_path])
+ # else:
+ # image = cv2.cvtColor(cv2.imread("/tmp/{}/paintedimages/{}/{:08d}.png".format(username, video_state["video_name"], index+ ".png")), cv2.COLOR_BGR2RGB)
+ return image
+
+def save_image_to_userfolder(video_state, index, image, type:bool):
+ if type:
+ image_path = "/tmp/{}/originimages/{}/{:08d}.png".format(video_state["user_name"], video_state["video_name"], index)
+ else:
+ image_path = "/tmp/{}/paintedimages/{}/{:08d}.png".format(video_state["user_name"], video_state["video_name"], index)
+ cv2.imwrite(image_path, image)
+ return image_path
+class BaseInpainter:
+ def __init__(self, E2FGVI_checkpoint, device) -> None:
+ """
+ E2FGVI_checkpoint: checkpoint of inpainter (version hq, with multi-resolution support)
+ """
+ net = importlib.import_module('inpainter.model.e2fgvi_hq')
+ self.model = net.InpaintGenerator().to(device)
+ self.model.load_state_dict(torch.load(E2FGVI_checkpoint, map_location=device))
+ self.model.eval()
+ self.device = device
+ # load configurations
+ with open("inpainter/config/config.yaml", 'r') as stream:
+ config = yaml.safe_load(stream)
+ self.neighbor_stride = config['neighbor_stride']
+ self.num_ref = config['num_ref']
+ self.step = config['step']
+
+ # sample reference frames from the whole video
+ def get_ref_index(self, f, neighbor_ids, length):
+ ref_index = []
+ if self.num_ref == -1:
+ for i in range(0, length, self.step):
+ if i not in neighbor_ids:
+ ref_index.append(i)
+ else:
+ start_idx = max(0, f - self.step * (self.num_ref // 2))
+ end_idx = min(length, f + self.step * (self.num_ref // 2))
+ for i in range(start_idx, end_idx + 1, self.step):
+ if i not in neighbor_ids:
+ if len(ref_index) > self.num_ref:
+ break
+ ref_index.append(i)
+ return ref_index
+
+ def inpaint_efficient(self, frames, masks, num_tcb, num_tca, dilate_radius=15, ratio=1):
+ """
+ Perform Inpainting for video subsets
+ frames: numpy array, T, H, W, 3
+ masks: numpy array, T, H, W
+ num_tcb: constant, number of temporal context before, frames
+ num_tca: constant, number of temporal context after, frames
+ dilate_radius: radius when applying dilation on masks
+ ratio: down-sample ratio
+
+ Output:
+ inpainted_frames: numpy array, T, H, W, 3
+ """
+ assert frames.shape[:3] == masks.shape, 'different size between frames and masks'
+ assert ratio > 0 and ratio <= 1, 'ratio must in (0, 1]'
+
+ # --------------------
+ # pre-processing
+ # --------------------
+ masks = masks.copy()
+ masks = np.clip(masks, 0, 1)
+ kernel = cv2.getStructuringElement(2, (dilate_radius, dilate_radius))
+ masks = np.stack([cv2.dilate(mask, kernel) for mask in masks], 0)
+ T, H, W = masks.shape
+ masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
+ # size: (w, h)
+ if ratio == 1:
+ size = None
+ binary_masks = masks
+ else:
+ size = [int(W*ratio), int(H*ratio)]
+ size = [si+1 if si%2>0 else si for si in size] # only consider even values
+ # shortest side should be larger than 50
+ if min(size) < 50:
+ ratio = 50. / min(H, W)
+ size = [int(W*ratio), int(H*ratio)]
+ binary_masks = resize_masks(masks, tuple(size))
+ frames = resize_frames(frames, tuple(size)) # T, H, W, 3
+ # frames and binary_masks are numpy arrays
+ h, w = frames.shape[1:3]
+ video_length = T - (num_tca + num_tcb) # real video length
+ # convert to tensor
+ imgs = (torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous().unsqueeze(0).float().div(255)) * 2 - 1
+ masks = torch.from_numpy(binary_masks).permute(0, 3, 1, 2).contiguous().unsqueeze(0)
+ imgs, masks = imgs.to(self.device), masks.to(self.device)
+ comp_frames = [None] * video_length
+ tcb_imgs = None
+ tca_imgs = None
+ tcb_masks = None
+ tca_masks = None
+ # --------------------
+ # end of pre-processing
+ # --------------------
+
+ # separate tc frames/masks from imgs and masks
+ if num_tcb > 0:
+ tcb_imgs = imgs[:, :num_tcb]
+ tcb_masks = masks[:, :num_tcb]
+ tcb_binary = binary_masks[:num_tcb]
+ if num_tca > 0:
+ tca_imgs = imgs[:, -num_tca:]
+ tca_masks = masks[:, -num_tca:]
+ tca_binary = binary_masks[-num_tca:]
+ end_idx = -num_tca
+ else:
+ end_idx = T
+
+ imgs = imgs[:, num_tcb:end_idx]
+ masks = masks[:, num_tcb:end_idx]
+ binary_masks = binary_masks[num_tcb:end_idx] # only neighbor area are involved
+ frames = frames[num_tcb:end_idx] # only neighbor area are involved
+
+ for f in tqdm(range(0, video_length, self.neighbor_stride), desc='Inpainting image'):
+ neighbor_ids = [
+ i for i in range(max(0, f - self.neighbor_stride),
+ min(video_length, f + self.neighbor_stride + 1))
+ ]
+ ref_ids = self.get_ref_index(f, neighbor_ids, video_length)
+
+ # selected_imgs = imgs[:1, neighbor_ids + ref_ids, :, :, :]
+ # selected_masks = masks[:1, neighbor_ids + ref_ids, :, :, :]
+
+ selected_imgs = imgs[:, neighbor_ids]
+ selected_masks = masks[:, neighbor_ids]
+ # pad before
+ if tcb_imgs is not None:
+ selected_imgs = torch.concat([selected_imgs, tcb_imgs], dim=1)
+ selected_masks = torch.concat([selected_masks, tcb_masks], dim=1)
+ # integrate ref frames
+ selected_imgs = torch.concat([selected_imgs, imgs[:, ref_ids]], dim=1)
+ selected_masks = torch.concat([selected_masks, masks[:, ref_ids]], dim=1)
+ # pad after
+ if tca_imgs is not None:
+ selected_imgs = torch.concat([selected_imgs, tca_imgs], dim=1)
+ selected_masks = torch.concat([selected_masks, tca_masks], dim=1)
+
+ with torch.no_grad():
+ masked_imgs = selected_imgs * (1 - selected_masks)
+ mod_size_h = 60
+ mod_size_w = 108
+ h_pad = (mod_size_h - h % mod_size_h) % mod_size_h
+ w_pad = (mod_size_w - w % mod_size_w) % mod_size_w
+ masked_imgs = torch.cat(
+ [masked_imgs, torch.flip(masked_imgs, [3])],
+ 3)[:, :, :, :h + h_pad, :]
+ masked_imgs = torch.cat(
+ [masked_imgs, torch.flip(masked_imgs, [4])],
+ 4)[:, :, :, :, :w + w_pad]
+ pred_imgs, _ = self.model(masked_imgs, len(neighbor_ids))
+ pred_imgs = pred_imgs[:, :, :h, :w]
+ pred_imgs = (pred_imgs + 1) / 2
+ pred_imgs = pred_imgs.cpu().permute(0, 2, 3, 1).numpy() * 255
+ for i in range(len(neighbor_ids)):
+ idx = neighbor_ids[i]
+ img = pred_imgs[i].astype(np.uint8) * binary_masks[idx] + frames[idx] * (
+ 1 - binary_masks[idx])
+ if comp_frames[idx] is None:
+ comp_frames[idx] = img
+ else:
+ comp_frames[idx] = comp_frames[idx].astype(
+ np.float32) * 0.5 + img.astype(np.float32) * 0.5
+ torch.cuda.empty_cache()
+ inpainted_frames = np.stack(comp_frames, 0)
+ return inpainted_frames.astype(np.uint8)
+
+ def inpaint(self, frames_path, masks, dilate_radius=15, ratio=1):
+ """
+ Perform Inpainting for video subsets
+ frames: numpy array, T, H, W, 3
+ masks: numpy array, T, H, W
+ dilate_radius: radius when applying dilation on masks
+ ratio: down-sample ratio
+
+ Output:
+ inpainted_frames: numpy array, T, H, W, 3
+ """
+ # assert frames.shape[:3] == masks.shape, 'different size between frames and masks'
+ assert ratio > 0 and ratio <= 1, 'ratio must in (0, 1]'
+
+ # set interval
+ interval = 45
+ context_range = 10 # for each split, consider its temporal context [-context_range] frames and [context_range] frames
+ # split frames into subsets
+ video_length = len(frames_path)
+ num_splits = video_length // interval
+ id_splits = [[i*interval, (i+1)*interval] for i in range(num_splits)] # id splits
+ # if remaining split > interval/2, add a new split, else, append to the last split
+ if video_length - id_splits[-1][-1] > interval / 2:
+ id_splits.append([num_splits*interval, video_length])
+ else:
+ id_splits[-1][-1] = video_length
+
+ # perform inpainting for each split
+ inpainted_splits = []
+ for id_split in id_splits:
+ video_split_path = frames_path[id_split[0]:id_split[1]]
+ video_split = read_image_from_split(video_split_path)
+ mask_split = masks[id_split[0]:id_split[1]]
+
+ # | id_before | ----- | id_split[0] | ----- | id_split[1] | ----- | id_after |
+ # add temporal context
+ id_before = max(0, id_split[0] - self.step * context_range)
+ try:
+ tcb_frames = np.stack([np.array(Image.open(frames_path[idb])) for idb in range(id_before, id_split[0]-self.step, self.step)], 0)
+ tcb_masks = np.stack([masks[idb] for idb in range(id_before, id_split[0]-self.step, self.step)], 0)
+ num_tcb = len(tcb_frames)
+ except:
+ num_tcb = 0
+ id_after = min(video_length, id_split[1] + self.step * context_range)
+ try:
+ tca_frames = np.stack([np.array(Image.open(frames_path[ida])) for ida in range(id_split[1]+self.step, id_after, self.step)], 0)
+ tca_masks = np.stack([masks[ida] for ida in range(id_split[1]+self.step, id_after, self.step)], 0)
+ num_tca = len(tca_frames)
+ except:
+ num_tca = 0
+
+ # concatenate temporal context frames/masks with input frames/masks (for parallel pre-processing)
+ if num_tcb > 0:
+ video_split = np.concatenate([tcb_frames, video_split], 0)
+ mask_split = np.concatenate([tcb_masks, mask_split], 0)
+ if num_tca > 0:
+ video_split = np.concatenate([video_split, tca_frames], 0)
+ mask_split = np.concatenate([mask_split, tca_masks], 0)
+
+ # inpaint each split
+ inpainted_splits.append(self.inpaint_efficient(video_split, mask_split, num_tcb, num_tca, dilate_radius, ratio))
+
+ inpainted_frames = np.concatenate(inpainted_splits, 0)
+ return inpainted_frames.astype(np.uint8)
+
+if __name__ == '__main__':
+
+ frame_path = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/parkour', '*.jpg'))
+ frame_path.sort()
+ mask_path = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/Annotations/480p/parkour', "*.png"))
+ mask_path.sort()
+ save_path = '/ssd1/gaomingqi/results/inpainting/parkour'
+
+ if not os.path.exists(save_path):
+ os.mkdir(save_path)
+
+ frames = []
+ masks = []
+ for fid, mid in zip(frame_path, mask_path):
+ frames.append(Image.open(fid).convert('RGB'))
+ masks.append(Image.open(mid).convert('P'))
+
+ frames = np.stack(frames, 0)
+ masks = np.stack(masks, 0)
+
+ # ----------------------------------------------
+ # how to use
+ # ----------------------------------------------
+ # 1/3: set checkpoint and device
+ checkpoint = '/ssd1/gaomingqi/checkpoints/E2FGVI-HQ-CVPR22.pth'
+ device = 'cuda:6'
+ # 2/3: initialise inpainter
+ base_inpainter = BaseInpainter(checkpoint, device)
+ # 3/3: inpainting (frames: numpy array, T, H, W, 3; masks: numpy array, T, H, W)
+ # ratio: (0, 1], ratio for down sample, default value is 1
+ inpainted_frames = base_inpainter.inpaint(frames, masks, ratio=0.01) # numpy array, T, H, W, 3
+ # ----------------------------------------------
+ # end
+ # ----------------------------------------------
+ # save
+ for ti, inpainted_frame in enumerate(inpainted_frames):
+ frame = Image.fromarray(inpainted_frame).convert('RGB')
+ frame.save(os.path.join(save_path, f'{ti:05d}.jpg'))
\ No newline at end of file
diff --git a/inpainter/config/config.yaml b/inpainter/config/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ef4c180a74866cf25839f91a7b474bef679ea342
--- /dev/null
+++ b/inpainter/config/config.yaml
@@ -0,0 +1,4 @@
+# config info for E2FGVI
+neighbor_stride: 5
+num_ref: -1
+step: 10
diff --git a/inpainter/model/e2fgvi.py b/inpainter/model/e2fgvi.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea90b61e0c7fe44b1968a2c59592bf50e0426bb0
--- /dev/null
+++ b/inpainter/model/e2fgvi.py
@@ -0,0 +1,350 @@
+''' Towards An End-to-End Framework for Video Inpainting
+'''
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from model.modules.flow_comp import SPyNet
+from model.modules.feat_prop import BidirectionalPropagation, SecondOrderDeformableAlignment
+from model.modules.tfocal_transformer import TemporalFocalTransformerBlock, SoftSplit, SoftComp
+from model.modules.spectral_norm import spectral_norm as _spectral_norm
+
+
+class BaseNetwork(nn.Module):
+ def __init__(self):
+ super(BaseNetwork, self).__init__()
+
+ def print_network(self):
+ if isinstance(self, list):
+ self = self[0]
+ num_params = 0
+ for param in self.parameters():
+ num_params += param.numel()
+ print(
+ 'Network [%s] was created. Total number of parameters: %.1f million. '
+ 'To see the architecture, do print(network).' %
+ (type(self).__name__, num_params / 1000000))
+
+ def init_weights(self, init_type='normal', gain=0.02):
+ '''
+ initialize network's weights
+ init_type: normal | xavier | kaiming | orthogonal
+ https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
+ '''
+ def init_func(m):
+ classname = m.__class__.__name__
+ if classname.find('InstanceNorm2d') != -1:
+ if hasattr(m, 'weight') and m.weight is not None:
+ nn.init.constant_(m.weight.data, 1.0)
+ if hasattr(m, 'bias') and m.bias is not None:
+ nn.init.constant_(m.bias.data, 0.0)
+ elif hasattr(m, 'weight') and (classname.find('Conv') != -1
+ or classname.find('Linear') != -1):
+ if init_type == 'normal':
+ nn.init.normal_(m.weight.data, 0.0, gain)
+ elif init_type == 'xavier':
+ nn.init.xavier_normal_(m.weight.data, gain=gain)
+ elif init_type == 'xavier_uniform':
+ nn.init.xavier_uniform_(m.weight.data, gain=1.0)
+ elif init_type == 'kaiming':
+ nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+ elif init_type == 'orthogonal':
+ nn.init.orthogonal_(m.weight.data, gain=gain)
+ elif init_type == 'none': # uses pytorch's default init method
+ m.reset_parameters()
+ else:
+ raise NotImplementedError(
+ 'initialization method [%s] is not implemented' %
+ init_type)
+ if hasattr(m, 'bias') and m.bias is not None:
+ nn.init.constant_(m.bias.data, 0.0)
+
+ self.apply(init_func)
+
+ # propagate to children
+ for m in self.children():
+ if hasattr(m, 'init_weights'):
+ m.init_weights(init_type, gain)
+
+
+class Encoder(nn.Module):
+ def __init__(self):
+ super(Encoder, self).__init__()
+ self.group = [1, 2, 4, 8, 1]
+ self.layers = nn.ModuleList([
+ nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
+ nn.LeakyReLU(0.2, inplace=True)
+ ])
+
+ def forward(self, x):
+ bt, c, h, w = x.size()
+ h, w = h // 4, w // 4
+ out = x
+ for i, layer in enumerate(self.layers):
+ if i == 8:
+ x0 = out
+ if i > 8 and i % 2 == 0:
+ g = self.group[(i - 8) // 2]
+ x = x0.view(bt, g, -1, h, w)
+ o = out.view(bt, g, -1, h, w)
+ out = torch.cat([x, o], 2).view(bt, -1, h, w)
+ out = layer(out)
+ return out
+
+
+class deconv(nn.Module):
+ def __init__(self,
+ input_channel,
+ output_channel,
+ kernel_size=3,
+ padding=0):
+ super().__init__()
+ self.conv = nn.Conv2d(input_channel,
+ output_channel,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=padding)
+
+ def forward(self, x):
+ x = F.interpolate(x,
+ scale_factor=2,
+ mode='bilinear',
+ align_corners=True)
+ return self.conv(x)
+
+
+class InpaintGenerator(BaseNetwork):
+ def __init__(self, init_weights=True):
+ super(InpaintGenerator, self).__init__()
+ channel = 256
+ hidden = 512
+
+ # encoder
+ self.encoder = Encoder()
+
+ # decoder
+ self.decoder = nn.Sequential(
+ deconv(channel // 2, 128, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ deconv(64, 64, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))
+
+ # feature propagation module
+ self.feat_prop_module = BidirectionalPropagation(channel // 2)
+
+ # soft split and soft composition
+ kernel_size = (7, 7)
+ padding = (3, 3)
+ stride = (3, 3)
+ output_size = (60, 108)
+ t2t_params = {
+ 'kernel_size': kernel_size,
+ 'stride': stride,
+ 'padding': padding,
+ 'output_size': output_size
+ }
+ self.ss = SoftSplit(channel // 2,
+ hidden,
+ kernel_size,
+ stride,
+ padding,
+ t2t_param=t2t_params)
+ self.sc = SoftComp(channel // 2, hidden, output_size, kernel_size,
+ stride, padding)
+
+ n_vecs = 1
+ for i, d in enumerate(kernel_size):
+ n_vecs *= int((output_size[i] + 2 * padding[i] -
+ (d - 1) - 1) / stride[i] + 1)
+
+ blocks = []
+ depths = 8
+ num_heads = [4] * depths
+ window_size = [(5, 9)] * depths
+ focal_windows = [(5, 9)] * depths
+ focal_levels = [2] * depths
+ pool_method = "fc"
+
+ for i in range(depths):
+ blocks.append(
+ TemporalFocalTransformerBlock(dim=hidden,
+ num_heads=num_heads[i],
+ window_size=window_size[i],
+ focal_level=focal_levels[i],
+ focal_window=focal_windows[i],
+ n_vecs=n_vecs,
+ t2t_params=t2t_params,
+ pool_method=pool_method))
+ self.transformer = nn.Sequential(*blocks)
+
+ if init_weights:
+ self.init_weights()
+ # Need to initial the weights of MSDeformAttn specifically
+ for m in self.modules():
+ if isinstance(m, SecondOrderDeformableAlignment):
+ m.init_offset()
+
+ # flow completion network
+ self.update_spynet = SPyNet()
+
+ def forward_bidirect_flow(self, masked_local_frames):
+ b, l_t, c, h, w = masked_local_frames.size()
+
+ # compute forward and backward flows of masked frames
+ masked_local_frames = F.interpolate(masked_local_frames.view(
+ -1, c, h, w),
+ scale_factor=1 / 4,
+ mode='bilinear',
+ align_corners=True,
+ recompute_scale_factor=True)
+ masked_local_frames = masked_local_frames.view(b, l_t, c, h // 4,
+ w // 4)
+ mlf_1 = masked_local_frames[:, :-1, :, :, :].reshape(
+ -1, c, h // 4, w // 4)
+ mlf_2 = masked_local_frames[:, 1:, :, :, :].reshape(
+ -1, c, h // 4, w // 4)
+ pred_flows_forward = self.update_spynet(mlf_1, mlf_2)
+ pred_flows_backward = self.update_spynet(mlf_2, mlf_1)
+
+ pred_flows_forward = pred_flows_forward.view(b, l_t - 1, 2, h // 4,
+ w // 4)
+ pred_flows_backward = pred_flows_backward.view(b, l_t - 1, 2, h // 4,
+ w // 4)
+
+ return pred_flows_forward, pred_flows_backward
+
+ def forward(self, masked_frames, num_local_frames):
+ l_t = num_local_frames
+ b, t, ori_c, ori_h, ori_w = masked_frames.size()
+
+ # normalization before feeding into the flow completion module
+ masked_local_frames = (masked_frames[:, :l_t, ...] + 1) / 2
+ pred_flows = self.forward_bidirect_flow(masked_local_frames)
+
+ # extracting features and performing the feature propagation on local features
+ enc_feat = self.encoder(masked_frames.view(b * t, ori_c, ori_h, ori_w))
+ _, c, h, w = enc_feat.size()
+ local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
+ ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
+ local_feat = self.feat_prop_module(local_feat, pred_flows[0],
+ pred_flows[1])
+ enc_feat = torch.cat((local_feat, ref_feat), dim=1)
+
+ # content hallucination through stacking multiple temporal focal transformer blocks
+ trans_feat = self.ss(enc_feat.view(-1, c, h, w), b)
+ trans_feat = self.transformer(trans_feat)
+ trans_feat = self.sc(trans_feat, t)
+ trans_feat = trans_feat.view(b, t, -1, h, w)
+ enc_feat = enc_feat + trans_feat
+
+ # decode frames from features
+ output = self.decoder(enc_feat.view(b * t, c, h, w))
+ output = torch.tanh(output)
+ return output, pred_flows
+
+
+# ######################################################################
+# Discriminator for Temporal Patch GAN
+# ######################################################################
+
+
+class Discriminator(BaseNetwork):
+ def __init__(self,
+ in_channels=3,
+ use_sigmoid=False,
+ use_spectral_norm=True,
+ init_weights=True):
+ super(Discriminator, self).__init__()
+ self.use_sigmoid = use_sigmoid
+ nf = 32
+
+ self.conv = nn.Sequential(
+ spectral_norm(
+ nn.Conv3d(in_channels=in_channels,
+ out_channels=nf * 1,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=1,
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(64, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 1,
+ nf * 2,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(128, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 2,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(256, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 4,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(256, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 4,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(256, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv3d(nf * 4,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2)))
+
+ if init_weights:
+ self.init_weights()
+
+ def forward(self, xs):
+ # T, C, H, W = xs.shape (old)
+ # B, T, C, H, W (new)
+ xs_t = torch.transpose(xs, 1, 2)
+ feat = self.conv(xs_t)
+ if self.use_sigmoid:
+ feat = torch.sigmoid(feat)
+ out = torch.transpose(feat, 1, 2) # B, T, C, H, W
+ return out
+
+
+def spectral_norm(module, mode=True):
+ if mode:
+ return _spectral_norm(module)
+ return module
diff --git a/inpainter/model/e2fgvi_hq.py b/inpainter/model/e2fgvi_hq.py
new file mode 100644
index 0000000000000000000000000000000000000000..8741f231e5bbb0503b72f5604af178131d7e50d1
--- /dev/null
+++ b/inpainter/model/e2fgvi_hq.py
@@ -0,0 +1,350 @@
+''' Towards An End-to-End Framework for Video Inpainting
+'''
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from inpainter.model.modules.flow_comp import SPyNet
+from inpainter.model.modules.feat_prop import BidirectionalPropagation, SecondOrderDeformableAlignment
+from inpainter.model.modules.tfocal_transformer_hq import TemporalFocalTransformerBlock, SoftSplit, SoftComp
+from inpainter.model.modules.spectral_norm import spectral_norm as _spectral_norm
+
+
+class BaseNetwork(nn.Module):
+ def __init__(self):
+ super(BaseNetwork, self).__init__()
+
+ def print_network(self):
+ if isinstance(self, list):
+ self = self[0]
+ num_params = 0
+ for param in self.parameters():
+ num_params += param.numel()
+ print(
+ 'Network [%s] was created. Total number of parameters: %.1f million. '
+ 'To see the architecture, do print(network).' %
+ (type(self).__name__, num_params / 1000000))
+
+ def init_weights(self, init_type='normal', gain=0.02):
+ '''
+ initialize network's weights
+ init_type: normal | xavier | kaiming | orthogonal
+ https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
+ '''
+ def init_func(m):
+ classname = m.__class__.__name__
+ if classname.find('InstanceNorm2d') != -1:
+ if hasattr(m, 'weight') and m.weight is not None:
+ nn.init.constant_(m.weight.data, 1.0)
+ if hasattr(m, 'bias') and m.bias is not None:
+ nn.init.constant_(m.bias.data, 0.0)
+ elif hasattr(m, 'weight') and (classname.find('Conv') != -1
+ or classname.find('Linear') != -1):
+ if init_type == 'normal':
+ nn.init.normal_(m.weight.data, 0.0, gain)
+ elif init_type == 'xavier':
+ nn.init.xavier_normal_(m.weight.data, gain=gain)
+ elif init_type == 'xavier_uniform':
+ nn.init.xavier_uniform_(m.weight.data, gain=1.0)
+ elif init_type == 'kaiming':
+ nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+ elif init_type == 'orthogonal':
+ nn.init.orthogonal_(m.weight.data, gain=gain)
+ elif init_type == 'none': # uses pytorch's default init method
+ m.reset_parameters()
+ else:
+ raise NotImplementedError(
+ 'initialization method [%s] is not implemented' %
+ init_type)
+ if hasattr(m, 'bias') and m.bias is not None:
+ nn.init.constant_(m.bias.data, 0.0)
+
+ self.apply(init_func)
+
+ # propagate to children
+ for m in self.children():
+ if hasattr(m, 'init_weights'):
+ m.init_weights(init_type, gain)
+
+
+class Encoder(nn.Module):
+ def __init__(self):
+ super(Encoder, self).__init__()
+ self.group = [1, 2, 4, 8, 1]
+ self.layers = nn.ModuleList([
+ nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
+ nn.LeakyReLU(0.2, inplace=True)
+ ])
+
+ def forward(self, x):
+ bt, c, _, _ = x.size()
+ # h, w = h//4, w//4
+ out = x
+ for i, layer in enumerate(self.layers):
+ if i == 8:
+ x0 = out
+ _, _, h, w = x0.size()
+ if i > 8 and i % 2 == 0:
+ g = self.group[(i - 8) // 2]
+ x = x0.view(bt, g, -1, h, w)
+ o = out.view(bt, g, -1, h, w)
+ out = torch.cat([x, o], 2).view(bt, -1, h, w)
+ out = layer(out)
+ return out
+
+
+class deconv(nn.Module):
+ def __init__(self,
+ input_channel,
+ output_channel,
+ kernel_size=3,
+ padding=0):
+ super().__init__()
+ self.conv = nn.Conv2d(input_channel,
+ output_channel,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=padding)
+
+ def forward(self, x):
+ x = F.interpolate(x,
+ scale_factor=2,
+ mode='bilinear',
+ align_corners=True)
+ return self.conv(x)
+
+
+class InpaintGenerator(BaseNetwork):
+ def __init__(self, init_weights=True):
+ super(InpaintGenerator, self).__init__()
+ channel = 256
+ hidden = 512
+
+ # encoder
+ self.encoder = Encoder()
+
+ # decoder
+ self.decoder = nn.Sequential(
+ deconv(channel // 2, 128, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ deconv(64, 64, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))
+
+ # feature propagation module
+ self.feat_prop_module = BidirectionalPropagation(channel // 2)
+
+ # soft split and soft composition
+ kernel_size = (7, 7)
+ padding = (3, 3)
+ stride = (3, 3)
+ output_size = (60, 108)
+ t2t_params = {
+ 'kernel_size': kernel_size,
+ 'stride': stride,
+ 'padding': padding
+ }
+ self.ss = SoftSplit(channel // 2,
+ hidden,
+ kernel_size,
+ stride,
+ padding,
+ t2t_param=t2t_params)
+ self.sc = SoftComp(channel // 2, hidden, kernel_size, stride, padding)
+
+ n_vecs = 1
+ for i, d in enumerate(kernel_size):
+ n_vecs *= int((output_size[i] + 2 * padding[i] -
+ (d - 1) - 1) / stride[i] + 1)
+
+ blocks = []
+ depths = 8
+ num_heads = [4] * depths
+ window_size = [(5, 9)] * depths
+ focal_windows = [(5, 9)] * depths
+ focal_levels = [2] * depths
+ pool_method = "fc"
+
+ for i in range(depths):
+ blocks.append(
+ TemporalFocalTransformerBlock(dim=hidden,
+ num_heads=num_heads[i],
+ window_size=window_size[i],
+ focal_level=focal_levels[i],
+ focal_window=focal_windows[i],
+ n_vecs=n_vecs,
+ t2t_params=t2t_params,
+ pool_method=pool_method))
+ self.transformer = nn.Sequential(*blocks)
+
+ if init_weights:
+ self.init_weights()
+ # Need to initial the weights of MSDeformAttn specifically
+ for m in self.modules():
+ if isinstance(m, SecondOrderDeformableAlignment):
+ m.init_offset()
+
+ # flow completion network
+ self.update_spynet = SPyNet()
+
+ def forward_bidirect_flow(self, masked_local_frames):
+ b, l_t, c, h, w = masked_local_frames.size()
+
+ # compute forward and backward flows of masked frames
+ masked_local_frames = F.interpolate(masked_local_frames.view(
+ -1, c, h, w),
+ scale_factor=1 / 4,
+ mode='bilinear',
+ align_corners=True,
+ recompute_scale_factor=True)
+ masked_local_frames = masked_local_frames.view(b, l_t, c, h // 4,
+ w // 4)
+ mlf_1 = masked_local_frames[:, :-1, :, :, :].reshape(
+ -1, c, h // 4, w // 4)
+ mlf_2 = masked_local_frames[:, 1:, :, :, :].reshape(
+ -1, c, h // 4, w // 4)
+ pred_flows_forward = self.update_spynet(mlf_1, mlf_2)
+ pred_flows_backward = self.update_spynet(mlf_2, mlf_1)
+
+ pred_flows_forward = pred_flows_forward.view(b, l_t - 1, 2, h // 4,
+ w // 4)
+ pred_flows_backward = pred_flows_backward.view(b, l_t - 1, 2, h // 4,
+ w // 4)
+
+ return pred_flows_forward, pred_flows_backward
+
+ def forward(self, masked_frames, num_local_frames):
+ l_t = num_local_frames
+ b, t, ori_c, ori_h, ori_w = masked_frames.size()
+
+ # normalization before feeding into the flow completion module
+ masked_local_frames = (masked_frames[:, :l_t, ...] + 1) / 2
+ pred_flows = self.forward_bidirect_flow(masked_local_frames)
+
+ # extracting features and performing the feature propagation on local features
+ enc_feat = self.encoder(masked_frames.view(b * t, ori_c, ori_h, ori_w))
+ _, c, h, w = enc_feat.size()
+ fold_output_size = (h, w)
+ local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
+ ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
+ local_feat = self.feat_prop_module(local_feat, pred_flows[0],
+ pred_flows[1])
+ enc_feat = torch.cat((local_feat, ref_feat), dim=1)
+
+ # content hallucination through stacking multiple temporal focal transformer blocks
+ trans_feat = self.ss(enc_feat.view(-1, c, h, w), b, fold_output_size)
+ trans_feat = self.transformer([trans_feat, fold_output_size])
+ trans_feat = self.sc(trans_feat[0], t, fold_output_size)
+ trans_feat = trans_feat.view(b, t, -1, h, w)
+ enc_feat = enc_feat + trans_feat
+
+ # decode frames from features
+ output = self.decoder(enc_feat.view(b * t, c, h, w))
+ output = torch.tanh(output)
+ return output, pred_flows
+
+
+# ######################################################################
+# Discriminator for Temporal Patch GAN
+# ######################################################################
+
+
+class Discriminator(BaseNetwork):
+ def __init__(self,
+ in_channels=3,
+ use_sigmoid=False,
+ use_spectral_norm=True,
+ init_weights=True):
+ super(Discriminator, self).__init__()
+ self.use_sigmoid = use_sigmoid
+ nf = 32
+
+ self.conv = nn.Sequential(
+ spectral_norm(
+ nn.Conv3d(in_channels=in_channels,
+ out_channels=nf * 1,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=1,
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(64, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 1,
+ nf * 2,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(128, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 2,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(256, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 4,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(256, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 4,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(256, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv3d(nf * 4,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2)))
+
+ if init_weights:
+ self.init_weights()
+
+ def forward(self, xs):
+ # T, C, H, W = xs.shape (old)
+ # B, T, C, H, W (new)
+ xs_t = torch.transpose(xs, 1, 2)
+ feat = self.conv(xs_t)
+ if self.use_sigmoid:
+ feat = torch.sigmoid(feat)
+ out = torch.transpose(feat, 1, 2) # B, T, C, H, W
+ return out
+
+
+def spectral_norm(module, mode=True):
+ if mode:
+ return _spectral_norm(module)
+ return module
diff --git a/inpainter/model/modules/feat_prop.py b/inpainter/model/modules/feat_prop.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ac41286ba5640a191cf0edb47ad01d9ef91b623
--- /dev/null
+++ b/inpainter/model/modules/feat_prop.py
@@ -0,0 +1,149 @@
+"""
+ BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment, CVPR 2022
+"""
+import torch
+import torch.nn as nn
+
+from mmcv.ops import ModulatedDeformConv2d, modulated_deform_conv2d
+from mmengine.model import constant_init
+
+from inpainter.model.modules.flow_comp import flow_warp
+
+
+class SecondOrderDeformableAlignment(ModulatedDeformConv2d):
+ """Second-order deformable alignment module."""
+ def __init__(self, *args, **kwargs):
+ self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
+
+ super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
+
+ self.conv_offset = nn.Sequential(
+ nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1),
+ )
+
+ self.init_offset()
+
+ def init_offset(self):
+ constant_init(self.conv_offset[-1], val=0, bias=0)
+
+ def forward(self, x, extra_feat, flow_1, flow_2):
+ extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
+ out = self.conv_offset(extra_feat)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+
+ # offset
+ offset = self.max_residue_magnitude * torch.tanh(
+ torch.cat((o1, o2), dim=1))
+ offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
+ offset_1 = offset_1 + flow_1.flip(1).repeat(1,
+ offset_1.size(1) // 2, 1,
+ 1)
+ offset_2 = offset_2 + flow_2.flip(1).repeat(1,
+ offset_2.size(1) // 2, 1,
+ 1)
+ offset = torch.cat([offset_1, offset_2], dim=1)
+
+ # mask
+ mask = torch.sigmoid(mask)
+
+ return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
+ self.stride, self.padding,
+ self.dilation, self.groups,
+ self.deform_groups)
+
+
+class BidirectionalPropagation(nn.Module):
+ def __init__(self, channel):
+ super(BidirectionalPropagation, self).__init__()
+ modules = ['backward_', 'forward_']
+ self.deform_align = nn.ModuleDict()
+ self.backbone = nn.ModuleDict()
+ self.channel = channel
+
+ for i, module in enumerate(modules):
+ self.deform_align[module] = SecondOrderDeformableAlignment(
+ 2 * channel, channel, 3, padding=1, deform_groups=16)
+
+ self.backbone[module] = nn.Sequential(
+ nn.Conv2d((2 + i) * channel, channel, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(channel, channel, 3, 1, 1),
+ )
+
+ self.fusion = nn.Conv2d(2 * channel, channel, 1, 1, 0)
+
+ def forward(self, x, flows_backward, flows_forward):
+ """
+ x shape : [b, t, c, h, w]
+ return [b, t, c, h, w]
+ """
+ b, t, c, h, w = x.shape
+ feats = {}
+ feats['spatial'] = [x[:, i, :, :, :] for i in range(0, t)]
+
+ for module_name in ['backward_', 'forward_']:
+
+ feats[module_name] = []
+
+ frame_idx = range(0, t)
+ flow_idx = range(-1, t - 1)
+ mapping_idx = list(range(0, len(feats['spatial'])))
+ mapping_idx += mapping_idx[::-1]
+
+ if 'backward' in module_name:
+ frame_idx = frame_idx[::-1]
+ flows = flows_backward
+ else:
+ flows = flows_forward
+
+ feat_prop = x.new_zeros(b, self.channel, h, w)
+ for i, idx in enumerate(frame_idx):
+ feat_current = feats['spatial'][mapping_idx[idx]]
+
+ if i > 0:
+ flow_n1 = flows[:, flow_idx[i], :, :, :]
+ cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))
+
+ # initialize second-order features
+ feat_n2 = torch.zeros_like(feat_prop)
+ flow_n2 = torch.zeros_like(flow_n1)
+ cond_n2 = torch.zeros_like(cond_n1)
+ if i > 1:
+ feat_n2 = feats[module_name][-2]
+ flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
+ flow_n2 = flow_n1 + flow_warp(
+ flow_n2, flow_n1.permute(0, 2, 3, 1))
+ cond_n2 = flow_warp(feat_n2,
+ flow_n2.permute(0, 2, 3, 1))
+
+ cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
+ feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
+ feat_prop = self.deform_align[module_name](feat_prop, cond,
+ flow_n1,
+ flow_n2)
+
+ feat = [feat_current] + [
+ feats[k][idx]
+ for k in feats if k not in ['spatial', module_name]
+ ] + [feat_prop]
+
+ feat = torch.cat(feat, dim=1)
+ feat_prop = feat_prop + self.backbone[module_name](feat)
+ feats[module_name].append(feat_prop)
+
+ if 'backward' in module_name:
+ feats[module_name] = feats[module_name][::-1]
+
+ outputs = []
+ for i in range(0, t):
+ align_feats = [feats[k].pop(0) for k in feats if k != 'spatial']
+ align_feats = torch.cat(align_feats, dim=1)
+ outputs.append(self.fusion(align_feats))
+
+ return torch.stack(outputs, dim=1) + x
diff --git a/inpainter/model/modules/flow_comp.py b/inpainter/model/modules/flow_comp.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3abf2f72a6162e2b420c572c55081c557638c59
--- /dev/null
+++ b/inpainter/model/modules/flow_comp.py
@@ -0,0 +1,450 @@
+import numpy as np
+
+import torch.nn as nn
+import torch.nn.functional as F
+import torch
+
+from mmcv.cnn import ConvModule
+from mmengine.runner import load_checkpoint
+
+
+class FlowCompletionLoss(nn.Module):
+ """Flow completion loss"""
+ def __init__(self):
+ super().__init__()
+ self.fix_spynet = SPyNet()
+ for p in self.fix_spynet.parameters():
+ p.requires_grad = False
+
+ self.l1_criterion = nn.L1Loss()
+
+ def forward(self, pred_flows, gt_local_frames):
+ b, l_t, c, h, w = gt_local_frames.size()
+
+ with torch.no_grad():
+ # compute gt forward and backward flows
+ gt_local_frames = F.interpolate(gt_local_frames.view(-1, c, h, w),
+ scale_factor=1 / 4,
+ mode='bilinear',
+ align_corners=True,
+ recompute_scale_factor=True)
+ gt_local_frames = gt_local_frames.view(b, l_t, c, h // 4, w // 4)
+ gtlf_1 = gt_local_frames[:, :-1, :, :, :].reshape(
+ -1, c, h // 4, w // 4)
+ gtlf_2 = gt_local_frames[:, 1:, :, :, :].reshape(
+ -1, c, h // 4, w // 4)
+ gt_flows_forward = self.fix_spynet(gtlf_1, gtlf_2)
+ gt_flows_backward = self.fix_spynet(gtlf_2, gtlf_1)
+
+ # calculate loss for flow completion
+ forward_flow_loss = self.l1_criterion(
+ pred_flows[0].view(-1, 2, h // 4, w // 4), gt_flows_forward)
+ backward_flow_loss = self.l1_criterion(
+ pred_flows[1].view(-1, 2, h // 4, w // 4), gt_flows_backward)
+ flow_loss = forward_flow_loss + backward_flow_loss
+
+ return flow_loss
+
+
+class SPyNet(nn.Module):
+ """SPyNet network structure.
+ The difference to the SPyNet in [tof.py] is that
+ 1. more SPyNetBasicModule is used in this version, and
+ 2. no batch normalization is used in this version.
+ Paper:
+ Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
+ Args:
+ pretrained (str): path for pre-trained SPyNet. Default: None.
+ """
+ def __init__(
+ self,
+ use_pretrain=True,
+ pretrained='https://download.openmmlab.com/mmediting/restorers/basicvsr/spynet_20210409-c6c1bd09.pth'
+ ):
+ super().__init__()
+
+ self.basic_module = nn.ModuleList(
+ [SPyNetBasicModule() for _ in range(6)])
+
+ if use_pretrain:
+ if isinstance(pretrained, str):
+ print("load pretrained SPyNet...")
+ load_checkpoint(self, pretrained, strict=True)
+ elif pretrained is not None:
+ raise TypeError('[pretrained] should be str or None, '
+ f'but got {type(pretrained)}.')
+
+ self.register_buffer(
+ 'mean',
+ torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ self.register_buffer(
+ 'std',
+ torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def compute_flow(self, ref, supp):
+ """Compute flow from ref to supp.
+ Note that in this function, the images are already resized to a
+ multiple of 32.
+ Args:
+ ref (Tensor): Reference image with shape of (n, 3, h, w).
+ supp (Tensor): Supporting image with shape of (n, 3, h, w).
+ Returns:
+ Tensor: Estimated optical flow: (n, 2, h, w).
+ """
+ n, _, h, w = ref.size()
+
+ # normalize the input images
+ ref = [(ref - self.mean) / self.std]
+ supp = [(supp - self.mean) / self.std]
+
+ # generate downsampled frames
+ for level in range(5):
+ ref.append(
+ F.avg_pool2d(input=ref[-1],
+ kernel_size=2,
+ stride=2,
+ count_include_pad=False))
+ supp.append(
+ F.avg_pool2d(input=supp[-1],
+ kernel_size=2,
+ stride=2,
+ count_include_pad=False))
+ ref = ref[::-1]
+ supp = supp[::-1]
+
+ # flow computation
+ flow = ref[0].new_zeros(n, 2, h // 32, w // 32)
+ for level in range(len(ref)):
+ if level == 0:
+ flow_up = flow
+ else:
+ flow_up = F.interpolate(input=flow,
+ scale_factor=2,
+ mode='bilinear',
+ align_corners=True) * 2.0
+
+ # add the residue to the upsampled flow
+ flow = flow_up + self.basic_module[level](torch.cat([
+ ref[level],
+ flow_warp(supp[level],
+ flow_up.permute(0, 2, 3, 1).contiguous(),
+ padding_mode='border'), flow_up
+ ], 1))
+
+ return flow
+
+ def forward(self, ref, supp):
+ """Forward function of SPyNet.
+ This function computes the optical flow from ref to supp.
+ Args:
+ ref (Tensor): Reference image with shape of (n, 3, h, w).
+ supp (Tensor): Supporting image with shape of (n, 3, h, w).
+ Returns:
+ Tensor: Estimated optical flow: (n, 2, h, w).
+ """
+
+ # upsize to a multiple of 32
+ h, w = ref.shape[2:4]
+ w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1)
+ h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1)
+ ref = F.interpolate(input=ref,
+ size=(h_up, w_up),
+ mode='bilinear',
+ align_corners=False)
+ supp = F.interpolate(input=supp,
+ size=(h_up, w_up),
+ mode='bilinear',
+ align_corners=False)
+
+ # compute flow, and resize back to the original resolution
+ flow = F.interpolate(input=self.compute_flow(ref, supp),
+ size=(h, w),
+ mode='bilinear',
+ align_corners=False)
+
+ # adjust the flow values
+ flow[:, 0, :, :] *= float(w) / float(w_up)
+ flow[:, 1, :, :] *= float(h) / float(h_up)
+
+ return flow
+
+
+class SPyNetBasicModule(nn.Module):
+ """Basic Module for SPyNet.
+ Paper:
+ Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
+ """
+ def __init__(self):
+ super().__init__()
+
+ self.basic_module = nn.Sequential(
+ ConvModule(in_channels=8,
+ out_channels=32,
+ kernel_size=7,
+ stride=1,
+ padding=3,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU')),
+ ConvModule(in_channels=32,
+ out_channels=64,
+ kernel_size=7,
+ stride=1,
+ padding=3,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU')),
+ ConvModule(in_channels=64,
+ out_channels=32,
+ kernel_size=7,
+ stride=1,
+ padding=3,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU')),
+ ConvModule(in_channels=32,
+ out_channels=16,
+ kernel_size=7,
+ stride=1,
+ padding=3,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU')),
+ ConvModule(in_channels=16,
+ out_channels=2,
+ kernel_size=7,
+ stride=1,
+ padding=3,
+ norm_cfg=None,
+ act_cfg=None))
+
+ def forward(self, tensor_input):
+ """
+ Args:
+ tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
+ 8 channels contain:
+ [reference image (3), neighbor image (3), initial flow (2)].
+ Returns:
+ Tensor: Refined flow with shape (b, 2, h, w)
+ """
+ return self.basic_module(tensor_input)
+
+
+# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
+def make_colorwheel():
+ """
+ Generates a color wheel for optical flow visualization as presented in:
+ Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
+ URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
+
+ Code follows the original C++ source code of Daniel Scharstein.
+ Code follows the the Matlab source code of Deqing Sun.
+
+ Returns:
+ np.ndarray: Color wheel
+ """
+
+ RY = 15
+ YG = 6
+ GC = 4
+ CB = 11
+ BM = 13
+ MR = 6
+
+ ncols = RY + YG + GC + CB + BM + MR
+ colorwheel = np.zeros((ncols, 3))
+ col = 0
+
+ # RY
+ colorwheel[0:RY, 0] = 255
+ colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
+ col = col + RY
+ # YG
+ colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
+ colorwheel[col:col + YG, 1] = 255
+ col = col + YG
+ # GC
+ colorwheel[col:col + GC, 1] = 255
+ colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
+ col = col + GC
+ # CB
+ colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
+ colorwheel[col:col + CB, 2] = 255
+ col = col + CB
+ # BM
+ colorwheel[col:col + BM, 2] = 255
+ colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
+ col = col + BM
+ # MR
+ colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
+ colorwheel[col:col + MR, 0] = 255
+ return colorwheel
+
+
+def flow_uv_to_colors(u, v, convert_to_bgr=False):
+ """
+ Applies the flow color wheel to (possibly clipped) flow components u and v.
+
+ According to the C++ source code of Daniel Scharstein
+ According to the Matlab source code of Deqing Sun
+
+ Args:
+ u (np.ndarray): Input horizontal flow of shape [H,W]
+ v (np.ndarray): Input vertical flow of shape [H,W]
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
+
+ Returns:
+ np.ndarray: Flow visualization image of shape [H,W,3]
+ """
+ flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
+ colorwheel = make_colorwheel() # shape [55x3]
+ ncols = colorwheel.shape[0]
+ rad = np.sqrt(np.square(u) + np.square(v))
+ a = np.arctan2(-v, -u) / np.pi
+ fk = (a + 1) / 2 * (ncols - 1)
+ k0 = np.floor(fk).astype(np.int32)
+ k1 = k0 + 1
+ k1[k1 == ncols] = 0
+ f = fk - k0
+ for i in range(colorwheel.shape[1]):
+ tmp = colorwheel[:, i]
+ col0 = tmp[k0] / 255.0
+ col1 = tmp[k1] / 255.0
+ col = (1 - f) * col0 + f * col1
+ idx = (rad <= 1)
+ col[idx] = 1 - rad[idx] * (1 - col[idx])
+ col[~idx] = col[~idx] * 0.75 # out of range
+ # Note the 2-i => BGR instead of RGB
+ ch_idx = 2 - i if convert_to_bgr else i
+ flow_image[:, :, ch_idx] = np.floor(255 * col)
+ return flow_image
+
+
+def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
+ """
+ Expects a two dimensional flow image of shape.
+
+ Args:
+ flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
+ clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
+
+ Returns:
+ np.ndarray: Flow visualization image of shape [H,W,3]
+ """
+ assert flow_uv.ndim == 3, 'input flow must have three dimensions'
+ assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
+ if clip_flow is not None:
+ flow_uv = np.clip(flow_uv, 0, clip_flow)
+ u = flow_uv[:, :, 0]
+ v = flow_uv[:, :, 1]
+ rad = np.sqrt(np.square(u) + np.square(v))
+ rad_max = np.max(rad)
+ epsilon = 1e-5
+ u = u / (rad_max + epsilon)
+ v = v / (rad_max + epsilon)
+ return flow_uv_to_colors(u, v, convert_to_bgr)
+
+
+def flow_warp(x,
+ flow,
+ interpolation='bilinear',
+ padding_mode='zeros',
+ align_corners=True):
+ """Warp an image or a feature map with optical flow.
+ Args:
+ x (Tensor): Tensor with size (n, c, h, w).
+ flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is
+ a two-channel, denoting the width and height relative offsets.
+ Note that the values are not normalized to [-1, 1].
+ interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
+ Default: 'bilinear'.
+ padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
+ Default: 'zeros'.
+ align_corners (bool): Whether align corners. Default: True.
+ Returns:
+ Tensor: Warped image or feature map.
+ """
+ if x.size()[-2:] != flow.size()[1:3]:
+ raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
+ f'flow ({flow.size()[1:3]}) are not the same.')
+ _, _, h, w = x.size()
+ # create mesh grid
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
+ grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (w, h, 2)
+ grid.requires_grad = False
+
+ grid_flow = grid + flow
+ # scale grid_flow to [-1,1]
+ grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
+ grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
+ grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
+ output = F.grid_sample(x,
+ grid_flow,
+ mode=interpolation,
+ padding_mode=padding_mode,
+ align_corners=align_corners)
+ return output
+
+
+def initial_mask_flow(mask):
+ """
+ mask 1 indicates valid pixel 0 indicates unknown pixel
+ """
+ B, T, C, H, W = mask.shape
+
+ # calculate relative position
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
+
+ grid_y, grid_x = grid_y.type_as(mask), grid_x.type_as(mask)
+ abs_relative_pos_y = H - torch.abs(grid_y[None, :, :] - grid_y[:, None, :])
+ relative_pos_y = H - (grid_y[None, :, :] - grid_y[:, None, :])
+
+ abs_relative_pos_x = W - torch.abs(grid_x[:, None, :] - grid_x[:, :, None])
+ relative_pos_x = W - (grid_x[:, None, :] - grid_x[:, :, None])
+
+ # calculate the nearest indices
+ pos_up = mask.unsqueeze(3).repeat(
+ 1, 1, 1, H, 1, 1).flip(4) * abs_relative_pos_y[None, None, None] * (
+ relative_pos_y <= H)[None, None, None]
+ nearest_indice_up = pos_up.max(dim=4)[1]
+
+ pos_down = mask.unsqueeze(3).repeat(1, 1, 1, H, 1, 1) * abs_relative_pos_y[
+ None, None, None] * (relative_pos_y <= H)[None, None, None]
+ nearest_indice_down = (pos_down).max(dim=4)[1]
+
+ pos_left = mask.unsqueeze(4).repeat(
+ 1, 1, 1, 1, W, 1).flip(5) * abs_relative_pos_x[None, None, None] * (
+ relative_pos_x <= W)[None, None, None]
+ nearest_indice_left = (pos_left).max(dim=5)[1]
+
+ pos_right = mask.unsqueeze(4).repeat(
+ 1, 1, 1, 1, W, 1) * abs_relative_pos_x[None, None, None] * (
+ relative_pos_x <= W)[None, None, None]
+ nearest_indice_right = (pos_right).max(dim=5)[1]
+
+ # NOTE: IMPORTANT !!! depending on how to use this offset
+ initial_offset_up = -(nearest_indice_up - grid_y[None, None, None]).flip(3)
+ initial_offset_down = nearest_indice_down - grid_y[None, None, None]
+
+ initial_offset_left = -(nearest_indice_left -
+ grid_x[None, None, None]).flip(4)
+ initial_offset_right = nearest_indice_right - grid_x[None, None, None]
+
+ # nearest_indice_x = (mask.unsqueeze(1).repeat(1, img_width, 1) * relative_pos_x).max(dim=2)[1]
+ # initial_offset_x = nearest_indice_x - grid_x
+
+ # handle the boundary cases
+ final_offset_down = (initial_offset_down < 0) * initial_offset_up + (
+ initial_offset_down > 0) * initial_offset_down
+ final_offset_up = (initial_offset_up > 0) * initial_offset_down + (
+ initial_offset_up < 0) * initial_offset_up
+ final_offset_right = (initial_offset_right < 0) * initial_offset_left + (
+ initial_offset_right > 0) * initial_offset_right
+ final_offset_left = (initial_offset_left > 0) * initial_offset_right + (
+ initial_offset_left < 0) * initial_offset_left
+ zero_offset = torch.zeros_like(final_offset_down)
+ # out = torch.cat([final_offset_left, zero_offset, final_offset_right, zero_offset, zero_offset, final_offset_up, zero_offset, final_offset_down], dim=2)
+ out = torch.cat([
+ zero_offset, final_offset_left, zero_offset, final_offset_right,
+ final_offset_up, zero_offset, final_offset_down, zero_offset
+ ],
+ dim=2)
+
+ return out
diff --git a/inpainter/model/modules/spectral_norm.py b/inpainter/model/modules/spectral_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f38c34e98c03caa28ce0b15a4083215fb7d8e9af
--- /dev/null
+++ b/inpainter/model/modules/spectral_norm.py
@@ -0,0 +1,288 @@
+"""
+Spectral Normalization from https://arxiv.org/abs/1802.05957
+"""
+import torch
+from torch.nn.functional import normalize
+
+
+class SpectralNorm(object):
+ # Invariant before and after each forward call:
+ # u = normalize(W @ v)
+ # NB: At initialization, this invariant is not enforced
+
+ _version = 1
+
+ # At version 1:
+ # made `W` not a buffer,
+ # added `v` as a buffer, and
+ # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.
+
+ def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
+ self.name = name
+ self.dim = dim
+ if n_power_iterations <= 0:
+ raise ValueError(
+ 'Expected n_power_iterations to be positive, but '
+ 'got n_power_iterations={}'.format(n_power_iterations))
+ self.n_power_iterations = n_power_iterations
+ self.eps = eps
+
+ def reshape_weight_to_matrix(self, weight):
+ weight_mat = weight
+ if self.dim != 0:
+ # permute dim to front
+ weight_mat = weight_mat.permute(
+ self.dim,
+ *[d for d in range(weight_mat.dim()) if d != self.dim])
+ height = weight_mat.size(0)
+ return weight_mat.reshape(height, -1)
+
+ def compute_weight(self, module, do_power_iteration):
+ # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
+ # updated in power iteration **in-place**. This is very important
+ # because in `DataParallel` forward, the vectors (being buffers) are
+ # broadcast from the parallelized module to each module replica,
+ # which is a new module object created on the fly. And each replica
+ # runs its own spectral norm power iteration. So simply assigning
+ # the updated vectors to the module this function runs on will cause
+ # the update to be lost forever. And the next time the parallelized
+ # module is replicated, the same randomly initialized vectors are
+ # broadcast and used!
+ #
+ # Therefore, to make the change propagate back, we rely on two
+ # important behaviors (also enforced via tests):
+ # 1. `DataParallel` doesn't clone storage if the broadcast tensor
+ # is already on correct device; and it makes sure that the
+ # parallelized module is already on `device[0]`.
+ # 2. If the out tensor in `out=` kwarg has correct shape, it will
+ # just fill in the values.
+ # Therefore, since the same power iteration is performed on all
+ # devices, simply updating the tensors in-place will make sure that
+ # the module replica on `device[0]` will update the _u vector on the
+ # parallized module (by shared storage).
+ #
+ # However, after we update `u` and `v` in-place, we need to **clone**
+ # them before using them to normalize the weight. This is to support
+ # backproping through two forward passes, e.g., the common pattern in
+ # GAN training: loss = D(real) - D(fake). Otherwise, engine will
+ # complain that variables needed to do backward for the first forward
+ # (i.e., the `u` and `v` vectors) are changed in the second forward.
+ weight = getattr(module, self.name + '_orig')
+ u = getattr(module, self.name + '_u')
+ v = getattr(module, self.name + '_v')
+ weight_mat = self.reshape_weight_to_matrix(weight)
+
+ if do_power_iteration:
+ with torch.no_grad():
+ for _ in range(self.n_power_iterations):
+ # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
+ # are the first left and right singular vectors.
+ # This power iteration produces approximations of `u` and `v`.
+ v = normalize(torch.mv(weight_mat.t(), u),
+ dim=0,
+ eps=self.eps,
+ out=v)
+ u = normalize(torch.mv(weight_mat, v),
+ dim=0,
+ eps=self.eps,
+ out=u)
+ if self.n_power_iterations > 0:
+ # See above on why we need to clone
+ u = u.clone()
+ v = v.clone()
+
+ sigma = torch.dot(u, torch.mv(weight_mat, v))
+ weight = weight / sigma
+ return weight
+
+ def remove(self, module):
+ with torch.no_grad():
+ weight = self.compute_weight(module, do_power_iteration=False)
+ delattr(module, self.name)
+ delattr(module, self.name + '_u')
+ delattr(module, self.name + '_v')
+ delattr(module, self.name + '_orig')
+ module.register_parameter(self.name,
+ torch.nn.Parameter(weight.detach()))
+
+ def __call__(self, module, inputs):
+ setattr(
+ module, self.name,
+ self.compute_weight(module, do_power_iteration=module.training))
+
+ def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
+ # Tries to returns a vector `v` s.t. `u = normalize(W @ v)`
+ # (the invariant at top of this class) and `u @ W @ v = sigma`.
+ # This uses pinverse in case W^T W is not invertible.
+ v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(),
+ weight_mat.t(), u.unsqueeze(1)).squeeze(1)
+ return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))
+
+ @staticmethod
+ def apply(module, name, n_power_iterations, dim, eps):
+ for k, hook in module._forward_pre_hooks.items():
+ if isinstance(hook, SpectralNorm) and hook.name == name:
+ raise RuntimeError(
+ "Cannot register two spectral_norm hooks on "
+ "the same parameter {}".format(name))
+
+ fn = SpectralNorm(name, n_power_iterations, dim, eps)
+ weight = module._parameters[name]
+
+ with torch.no_grad():
+ weight_mat = fn.reshape_weight_to_matrix(weight)
+
+ h, w = weight_mat.size()
+ # randomly initialize `u` and `v`
+ u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
+ v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)
+
+ delattr(module, fn.name)
+ module.register_parameter(fn.name + "_orig", weight)
+ # We still need to assign weight back as fn.name because all sorts of
+ # things may assume that it exists, e.g., when initializing weights.
+ # However, we can't directly assign as it could be an nn.Parameter and
+ # gets added as a parameter. Instead, we register weight.data as a plain
+ # attribute.
+ setattr(module, fn.name, weight.data)
+ module.register_buffer(fn.name + "_u", u)
+ module.register_buffer(fn.name + "_v", v)
+
+ module.register_forward_pre_hook(fn)
+
+ module._register_state_dict_hook(SpectralNormStateDictHook(fn))
+ module._register_load_state_dict_pre_hook(
+ SpectralNormLoadStateDictPreHook(fn))
+ return fn
+
+
+# This is a top level class because Py2 pickle doesn't like inner class nor an
+# instancemethod.
+class SpectralNormLoadStateDictPreHook(object):
+ # See docstring of SpectralNorm._version on the changes to spectral_norm.
+ def __init__(self, fn):
+ self.fn = fn
+
+ # For state_dict with version None, (assuming that it has gone through at
+ # least one training forward), we have
+ #
+ # u = normalize(W_orig @ v)
+ # W = W_orig / sigma, where sigma = u @ W_orig @ v
+ #
+ # To compute `v`, we solve `W_orig @ x = u`, and let
+ # v = x / (u @ W_orig @ x) * (W / W_orig).
+ def __call__(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ fn = self.fn
+ version = local_metadata.get('spectral_norm',
+ {}).get(fn.name + '.version', None)
+ if version is None or version < 1:
+ with torch.no_grad():
+ weight_orig = state_dict[prefix + fn.name + '_orig']
+ # weight = state_dict.pop(prefix + fn.name)
+ # sigma = (weight_orig / weight).mean()
+ weight_mat = fn.reshape_weight_to_matrix(weight_orig)
+ u = state_dict[prefix + fn.name + '_u']
+ # v = fn._solve_v_and_rescale(weight_mat, u, sigma)
+ # state_dict[prefix + fn.name + '_v'] = v
+
+
+# This is a top level class because Py2 pickle doesn't like inner class nor an
+# instancemethod.
+class SpectralNormStateDictHook(object):
+ # See docstring of SpectralNorm._version on the changes to spectral_norm.
+ def __init__(self, fn):
+ self.fn = fn
+
+ def __call__(self, module, state_dict, prefix, local_metadata):
+ if 'spectral_norm' not in local_metadata:
+ local_metadata['spectral_norm'] = {}
+ key = self.fn.name + '.version'
+ if key in local_metadata['spectral_norm']:
+ raise RuntimeError(
+ "Unexpected key in metadata['spectral_norm']: {}".format(key))
+ local_metadata['spectral_norm'][key] = self.fn._version
+
+
+def spectral_norm(module,
+ name='weight',
+ n_power_iterations=1,
+ eps=1e-12,
+ dim=None):
+ r"""Applies spectral normalization to a parameter in the given module.
+
+ .. math::
+ \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
+ \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
+
+ Spectral normalization stabilizes the training of discriminators (critics)
+ in Generative Adversarial Networks (GANs) by rescaling the weight tensor
+ with spectral norm :math:`\sigma` of the weight matrix calculated using
+ power iteration method. If the dimension of the weight tensor is greater
+ than 2, it is reshaped to 2D in power iteration method to get spectral
+ norm. This is implemented via a hook that calculates spectral norm and
+ rescales weight before every :meth:`~Module.forward` call.
+
+ See `Spectral Normalization for Generative Adversarial Networks`_ .
+
+ .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
+
+ Args:
+ module (nn.Module): containing module
+ name (str, optional): name of weight parameter
+ n_power_iterations (int, optional): number of power iterations to
+ calculate spectral norm
+ eps (float, optional): epsilon for numerical stability in
+ calculating norms
+ dim (int, optional): dimension corresponding to number of outputs,
+ the default is ``0``, except for modules that are instances of
+ ConvTranspose{1,2,3}d, when it is ``1``
+
+ Returns:
+ The original module with the spectral norm hook
+
+ Example::
+
+ >>> m = spectral_norm(nn.Linear(20, 40))
+ >>> m
+ Linear(in_features=20, out_features=40, bias=True)
+ >>> m.weight_u.size()
+ torch.Size([40])
+
+ """
+ if dim is None:
+ if isinstance(module,
+ (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
+ torch.nn.ConvTranspose3d)):
+ dim = 1
+ else:
+ dim = 0
+ SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
+ return module
+
+
+def remove_spectral_norm(module, name='weight'):
+ r"""Removes the spectral normalization reparameterization from a module.
+
+ Args:
+ module (Module): containing module
+ name (str, optional): name of weight parameter
+
+ Example:
+ >>> m = spectral_norm(nn.Linear(40, 10))
+ >>> remove_spectral_norm(m)
+ """
+ for k, hook in module._forward_pre_hooks.items():
+ if isinstance(hook, SpectralNorm) and hook.name == name:
+ hook.remove(module)
+ del module._forward_pre_hooks[k]
+ return module
+
+ raise ValueError("spectral_norm of '{}' not found in {}".format(
+ name, module))
+
+
+def use_spectral_norm(module, use_sn=False):
+ if use_sn:
+ return spectral_norm(module)
+ return module
\ No newline at end of file
diff --git a/inpainter/model/modules/tfocal_transformer.py b/inpainter/model/modules/tfocal_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..179508f490f2662331a8817b37513005e98fe4de
--- /dev/null
+++ b/inpainter/model/modules/tfocal_transformer.py
@@ -0,0 +1,536 @@
+"""
+ This code is based on:
+ [1] FuseFormer: Fusing Fine-Grained Information in Transformers for Video Inpainting, ICCV 2021
+ https://github.com/ruiliu-ai/FuseFormer
+ [2] Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet, ICCV 2021
+ https://github.com/yitu-opensource/T2T-ViT
+ [3] Focal Self-attention for Local-Global Interactions in Vision Transformers, NeurIPS 2021
+ https://github.com/microsoft/Focal-Transformer
+"""
+
+import math
+from functools import reduce
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SoftSplit(nn.Module):
+ def __init__(self, channel, hidden, kernel_size, stride, padding,
+ t2t_param):
+ super(SoftSplit, self).__init__()
+ self.kernel_size = kernel_size
+ self.t2t = nn.Unfold(kernel_size=kernel_size,
+ stride=stride,
+ padding=padding)
+ c_in = reduce((lambda x, y: x * y), kernel_size) * channel
+ self.embedding = nn.Linear(c_in, hidden)
+
+ self.f_h = int(
+ (t2t_param['output_size'][0] + 2 * t2t_param['padding'][0] -
+ (t2t_param['kernel_size'][0] - 1) - 1) / t2t_param['stride'][0] +
+ 1)
+ self.f_w = int(
+ (t2t_param['output_size'][1] + 2 * t2t_param['padding'][1] -
+ (t2t_param['kernel_size'][1] - 1) - 1) / t2t_param['stride'][1] +
+ 1)
+
+ def forward(self, x, b):
+ feat = self.t2t(x)
+ feat = feat.permute(0, 2, 1)
+ # feat shape [b*t, num_vec, ks*ks*c]
+ feat = self.embedding(feat)
+ # feat shape after embedding [b, t*num_vec, hidden]
+ feat = feat.view(b, -1, self.f_h, self.f_w, feat.size(2))
+ return feat
+
+
+class SoftComp(nn.Module):
+ def __init__(self, channel, hidden, output_size, kernel_size, stride,
+ padding):
+ super(SoftComp, self).__init__()
+ self.relu = nn.LeakyReLU(0.2, inplace=True)
+ c_out = reduce((lambda x, y: x * y), kernel_size) * channel
+ self.embedding = nn.Linear(hidden, c_out)
+ self.t2t = torch.nn.Fold(output_size=output_size,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding)
+ h, w = output_size
+ self.bias = nn.Parameter(torch.zeros((channel, h, w),
+ dtype=torch.float32),
+ requires_grad=True)
+
+ def forward(self, x, t):
+ b_, _, _, _, c_ = x.shape
+ x = x.view(b_, -1, c_)
+ feat = self.embedding(x)
+ b, _, c = feat.size()
+ feat = feat.view(b * t, -1, c).permute(0, 2, 1)
+ feat = self.t2t(feat) + self.bias[None]
+ return feat
+
+
+class FusionFeedForward(nn.Module):
+ def __init__(self, d_model, n_vecs=None, t2t_params=None):
+ super(FusionFeedForward, self).__init__()
+ # We set d_ff as a default to 1960
+ hd = 1960
+ self.conv1 = nn.Sequential(nn.Linear(d_model, hd))
+ self.conv2 = nn.Sequential(nn.GELU(), nn.Linear(hd, d_model))
+ assert t2t_params is not None and n_vecs is not None
+ tp = t2t_params.copy()
+ self.fold = nn.Fold(**tp)
+ del tp['output_size']
+ self.unfold = nn.Unfold(**tp)
+ self.n_vecs = n_vecs
+
+ def forward(self, x):
+ x = self.conv1(x)
+ b, n, c = x.size()
+ normalizer = x.new_ones(b, n, 49).view(-1, self.n_vecs,
+ 49).permute(0, 2, 1)
+ x = self.unfold(
+ self.fold(x.view(-1, self.n_vecs, c).permute(0, 2, 1)) /
+ self.fold(normalizer)).permute(0, 2, 1).contiguous().view(b, n, c)
+ x = self.conv2(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: shape is (B, T, H, W, C)
+ window_size (tuple[int]): window size
+ Returns:
+ windows: (B*num_windows, T*window_size*window_size, C)
+ """
+ B, T, H, W, C = x.shape
+ x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
+ window_size[1], C)
+ windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
+ -1, T * window_size[0] * window_size[1], C)
+ return windows
+
+
+def window_partition_noreshape(x, window_size):
+ """
+ Args:
+ x: shape is (B, T, H, W, C)
+ window_size (tuple[int]): window size
+ Returns:
+ windows: (B, num_windows_h, num_windows_w, T, window_size, window_size, C)
+ """
+ B, T, H, W, C = x.shape
+ x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
+ window_size[1], C)
+ windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous()
+ return windows
+
+
+def window_reverse(windows, window_size, T, H, W):
+ """
+ Args:
+ windows: shape is (num_windows*B, T, window_size, window_size, C)
+ window_size (tuple[int]): Window size
+ T (int): Temporal length of video
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, T, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
+ x = windows.view(B, H // window_size[0], W // window_size[1], T,
+ window_size[0], window_size[1], -1)
+ x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, T, H, W, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ """Temporal focal window attention
+ """
+ def __init__(self, dim, expand_size, window_size, focal_window,
+ focal_level, num_heads, qkv_bias, pool_method):
+
+ super().__init__()
+ self.dim = dim
+ self.expand_size = expand_size
+ self.window_size = window_size # Wh, Ww
+ self.pool_method = pool_method
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+ self.focal_level = focal_level
+ self.focal_window = focal_window
+
+ if any(i > 0 for i in self.expand_size) and focal_level > 0:
+ # get mask for rolled k and rolled v
+ mask_tl = torch.ones(self.window_size[0], self.window_size[1])
+ mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0
+ mask_tr = torch.ones(self.window_size[0], self.window_size[1])
+ mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0
+ mask_bl = torch.ones(self.window_size[0], self.window_size[1])
+ mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0
+ mask_br = torch.ones(self.window_size[0], self.window_size[1])
+ mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0
+ mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br),
+ 0).flatten(0)
+ self.register_buffer("valid_ind_rolled",
+ mask_rolled.nonzero(as_tuple=False).view(-1))
+
+ if pool_method != "none" and focal_level > 1:
+ self.unfolds = nn.ModuleList()
+
+ # build relative position bias between local patch and pooled windows
+ for k in range(focal_level - 1):
+ stride = 2**k
+ kernel_size = tuple(2 * (i // 2) + 2**k + (2**k - 1)
+ for i in self.focal_window)
+ # define unfolding operations
+ self.unfolds += [
+ nn.Unfold(kernel_size=kernel_size,
+ stride=stride,
+ padding=tuple(i // 2 for i in kernel_size))
+ ]
+
+ # define unfolding index for focal_level > 0
+ if k > 0:
+ mask = torch.zeros(kernel_size)
+ mask[(2**k) - 1:, (2**k) - 1:] = 1
+ self.register_buffer(
+ "valid_ind_unfold_{}".format(k),
+ mask.flatten(0).nonzero(as_tuple=False).view(-1))
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x_all, mask_all=None):
+ """
+ Args:
+ x: input features with shape of (B, T, Wh, Ww, C)
+ mask: (0/-inf) mask with shape of (num_windows, T*Wh*Ww, T*Wh*Ww) or None
+
+ output: (nW*B, Wh*Ww, C)
+ """
+ x = x_all[0]
+
+ B, T, nH, nW, C = x.shape
+ qkv = self.qkv(x).reshape(B, T, nH, nW, 3,
+ C).permute(4, 0, 1, 2, 3, 5).contiguous()
+ q, k, v = qkv[0], qkv[1], qkv[2] # B, T, nH, nW, C
+
+ # partition q map
+ (q_windows, k_windows, v_windows) = map(
+ lambda t: window_partition(t, self.window_size).view(
+ -1, T, self.window_size[0] * self.window_size[1], self.
+ num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).
+ contiguous().view(-1, self.num_heads, T * self.window_size[
+ 0] * self.window_size[1], C // self.num_heads), (q, k, v))
+ # q(k/v)_windows shape : [16, 4, 225, 128]
+
+ if any(i > 0 for i in self.expand_size) and self.focal_level > 0:
+ (k_tl, v_tl) = map(
+ lambda t: torch.roll(t,
+ shifts=(-self.expand_size[0], -self.
+ expand_size[1]),
+ dims=(2, 3)), (k, v))
+ (k_tr, v_tr) = map(
+ lambda t: torch.roll(t,
+ shifts=(-self.expand_size[0], self.
+ expand_size[1]),
+ dims=(2, 3)), (k, v))
+ (k_bl, v_bl) = map(
+ lambda t: torch.roll(t,
+ shifts=(self.expand_size[0], -self.
+ expand_size[1]),
+ dims=(2, 3)), (k, v))
+ (k_br, v_br) = map(
+ lambda t: torch.roll(t,
+ shifts=(self.expand_size[0], self.
+ expand_size[1]),
+ dims=(2, 3)), (k, v))
+
+ (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
+ lambda t: window_partition(t, self.window_size).view(
+ -1, T, self.window_size[0] * self.window_size[1], self.
+ num_heads, C // self.num_heads), (k_tl, k_tr, k_bl, k_br))
+ (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
+ lambda t: window_partition(t, self.window_size).view(
+ -1, T, self.window_size[0] * self.window_size[1], self.
+ num_heads, C // self.num_heads), (v_tl, v_tr, v_bl, v_br))
+ k_rolled = torch.cat(
+ (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows),
+ 2).permute(0, 3, 1, 2, 4).contiguous()
+ v_rolled = torch.cat(
+ (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows),
+ 2).permute(0, 3, 1, 2, 4).contiguous()
+
+ # mask out tokens in current window
+ k_rolled = k_rolled[:, :, :, self.valid_ind_rolled]
+ v_rolled = v_rolled[:, :, :, self.valid_ind_rolled]
+ temp_N = k_rolled.shape[3]
+ k_rolled = k_rolled.view(-1, self.num_heads, T * temp_N,
+ C // self.num_heads)
+ v_rolled = v_rolled.view(-1, self.num_heads, T * temp_N,
+ C // self.num_heads)
+ k_rolled = torch.cat((k_windows, k_rolled), 2)
+ v_rolled = torch.cat((v_windows, v_rolled), 2)
+ else:
+ k_rolled = k_windows
+ v_rolled = v_windows
+
+ # q(k/v)_windows shape : [16, 4, 225, 128]
+ # k_rolled.shape : [16, 4, 5, 165, 128]
+ # ideal expanded window size 153 ((5+2*2)*(9+2*4))
+ # k_windows=45 expand_window=108 overlap_window=12 (since expand_size < window_size / 2)
+
+ if self.pool_method != "none" and self.focal_level > 1:
+ k_pooled = []
+ v_pooled = []
+ for k in range(self.focal_level - 1):
+ stride = 2**k
+ x_window_pooled = x_all[k + 1].permute(
+ 0, 3, 1, 2, 4).contiguous() # B, T, nWh, nWw, C
+
+ nWh, nWw = x_window_pooled.shape[2:4]
+
+ # generate mask for pooled windows
+ mask = x_window_pooled.new(T, nWh, nWw).fill_(1)
+ # unfold mask: [nWh*nWw//s//s, k*k, 1]
+ unfolded_mask = self.unfolds[k](mask.unsqueeze(1)).view(
+ 1, T, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(4, 1, 2, 3, 0).contiguous().\
+ view(nWh*nWw // stride // stride, -1, 1)
+
+ if k > 0:
+ valid_ind_unfold_k = getattr(
+ self, "valid_ind_unfold_{}".format(k))
+ unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]
+
+ x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
+ x_window_masks = x_window_masks.masked_fill(
+ x_window_masks == 0,
+ float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
+ mask_all[k + 1] = x_window_masks
+
+ # generate k and v for pooled windows
+ qkv_pooled = self.qkv(x_window_pooled).reshape(
+ B, T, nWh, nWw, 3, C).permute(4, 0, 1, 5, 2,
+ 3).view(3, -1, C, nWh,
+ nWw).contiguous()
+ k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[
+ 2] # B*T, C, nWh, nWw
+ # k_pooled_k shape: [5, 512, 4, 4]
+ # self.unfolds[k](k_pooled_k) shape: [5, 23040 (512 * 5 * 9 ), 16]
+
+ (k_pooled_k, v_pooled_k) = map(
+ lambda t: self.unfolds[k](t).view(
+ B, T, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 5, 1, 3, 4, 2).contiguous().\
+ view(-1, T, self.unfolds[k].kernel_size[0]*self.unfolds[k].kernel_size[1], self.num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).contiguous(),
+ (k_pooled_k, v_pooled_k) # (B x (nH*nW)) x nHeads x T x (unfold_wsize x unfold_wsize) x head_dim
+ )
+ # k_pooled_k shape : [16, 4, 5, 45, 128]
+
+ # select valid unfolding index
+ if k > 0:
+ (k_pooled_k, v_pooled_k) = map(
+ lambda t: t[:, :, :, valid_ind_unfold_k],
+ (k_pooled_k, v_pooled_k))
+
+ k_pooled_k = k_pooled_k.view(
+ -1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
+ self.unfolds[k].kernel_size[1], C // self.num_heads)
+ v_pooled_k = v_pooled_k.view(
+ -1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
+ self.unfolds[k].kernel_size[1], C // self.num_heads)
+
+ k_pooled += [k_pooled_k]
+ v_pooled += [v_pooled_k]
+
+ # k_all (v_all) shape : [16, 4, 5 * 210, 128]
+ k_all = torch.cat([k_rolled] + k_pooled, 2)
+ v_all = torch.cat([v_rolled] + v_pooled, 2)
+ else:
+ k_all = k_rolled
+ v_all = v_rolled
+
+ N = k_all.shape[-2]
+ q_windows = q_windows * self.scale
+ attn = (
+ q_windows @ k_all.transpose(-2, -1)
+ ) # B*nW, nHead, T*window_size*window_size, T*focal_window_size*focal_window_size
+ # T * 45
+ window_area = T * self.window_size[0] * self.window_size[1]
+ # T * 165
+ window_area_rolled = k_rolled.shape[2]
+
+ if self.pool_method != "none" and self.focal_level > 1:
+ offset = window_area_rolled
+ for k in range(self.focal_level - 1):
+ # add attentional mask
+ # mask_all[1] shape [1, 16, T * 45]
+
+ bias = tuple((i + 2**k - 1) for i in self.focal_window)
+
+ if mask_all[k + 1] is not None:
+ attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] = \
+ attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] + \
+ mask_all[k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])
+
+ offset += T * bias[0] * bias[1]
+
+ if mask_all[0] is not None:
+ nW = mask_all[0].shape[0]
+ attn = attn.view(attn.shape[0] // nW, nW, self.num_heads,
+ window_area, N)
+ attn[:, :, :, :, :
+ window_area] = attn[:, :, :, :, :window_area] + mask_all[0][
+ None, :, None, :, :]
+ attn = attn.view(-1, self.num_heads, window_area, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area,
+ C)
+ x = self.proj(x)
+ return x
+
+
+class TemporalFocalTransformerBlock(nn.Module):
+ r""" Temporal Focal Transformer Block.
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (tuple[int]): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ focal_level (int): The number level of focal window.
+ focal_window (int): Window size of each focal window.
+ n_vecs (int): Required for F3N.
+ t2t_params (int): T2T parameters for F3N.
+ """
+ def __init__(self,
+ dim,
+ num_heads,
+ window_size=(5, 9),
+ mlp_ratio=4.,
+ qkv_bias=True,
+ pool_method="fc",
+ focal_level=2,
+ focal_window=(5, 9),
+ norm_layer=nn.LayerNorm,
+ n_vecs=None,
+ t2t_params=None):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.expand_size = tuple(i // 2 for i in window_size) # TODO
+ self.mlp_ratio = mlp_ratio
+ self.pool_method = pool_method
+ self.focal_level = focal_level
+ self.focal_window = focal_window
+
+ self.window_size_glo = self.window_size
+
+ self.pool_layers = nn.ModuleList()
+ if self.pool_method != "none":
+ for k in range(self.focal_level - 1):
+ window_size_glo = tuple(
+ math.floor(i / (2**k)) for i in self.window_size_glo)
+ self.pool_layers.append(
+ nn.Linear(window_size_glo[0] * window_size_glo[1], 1))
+ self.pool_layers[-1].weight.data.fill_(
+ 1. / (window_size_glo[0] * window_size_glo[1]))
+ self.pool_layers[-1].bias.data.fill_(0)
+
+ self.norm1 = norm_layer(dim)
+
+ self.attn = WindowAttention(dim,
+ expand_size=self.expand_size,
+ window_size=self.window_size,
+ focal_window=focal_window,
+ focal_level=focal_level,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ pool_method=pool_method)
+
+ self.norm2 = norm_layer(dim)
+ self.mlp = FusionFeedForward(dim, n_vecs=n_vecs, t2t_params=t2t_params)
+
+ def forward(self, x):
+ B, T, H, W, C = x.shape
+
+ shortcut = x
+ x = self.norm1(x)
+
+ shifted_x = x
+
+ x_windows_all = [shifted_x]
+ x_window_masks_all = [None]
+
+ # partition windows tuple(i // 2 for i in window_size)
+ if self.focal_level > 1 and self.pool_method != "none":
+ # if we add coarser granularity and the pool method is not none
+ for k in range(self.focal_level - 1):
+ window_size_glo = tuple(
+ math.floor(i / (2**k)) for i in self.window_size_glo)
+ pooled_h = math.ceil(H / window_size_glo[0]) * (2**k)
+ pooled_w = math.ceil(W / window_size_glo[1]) * (2**k)
+ H_pool = pooled_h * window_size_glo[0]
+ W_pool = pooled_w * window_size_glo[1]
+
+ x_level_k = shifted_x
+ # trim or pad shifted_x depending on the required size
+ if H > H_pool:
+ trim_t = (H - H_pool) // 2
+ trim_b = H - H_pool - trim_t
+ x_level_k = x_level_k[:, :, trim_t:-trim_b]
+ elif H < H_pool:
+ pad_t = (H_pool - H) // 2
+ pad_b = H_pool - H - pad_t
+ x_level_k = F.pad(x_level_k, (0, 0, 0, 0, pad_t, pad_b))
+
+ if W > W_pool:
+ trim_l = (W - W_pool) // 2
+ trim_r = W - W_pool - trim_l
+ x_level_k = x_level_k[:, :, :, trim_l:-trim_r]
+ elif W < W_pool:
+ pad_l = (W_pool - W) // 2
+ pad_r = W_pool - W - pad_l
+ x_level_k = F.pad(x_level_k, (0, 0, pad_l, pad_r))
+
+ x_windows_noreshape = window_partition_noreshape(
+ x_level_k.contiguous(), window_size_glo
+ ) # B, nw, nw, T, window_size, window_size, C
+ nWh, nWw = x_windows_noreshape.shape[1:3]
+ x_windows_noreshape = x_windows_noreshape.view(
+ B, nWh, nWw, T, window_size_glo[0] * window_size_glo[1],
+ C).transpose(4, 5) # B, nWh, nWw, T, C, wsize**2
+ x_windows_pooled = self.pool_layers[k](
+ x_windows_noreshape).flatten(-2) # B, nWh, nWw, T, C
+
+ x_windows_all += [x_windows_pooled]
+ x_window_masks_all += [None]
+
+ attn_windows = self.attn(
+ x_windows_all,
+ mask_all=x_window_masks_all) # nW*B, T*window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, T, self.window_size[0],
+ self.window_size[1], C)
+ shifted_x = window_reverse(attn_windows, self.window_size, T, H,
+ W) # B T H' W' C
+
+ # FFN
+ x = shortcut + shifted_x
+ y = self.norm2(x)
+ x = x + self.mlp(y.view(B, T * H * W, C)).view(B, T, H, W, C)
+
+ return x
diff --git a/inpainter/model/modules/tfocal_transformer_hq.py b/inpainter/model/modules/tfocal_transformer_hq.py
new file mode 100644
index 0000000000000000000000000000000000000000..efabefb81bb5e7db60154c5eac770afa39798db3
--- /dev/null
+++ b/inpainter/model/modules/tfocal_transformer_hq.py
@@ -0,0 +1,567 @@
+"""
+ This code is based on:
+ [1] FuseFormer: Fusing Fine-Grained Information in Transformers for Video Inpainting, ICCV 2021
+ https://github.com/ruiliu-ai/FuseFormer
+ [2] Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet, ICCV 2021
+ https://github.com/yitu-opensource/T2T-ViT
+ [3] Focal Self-attention for Local-Global Interactions in Vision Transformers, NeurIPS 2021
+ https://github.com/microsoft/Focal-Transformer
+"""
+
+import math
+from functools import reduce
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SoftSplit(nn.Module):
+ def __init__(self, channel, hidden, kernel_size, stride, padding,
+ t2t_param):
+ super(SoftSplit, self).__init__()
+ self.kernel_size = kernel_size
+ self.t2t = nn.Unfold(kernel_size=kernel_size,
+ stride=stride,
+ padding=padding)
+ c_in = reduce((lambda x, y: x * y), kernel_size) * channel
+ self.embedding = nn.Linear(c_in, hidden)
+
+ self.t2t_param = t2t_param
+
+ def forward(self, x, b, output_size):
+ f_h = int((output_size[0] + 2 * self.t2t_param['padding'][0] -
+ (self.t2t_param['kernel_size'][0] - 1) - 1) /
+ self.t2t_param['stride'][0] + 1)
+ f_w = int((output_size[1] + 2 * self.t2t_param['padding'][1] -
+ (self.t2t_param['kernel_size'][1] - 1) - 1) /
+ self.t2t_param['stride'][1] + 1)
+
+ feat = self.t2t(x)
+ feat = feat.permute(0, 2, 1)
+ # feat shape [b*t, num_vec, ks*ks*c]
+ feat = self.embedding(feat)
+ # feat shape after embedding [b, t*num_vec, hidden]
+ feat = feat.view(b, -1, f_h, f_w, feat.size(2))
+ return feat
+
+
+class SoftComp(nn.Module):
+ def __init__(self, channel, hidden, kernel_size, stride, padding):
+ super(SoftComp, self).__init__()
+ self.relu = nn.LeakyReLU(0.2, inplace=True)
+ c_out = reduce((lambda x, y: x * y), kernel_size) * channel
+ self.embedding = nn.Linear(hidden, c_out)
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.bias_conv = nn.Conv2d(channel,
+ channel,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ # TODO upsample conv
+ # self.bias_conv = nn.Conv2d()
+ # self.bias = nn.Parameter(torch.zeros((channel, h, w), dtype=torch.float32), requires_grad=True)
+
+ def forward(self, x, t, output_size):
+ b_, _, _, _, c_ = x.shape
+ x = x.view(b_, -1, c_)
+ feat = self.embedding(x)
+ b, _, c = feat.size()
+ feat = feat.view(b * t, -1, c).permute(0, 2, 1)
+ feat = F.fold(feat,
+ output_size=output_size,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ padding=self.padding)
+ feat = self.bias_conv(feat)
+ return feat
+
+
+class FusionFeedForward(nn.Module):
+ def __init__(self, d_model, n_vecs=None, t2t_params=None):
+ super(FusionFeedForward, self).__init__()
+ # We set d_ff as a default to 1960
+ hd = 1960
+ self.conv1 = nn.Sequential(nn.Linear(d_model, hd))
+ self.conv2 = nn.Sequential(nn.GELU(), nn.Linear(hd, d_model))
+ assert t2t_params is not None and n_vecs is not None
+ self.t2t_params = t2t_params
+
+ def forward(self, x, output_size):
+ n_vecs = 1
+ for i, d in enumerate(self.t2t_params['kernel_size']):
+ n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] -
+ (d - 1) - 1) / self.t2t_params['stride'][i] + 1)
+
+ x = self.conv1(x)
+ b, n, c = x.size()
+ normalizer = x.new_ones(b, n, 49).view(-1, n_vecs, 49).permute(0, 2, 1)
+ normalizer = F.fold(normalizer,
+ output_size=output_size,
+ kernel_size=self.t2t_params['kernel_size'],
+ padding=self.t2t_params['padding'],
+ stride=self.t2t_params['stride'])
+
+ x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1),
+ output_size=output_size,
+ kernel_size=self.t2t_params['kernel_size'],
+ padding=self.t2t_params['padding'],
+ stride=self.t2t_params['stride'])
+
+ x = F.unfold(x / normalizer,
+ kernel_size=self.t2t_params['kernel_size'],
+ padding=self.t2t_params['padding'],
+ stride=self.t2t_params['stride']).permute(
+ 0, 2, 1).contiguous().view(b, n, c)
+ x = self.conv2(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: shape is (B, T, H, W, C)
+ window_size (tuple[int]): window size
+ Returns:
+ windows: (B*num_windows, T*window_size*window_size, C)
+ """
+ B, T, H, W, C = x.shape
+
+ x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
+ window_size[1], C)
+
+ windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
+ -1, T * window_size[0] * window_size[1], C)
+ return windows
+
+
+def window_partition_noreshape(x, window_size):
+ """
+ Args:
+ x: shape is (B, T, H, W, C)
+ window_size (tuple[int]): window size
+ Returns:
+ windows: (B, num_windows_h, num_windows_w, T, window_size, window_size, C)
+ """
+ B, T, H, W, C = x.shape
+ x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
+ window_size[1], C)
+ windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous()
+ return windows
+
+
+def window_reverse(windows, window_size, T, H, W):
+ """
+ Args:
+ windows: shape is (num_windows*B, T, window_size, window_size, C)
+ window_size (tuple[int]): Window size
+ T (int): Temporal length of video
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, T, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
+ x = windows.view(B, H // window_size[0], W // window_size[1], T,
+ window_size[0], window_size[1], -1)
+ x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, T, H, W, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ """Temporal focal window attention
+ """
+ def __init__(self, dim, expand_size, window_size, focal_window,
+ focal_level, num_heads, qkv_bias, pool_method):
+
+ super().__init__()
+ self.dim = dim
+ self.expand_size = expand_size
+ self.window_size = window_size # Wh, Ww
+ self.pool_method = pool_method
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+ self.focal_level = focal_level
+ self.focal_window = focal_window
+
+ if any(i > 0 for i in self.expand_size) and focal_level > 0:
+ # get mask for rolled k and rolled v
+ mask_tl = torch.ones(self.window_size[0], self.window_size[1])
+ mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0
+ mask_tr = torch.ones(self.window_size[0], self.window_size[1])
+ mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0
+ mask_bl = torch.ones(self.window_size[0], self.window_size[1])
+ mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0
+ mask_br = torch.ones(self.window_size[0], self.window_size[1])
+ mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0
+ mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br),
+ 0).flatten(0)
+ self.register_buffer("valid_ind_rolled",
+ mask_rolled.nonzero(as_tuple=False).view(-1))
+
+ if pool_method != "none" and focal_level > 1:
+ self.unfolds = nn.ModuleList()
+
+ # build relative position bias between local patch and pooled windows
+ for k in range(focal_level - 1):
+ stride = 2**k
+ kernel_size = tuple(2 * (i // 2) + 2**k + (2**k - 1)
+ for i in self.focal_window)
+ # define unfolding operations
+ self.unfolds += [
+ nn.Unfold(kernel_size=kernel_size,
+ stride=stride,
+ padding=tuple(i // 2 for i in kernel_size))
+ ]
+
+ # define unfolding index for focal_level > 0
+ if k > 0:
+ mask = torch.zeros(kernel_size)
+ mask[(2**k) - 1:, (2**k) - 1:] = 1
+ self.register_buffer(
+ "valid_ind_unfold_{}".format(k),
+ mask.flatten(0).nonzero(as_tuple=False).view(-1))
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x_all, mask_all=None):
+ """
+ Args:
+ x: input features with shape of (B, T, Wh, Ww, C)
+ mask: (0/-inf) mask with shape of (num_windows, T*Wh*Ww, T*Wh*Ww) or None
+
+ output: (nW*B, Wh*Ww, C)
+ """
+ x = x_all[0]
+
+ B, T, nH, nW, C = x.shape
+ qkv = self.qkv(x).reshape(B, T, nH, nW, 3,
+ C).permute(4, 0, 1, 2, 3, 5).contiguous()
+ q, k, v = qkv[0], qkv[1], qkv[2] # B, T, nH, nW, C
+
+ # partition q map
+ (q_windows, k_windows, v_windows) = map(
+ lambda t: window_partition(t, self.window_size).view(
+ -1, T, self.window_size[0] * self.window_size[1], self.
+ num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).
+ contiguous().view(-1, self.num_heads, T * self.window_size[
+ 0] * self.window_size[1], C // self.num_heads), (q, k, v))
+ # q(k/v)_windows shape : [16, 4, 225, 128]
+
+ if any(i > 0 for i in self.expand_size) and self.focal_level > 0:
+ (k_tl, v_tl) = map(
+ lambda t: torch.roll(t,
+ shifts=(-self.expand_size[0], -self.
+ expand_size[1]),
+ dims=(2, 3)), (k, v))
+ (k_tr, v_tr) = map(
+ lambda t: torch.roll(t,
+ shifts=(-self.expand_size[0], self.
+ expand_size[1]),
+ dims=(2, 3)), (k, v))
+ (k_bl, v_bl) = map(
+ lambda t: torch.roll(t,
+ shifts=(self.expand_size[0], -self.
+ expand_size[1]),
+ dims=(2, 3)), (k, v))
+ (k_br, v_br) = map(
+ lambda t: torch.roll(t,
+ shifts=(self.expand_size[0], self.
+ expand_size[1]),
+ dims=(2, 3)), (k, v))
+
+ (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
+ lambda t: window_partition(t, self.window_size).view(
+ -1, T, self.window_size[0] * self.window_size[1], self.
+ num_heads, C // self.num_heads), (k_tl, k_tr, k_bl, k_br))
+ (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
+ lambda t: window_partition(t, self.window_size).view(
+ -1, T, self.window_size[0] * self.window_size[1], self.
+ num_heads, C // self.num_heads), (v_tl, v_tr, v_bl, v_br))
+ k_rolled = torch.cat(
+ (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows),
+ 2).permute(0, 3, 1, 2, 4).contiguous()
+ v_rolled = torch.cat(
+ (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows),
+ 2).permute(0, 3, 1, 2, 4).contiguous()
+
+ # mask out tokens in current window
+ k_rolled = k_rolled[:, :, :, self.valid_ind_rolled]
+ v_rolled = v_rolled[:, :, :, self.valid_ind_rolled]
+ temp_N = k_rolled.shape[3]
+ k_rolled = k_rolled.view(-1, self.num_heads, T * temp_N,
+ C // self.num_heads)
+ v_rolled = v_rolled.view(-1, self.num_heads, T * temp_N,
+ C // self.num_heads)
+ k_rolled = torch.cat((k_windows, k_rolled), 2)
+ v_rolled = torch.cat((v_windows, v_rolled), 2)
+ else:
+ k_rolled = k_windows
+ v_rolled = v_windows
+
+ # q(k/v)_windows shape : [16, 4, 225, 128]
+ # k_rolled.shape : [16, 4, 5, 165, 128]
+ # ideal expanded window size 153 ((5+2*2)*(9+2*4))
+ # k_windows=45 expand_window=108 overlap_window=12 (since expand_size < window_size / 2)
+
+ if self.pool_method != "none" and self.focal_level > 1:
+ k_pooled = []
+ v_pooled = []
+ for k in range(self.focal_level - 1):
+ stride = 2**k
+ # B, T, nWh, nWw, C
+ x_window_pooled = x_all[k + 1].permute(0, 3, 1, 2,
+ 4).contiguous()
+
+ nWh, nWw = x_window_pooled.shape[2:4]
+
+ # generate mask for pooled windows
+ mask = x_window_pooled.new(T, nWh, nWw).fill_(1)
+ # unfold mask: [nWh*nWw//s//s, k*k, 1]
+ unfolded_mask = self.unfolds[k](mask.unsqueeze(1)).view(
+ 1, T, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(4, 1, 2, 3, 0).contiguous().\
+ view(nWh*nWw // stride // stride, -1, 1)
+
+ if k > 0:
+ valid_ind_unfold_k = getattr(
+ self, "valid_ind_unfold_{}".format(k))
+ unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]
+
+ x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
+ x_window_masks = x_window_masks.masked_fill(
+ x_window_masks == 0,
+ float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
+ mask_all[k + 1] = x_window_masks
+
+ # generate k and v for pooled windows
+ qkv_pooled = self.qkv(x_window_pooled).reshape(
+ B, T, nWh, nWw, 3, C).permute(4, 0, 1, 5, 2,
+ 3).view(3, -1, C, nWh,
+ nWw).contiguous()
+ # B*T, C, nWh, nWw
+ k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2]
+ # k_pooled_k shape: [5, 512, 4, 4]
+ # self.unfolds[k](k_pooled_k) shape: [5, 23040 (512 * 5 * 9 ), 16]
+
+ (k_pooled_k, v_pooled_k) = map(
+ lambda t: self.unfolds[k]
+ (t).view(B, T, C, self.unfolds[k].kernel_size[0], self.
+ unfolds[k].kernel_size[1], -1)
+ .permute(0, 5, 1, 3, 4, 2).contiguous().view(
+ -1, T, self.unfolds[k].kernel_size[0] * self.unfolds[
+ k].kernel_size[1], self.num_heads, C // self.
+ num_heads).permute(0, 3, 1, 2, 4).contiguous(),
+ # (B x (nH*nW)) x nHeads x T x (unfold_wsize x unfold_wsize) x head_dim
+ (k_pooled_k, v_pooled_k))
+ # k_pooled_k shape : [16, 4, 5, 45, 128]
+
+ # select valid unfolding index
+ if k > 0:
+ (k_pooled_k, v_pooled_k) = map(
+ lambda t: t[:, :, :, valid_ind_unfold_k],
+ (k_pooled_k, v_pooled_k))
+
+ k_pooled_k = k_pooled_k.view(
+ -1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
+ self.unfolds[k].kernel_size[1], C // self.num_heads)
+ v_pooled_k = v_pooled_k.view(
+ -1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
+ self.unfolds[k].kernel_size[1], C // self.num_heads)
+
+ k_pooled += [k_pooled_k]
+ v_pooled += [v_pooled_k]
+
+ # k_all (v_all) shape : [16, 4, 5 * 210, 128]
+ k_all = torch.cat([k_rolled] + k_pooled, 2)
+ v_all = torch.cat([v_rolled] + v_pooled, 2)
+ else:
+ k_all = k_rolled
+ v_all = v_rolled
+
+ N = k_all.shape[-2]
+ q_windows = q_windows * self.scale
+ # B*nW, nHead, T*window_size*window_size, T*focal_window_size*focal_window_size
+ attn = (q_windows @ k_all.transpose(-2, -1))
+ # T * 45
+ window_area = T * self.window_size[0] * self.window_size[1]
+ # T * 165
+ window_area_rolled = k_rolled.shape[2]
+
+ if self.pool_method != "none" and self.focal_level > 1:
+ offset = window_area_rolled
+ for k in range(self.focal_level - 1):
+ # add attentional mask
+ # mask_all[1] shape [1, 16, T * 45]
+
+ bias = tuple((i + 2**k - 1) for i in self.focal_window)
+
+ if mask_all[k + 1] is not None:
+ attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] = \
+ attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] + \
+ mask_all[k+1][:, :, None, None, :].repeat(
+ attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])
+
+ offset += T * bias[0] * bias[1]
+
+ if mask_all[0] is not None:
+ nW = mask_all[0].shape[0]
+ attn = attn.view(attn.shape[0] // nW, nW, self.num_heads,
+ window_area, N)
+ attn[:, :, :, :, :
+ window_area] = attn[:, :, :, :, :window_area] + mask_all[0][
+ None, :, None, :, :]
+ attn = attn.view(-1, self.num_heads, window_area, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area,
+ C)
+ x = self.proj(x)
+ return x
+
+
+class TemporalFocalTransformerBlock(nn.Module):
+ r""" Temporal Focal Transformer Block.
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (tuple[int]): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ focal_level (int): The number level of focal window.
+ focal_window (int): Window size of each focal window.
+ n_vecs (int): Required for F3N.
+ t2t_params (int): T2T parameters for F3N.
+ """
+ def __init__(self,
+ dim,
+ num_heads,
+ window_size=(5, 9),
+ mlp_ratio=4.,
+ qkv_bias=True,
+ pool_method="fc",
+ focal_level=2,
+ focal_window=(5, 9),
+ norm_layer=nn.LayerNorm,
+ n_vecs=None,
+ t2t_params=None):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.expand_size = tuple(i // 2 for i in window_size) # TODO
+ self.mlp_ratio = mlp_ratio
+ self.pool_method = pool_method
+ self.focal_level = focal_level
+ self.focal_window = focal_window
+
+ self.window_size_glo = self.window_size
+
+ self.pool_layers = nn.ModuleList()
+ if self.pool_method != "none":
+ for k in range(self.focal_level - 1):
+ window_size_glo = tuple(
+ math.floor(i / (2**k)) for i in self.window_size_glo)
+ self.pool_layers.append(
+ nn.Linear(window_size_glo[0] * window_size_glo[1], 1))
+ self.pool_layers[-1].weight.data.fill_(
+ 1. / (window_size_glo[0] * window_size_glo[1]))
+ self.pool_layers[-1].bias.data.fill_(0)
+
+ self.norm1 = norm_layer(dim)
+
+ self.attn = WindowAttention(dim,
+ expand_size=self.expand_size,
+ window_size=self.window_size,
+ focal_window=focal_window,
+ focal_level=focal_level,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ pool_method=pool_method)
+
+ self.norm2 = norm_layer(dim)
+ self.mlp = FusionFeedForward(dim, n_vecs=n_vecs, t2t_params=t2t_params)
+
+ def forward(self, x):
+ output_size = x[1]
+ x = x[0]
+
+ B, T, H, W, C = x.shape
+
+ shortcut = x
+ x = self.norm1(x)
+
+ shifted_x = x
+
+ x_windows_all = [shifted_x]
+ x_window_masks_all = [None]
+
+ # partition windows tuple(i // 2 for i in window_size)
+ if self.focal_level > 1 and self.pool_method != "none":
+ # if we add coarser granularity and the pool method is not none
+ for k in range(self.focal_level - 1):
+ window_size_glo = tuple(
+ math.floor(i / (2**k)) for i in self.window_size_glo)
+ pooled_h = math.ceil(H / window_size_glo[0]) * (2**k)
+ pooled_w = math.ceil(W / window_size_glo[1]) * (2**k)
+ H_pool = pooled_h * window_size_glo[0]
+ W_pool = pooled_w * window_size_glo[1]
+
+ x_level_k = shifted_x
+ # trim or pad shifted_x depending on the required size
+ if H > H_pool:
+ trim_t = (H - H_pool) // 2
+ trim_b = H - H_pool - trim_t
+ x_level_k = x_level_k[:, :, trim_t:-trim_b]
+ elif H < H_pool:
+ pad_t = (H_pool - H) // 2
+ pad_b = H_pool - H - pad_t
+ x_level_k = F.pad(x_level_k, (0, 0, 0, 0, pad_t, pad_b))
+
+ if W > W_pool:
+ trim_l = (W - W_pool) // 2
+ trim_r = W - W_pool - trim_l
+ x_level_k = x_level_k[:, :, :, trim_l:-trim_r]
+ elif W < W_pool:
+ pad_l = (W_pool - W) // 2
+ pad_r = W_pool - W - pad_l
+ x_level_k = F.pad(x_level_k, (0, 0, pad_l, pad_r))
+
+ x_windows_noreshape = window_partition_noreshape(
+ x_level_k.contiguous(), window_size_glo
+ ) # B, nw, nw, T, window_size, window_size, C
+ nWh, nWw = x_windows_noreshape.shape[1:3]
+ x_windows_noreshape = x_windows_noreshape.view(
+ B, nWh, nWw, T, window_size_glo[0] * window_size_glo[1],
+ C).transpose(4, 5) # B, nWh, nWw, T, C, wsize**2
+ x_windows_pooled = self.pool_layers[k](
+ x_windows_noreshape).flatten(-2) # B, nWh, nWw, T, C
+
+ x_windows_all += [x_windows_pooled]
+ x_window_masks_all += [None]
+
+ # nW*B, T*window_size*window_size, C
+ attn_windows = self.attn(x_windows_all, mask_all=x_window_masks_all)
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, T, self.window_size[0],
+ self.window_size[1], C)
+ shifted_x = window_reverse(attn_windows, self.window_size, T, H,
+ W) # B T H' W' C
+
+ # FFN
+ x = shortcut + shifted_x
+ y = self.norm2(x)
+ x = x + self.mlp(y.view(B, T * H * W, C), output_size).view(
+ B, T, H, W, C)
+
+ return x, output_size
diff --git a/inpainter/util/__init__.py b/inpainter/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/inpainter/util/tensor_util.py b/inpainter/util/tensor_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..71a4746a5ecde78dc582f6169d12db9ac58d209f
--- /dev/null
+++ b/inpainter/util/tensor_util.py
@@ -0,0 +1,24 @@
+import cv2
+import numpy as np
+
+# resize frames
+def resize_frames(frames, size=None):
+ """
+ size: (w, h)
+ """
+ if size is not None:
+ frames = [cv2.resize(f, size) for f in frames]
+ frames = np.stack(frames, 0)
+
+ return frames
+
+# resize frames
+def resize_masks(masks, size=None):
+ """
+ size: (w, h)
+ """
+ if size is not None:
+ masks = [np.expand_dims(cv2.resize(m, size), 2) for m in masks]
+ masks = np.stack(masks, 0)
+
+ return masks
diff --git a/overleaf/.DS_Store b/overleaf/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..61e4b14b963c4c6c8475c11becaeb63034095a1b
Binary files /dev/null and b/overleaf/.DS_Store differ
diff --git a/overleaf/Track Anything.zip b/overleaf/Track Anything.zip
new file mode 100644
index 0000000000000000000000000000000000000000..c6ca6ab814cd0bbd2e388bfcbdc1a14cf44b91d2
--- /dev/null
+++ b/overleaf/Track Anything.zip
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d271378ac9538e322b362b43a41e2c22a21cffac6f539a0c3e5b140c3b24b47e
+size 5370701
diff --git a/overleaf/Track Anything/figs/avengers_1.pdf b/overleaf/Track Anything/figs/avengers_1.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..f0336aa313befe8b5a784920ba6657b166af8206
--- /dev/null
+++ b/overleaf/Track Anything/figs/avengers_1.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a519eb00a2d315ecdc36b5a53e174e9b3361a9526c7fcd8a96bfefde2eeb940f
+size 2570569
diff --git a/overleaf/Track Anything/figs/davisresults.pdf b/overleaf/Track Anything/figs/davisresults.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..465c78b08ce5d2e100d0a4742abddeb719ed2f19
--- /dev/null
+++ b/overleaf/Track Anything/figs/davisresults.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fefd74df3daafd48ffb72a725c43354712a244db70e6c5d7ae8773203e0be492
+size 1349133
diff --git a/overleaf/Track Anything/figs/failedcases.pdf b/overleaf/Track Anything/figs/failedcases.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..daed58affbc580d37cbfa9eebbed947d2d78965a
--- /dev/null
+++ b/overleaf/Track Anything/figs/failedcases.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ccb662ff62914d05fe8dc99640b9f89b32847675dd2069900a27771569378aa4
+size 1200242
diff --git a/overleaf/Track Anything/figs/overview_4.pdf b/overleaf/Track Anything/figs/overview_4.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..d8882674696d0c04b50c71933a80f0ccd5967536
Binary files /dev/null and b/overleaf/Track Anything/figs/overview_4.pdf differ
diff --git a/overleaf/Track Anything/neurips_2022.bbl b/overleaf/Track Anything/neurips_2022.bbl
new file mode 100644
index 0000000000000000000000000000000000000000..88a2605a2bb47ae690cadd8474416b2e02cc9296
--- /dev/null
+++ b/overleaf/Track Anything/neurips_2022.bbl
@@ -0,0 +1,105 @@
+\begin{thebibliography}{10}
+
+\bibitem{xmem}
+Ho~Kei Cheng and Alexander~G. Schwing.
+\newblock Xmem: Long-term video object segmentation with an atkinson-shiffrin
+ memory model.
+\newblock In {\em {ECCV} {(28)}}, volume 13688 of {\em Lecture Notes in
+ Computer Science}, pages 640--658. Springer, 2022.
+
+\bibitem{mivos}
+Ho~Kei Cheng, Yu{-}Wing Tai, and Chi{-}Keung Tang.
+\newblock Modular interactive video object segmentation: Interaction-to-mask,
+ propagation and difference-aware fusion.
+\newblock In {\em {CVPR}}, pages 5559--5568. Computer Vision Foundation /
+ {IEEE}, 2021.
+
+\bibitem{vit}
+Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn,
+ Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg
+ Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby.
+\newblock An image is worth 16x16 words: Transformers for image recognition at
+ scale.
+\newblock In {\em {ICLR}}. OpenReview.net, 2021.
+
+\bibitem{vos}
+Mingqi Gao, Feng Zheng, James J.~Q. Yu, Caifeng Shan, Guiguang Ding, and
+ Jungong Han.
+\newblock Deep learning for video object segmentation: a review.
+\newblock {\em Artif. Intell. Rev.}, 56(1):457--531, 2023.
+
+\bibitem{sam}
+Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura
+ Gustafson, Tete Xiao, Spencer Whitehead, Alexander~C Berg, Wan-Yen Lo, et~al.
+\newblock Segment anything.
+\newblock {\em arXiv preprint arXiv:2304.02643}, 2023.
+
+\bibitem{vot10}
+Matej Kristan, Ale{\v{s}} Leonardis, Ji{\v{r}}{\'\i} Matas, Michael Felsberg,
+ Roman Pflugfelder, Joni-Kristian K{\"a}m{\"a}r{\"a}inen, Hyung~Jin Chang,
+ Martin Danelljan, Luka~{\v{C}}ehovin Zajc, Alan Luke{\v{z}}i{\v{c}}, et~al.
+\newblock The tenth visual object tracking vot2022 challenge results.
+\newblock In {\em Computer Vision--ECCV 2022 Workshops: Tel Aviv, Israel,
+ October 23--27, 2022, Proceedings, Part VIII}, pages 431--460. Springer,
+ 2023.
+
+\bibitem{vot8}
+Matej Kristan, Ale{\v{s}} Leonardis, Ji{\v{r}}{\'\i} Matas, Michael Felsberg,
+ Roman Pflugfelder, Joni-Kristian K{\"a}m{\"a}r{\"a}inen, Martin Danelljan,
+ Luka~{\v{C}}ehovin Zajc, Alan Luke{\v{z}}i{\v{c}}, Ondrej Drbohlav, et~al.
+\newblock The eighth visual object tracking vot2020 challenge results.
+\newblock In {\em European Conference on Computer Vision}, pages 547--601.
+ Springer, 2020.
+
+\bibitem{vot6}
+Matej Kristan, Ales Leonardis, Jiri Matas, Michael Felsberg, Roman Pflugfelder,
+ Luka ˇCehovin~Zajc, Tomas Vojir, Goutam Bhat, Alan Lukezic, Abdelrahman
+ Eldesokey, et~al.
+\newblock The sixth visual object tracking vot2018 challenge results.
+\newblock In {\em Proceedings of the European Conference on Computer Vision
+ (ECCV) Workshops}, pages 0--0, 2018.
+
+\bibitem{vot9}
+Matej Kristan, Ji{\v{r}}{\'\i} Matas, Ale{\v{s}} Leonardis, Michael Felsberg,
+ Roman Pflugfelder, Joni-Kristian K{\"a}m{\"a}r{\"a}inen, Hyung~Jin Chang,
+ Martin Danelljan, Luka Cehovin, Alan Luke{\v{z}}i{\v{c}}, et~al.
+\newblock The ninth visual object tracking vot2021 challenge results.
+\newblock In {\em Proceedings of the IEEE/CVF International Conference on
+ Computer Vision}, pages 2711--2738, 2021.
+
+\bibitem{vot7}
+Matej Kristan, Jiri Matas, Ales Leonardis, Michael Felsberg, Roman Pflugfelder,
+ Joni-Kristian Kamarainen, Luka ˇCehovin~Zajc, Ondrej Drbohlav, Alan Lukezic,
+ Amanda Berg, et~al.
+\newblock The seventh visual object tracking vot2019 challenge results.
+\newblock In {\em Proceedings of the IEEE/CVF International Conference on
+ Computer Vision Workshops}, pages 0--0, 2019.
+
+\bibitem{e2fgvi}
+Zhen Li, Chengze Lu, Jianhua Qin, Chun{-}Le Guo, and Ming{-}Ming Cheng.
+\newblock Towards an end-to-end framework for flow-guided video inpainting.
+\newblock In {\em {CVPR}}, pages 17541--17550. {IEEE}, 2022.
+
+\bibitem{stm}
+Seoung~Wug Oh, Joon{-}Young Lee, Ning Xu, and Seon~Joo Kim.
+\newblock Video object segmentation using space-time memory networks.
+\newblock In {\em {ICCV}}, pages 9225--9234. {IEEE}, 2019.
+
+\bibitem{davis}
+Jordi Pont{-}Tuset, Federico Perazzi, Sergi Caelles, Pablo Arbelaez, Alexander
+ Sorkine{-}Hornung, and Luc~Van Gool.
+\newblock The 2017 {DAVIS} challenge on video object segmentation.
+\newblock {\em CoRR}, abs/1704.00675, 2017.
+
+\bibitem{siammask}
+Qiang Wang, Li~Zhang, Luca Bertinetto, Weiming Hu, and Philip H.~S. Torr.
+\newblock Fast online object tracking and segmentation: {A} unifying approach.
+\newblock In {\em {CVPR}}, pages 1328--1338. Computer Vision Foundation /
+ {IEEE}, 2019.
+
+\bibitem{aot}
+Zongxin Yang, Yunchao Wei, and Yi~Yang.
+\newblock Associating objects with transformers for video object segmentation.
+\newblock In {\em NeurIPS}, pages 2491--2502, 2021.
+
+\end{thebibliography}
diff --git a/overleaf/Track Anything/neurips_2022.bib b/overleaf/Track Anything/neurips_2022.bib
new file mode 100644
index 0000000000000000000000000000000000000000..64dc0f1fcc48a2b8b56857dd3459fcb7b45b8775
--- /dev/null
+++ b/overleaf/Track Anything/neurips_2022.bib
@@ -0,0 +1,187 @@
+@article{sam,
+ title={Segment anything},
+ author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C and Lo, Wan-Yen and others},
+ journal={arXiv preprint arXiv:2304.02643},
+ year={2023}
+}
+
+@inproceedings{xmem,
+ author = {Ho Kei Cheng and
+ Alexander G. Schwing},
+ title = {XMem: Long-Term Video Object Segmentation with an Atkinson-Shiffrin
+ Memory Model},
+ booktitle = {{ECCV} {(28)}},
+ series = {Lecture Notes in Computer Science},
+ volume = {13688},
+ pages = {640--658},
+ publisher = {Springer},
+ year = {2022}
+}
+
+
+%related
+
+@article{vos,
+ author = {Mingqi Gao and
+ Feng Zheng and
+ James J. Q. Yu and
+ Caifeng Shan and
+ Guiguang Ding and
+ Jungong Han},
+ title = {Deep learning for video object segmentation: a review},
+ journal = {Artif. Intell. Rev.},
+ volume = {56},
+ number = {1},
+ pages = {457--531},
+ year = {2023}
+}
+
+@inproceedings{vot9,
+ title={The ninth visual object tracking vot2021 challenge results},
+ author={Kristan, Matej and Matas, Ji{\v{r}}{\'\i} and Leonardis, Ale{\v{s}} and Felsberg, Michael and Pflugfelder, Roman and K{\"a}m{\"a}r{\"a}inen, Joni-Kristian and Chang, Hyung Jin and Danelljan, Martin and Cehovin, Luka and Luke{\v{z}}i{\v{c}}, Alan and others},
+ booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
+ pages={2711--2738},
+ year={2021}
+}
+
+@inproceedings{vot10,
+ title={The Tenth Visual Object Tracking VOT2022 Challenge Results},
+ author={Kristan, Matej and Leonardis, Ale{\v{s}} and Matas, Ji{\v{r}}{\'\i} and Felsberg, Michael and Pflugfelder, Roman and K{\"a}m{\"a}r{\"a}inen, Joni-Kristian and Chang, Hyung Jin and Danelljan, Martin and Zajc, Luka {\v{C}}ehovin and Luke{\v{z}}i{\v{c}}, Alan and others},
+ booktitle={Computer Vision--ECCV 2022 Workshops: Tel Aviv, Israel, October 23--27, 2022, Proceedings, Part VIII},
+ pages={431--460},
+ year={2023},
+ organization={Springer}
+}
+
+@inproceedings{vot8,
+ title={The eighth visual object tracking VOT2020 challenge results},
+ author={Kristan, Matej and Leonardis, Ale{\v{s}} and Matas, Ji{\v{r}}{\'\i} and Felsberg, Michael and Pflugfelder, Roman and K{\"a}m{\"a}r{\"a}inen, Joni-Kristian and Danelljan, Martin and Zajc, Luka {\v{C}}ehovin and Luke{\v{z}}i{\v{c}}, Alan and Drbohlav, Ondrej and others},
+ booktitle={European Conference on Computer Vision},
+ pages={547--601},
+ year={2020},
+ organization={Springer}
+}
+@inproceedings{vot7,
+ title={The seventh visual object tracking vot2019 challenge results},
+ author={Kristan, Matej and Matas, Jiri and Leonardis, Ales and Felsberg, Michael and Pflugfelder, Roman and Kamarainen, Joni-Kristian and ˇCehovin Zajc, Luka and Drbohlav, Ondrej and Lukezic, Alan and Berg, Amanda and others},
+ booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision Workshops},
+ pages={0--0},
+ year={2019}
+}
+@inproceedings{vot6,
+ title={The sixth visual object tracking vot2018 challenge results},
+ author={Kristan, Matej and Leonardis, Ales and Matas, Jiri and Felsberg, Michael and Pflugfelder, Roman and ˇCehovin Zajc, Luka and Vojir, Tomas and Bhat, Goutam and Lukezic, Alan and Eldesokey, Abdelrahman and others},
+ booktitle={Proceedings of the European Conference on Computer Vision (ECCV) Workshops},
+ pages={0--0},
+ year={2018}
+}
+
+@inproceedings{vit,
+ author = {Alexey Dosovitskiy and
+ Lucas Beyer and
+ Alexander Kolesnikov and
+ Dirk Weissenborn and
+ Xiaohua Zhai and
+ Thomas Unterthiner and
+ Mostafa Dehghani and
+ Matthias Minderer and
+ Georg Heigold and
+ Sylvain Gelly and
+ Jakob Uszkoreit and
+ Neil Houlsby},
+ title = {An Image is Worth 16x16 Words: Transformers for Image Recognition
+ at Scale},
+ booktitle = {{ICLR}},
+ publisher = {OpenReview.net},
+ year = {2021}
+}
+
+@inproceedings{stm,
+ author = {Seoung Wug Oh and
+ Joon{-}Young Lee and
+ Ning Xu and
+ Seon Joo Kim},
+ title = {Video Object Segmentation Using Space-Time Memory Networks},
+ booktitle = {{ICCV}},
+ pages = {9225--9234},
+ publisher = {{IEEE}},
+ year = {2019}
+}
+
+@inproceedings{siammask,
+ author = {Qiang Wang and
+ Li Zhang and
+ Luca Bertinetto and
+ Weiming Hu and
+ Philip H. S. Torr},
+ title = {Fast Online Object Tracking and Segmentation: {A} Unifying Approach},
+ booktitle = {{CVPR}},
+ pages = {1328--1338},
+ publisher = {Computer Vision Foundation / {IEEE}},
+ year = {2019}
+}
+
+@inproceedings{mivos,
+ author = {Ho Kei Cheng and
+ Yu{-}Wing Tai and
+ Chi{-}Keung Tang},
+ title = {Modular Interactive Video Object Segmentation: Interaction-to-Mask,
+ Propagation and Difference-Aware Fusion},
+ booktitle = {{CVPR}},
+ pages = {5559--5568},
+ publisher = {Computer Vision Foundation / {IEEE}},
+ year = {2021}
+}
+
+@article{davis,
+ author = {Jordi Pont{-}Tuset and
+ Federico Perazzi and
+ Sergi Caelles and
+ Pablo Arbelaez and
+ Alexander Sorkine{-}Hornung and
+ Luc Van Gool},
+ title = {The 2017 {DAVIS} Challenge on Video Object Segmentation},
+ journal = {CoRR},
+ volume = {abs/1704.00675},
+ year = {2017}
+}
+
+@inproceedings{aot,
+ author = {Zongxin Yang and
+ Yunchao Wei and
+ Yi Yang},
+ title = {Associating Objects with Transformers for Video Object Segmentation},
+ booktitle = {NeurIPS},
+ pages = {2491--2502},
+ year = {2021}
+}
+
+@inproceedings{icip,
+ author = {St{\'{e}}phane Vujasinovic and
+ Sebastian Bullinger and
+ Stefan Becker and
+ Norbert Scherer{-}Negenborn and
+ Michael Arens and
+ Rainer Stiefelhagen},
+ title = {Revisiting Click-Based Interactive Video Object Segmentation},
+ booktitle = {{ICIP}},
+ pages = {2756--2760},
+ publisher = {{IEEE}},
+ year = {2022}
+}
+
+
+
+
+@inproceedings{e2fgvi,
+ author = {Zhen Li and
+ Chengze Lu and
+ Jianhua Qin and
+ Chun{-}Le Guo and
+ Ming{-}Ming Cheng},
+ title = {Towards An End-to-End Framework for Flow-Guided Video Inpainting},
+ booktitle = {{CVPR}},
+ pages = {17541--17550},
+ publisher = {{IEEE}},
+ year = {2022}
+}
\ No newline at end of file
diff --git a/overleaf/Track Anything/neurips_2022.sty b/overleaf/Track Anything/neurips_2022.sty
new file mode 100644
index 0000000000000000000000000000000000000000..6dc2e3a0903d4eead078cb3c882d032d01131176
--- /dev/null
+++ b/overleaf/Track Anything/neurips_2022.sty
@@ -0,0 +1,381 @@
+% partial rewrite of the LaTeX2e package for submissions to the
+% Conference on Neural Information Processing Systems (NeurIPS):
+%
+% - uses more LaTeX conventions
+% - line numbers at submission time replaced with aligned numbers from
+% lineno package
+% - \nipsfinalcopy replaced with [final] package option
+% - automatically loads times package for authors
+% - loads natbib automatically; this can be suppressed with the
+% [nonatbib] package option
+% - adds foot line to first page identifying the conference
+% - adds preprint option for submission to e.g. arXiv
+% - conference acronym modified
+%
+% Roman Garnett (garnett@wustl.edu) and the many authors of
+% nips15submit_e.sty, including MK and drstrip@sandia
+%
+% last revision: March 2022
+
+\NeedsTeXFormat{LaTeX2e}
+\ProvidesPackage{neurips_2022}[2022/03/31 NeurIPS 2022 submission/camera-ready style file]
+
+% declare final option, which creates camera-ready copy
+\newif\if@neuripsfinal\@neuripsfinalfalse
+\DeclareOption{final}{
+ \@neuripsfinaltrue
+}
+
+% declare nonatbib option, which does not load natbib in case of
+% package clash (users can pass options to natbib via
+% \PassOptionsToPackage)
+\newif\if@natbib\@natbibtrue
+\DeclareOption{nonatbib}{
+ \@natbibfalse
+}
+
+% declare preprint option, which creates a preprint version ready for
+% upload to, e.g., arXiv
+\newif\if@preprint\@preprintfalse
+\DeclareOption{preprint}{
+ \@preprinttrue
+}
+
+\ProcessOptions\relax
+
+% determine whether this is an anonymized submission
+\newif\if@submission\@submissiontrue
+\if@neuripsfinal\@submissionfalse\fi
+\if@preprint\@submissionfalse\fi
+
+% fonts
+\renewcommand{\rmdefault}{ptm}
+\renewcommand{\sfdefault}{phv}
+
+% change this every year for notice string at bottom
+\newcommand{\@neuripsordinal}{36th}
+\newcommand{\@neuripsyear}{2022}
+\newcommand{\@neuripslocation}{New Orleans}
+
+% acknowledgments
+\usepackage{environ}
+\newcommand{\acksection}{\section*{Acknowledgments and Disclosure of Funding}}
+\NewEnviron{ack}{%
+ \acksection
+ \BODY
+}
+
+
+% load natbib unless told otherwise
+%\if@natbib
+% \RequirePackage{natbib}
+%\fi
+
+% set page geometry
+\usepackage[verbose=true,letterpaper]{geometry}
+\AtBeginDocument{
+ \newgeometry{
+ textheight=9in,
+ textwidth=5.5in,
+ top=1in,
+ headheight=12pt,
+ headsep=25pt,
+ footskip=30pt
+ }
+ \@ifpackageloaded{fullpage}
+ {\PackageWarning{neurips_2022}{fullpage package not allowed! Overwriting formatting.}}
+ {}
+}
+
+\widowpenalty=10000
+\clubpenalty=10000
+\flushbottom
+\sloppy
+
+
+% font sizes with reduced leading
+\renewcommand{\normalsize}{%
+ \@setfontsize\normalsize\@xpt\@xipt
+ \abovedisplayskip 7\p@ \@plus 2\p@ \@minus 5\p@
+ \abovedisplayshortskip \z@ \@plus 3\p@
+ \belowdisplayskip \abovedisplayskip
+ \belowdisplayshortskip 4\p@ \@plus 3\p@ \@minus 3\p@
+}
+\normalsize
+\renewcommand{\small}{%
+ \@setfontsize\small\@ixpt\@xpt
+ \abovedisplayskip 6\p@ \@plus 1.5\p@ \@minus 4\p@
+ \abovedisplayshortskip \z@ \@plus 2\p@
+ \belowdisplayskip \abovedisplayskip
+ \belowdisplayshortskip 3\p@ \@plus 2\p@ \@minus 2\p@
+}
+\renewcommand{\footnotesize}{\@setfontsize\footnotesize\@ixpt\@xpt}
+\renewcommand{\scriptsize}{\@setfontsize\scriptsize\@viipt\@viiipt}
+\renewcommand{\tiny}{\@setfontsize\tiny\@vipt\@viipt}
+\renewcommand{\large}{\@setfontsize\large\@xiipt{14}}
+\renewcommand{\Large}{\@setfontsize\Large\@xivpt{16}}
+\renewcommand{\LARGE}{\@setfontsize\LARGE\@xviipt{20}}
+\renewcommand{\huge}{\@setfontsize\huge\@xxpt{23}}
+\renewcommand{\Huge}{\@setfontsize\Huge\@xxvpt{28}}
+
+% sections with less space
+\providecommand{\section}{}
+\renewcommand{\section}{%
+ \@startsection{section}{1}{\z@}%
+ {-2.0ex \@plus -0.5ex \@minus -0.2ex}%
+ { 1.5ex \@plus 0.3ex \@minus 0.2ex}%
+ {\large\bf\raggedright}%
+}
+\providecommand{\subsection}{}
+\renewcommand{\subsection}{%
+ \@startsection{subsection}{2}{\z@}%
+ {-1.8ex \@plus -0.5ex \@minus -0.2ex}%
+ { 0.8ex \@plus 0.2ex}%
+ {\normalsize\bf\raggedright}%
+}
+\providecommand{\subsubsection}{}
+\renewcommand{\subsubsection}{%
+ \@startsection{subsubsection}{3}{\z@}%
+ {-1.5ex \@plus -0.5ex \@minus -0.2ex}%
+ { 0.5ex \@plus 0.2ex}%
+ {\normalsize\bf\raggedright}%
+}
+\providecommand{\paragraph}{}
+\renewcommand{\paragraph}{%
+ \@startsection{paragraph}{4}{\z@}%
+ {1.5ex \@plus 0.5ex \@minus 0.2ex}%
+ {-1em}%
+ {\normalsize\bf}%
+}
+\providecommand{\subparagraph}{}
+\renewcommand{\subparagraph}{%
+ \@startsection{subparagraph}{5}{\z@}%
+ {1.5ex \@plus 0.5ex \@minus 0.2ex}%
+ {-1em}%
+ {\normalsize\bf}%
+}
+\providecommand{\subsubsubsection}{}
+\renewcommand{\subsubsubsection}{%
+ \vskip5pt{\noindent\normalsize\rm\raggedright}%
+}
+
+% float placement
+\renewcommand{\topfraction }{0.85}
+\renewcommand{\bottomfraction }{0.4}
+\renewcommand{\textfraction }{0.1}
+\renewcommand{\floatpagefraction}{0.7}
+
+\newlength{\@neuripsabovecaptionskip}\setlength{\@neuripsabovecaptionskip}{7\p@}
+\newlength{\@neuripsbelowcaptionskip}\setlength{\@neuripsbelowcaptionskip}{\z@}
+
+\setlength{\abovecaptionskip}{\@neuripsabovecaptionskip}
+\setlength{\belowcaptionskip}{\@neuripsbelowcaptionskip}
+
+% swap above/belowcaptionskip lengths for tables
+\renewenvironment{table}
+ {\setlength{\abovecaptionskip}{\@neuripsbelowcaptionskip}%
+ \setlength{\belowcaptionskip}{\@neuripsabovecaptionskip}%
+ \@float{table}}
+ {\end@float}
+
+% footnote formatting
+\setlength{\footnotesep }{6.65\p@}
+\setlength{\skip\footins}{9\p@ \@plus 4\p@ \@minus 2\p@}
+\renewcommand{\footnoterule}{\kern-3\p@ \hrule width 12pc \kern 2.6\p@}
+\setcounter{footnote}{0}
+
+% paragraph formatting
+\setlength{\parindent}{\z@}
+\setlength{\parskip }{5.5\p@}
+
+% list formatting
+\setlength{\topsep }{4\p@ \@plus 1\p@ \@minus 2\p@}
+\setlength{\partopsep }{1\p@ \@plus 0.5\p@ \@minus 0.5\p@}
+\setlength{\itemsep }{2\p@ \@plus 1\p@ \@minus 0.5\p@}
+\setlength{\parsep }{2\p@ \@plus 1\p@ \@minus 0.5\p@}
+\setlength{\leftmargin }{3pc}
+\setlength{\leftmargini }{\leftmargin}
+\setlength{\leftmarginii }{2em}
+\setlength{\leftmarginiii}{1.5em}
+\setlength{\leftmarginiv }{1.0em}
+\setlength{\leftmarginv }{0.5em}
+\def\@listi {\leftmargin\leftmargini}
+\def\@listii {\leftmargin\leftmarginii
+ \labelwidth\leftmarginii
+ \advance\labelwidth-\labelsep
+ \topsep 2\p@ \@plus 1\p@ \@minus 0.5\p@
+ \parsep 1\p@ \@plus 0.5\p@ \@minus 0.5\p@
+ \itemsep \parsep}
+\def\@listiii{\leftmargin\leftmarginiii
+ \labelwidth\leftmarginiii
+ \advance\labelwidth-\labelsep
+ \topsep 1\p@ \@plus 0.5\p@ \@minus 0.5\p@
+ \parsep \z@
+ \partopsep 0.5\p@ \@plus 0\p@ \@minus 0.5\p@
+ \itemsep \topsep}
+\def\@listiv {\leftmargin\leftmarginiv
+ \labelwidth\leftmarginiv
+ \advance\labelwidth-\labelsep}
+\def\@listv {\leftmargin\leftmarginv
+ \labelwidth\leftmarginv
+ \advance\labelwidth-\labelsep}
+\def\@listvi {\leftmargin\leftmarginvi
+ \labelwidth\leftmarginvi
+ \advance\labelwidth-\labelsep}
+
+% create title
+\providecommand{\maketitle}{}
+\renewcommand{\maketitle}{%
+ \par
+ \begingroup
+ \renewcommand{\thefootnote}{\fnsymbol{footnote}}
+ % for perfect author name centering
+ \renewcommand{\@makefnmark}{\hbox to \z@{$^{\@thefnmark}$\hss}}
+ % The footnote-mark was overlapping the footnote-text,
+ % added the following to fix this problem (MK)
+ \long\def\@makefntext##1{%
+ \parindent 1em\noindent
+ \hbox to 1.8em{\hss $\m@th ^{\@thefnmark}$}##1
+ }
+ \thispagestyle{empty}
+ \@maketitle
+ \@thanks
+ \@notice
+ \endgroup
+ \let\maketitle\relax
+ \let\thanks\relax
+}
+
+% rules for title box at top of first page
+\newcommand{\@toptitlebar}{
+ \hrule height 4\p@
+ \vskip 0.25in
+ \vskip -\parskip%
+}
+\newcommand{\@bottomtitlebar}{
+ \vskip 0.29in
+ \vskip -\parskip
+ \hrule height 1\p@
+ \vskip 0.09in%
+}
+
+% create title (includes both anonymized and non-anonymized versions)
+\providecommand{\@maketitle}{}
+\renewcommand{\@maketitle}{%
+ \vbox{%
+ \hsize\textwidth
+ \linewidth\hsize
+ \vskip 0.1in
+ \@toptitlebar
+ \centering
+ {\LARGE\bf \@title\par}
+ \@bottomtitlebar
+ \if@submission
+ \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}
+ Anonymous Author(s) \\
+ Affiliation \\
+ Address \\
+ \texttt{email} \\
+ \end{tabular}%
+ \else
+ \def\And{%
+ \end{tabular}\hfil\linebreak[0]\hfil%
+ \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}\ignorespaces%
+ }
+ \def\AND{%
+ \end{tabular}\hfil\linebreak[4]\hfil%
+ \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}\ignorespaces%
+ }
+ \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}\@author\end{tabular}%
+ \fi
+ \vskip 0.3in \@minus 0.1in
+ }
+}
+
+% add conference notice to bottom of first page
+\newcommand{\ftype@noticebox}{8}
+\newcommand{\@notice}{%
+ % give a bit of extra room back to authors on first page
+ \enlargethispage{2\baselineskip}%
+ \@float{noticebox}[b]%
+ \footnotesize\@noticestring%
+ \end@float%
+}
+
+% abstract styling
+\renewenvironment{abstract}%
+{%
+ \vskip 0.075in%
+ \centerline%
+ {\large\bf Abstract}%
+ \vspace{0.5ex}%
+ \begin{quote}%
+}
+{
+ \par%
+ \end{quote}%
+ \vskip 1ex%
+}
+
+% For the paper checklist
+\newcommand{\answerYes}[1][]{\textcolor{blue}{[Yes] #1}}
+\newcommand{\answerNo}[1][]{\textcolor{orange}{[No] #1}}
+\newcommand{\answerNA}[1][]{\textcolor{gray}{[N/A] #1}}
+\newcommand{\answerTODO}[1][]{\textcolor{red}{\bf [TODO]}}
+
+% handle tweaks for camera-ready copy vs. submission copy
+\if@preprint
+ \newcommand{\@noticestring}{%
+ Preprint. Under review.%
+ }
+\else
+ \if@neuripsfinal
+ \newcommand{\@noticestring}{%
+ \@neuripsordinal\/ Conference on Neural Information Processing Systems
+ (NeurIPS \@neuripsyear).%, \@neuripslocation.%
+ }
+ \else
+ \newcommand{\@noticestring}{%
+ Submitted to \@neuripsordinal\/ Conference on Neural Information
+ Processing Systems (NeurIPS \@neuripsyear). Do not distribute.%
+ }
+
+ % hide the acknowledgements
+ \NewEnviron{hide}{}
+ \let\ack\hide
+ \let\endack\endhide
+
+ % line numbers for submission
+ \RequirePackage{lineno}
+ \linenumbers
+
+ % fix incompatibilities between lineno and amsmath, if required, by
+ % transparently wrapping linenomath environments around amsmath
+ % environments
+ \AtBeginDocument{%
+ \@ifpackageloaded{amsmath}{%
+ \newcommand*\patchAmsMathEnvironmentForLineno[1]{%
+ \expandafter\let\csname old#1\expandafter\endcsname\csname #1\endcsname
+ \expandafter\let\csname oldend#1\expandafter\endcsname\csname end#1\endcsname
+ \renewenvironment{#1}%
+ {\linenomath\csname old#1\endcsname}%
+ {\csname oldend#1\endcsname\endlinenomath}%
+ }%
+ \newcommand*\patchBothAmsMathEnvironmentsForLineno[1]{%
+ \patchAmsMathEnvironmentForLineno{#1}%
+ \patchAmsMathEnvironmentForLineno{#1*}%
+ }%
+ \patchBothAmsMathEnvironmentsForLineno{equation}%
+ \patchBothAmsMathEnvironmentsForLineno{align}%
+ \patchBothAmsMathEnvironmentsForLineno{flalign}%
+ \patchBothAmsMathEnvironmentsForLineno{alignat}%
+ \patchBothAmsMathEnvironmentsForLineno{gather}%
+ \patchBothAmsMathEnvironmentsForLineno{multline}%
+ }
+ {}
+ }
+ \fi
+\fi
+
+
+\endinput
diff --git a/overleaf/Track Anything/neurips_2022.tex b/overleaf/Track Anything/neurips_2022.tex
new file mode 100644
index 0000000000000000000000000000000000000000..f14a483fc519ccd4312233d185747c6d0bbaf1c7
--- /dev/null
+++ b/overleaf/Track Anything/neurips_2022.tex
@@ -0,0 +1,378 @@
+\documentclass{article}
+
+
+% if you need to pass options to natbib, use, e.g.:
+% \PassOptionsToPackage{numbers, compress}{natbib}
+% before loading neurips_2022
+
+
+% ready for submission
+% \usepackage{neurips_2022}
+
+
+% to compile a preprint version, e.g., for submission to arXiv, add add the
+% [preprint] option:
+ \usepackage[preprint]{neurips_2022}
+
+% to compile a camera-ready version, add the [final] option, e.g.:
+% \usepackage[final]{neurips_2022}
+
+
+% to avoid loading the natbib package, add option nonatbib:
+% \usepackage[nonatbib]{neurips_2022}
+\usepackage{graphicx}
+\usepackage[utf8]{inputenc} % allow utf-8 input
+\usepackage[T1]{fontenc} % use 8-bit T1 fonts
+\usepackage{hyperref} % hyperlinks
+\usepackage{url} % simple URL typesetting
+\usepackage{booktabs} % professional-quality tables
+\usepackage{amsfonts} % blackboard math symbols
+\usepackage{nicefrac} % compact symbols for 1/2, etc.
+\usepackage{microtype} % microtypography
+\usepackage{xcolor} % colors
+% \usepackage{acmart}
+
+\title{Track Anything: High-performance Interactive Tracking and Segmentation}
+\title{Track Anything: High-performance Object Tracking in Videos by Interactive Masks}
+% \title{Track Anything: Interaction to Mask in Videos}
+\title{Track Anything: Segment Anything Meets Videos}
+
+% \author{%
+% David S.~Hippocampus\thanks{Use footnote for providing further information
+% about author (webpage, alternative address)---\emph{not} for acknowledging
+% funding agencies.} \\
+% SUSTech VIPG\\
+
+% \author{Jinyu Yang}
+% \authornote{equal}
+
+% \author{Mingqi Gao}
+% \authornotemark[1]
+
+\author{%
+ Jinyu Yang\thanks{Equal contribution. Alphabetical order.},\enskip Mingqi Gao\footnotemark[1],\enskip Zhe Li\footnotemark[1],\enskip Shang Gao, Fangjing Wang, Feng Zheng \\
+ SUSTech VIP Lab\\
+ % Cranberry-Lemon University\\
+ % Pittsburgh, PA 15213 \\
+ % \texttt{hippo@cs.cranberry-lemon.edu} \\
+ % \url{https://github.com/gaomingqi/Track-Anything}\\
+ % examples of more authors
+ % \And
+ % Coauthor \\
+ % Affiliation \\
+ % Address \\
+ % \texttt{email} \\
+ % \AND
+ % Coauthor \\
+ % Affiliation \\
+ % Address \\
+ % \texttt{email} \\
+ % \And
+ % Coauthor \\
+ % Affiliation \\
+ % Address \\
+ % \texttt{email} \\
+ % \And
+ % Coauthor \\
+ % Affiliation \\
+ % Address \\
+ % \texttt{email} \\
+ % \thanks{these authors contributed equally}
+}
+% \affiliation{\institution{SUSTech VIP Lab}}
+% \footnote{Equal contribution. Alphabetical order.}
+
+\begin{document}
+
+
+\maketitle
+
+
+\begin{abstract}
+
+Recently, the Segment Anything Model (SAM) gains lots of attention rapidly due to its impressive segmentation performance on images.
+Regarding its strong ability on image segmentation and high interactivity with different prompts, we found that it performs poorly on consistent segmentation in videos.
+Therefore, in this report, we propose Track Anything Model (TAM), which achieves high-performance interactive tracking and segmentation in videos.
+To be detailed, given a video sequence, only with very little human participation, \textit{i.e.}, several clicks, people can track anything they are interested in, and get satisfactory results in one-pass inference.
+Without additional training, such an interactive design performs impressively on video object tracking and segmentation.
+% superior to prior works on video object tracking and segmentation.
+All resources are available on \url{https://github.com/gaomingqi/Track-Anything}.
+We hope this work can facilitate related research.
+
+\end{abstract}
+
+\section{Introduction}
+
+Tracking an arbitrary object in generic scenes is important, and Video Object Tracking (VOT) is a fundamental task in computer vision.
+Similar to VOT, Video Object Segmentation (VOS) aims to separate the target (region of interest) from the background in a video sequence, which can be seen as a kind of more fine-grained object tracking.
+We notice that current state-of-the-art video trackers/segmenters are trained on large-scale manually-annotated datasets and initialized by a bounding box or a segmentation mask.
+On the one hand, the massive human labor force is hidden behind huge amounts of labeled data.
+% Recently, interactive algorithms help to liberate users from labor-expensive initialization and annotation.
+Moreover, current initialization settings, especially the semi-supervised VOS, need specific object mask groundtruth for model initialization.
+How to liberate researchers from labor-expensive annotation and initialization is much of important.
+
+
+Recently, Segment-Anything Model (SAM)~\cite{sam} has been proposed, which is a large foundation model for image segmentation.
+It supports flexible prompts and computes masks in real-time, thus allowing interactive use.
+We conclude that SAM has the following advantages that can assist interactive tracking:
+\textbf{1) Strong image segmentation ability.}
+Trained on 11 million images and 1.1 billion masks, SAM can produce high-quality masks and do zero-shot segmentation in generic scenarios.
+\textbf{2) High interactivity with different kinds of prompts. }
+With input user-friendly prompts of points, boxes, or language, SAM can give satisfactory segmentation masks on specific image areas.
+However, using SAM in videos directly did not give us an impressive performance due to its deficiency in temporal correspondence.
+
+On the other hand, tracking or segmenting in videos faces challenges from scale variation, target deformation, motion blur, camera motion, similar objects, and so on~\cite{vos,vot6,vot7,vot8,vot9,vot10}.
+Even the state-of-the-art models suffer from complex scenarios in the public datasets~\cite{xmem}, not to mention the real-world applications.
+Therefore, a question is considered by us:
+\textit{can we achieve high-performance tracking/segmentation in videos through the way of interaction?}
+
+In this technical report, we introduce our Track-Anything project, which develops an efficient toolkit for high-performance object tracking and segmentation in videos.
+With a user-friendly interface, the Track Anything Model (TAM) can track and segment any objects in a given video with only one-pass inference.
+Figure~\ref{fig:overview} shows the one-pass interactive process in the proposed TAM.
+In detail, TAM combines SAM~\cite{sam}, a large segmentation model, and XMem~\cite{xmem}, an advanced VOS model.
+As shown, we integrate them in an interactive way.
+Firstly, users can interactively initialize the SAM, \textit{i.e.}, clicking on the object, to define a target object;
+then, XMem is used to give a mask prediction of the object in the next frame according to both temporal and spatial correspondence;
+next, SAM is utilized to give a more precise mask description;
+during the tracking process, users can pause and correct as soon as they notice tracking failures.
+
+Our contributions can be concluded as follows:
+
+1) We promote the SAM applications to the video level to achieve interactive video object tracking and segmentation.
+% We combine the SAM with VOS models to achieve interactive video object tracking and segmentation.
+Rather than separately using SAM per frame, we integrate SAM into the process of temporal correspondence construction.
+
+2) We propose one-pass interactive tracking and segmentation for efficient annotation and a user-friendly tracking interface, which uses very small amounts of human participation to solve extreme difficulties in video object perception.
+
+3) Our proposed method shows superior performance and high usability in complex scenes and has many potential applications.
+
+% \section{Related Works}
+
+% \textbf{Video Object Tracking.}
+
+
+
+% \textbf{Video Object Segmentation.}
+\section{Track Anything Task}
+
+Inspired by the Segment Anything task~\cite{sam}, we propose the Track Anything task, which aims to flexible object tracking in arbitrary videos.
+Here we define that the target objects can be flexibly selected, added, or removed in any way according to the users' interests.
+Also, the video length and types can be arbitrary rather than limited to trimmed or natural videos.
+With such settings, diverse downstream tasks can be achieved, including single/multiple object tracking, short-/long-term object tracking, unsupervised VOS, semi-supervised VOS, referring VOS, interactive VOS, long-term VOS, and so on.
+
+\section{Methodology}
+
+\subsection{Preliminaries}
+
+\textbf{Segment Anything Model~\cite{sam}.}
+Very recently, the Segment Anything Model (SAM) has been proposed by Meta AI Research and gets numerous attention.
+As a foundation model for image segmentation, SAM is based on ViT~\cite{vit} and trained on the large-scale dataset SA-1B~\cite{sam}.
+Obviously, SAM shows promising segmentation ability on images, especially on zero-shot segmentation tasks.
+Unfortunately, SAM only shows superior performance on image segmentation, while it cannot deal with complex video segmentation.
+
+
+\textbf{XMem~\cite{xmem}.}
+Given the mask description of the target object at the first frame, XMem can track the object and generate corresponding masks in the subsequent frames.
+Inspired by the Atkinson-Shiffrin memory model, it aims to solve the difficulties in long-term videos with unified feature memory stores.
+The drawbacks of XMem are also obvious: 1) as a semi-supervised VOS model, it requires a precise mask to initialize; 2) for long videos, it is difficult for XMem to recover from tracking or segmentation failure.
+In this paper, we solve both difficulties by importing interactive tracking with SAM.
+
+
+\textbf{Interactive Video Object Segmentation.}
+Interactive VOS~\cite{mivos} takes user interactions as inputs, \textit{e.g.}, scribbles.
+Then, users can iteratively refine the segmentation results until they are satisfied with them.
+Interactive VOS gains lots of attention as it is much easier to provide scribbles than to specify every pixel for an object mask.
+However, we found that current interactive VOS methods require multiple rounds to refine the results, which impedes their efficiency in real-world applications.
+
+\begin{figure}[t]
+\centering
+\includegraphics[width=\linewidth]{figs/overview_4.pdf}
+\caption{Pipeline of our proposed Track Anything Model (TAM). Only within one round of inference can the TAM obtain impressive tracking and segmentation performance on the human-selected target.}
+\label{fig:overview}
+\end{figure}
+
+\begin{table}
+ \caption{Results on DAVIS-2016-val and DAVIS-2017-test-dev datasets~\cite{davis}.}
+ \label{davis1617}
+ \centering
+ \small
+ \setlength\tabcolsep{4pt}
+ \begin{tabular}{l|c|c|c|ccc|ccc}
+ \toprule
+ & & & &\multicolumn{3}{c|}{DAVIS-2016-val} &\multicolumn{3}{c}{DAVIS-2017-test-dev} \\
+ Method & Venue & Initialization & Evaluation& $J\&F$ & $J$ &$F$ &$J\&F$ & $J$ &$F$\\
+ \midrule
+ STM~\cite{stm} & ICCV2019 &Mask & One Pass &89.3 &88.7 &89.9 & 72.2 & 69.3 & 75.2 \\
+ AOT~\cite{aot} &NeurIPS2021 &Mask & One Pass & 91.1 & 90.1 & 92.1 & 79.6 & 75.9 & 83.3 \\
+ XMem~\cite{xmem} & NeurIPS2022 &Mask & One Pass & 92.0 &90.7 &93.2 & 81.2 & 77.6 & 84.7\\
+ \midrule
+ % SiamMask~\cite{siammask}& CVPR2019 &Box & One Pass & 69.8 &71.7 &67.8 &56.4 &54.3 &58.5 \\
+ SiamMask~\cite{siammask}& CVPR2019 &Box & One Pass & 69.8 &71.7 &67.8 &- &- &- \\
+ \midrule
+ % MiVOS~\cite{mivos} & CVPR2021 &Scribble &8 Rounds &91.0 &89.6 &92.4 & 84.5 &81.7 &87.4\\
+ MiVOS~\cite{mivos} & CVPR2021 &Scribble &8 Rounds &91.0 &89.6 &92.4 &78.6 &74.9 &82.2\\
+ % \midrule
+ % & ICIP2022 &Click & \\
+ \midrule
+ TAM (Proposed) &- & Click & One Pass & 88.4 & 87.5 &89.4 & 73.1 & 69.8 & 76.4\\
+ % Ours & & 5 Clicks & \\
+ \bottomrule
+ \end{tabular}
+\end{table}
+
+
+
+\subsection{Implementation}\label{implementation}
+
+Inspired by SAM, we consider tracking anything in videos.
+We aim to define this task with high interactivity and ease of use.
+It leads to ease of use and is able to obtain high performance with very little human interaction effort.
+Figure~\ref{fig:overview} shows the pipeline of our Track Anything Model (TAM).
+As shown, we divide our Track-Anything process into the following four steps:
+
+\textbf{Step 1: Initialization with SAM~\cite{sam}.}
+As SAM provides us an opportunity to segment a region of interest with weak prompts, \textit{e.g.}, points, and bounding boxes, we use it to give an initial mask of the target object.
+Following SAM, users can get a mask description of the interested object by a click or modify the object mask with several clicks to get a satisfactory initialization.
+
+\textbf{Step 2: Tracking with XMem~\cite{xmem}.}
+Given the initialized mask, XMem performs semi-supervised VOS on the following frames.
+Since XMem is an advanced VOS method that can output satisfactory results on simple scenarios, we output the predicted masks of XMem on most occasions.
+When the mask quality is not such good, we save the XMem predictions and corresponding intermediate parameters, \textit{i.e.}, probes and affinities, and skip to step 3.
+% Given the initialized mask and the whole sequence, XMem performs semi-supervised VOS, which aims to solve the performance decay in long-term prediction with memory potentiation.
+
+
+\textbf{Step 3: Refinement with SAM~\cite{sam}.}
+We notice that during the inference of VOS models, keep predicting consistent and precise masks are challenging.
+In fact, most state-of-the-art VOS models tend to segment more and more coarsely over time during inference.
+Therefore, we utilize SAM to refine the masks predicted by XMem when its quality assessment is not satisfactory.
+Specifically, we project the probes and affinities to be point prompts for SAM, and the predicted mask from Step 2 is used as a mask prompt for SAM.
+Then, with these prompts, SAM is able to produce a refined segmentation mask.
+Such refined masks will also be added to the temporal correspondence of XMem to refine all subsequent object discrimination.
+
+\textbf{Step 4: Correction with human participation.}
+% Long video annotation.
+After the above three steps, the TAM can now successfully solve some common challenges and predict segmentation masks.
+However, we notice that it is still difficult to accurately distinguish the objects in some extremely challenging scenarios, especially when processing long videos.
+Therefore, we propose to add human correction during inference, which can bring a qualitative leap in performance with only very small human efforts.
+In detail, users can compulsively stop the TAM process and correct the mask of the current frame with positive and negative clicks.
+
+\section{Experiments}
+
+\subsection{Quantitative Results}
+
+
+To evaluate TAM, we utilize the validation set of DAVIS-2016 and test-development set of DAVIS-2017~\cite{davis}.
+% The evaluation process follows the one we proposed in Section~\ref{implementation}.
+Then, we execute the proposed TAM as demonstrated in Section~\ref{implementation}.
+The results are given in Table~\ref{davis1617}.
+As shown, our TAM obtains $J\&F$ scores of 88.4 and 73.1 on DAVIS-2016-val and DAVIS-2017-test-dev datasets, respectively.
+Note that TAM is initialized by clicks and evaluated in one pass.
+Notably, we found that TAM performs well when against difficult and complex scenarios.
+% During the evaluation,
+
+% click-based interactive video object segmentation
+
+% CLICK-BASED INTERACTIVE VIDEO OBJECT
+% SEGMENTATION
+
+
+\begin{figure}[t]
+\centering
+\includegraphics[width=\linewidth]{figs/davisresults.pdf}
+\caption{Qualitative results on video sequences from DAVIS-16 and DAVIS-17 datasets~\cite{davis}.}
+\label{fig:davisresult}
+\end{figure}
+
+
+\begin{figure}[t]
+\centering
+\includegraphics[width=\linewidth]{figs/failedcases.pdf}
+\caption{Failed cases.}
+\label{fig:failedcases}
+\end{figure}
+
+\subsection{Qualitative Results}
+
+% As we use a new one-pass interactive method to evaluation our TAM, here we only present some qualitative results.
+We also give some qualitative results in Figure~\ref{fig:davisresult}.
+As shown, TAM can handle multi-object separation, target deformation, scale change, and camera motion well, which demonstrates its superior tracking and segmentation abilities within only click initialization and one-round inference.
+
+\subsection{Failed Cases}
+We here also analyze the failed cases, as shown in Figure~\ref{fig:failedcases}.
+Overall, we notice that the failed cases typically appear on the following two occasions.
+1)
+% Separated masks of one object in a long video.
+Current VOS models are mostly designed for short videos, which focus more on maintaining short-term memory rather than long-term memory.
+This leads to mask shrinkage or lacking refinement in long-term videos, as shown in seq (a).
+Essentially, we aim to solve them in step 3 by the refinement ability of SAM, while its effectiveness is lower than expected in realistic applications.
+It indicates that the ability of SAM refinement based on multiple prompts can be further improved in the future.
+On the other hand, human participation/interaction in TAM can be an approach to solving such difficulties, while too much interaction will also result in low efficiency.
+Thus, the mechanism of long-term memory preserving and transient memory updating is still important.
+% Limited refinement by SAM. Although SAM supports to refine previous predictions, via point and mask prompts, . How to .
+2) When the object structure is complex, \textit{e.g.}, the bicycle wheels in seq (b) contain many cavities in groundtruth masks. We found it very difficult to get a fine-grained initialized mask by propagating the clicks.
+Thus, the coarse initialized masks may have side effects on the subsequent frames and lead to poor predictions.
+This also inspires us that SAM is still struggling with complex and precision structures.
+
+
+\begin{figure}[t]
+\centering
+\includegraphics[width=\linewidth]{figs/avengers_1.pdf}
+\caption{Raw frames, object masks, and inpainted results from the movie \textit{Captain America: Civil War (2016)}.}
+\label{fig:captain}
+\end{figure}
+
+
+
+\section{Applications}
+The proposed Track Anything Model (TAM) provides many possibilities for flexible tracking and segmentation in videos.
+Here, we demonstrate several applications enabled by our proposed method.
+% Our method may be able to a variety of applications.
+In such an interactive way, diverse downstream tasks can be easily achieved.
+% \textbf{Demo.}
+% It is able to solve diverse downstream tasks in such a interactive way.
+
+\textbf{Efficient video annotation.}
+TAM has the ability to segment the regions of interest in videos and flexibly choose the objects users want to track. Thus, it can be used for video annotation for tasks like video object tracking and video object segmentation.
+On the other hand, click-based interaction makes it easy to use, and the annotation process is of high efficiency.
+
+
+\textbf{Long-term object tracking.}
+The study of long-term tracking is gaining more and more attention because it is much closer to practical applications.
+Current long-term object tracking task requires the tracker to have the ability to handle target disappearance and reappearance while it is still limited in the scope of trimmed videos.
+Our TAM is more advanced in real-world applications which can handle the shot changes in long videos.
+
+
+\textbf{User-friendly video editing.}
+Track Anything Model provides us the opportunities to segment objects
+With the object segmentation masks provided by TAM, we are then able to remove or alter any of the existing objects in a given video.
+Here we combine E$^2$FGVI~\cite{e2fgvi} to evaluate its application value.
+
+\textbf{Visualized development toolkit for video tasks.}
+For ease of use, we also provide visualized interfaces for multiple video tasks, \textit{e.g.}, VOS, VOT, video inpainting, and so on.
+With the provided toolkit, users can apply their models on real-world videos and visualize the results instantaneously.
+Corresponding demos are available in Hugging Face\footnote{\url{https://huggingface.co/spaces/watchtowerss/Track-Anything}}.
+
+
+To show the effectiveness, we give a comprehensive test by applying TAM on the movie \textit{Captain America: Civil War (2016)}.
+Some representative results are given in Figure \ref{fig:captain}.
+As shown, TAM can present multiple object tracking precisely in videos with lots of shot changes and can further be helpful in video inpainting.
+
+% \section{Further work}
+
+
+% \section*{Acknowledgements}
+
+% \appendix
+
+% \section{Appendix}
+
+
+% Optionally include extra information (complete proofs, additional experiments and plots) in the appendix.
+% This section will often be part of the supplemental material.
+
+
+
+\bibliographystyle{plain}
+\bibliography{neurips_2022}
+
+\end{document}
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..16ec760b93f5187361da625198ba9beb57676424
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,17 @@
+progressbar2
+gdown
+gitpython
+git+https://github.com/cheind/py-thin-plate-spline
+hickle
+tensorboard
+numpy
+git+https://github.com/facebookresearch/segment-anything.git
+gradio==3.25.0
+opencv-python
+pycocotools
+matplotlib
+pyyaml
+av
+openmim
+tqdm
+psutil
\ No newline at end of file
diff --git a/sam_vit_h_4b8939.pth b/sam_vit_h_4b8939.pth
new file mode 100644
index 0000000000000000000000000000000000000000..8523acce9ddab1cf7e355628a08b1aab8ce08a72
--- /dev/null
+++ b/sam_vit_h_4b8939.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
+size 2564550879
diff --git a/template.html b/template.html
new file mode 100644
index 0000000000000000000000000000000000000000..c476716dc1e6d258db094bb0d79a6a232cb986b0
--- /dev/null
+++ b/template.html
@@ -0,0 +1,27 @@
+
+
+
+
+
+
+ Gradio Video Pause Time
+
+
+
+
+
+
diff --git a/templates/index.html b/templates/index.html
new file mode 100644
index 0000000000000000000000000000000000000000..33485832a851f1cc38f0d1b0ee073f7c99dc6725
--- /dev/null
+++ b/templates/index.html
@@ -0,0 +1,50 @@
+
+
+
+
+
+
+ Video Object Segmentation
+
+
+
+
Video Object Segmentation
+
+
+
+
+
+
+
+
+
+
+ Download Video
+
+
+
+
+
diff --git a/test.txt b/test.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/test_beta.txt b/test_beta.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/test_sample/huggingface_demo_operation.mp4 b/test_sample/huggingface_demo_operation.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..8d1a65f8068bff7204d7ff8e8bef4854de4f7ce4
--- /dev/null
+++ b/test_sample/huggingface_demo_operation.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a2594a9bb9862db33fc9f120d21ab1df76cc0653891a7e97f1ea4f396f084b43
+size 61159338
diff --git a/test_sample/test-sample1.mp4 b/test_sample/test-sample1.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..6f382584230c7e45d6ac884424b4c4d165e5bc40
--- /dev/null
+++ b/test_sample/test-sample1.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:403b711376a79026beedb7d0d919d35268298150120438a22a5330d0c8cdd6b6
+size 6039223
diff --git a/test_sample/test-sample13.mp4 b/test_sample/test-sample13.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..62d86d06b6b6e68c4e817f6cd996724cca44f9a1
--- /dev/null
+++ b/test_sample/test-sample13.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:54f24a0aaae482aff7ff3555256f60ad1931d478dc5694fb37624cac85479eee
+size 2528426
diff --git a/test_sample/test-sample2.mp4 b/test_sample/test-sample2.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..cf6c1b1bd5ed3e2d75f69f8c09502ce6b68feffd
Binary files /dev/null and b/test_sample/test-sample2.mp4 differ
diff --git a/test_sample/test-sample4.mp4 b/test_sample/test-sample4.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..f94ac09a75f973665408abfa312aec69d1e90930
Binary files /dev/null and b/test_sample/test-sample4.mp4 differ
diff --git a/test_sample/test-sample8.mp4 b/test_sample/test-sample8.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..59c4edca2167c29a3ae911c64229db4d03301b0a
--- /dev/null
+++ b/test_sample/test-sample8.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2414d24cc1ddfe1619c17e9876a7c3ed0f1f37da234c63c08af2cecbbb16c1ed
+size 8714250
diff --git a/text_server.py b/text_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0623a3d9632ae5eceb27dc002ed63952dbc22c1
--- /dev/null
+++ b/text_server.py
@@ -0,0 +1,72 @@
+import os
+import sys
+import cv2
+import time
+import json
+import queue
+import numpy as np
+import requests
+import concurrent.futures
+from PIL import Image
+from flask import Flask, render_template, request, jsonify, send_file
+import torchvision
+import torch
+
+from demo import automask_image_app, automask_video_app, sahi_autoseg_app
+sys.path.append(sys.path[0] + "/tracker")
+sys.path.append(sys.path[0] + "/tracker/model")
+from track_anything import TrackingAnything
+from track_anything import parse_augment
+
+# ... (all the functions defined in the original code except the Gradio part)
+
+app = Flask(__name__)
+app.config['UPLOAD_FOLDER'] = './uploaded_videos'
+app.config['ALLOWED_EXTENSIONS'] = {'mp4', 'avi', 'mov', 'mkv'}
+
+
+def allowed_file(filename):
+ return '.' in filename and filename.rsplit('.', 1)[1].lower() in app.config['ALLOWED_EXTENSIONS']
+
+@app.route("/")
+def index():
+ return render_template("index.html")
+
+@app.route("/upload_video", methods=["POST"])
+def upload_video():
+ # ... (handle video upload and processing)
+ return jsonify(status="success", data=video_data)
+
+@app.route("/template_select", methods=["POST"])
+def template_select():
+ # ... (handle template selection and processing)
+ return jsonify(status="success", data=template_data)
+
+@app.route("/sam_refine", methods=["POST"])
+def sam_refine_request():
+ # ... (handle sam refine and processing)
+ return jsonify(status="success", data=sam_data)
+
+@app.route("/track_video", methods=["POST"])
+def track_video():
+ # ... (handle video tracking and processing)
+ return jsonify(status="success", data=tracking_data)
+
+@app.route("/track_image", methods=["POST"])
+def track_image():
+ # ... (handle image tracking and processing)
+ return jsonify(status="success", data=tracking_data)
+
+@app.route("/download_video", methods=["GET"])
+def download_video():
+ try:
+ return send_file("output.mp4", attachment_filename="output.mp4")
+ except Exception as e:
+ return str(e)
+
+if __name__ == "__main__":
+ app.run(debug=True, host="0.0.0.0", port=args.port)
+
+
+if __name__ == '__main__':
+ app.run(host="0.0.0.0",port=12212, debug=True)
diff --git a/tools/__init__.py b/tools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tools/base_segmenter.py b/tools/base_segmenter.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b975bb779b47485f9e6ba7435646b4db40a2c6a
--- /dev/null
+++ b/tools/base_segmenter.py
@@ -0,0 +1,129 @@
+import time
+import torch
+import cv2
+from PIL import Image, ImageDraw, ImageOps
+import numpy as np
+from typing import Union
+from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
+import matplotlib.pyplot as plt
+import PIL
+from .mask_painter import mask_painter
+
+
+class BaseSegmenter:
+ def __init__(self, SAM_checkpoint, model_type, device='cuda:0'):
+ """
+ device: model device
+ SAM_checkpoint: path of SAM checkpoint
+ model_type: vit_b, vit_l, vit_h
+ """
+ print(f"Initializing BaseSegmenter to {device}")
+ assert model_type in ['vit_b', 'vit_l', 'vit_h'], 'model_type must be vit_b, vit_l, or vit_h'
+
+ self.device = device
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+ self.model = sam_model_registry[model_type](checkpoint=SAM_checkpoint)
+ self.model.to(device=self.device)
+ self.predictor = SamPredictor(self.model)
+ self.embedded = False
+
+ @torch.no_grad()
+ def set_image(self, image: np.ndarray):
+ # PIL.open(image_path) 3channel: RGB
+ # image embedding: avoid encode the same image multiple times
+ self.orignal_image = image
+ if self.embedded:
+ print('repeat embedding, please reset_image.')
+ return
+ self.predictor.set_image(image)
+ self.embedded = True
+ return
+
+ @torch.no_grad()
+ def reset_image(self):
+ # reset image embeding
+ self.predictor.reset_image()
+ self.embedded = False
+
+ def predict(self, prompts, mode, multimask=True):
+ """
+ image: numpy array, h, w, 3
+ prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input'
+ prompts['point_coords']: numpy array [N,2]
+ prompts['point_labels']: numpy array [1,N]
+ prompts['mask_input']: numpy array [1,256,256]
+ mode: 'point' (points only), 'mask' (mask only), 'both' (consider both)
+ mask_outputs: True (return 3 masks), False (return 1 mask only)
+ whem mask_outputs=True, mask_input=logits[np.argmax(scores), :, :][None, :, :]
+ """
+ assert self.embedded, 'prediction is called before set_image (feature embedding).'
+ assert mode in ['point', 'mask', 'both'], 'mode must be point, mask, or both'
+
+ if mode == 'point':
+ masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
+ point_labels=prompts['point_labels'],
+ multimask_output=multimask)
+ elif mode == 'mask':
+ masks, scores, logits = self.predictor.predict(mask_input=prompts['mask_input'],
+ multimask_output=multimask)
+ elif mode == 'both': # both
+ masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
+ point_labels=prompts['point_labels'],
+ mask_input=prompts['mask_input'],
+ multimask_output=multimask)
+ else:
+ raise("Not implement now!")
+ # masks (n, h, w), scores (n,), logits (n, 256, 256)
+ return masks, scores, logits
+
+
+if __name__ == "__main__":
+ # load and show an image
+ image = cv2.imread('/hhd3/gaoshang/truck.jpg')
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # numpy array (h, w, 3)
+
+ # initialise BaseSegmenter
+ SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
+ model_type = 'vit_h'
+ device = "cuda:4"
+ base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device)
+
+ # image embedding (once embedded, multiple prompts can be applied)
+ base_segmenter.set_image(image)
+
+ # examples
+ # point only ------------------------
+ mode = 'point'
+ prompts = {
+ 'point_coords': np.array([[500, 375], [1125, 625]]),
+ 'point_labels': np.array([1, 1]),
+ }
+ masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False) # masks (n, h, w), scores (n,), logits (n, 256, 256)
+ painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
+ cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
+
+ # both ------------------------
+ mode = 'both'
+ mask_input = logits[np.argmax(scores), :, :]
+ prompts = {'mask_input': mask_input [None, :, :]}
+ prompts = {
+ 'point_coords': np.array([[500, 375], [1125, 625]]),
+ 'point_labels': np.array([1, 0]),
+ 'mask_input': mask_input[None, :, :]
+ }
+ masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
+ painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
+ cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image)
+
+ # mask only ------------------------
+ mode = 'mask'
+ mask_input = logits[np.argmax(scores), :, :]
+
+ prompts = {'mask_input': mask_input[None, :, :]}
+
+ masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
+ painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
+ cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image)
diff --git a/tools/interact_tools.py b/tools/interact_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..daecc73e5f54c95b53c04520110775281a6e0560
--- /dev/null
+++ b/tools/interact_tools.py
@@ -0,0 +1,265 @@
+import time
+import torch
+import cv2
+from PIL import Image, ImageDraw, ImageOps
+import numpy as np
+from typing import Union
+from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
+import matplotlib.pyplot as plt
+import PIL
+from .mask_painter import mask_painter as mask_painter2
+from .base_segmenter import BaseSegmenter
+from .painter import mask_painter, point_painter
+import os
+import requests
+import sys
+
+
+mask_color = 3
+mask_alpha = 0.7
+contour_color = 1
+contour_width = 5
+point_color_ne = 8
+point_color_ps = 50
+point_alpha = 0.9
+point_radius = 15
+contour_color = 2
+contour_width = 5
+
+
+class SamControler():
+ def __init__(self, SAM_checkpoint, model_type, device):
+ '''
+ initialize sam controler
+ '''
+
+
+ self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
+
+
+ # def seg_again(self, image: np.ndarray):
+ # '''
+ # it is used when interact in video
+ # '''
+ # self.sam_controler.reset_image()
+ # self.sam_controler.set_image(image)
+ # return
+
+
+ def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True,mask_color=3):
+ '''
+ it is used in first frame in video
+ return: mask, logit, painted image(mask+point)
+ '''
+ # self.sam_controler.set_image(image)
+ origal_image = self.sam_controler.orignal_image
+ neg_flag = labels[-1]
+ if neg_flag==1:
+ #find neg
+ prompts = {
+ 'point_coords': points,
+ 'point_labels': labels,
+ }
+ masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
+ prompts = {
+ 'point_coords': points,
+ 'point_labels': labels,
+ 'mask_input': logit[None, :, :]
+ }
+ masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
+ else:
+ #find positive
+ prompts = {
+ 'point_coords': points,
+ 'point_labels': labels,
+ }
+ masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
+
+
+ assert len(points)==len(labels)
+
+ painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
+ painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
+ painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
+ painted_image = Image.fromarray(painted_image)
+
+ return mask, logit, painted_image
+
+ # def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
+ # origal_image = self.sam_controler.orignal_image
+ # if same:
+ # '''
+ # true; loop in the same image
+ # '''
+ # prompts = {
+ # 'point_coords': points,
+ # 'point_labels': labels,
+ # 'mask_input': logits[None, :, :]
+ # }
+ # masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
+ # mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
+
+ # painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
+ # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
+ # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
+ # painted_image = Image.fromarray(painted_image)
+
+ # return mask, logit, painted_image
+ # else:
+ # '''
+ # loop in the different image, interact in the video
+ # '''
+ # if image is None:
+ # raise('Image error')
+ # else:
+ # self.seg_again(image)
+ # prompts = {
+ # 'point_coords': points,
+ # 'point_labels': labels,
+ # }
+ # masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
+ # mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
+
+ # painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
+ # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
+ # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
+ # painted_image = Image.fromarray(painted_image)
+
+ # return mask, logit, painted_image
+
+
+
+
+
+
+# def initialize():
+# '''
+# initialize sam controler
+# '''
+# checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
+# folder = "segmenter"
+# SAM_checkpoint= './checkpoints/sam_vit_h_4b8939.pth'
+# download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
+
+
+# model_type = 'vit_h'
+# device = "cuda:0"
+# sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
+# return sam_controler
+
+
+# def seg_again(sam_controler, image: np.ndarray):
+# '''
+# it is used when interact in video
+# '''
+# sam_controler.reset_image()
+# sam_controler.set_image(image)
+# return
+
+
+# def first_frame_click(sam_controler, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
+# '''
+# it is used in first frame in video
+# return: mask, logit, painted image(mask+point)
+# '''
+# sam_controler.set_image(image)
+# prompts = {
+# 'point_coords': points,
+# 'point_labels': labels,
+# }
+# masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
+# mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
+
+# assert len(points)==len(labels)
+
+# painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
+# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
+# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
+# painted_image = Image.fromarray(painted_image)
+
+# return mask, logit, painted_image
+
+# def interact_loop(sam_controler, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
+# if same:
+# '''
+# true; loop in the same image
+# '''
+# prompts = {
+# 'point_coords': points,
+# 'point_labels': labels,
+# 'mask_input': logits[None, :, :]
+# }
+# masks, scores, logits = sam_controler.predict(prompts, 'both', multimask)
+# mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
+
+# painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
+# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
+# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
+# painted_image = Image.fromarray(painted_image)
+
+# return mask, logit, painted_image
+# else:
+# '''
+# loop in the different image, interact in the video
+# '''
+# if image is None:
+# raise('Image error')
+# else:
+# seg_again(sam_controler, image)
+# prompts = {
+# 'point_coords': points,
+# 'point_labels': labels,
+# }
+# masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
+# mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
+
+# painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
+# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
+# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
+# painted_image = Image.fromarray(painted_image)
+
+# return mask, logit, painted_image
+
+
+
+
+# if __name__ == "__main__":
+# points = np.array([[500, 375], [1125, 625]])
+# labels = np.array([1, 1])
+# image = cv2.imread('/hhd3/gaoshang/truck.jpg')
+# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+# sam_controler = initialize()
+# mask, logit, painted_image_full = first_frame_click(sam_controler,image, points, labels, multimask=True)
+# painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
+# painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
+# cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
+# cv2.imwrite('/hhd3/gaoshang/truck_change.jpg', image)
+# painted_image_full.save('/hhd3/gaoshang/truck_point_full.jpg')
+
+# mask, logit, painted_image_full = interact_loop(sam_controler,image,True, points, np.array([1, 0]), logit, multimask=True)
+# painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
+# painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
+# cv2.imwrite('/hhd3/gaoshang/truck_same.jpg', painted_image)
+# painted_image_full.save('/hhd3/gaoshang/truck_same_full.jpg')
+
+# mask, logit, painted_image_full = interact_loop(sam_controler,image, False, points, labels, multimask=True)
+# painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
+# painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
+# cv2.imwrite('/hhd3/gaoshang/truck_diff.jpg', painted_image)
+# painted_image_full.save('/hhd3/gaoshang/truck_diff_full.jpg')
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/tools/mask_painter.py b/tools/mask_painter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f471ea0116d656e2cc236832893b07c6d7be1643
--- /dev/null
+++ b/tools/mask_painter.py
@@ -0,0 +1,288 @@
+import cv2
+import torch
+import numpy as np
+from PIL import Image
+import copy
+import time
+
+
+def colormap(rgb=True):
+ color_list = np.array(
+ [
+ 0.000, 0.000, 0.000,
+ 1.000, 1.000, 1.000,
+ 1.000, 0.498, 0.313,
+ 0.392, 0.581, 0.929,
+ 0.000, 0.447, 0.741,
+ 0.850, 0.325, 0.098,
+ 0.929, 0.694, 0.125,
+ 0.494, 0.184, 0.556,
+ 0.466, 0.674, 0.188,
+ 0.301, 0.745, 0.933,
+ 0.635, 0.078, 0.184,
+ 0.300, 0.300, 0.300,
+ 0.600, 0.600, 0.600,
+ 1.000, 0.000, 0.000,
+ 1.000, 0.500, 0.000,
+ 0.749, 0.749, 0.000,
+ 0.000, 1.000, 0.000,
+ 0.000, 0.000, 1.000,
+ 0.667, 0.000, 1.000,
+ 0.333, 0.333, 0.000,
+ 0.333, 0.667, 0.000,
+ 0.333, 1.000, 0.000,
+ 0.667, 0.333, 0.000,
+ 0.667, 0.667, 0.000,
+ 0.667, 1.000, 0.000,
+ 1.000, 0.333, 0.000,
+ 1.000, 0.667, 0.000,
+ 1.000, 1.000, 0.000,
+ 0.000, 0.333, 0.500,
+ 0.000, 0.667, 0.500,
+ 0.000, 1.000, 0.500,
+ 0.333, 0.000, 0.500,
+ 0.333, 0.333, 0.500,
+ 0.333, 0.667, 0.500,
+ 0.333, 1.000, 0.500,
+ 0.667, 0.000, 0.500,
+ 0.667, 0.333, 0.500,
+ 0.667, 0.667, 0.500,
+ 0.667, 1.000, 0.500,
+ 1.000, 0.000, 0.500,
+ 1.000, 0.333, 0.500,
+ 1.000, 0.667, 0.500,
+ 1.000, 1.000, 0.500,
+ 0.000, 0.333, 1.000,
+ 0.000, 0.667, 1.000,
+ 0.000, 1.000, 1.000,
+ 0.333, 0.000, 1.000,
+ 0.333, 0.333, 1.000,
+ 0.333, 0.667, 1.000,
+ 0.333, 1.000, 1.000,
+ 0.667, 0.000, 1.000,
+ 0.667, 0.333, 1.000,
+ 0.667, 0.667, 1.000,
+ 0.667, 1.000, 1.000,
+ 1.000, 0.000, 1.000,
+ 1.000, 0.333, 1.000,
+ 1.000, 0.667, 1.000,
+ 0.167, 0.000, 0.000,
+ 0.333, 0.000, 0.000,
+ 0.500, 0.000, 0.000,
+ 0.667, 0.000, 0.000,
+ 0.833, 0.000, 0.000,
+ 1.000, 0.000, 0.000,
+ 0.000, 0.167, 0.000,
+ 0.000, 0.333, 0.000,
+ 0.000, 0.500, 0.000,
+ 0.000, 0.667, 0.000,
+ 0.000, 0.833, 0.000,
+ 0.000, 1.000, 0.000,
+ 0.000, 0.000, 0.167,
+ 0.000, 0.000, 0.333,
+ 0.000, 0.000, 0.500,
+ 0.000, 0.000, 0.667,
+ 0.000, 0.000, 0.833,
+ 0.000, 0.000, 1.000,
+ 0.143, 0.143, 0.143,
+ 0.286, 0.286, 0.286,
+ 0.429, 0.429, 0.429,
+ 0.571, 0.571, 0.571,
+ 0.714, 0.714, 0.714,
+ 0.857, 0.857, 0.857
+ ]
+ ).astype(np.float32)
+ color_list = color_list.reshape((-1, 3)) * 255
+ if not rgb:
+ color_list = color_list[:, ::-1]
+ return color_list
+
+
+color_list = colormap()
+color_list = color_list.astype('uint8').tolist()
+
+
+def vis_add_mask(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha):
+ background_color = np.array(background_color)
+ contour_color = np.array(contour_color)
+
+ # background_mask = 1 - background_mask
+ # contour_mask = 1 - contour_mask
+
+ for i in range(3):
+ image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
+ + background_color[i] * (background_alpha-background_mask*background_alpha)
+
+ image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
+ + contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
+
+ return image.astype('uint8')
+
+
+def mask_generator_00(mask, background_radius, contour_radius):
+ # no background width when '00'
+ # distance map
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+ dist_map = dist_transform_fore - dist_transform_back
+ # ...:::!!!:::...
+ contour_radius += 2
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+ contour_mask = contour_mask / np.max(contour_mask)
+ contour_mask[contour_mask>0.5] = 1.
+
+ return mask, contour_mask
+
+
+def mask_generator_01(mask, background_radius, contour_radius):
+ # no background width when '00'
+ # distance map
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+ dist_map = dist_transform_fore - dist_transform_back
+ # ...:::!!!:::...
+ contour_radius += 2
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+ contour_mask = contour_mask / np.max(contour_mask)
+ return mask, contour_mask
+
+
+def mask_generator_10(mask, background_radius, contour_radius):
+ # distance map
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+ dist_map = dist_transform_fore - dist_transform_back
+ # .....:::::!!!!!
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
+ background_mask = (background_mask - np.min(background_mask))
+ background_mask = background_mask / np.max(background_mask)
+ # ...:::!!!:::...
+ contour_radius += 2
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+ contour_mask = contour_mask / np.max(contour_mask)
+ contour_mask[contour_mask>0.5] = 1.
+ return background_mask, contour_mask
+
+
+def mask_generator_11(mask, background_radius, contour_radius):
+ # distance map
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+ dist_map = dist_transform_fore - dist_transform_back
+ # .....:::::!!!!!
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
+ background_mask = (background_mask - np.min(background_mask))
+ background_mask = background_mask / np.max(background_mask)
+ # ...:::!!!:::...
+ contour_radius += 2
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+ contour_mask = contour_mask / np.max(contour_mask)
+ return background_mask, contour_mask
+
+
+def mask_painter(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'):
+ """
+ Input:
+ input_image: numpy array
+ input_mask: numpy array
+ background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
+ background_blur_radius: radius of background blur, must be odd number
+ contour_width: width of mask contour, must be odd number
+ contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
+ contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
+ mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both
+
+ Output:
+ painted_image: numpy array
+ """
+ assert input_image.shape[:2] == input_mask.shape, 'different shape'
+ assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
+ assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'
+
+ # downsample input image and mask
+ width, height = input_image.shape[0], input_image.shape[1]
+ res = 1024
+ ratio = min(1.0 * res / max(width, height), 1.0)
+ input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
+ input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
+
+ # 0: background, 1: foreground
+ msk = np.clip(input_mask, 0, 1)
+
+ # generate masks for background and contour pixels
+ background_radius = (background_blur_radius - 1) // 2
+ contour_radius = (contour_width - 1) // 2
+ generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11}
+ background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
+
+ # paint
+ painted_image = vis_add_mask\
+ (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background
+
+ return painted_image
+
+
+if __name__ == '__main__':
+
+ background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
+ background_blur_radius = 31 # radius of background blur, must be odd number
+ contour_width = 11 # contour width, must be odd number
+ contour_color = 3 # id in color map, 0: black, 1: white, >1: others
+ contour_alpha = 1 # transparency of background, 0: no contour highlighted
+
+ # load input image and mask
+ input_image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB'))
+ input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P'))
+
+ # paint
+ overall_time_1 = 0
+ overall_time_2 = 0
+ overall_time_3 = 0
+ overall_time_4 = 0
+ overall_time_5 = 0
+
+ for i in range(50):
+ t2 = time.time()
+ painted_image_00 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00')
+ e2 = time.time()
+
+ t3 = time.time()
+ painted_image_10 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10')
+ e3 = time.time()
+
+ t1 = time.time()
+ painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha)
+ e1 = time.time()
+
+ t4 = time.time()
+ painted_image_01 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01')
+ e4 = time.time()
+
+ t5 = time.time()
+ painted_image_11 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11')
+ e5 = time.time()
+
+ overall_time_1 += (e1 - t1)
+ overall_time_2 += (e2 - t2)
+ overall_time_3 += (e3 - t3)
+ overall_time_4 += (e4 - t4)
+ overall_time_5 += (e5 - t5)
+
+ print(f'average time w gaussian: {overall_time_1/50}')
+ print(f'average time w/o gaussian00: {overall_time_2/50}')
+ print(f'average time w/o gaussian10: {overall_time_3/50}')
+ print(f'average time w/o gaussian01: {overall_time_4/50}')
+ print(f'average time w/o gaussian11: {overall_time_5/50}')
+
+ # save
+ painted_image_00 = Image.fromarray(painted_image_00)
+ painted_image_00.save('./test_img/painter_output_image_00.png')
+
+ painted_image_10 = Image.fromarray(painted_image_10)
+ painted_image_10.save('./test_img/painter_output_image_10.png')
+
+ painted_image_01 = Image.fromarray(painted_image_01)
+ painted_image_01.save('./test_img/painter_output_image_01.png')
+
+ painted_image_11 = Image.fromarray(painted_image_11)
+ painted_image_11.save('./test_img/painter_output_image_11.png')
diff --git a/tools/painter.py b/tools/painter.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e711d35aa8348d15cdad9d1cd413da41ea4f1ab
--- /dev/null
+++ b/tools/painter.py
@@ -0,0 +1,215 @@
+# paint masks, contours, or points on images, with specified colors
+import cv2
+import torch
+import numpy as np
+from PIL import Image
+import copy
+import time
+
+
+def colormap(rgb=True):
+ color_list = np.array(
+ [
+ 0.000, 0.000, 0.000,
+ 1.000, 1.000, 1.000,
+ 1.000, 0.498, 0.313,
+ 0.392, 0.581, 0.929,
+ 0.000, 0.447, 0.741,
+ 0.850, 0.325, 0.098,
+ 0.929, 0.694, 0.125,
+ 0.494, 0.184, 0.556,
+ 0.466, 0.674, 0.188,
+ 0.301, 0.745, 0.933,
+ 0.635, 0.078, 0.184,
+ 0.300, 0.300, 0.300,
+ 0.600, 0.600, 0.600,
+ 1.000, 0.000, 0.000,
+ 1.000, 0.500, 0.000,
+ 0.749, 0.749, 0.000,
+ 0.000, 1.000, 0.000,
+ 0.000, 0.000, 1.000,
+ 0.667, 0.000, 1.000,
+ 0.333, 0.333, 0.000,
+ 0.333, 0.667, 0.000,
+ 0.333, 1.000, 0.000,
+ 0.667, 0.333, 0.000,
+ 0.667, 0.667, 0.000,
+ 0.667, 1.000, 0.000,
+ 1.000, 0.333, 0.000,
+ 1.000, 0.667, 0.000,
+ 1.000, 1.000, 0.000,
+ 0.000, 0.333, 0.500,
+ 0.000, 0.667, 0.500,
+ 0.000, 1.000, 0.500,
+ 0.333, 0.000, 0.500,
+ 0.333, 0.333, 0.500,
+ 0.333, 0.667, 0.500,
+ 0.333, 1.000, 0.500,
+ 0.667, 0.000, 0.500,
+ 0.667, 0.333, 0.500,
+ 0.667, 0.667, 0.500,
+ 0.667, 1.000, 0.500,
+ 1.000, 0.000, 0.500,
+ 1.000, 0.333, 0.500,
+ 1.000, 0.667, 0.500,
+ 1.000, 1.000, 0.500,
+ 0.000, 0.333, 1.000,
+ 0.000, 0.667, 1.000,
+ 0.000, 1.000, 1.000,
+ 0.333, 0.000, 1.000,
+ 0.333, 0.333, 1.000,
+ 0.333, 0.667, 1.000,
+ 0.333, 1.000, 1.000,
+ 0.667, 0.000, 1.000,
+ 0.667, 0.333, 1.000,
+ 0.667, 0.667, 1.000,
+ 0.667, 1.000, 1.000,
+ 1.000, 0.000, 1.000,
+ 1.000, 0.333, 1.000,
+ 1.000, 0.667, 1.000,
+ 0.167, 0.000, 0.000,
+ 0.333, 0.000, 0.000,
+ 0.500, 0.000, 0.000,
+ 0.667, 0.000, 0.000,
+ 0.833, 0.000, 0.000,
+ 1.000, 0.000, 0.000,
+ 0.000, 0.167, 0.000,
+ 0.000, 0.333, 0.000,
+ 0.000, 0.500, 0.000,
+ 0.000, 0.667, 0.000,
+ 0.000, 0.833, 0.000,
+ 0.000, 1.000, 0.000,
+ 0.000, 0.000, 0.167,
+ 0.000, 0.000, 0.333,
+ 0.000, 0.000, 0.500,
+ 0.000, 0.000, 0.667,
+ 0.000, 0.000, 0.833,
+ 0.000, 0.000, 1.000,
+ 0.143, 0.143, 0.143,
+ 0.286, 0.286, 0.286,
+ 0.429, 0.429, 0.429,
+ 0.571, 0.571, 0.571,
+ 0.714, 0.714, 0.714,
+ 0.857, 0.857, 0.857
+ ]
+ ).astype(np.float32)
+ color_list = color_list.reshape((-1, 3)) * 255
+ if not rgb:
+ color_list = color_list[:, ::-1]
+ return color_list
+
+
+color_list = colormap()
+color_list = color_list.astype('uint8').tolist()
+
+
+def vis_add_mask(image, mask, color, alpha):
+ color = np.array(color_list[color])
+ mask = mask > 0.5
+ image[mask] = image[mask] * (1-alpha) + color * alpha
+ return image.astype('uint8')
+
+def point_painter(input_image, input_points, point_color=5, point_alpha=0.9, point_radius=15, contour_color=2, contour_width=5):
+ h, w = input_image.shape[:2]
+ point_mask = np.zeros((h, w)).astype('uint8')
+ for point in input_points:
+ point_mask[point[1], point[0]] = 1
+
+ kernel = cv2.getStructuringElement(2, (point_radius, point_radius))
+ point_mask = cv2.dilate(point_mask, kernel)
+
+ contour_radius = (contour_width - 1) // 2
+ dist_transform_fore = cv2.distanceTransform(point_mask, cv2.DIST_L2, 3)
+ dist_transform_back = cv2.distanceTransform(1-point_mask, cv2.DIST_L2, 3)
+ dist_map = dist_transform_fore - dist_transform_back
+ # ...:::!!!:::...
+ contour_radius += 2
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+ contour_mask = contour_mask / np.max(contour_mask)
+ contour_mask[contour_mask>0.5] = 1.
+
+ # paint mask
+ painted_image = vis_add_mask(input_image.copy(), point_mask, point_color, point_alpha)
+ # paint contour
+ painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
+ return painted_image
+
+def mask_painter(input_image, input_mask, mask_color=5, mask_alpha=0.7, contour_color=1, contour_width=3):
+ assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
+ # 0: background, 1: foreground
+ mask = np.clip(input_mask, 0, 1)
+ contour_radius = (contour_width - 1) // 2
+
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+ dist_map = dist_transform_fore - dist_transform_back
+ # ...:::!!!:::...
+ contour_radius += 2
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+ contour_mask = contour_mask / np.max(contour_mask)
+ contour_mask[contour_mask>0.5] = 1.
+
+ # paint mask
+ painted_image = vis_add_mask(input_image.copy(), mask.copy(), mask_color, mask_alpha)
+ # paint contour
+ painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
+
+ return painted_image
+
+def background_remover(input_image, input_mask):
+ """
+ input_image: H, W, 3, np.array
+ input_mask: H, W, np.array
+
+ image_wo_background: PIL.Image
+ """
+ assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
+ # 0: background, 1: foreground
+ mask = np.expand_dims(np.clip(input_mask, 0, 1), axis=2)*255
+ image_wo_background = np.concatenate([input_image, mask], axis=2) # H, W, 4
+ image_wo_background = Image.fromarray(image_wo_background).convert('RGBA')
+
+ return image_wo_background
+
+if __name__ == '__main__':
+ input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
+ input_mask = np.array(Image.open('images/painter_input_mask.jpg').convert('P'))
+
+ # example of mask painter
+ mask_color = 3
+ mask_alpha = 0.7
+ contour_color = 1
+ contour_width = 5
+
+ # save
+ painted_image = Image.fromarray(input_image)
+ painted_image.save('images/original.png')
+
+ painted_image = mask_painter(input_image, input_mask, mask_color, mask_alpha, contour_color, contour_width)
+ # save
+ painted_image = Image.fromarray(input_image)
+ painted_image.save('images/original1.png')
+
+ # example of point painter
+ input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
+ input_points = np.array([[500, 375], [70, 600]]) # x, y
+ point_color = 5
+ point_alpha = 0.9
+ point_radius = 15
+ contour_color = 2
+ contour_width = 5
+ painted_image_1 = point_painter(input_image, input_points, point_color, point_alpha, point_radius, contour_color, contour_width)
+ # save
+ painted_image = Image.fromarray(painted_image_1)
+ painted_image.save('images/point_painter_1.png')
+
+ input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
+ painted_image_2 = point_painter(input_image, input_points, point_color=9, point_radius=20, contour_color=29)
+ # save
+ painted_image = Image.fromarray(painted_image_2)
+ painted_image.save('images/point_painter_2.png')
+
+ # example of background remover
+ input_image = np.array(Image.open('images/original.png').convert('RGB'))
+ image_wo_background = background_remover(input_image, input_mask) # return PIL.Image
+ image_wo_background.save('images/image_wo_background.png')
diff --git a/track_anything.py b/track_anything.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b9f1b5023d5c5668b1888e6c4a04960098a32cb
--- /dev/null
+++ b/track_anything.py
@@ -0,0 +1,111 @@
+import PIL
+from tqdm import tqdm
+
+from tools.interact_tools import SamControler
+from tracker.base_tracker import BaseTracker
+from inpainter.base_inpainter import BaseInpainter
+import numpy as np
+import argparse
+import cv2
+
+def read_image_from_userfolder(image_path):
+ # if type:
+ image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
+ # else:
+ # image = cv2.cvtColor(cv2.imread("/tmp/{}/paintedimages/{}/{:08d}.png".format(username, video_state["video_name"], index+ ".png")), cv2.COLOR_BGR2RGB)
+ return image
+
+def save_image_to_userfolder(video_state, index, image, type:bool):
+ if type:
+ image_path = "/tmp/{}/originimages/{}/{:08d}.png".format(video_state["user_name"], video_state["video_name"], index)
+ else:
+ image_path = "/tmp/{}/paintedimages/{}/{:08d}.png".format(video_state["user_name"], video_state["video_name"], index)
+ cv2.imwrite(image_path, image)
+ return image_path
+class TrackingAnything():
+ def __init__(self, sam_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, args):
+ self.args = args
+ self.sam_checkpoint = sam_checkpoint
+ self.xmem_checkpoint = xmem_checkpoint
+ self.e2fgvi_checkpoint = e2fgvi_checkpoint
+ self.samcontroler = SamControler(self.sam_checkpoint, args.sam_model_type, args.device)
+ self.xmem = BaseTracker(self.xmem_checkpoint, device=args.device)
+ self.baseinpainter = BaseInpainter(self.e2fgvi_checkpoint, args.device)
+ # def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
+ # same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
+ # if first_flag:
+ # mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
+ # return mask, logit, painted_image
+
+ # if interact_flag:
+ # mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
+ # return mask, logit, painted_image
+
+ # mask, logit, painted_image = self.xmem.track(image, logit)
+ # return mask, logit, painted_image
+
+ def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
+ mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
+ return mask, logit, painted_image
+
+ # def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
+ # mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
+ # return mask, logit, painted_image
+
+ def generator(self, images: list, template_mask:np.ndarray, video_state:dict):
+
+ masks = []
+ logits = []
+ painted_images = []
+ for i in tqdm(range(len(images)), desc="Tracking image"):
+ if i ==0:
+ mask, logit, painted_image = self.xmem.track(read_image_from_userfolder(images[i]), template_mask)
+ masks.append(mask)
+ logits.append(logit)
+ # painted_images.append(painted_image)
+ painted_images.append(save_image_to_userfolder(video_state, index=i, image=cv2.cvtColor(np.asarray(painted_image),cv2.COLOR_BGR2RGB), type=False))
+
+ else:
+ mask, logit, painted_image = self.xmem.track(read_image_from_userfolder(images[i]))
+ masks.append(mask)
+ logits.append(logit)
+ # painted_images.append(painted_image)
+ painted_images.append(save_image_to_userfolder(video_state, index=i, image=cv2.cvtColor(np.asarray(painted_image),cv2.COLOR_BGR2RGB), type=False))
+ return masks, logits, painted_images
+
+
+def parse_augment():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--device', type=str, default="cuda:0")
+ parser.add_argument('--sam_model_type', type=str, default="vit_h")
+ parser.add_argument('--port', type=int, default=6080, help="only useful when running gradio applications")
+ parser.add_argument('--debug', action="store_true")
+ parser.add_argument('--mask_save', default=False)
+ args = parser.parse_args()
+
+ if args.debug:
+ print(args)
+ return args
+
+
+if __name__ == "__main__":
+ masks = None
+ logits = None
+ painted_images = None
+ images = []
+ image = np.array(PIL.Image.open('/hhd3/gaoshang/truck.jpg'))
+ args = parse_augment()
+ # images.append(np.ones((20,20,3)).astype('uint8'))
+ # images.append(np.ones((20,20,3)).astype('uint8'))
+ images.append(image)
+ images.append(image)
+
+ mask = np.zeros_like(image)[:,:,0]
+ mask[0,0]= 1
+ trackany = TrackingAnything('/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth','/ssd1/gaomingqi/checkpoints/XMem-s012.pth', args)
+ masks, logits ,painted_images= trackany.generator(images, mask)
+
+
+
+
+
\ No newline at end of file
diff --git a/tracker/.DS_Store b/tracker/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..36a69bcbb75da5f64a5a2520e748b7bb0efab525
Binary files /dev/null and b/tracker/.DS_Store differ
diff --git a/tracker/base_tracker.py b/tracker/base_tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d47f6b493afd9c144bf486ae0151f743e3c6371
--- /dev/null
+++ b/tracker/base_tracker.py
@@ -0,0 +1,261 @@
+# import for debugging
+import os
+import glob
+import numpy as np
+from PIL import Image
+# import for base_tracker
+import torch
+import yaml
+import torch.nn.functional as F
+from model.network import XMem
+from inference.inference_core import InferenceCore
+from tracker.util.mask_mapper import MaskMapper
+from torchvision import transforms
+from tracker.util.range_transform import im_normalization
+
+from tools.painter import mask_painter
+from tools.base_segmenter import BaseSegmenter
+from torchvision.transforms import Resize
+import progressbar
+
+
+class BaseTracker:
+ def __init__(self, xmem_checkpoint, device, sam_model=None, model_type=None) -> None:
+ """
+ device: model device
+ xmem_checkpoint: checkpoint of XMem model
+ """
+ # load configurations
+ with open("tracker/config/config.yaml", 'r') as stream:
+ config = yaml.safe_load(stream)
+ # initialise XMem
+ network = XMem(config, xmem_checkpoint).to(device).eval()
+ # initialise IncerenceCore
+ self.tracker = InferenceCore(network, config)
+ # data transformation
+ self.im_transform = transforms.Compose([
+ transforms.ToTensor(),
+ im_normalization,
+ ])
+ self.device = device
+
+ # changable properties
+ self.mapper = MaskMapper()
+ self.initialised = False
+
+ # # SAM-based refinement
+ # self.sam_model = sam_model
+ # self.resizer = Resize([256, 256])
+
+ @torch.no_grad()
+ def resize_mask(self, mask):
+ # mask transform is applied AFTER mapper, so we need to post-process it in eval.py
+ h, w = mask.shape[-2:]
+ min_hw = min(h, w)
+ return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)),
+ mode='nearest')
+
+ @torch.no_grad()
+ def track(self, frame, first_frame_annotation=None):
+ """
+ Input:
+ frames: numpy arrays (H, W, 3)
+ logit: numpy array (H, W), logit
+
+ Output:
+ mask: numpy arrays (H, W)
+ logit: numpy arrays, probability map (H, W)
+ painted_image: numpy array (H, W, 3)
+ """
+
+ if first_frame_annotation is not None: # first frame mask
+ # initialisation
+ mask, labels = self.mapper.convert_mask(first_frame_annotation)
+ mask = torch.Tensor(mask).to(self.device)
+ self.tracker.set_all_labels(list(self.mapper.remappings.values()))
+ else:
+ mask = None
+ labels = None
+ # prepare inputs
+ frame_tensor = self.im_transform(frame).to(self.device)
+ # track one frame
+ probs, _ = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W
+ # # refine
+ # if first_frame_annotation is None:
+ # out_mask = self.sam_refinement(frame, logits[1], ti)
+
+ # convert to mask
+ out_mask = torch.argmax(probs, dim=0)
+ out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
+
+ final_mask = np.zeros_like(out_mask)
+
+ # map back
+ for k, v in self.mapper.remappings.items():
+ final_mask[out_mask == v] = k
+
+ num_objs = final_mask.max()
+ painted_image = frame
+ for obj in range(1, num_objs+1):
+ if np.max(final_mask==obj) == 0:
+ continue
+ painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1)
+
+ # print(f'max memory allocated: {torch.cuda.max_memory_allocated()/(2**20)} MB')
+
+ return final_mask, final_mask, painted_image
+
+ @torch.no_grad()
+ def sam_refinement(self, frame, logits, ti):
+ """
+ refine segmentation results with mask prompt
+ """
+ # convert to 1, 256, 256
+ self.sam_model.set_image(frame)
+ mode = 'mask'
+ logits = logits.unsqueeze(0)
+ logits = self.resizer(logits).cpu().numpy()
+ prompts = {'mask_input': logits} # 1 256 256
+ masks, scores, logits = self.sam_model.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
+ painted_image = mask_painter(frame, masks[np.argmax(scores)].astype('uint8'), mask_alpha=0.8)
+ painted_image = Image.fromarray(painted_image)
+ painted_image.save(f'/ssd1/gaomingqi/refine/{ti:05d}.png')
+ self.sam_model.reset_image()
+
+ @torch.no_grad()
+ def clear_memory(self):
+ self.tracker.clear_memory()
+ self.mapper.clear_labels()
+ torch.cuda.empty_cache()
+
+
+## how to use:
+## 1/3) prepare device and xmem_checkpoint
+# device = 'cuda:2'
+# XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
+## 2/3) initialise Base Tracker
+# tracker = BaseTracker(XMEM_checkpoint, device, None, device) # leave an interface for sam model (currently set None)
+## 3/3)
+
+
+if __name__ == '__main__':
+ # video frames (take videos from DAVIS-2017 as examples)
+ video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/horsejump-high', '*.jpg'))
+ video_path_list.sort()
+ # load frames
+ frames = []
+ for video_path in video_path_list:
+ frames.append(np.array(Image.open(video_path).convert('RGB')))
+ frames = np.stack(frames, 0) # T, H, W, C
+ # load first frame annotation
+ first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/horsejump-high/00000.png'
+ first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W, C
+
+ # ------------------------------------------------------------------------------------
+ # how to use
+ # ------------------------------------------------------------------------------------
+ # 1/4: set checkpoint and device
+ device = 'cuda:2'
+ XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
+ # SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
+ # model_type = 'vit_h'
+ # ------------------------------------------------------------------------------------
+ # 2/4: initialise inpainter
+ tracker = BaseTracker(XMEM_checkpoint, device, None, device)
+ # ------------------------------------------------------------------------------------
+ # 3/4: for each frame, get tracking results by tracker.track(frame, first_frame_annotation)
+ # frame: numpy array (H, W, C), first_frame_annotation: numpy array (H, W), leave it blank when tracking begins
+ painted_frames = []
+ for ti, frame in enumerate(frames):
+ if ti == 0:
+ mask, prob, painted_frame = tracker.track(frame, first_frame_annotation)
+ # mask:
+ else:
+ mask, prob, painted_frame = tracker.track(frame)
+ painted_frames.append(painted_frame)
+ # ----------------------------------------------
+ # 3/4: clear memory in XMEM for the next video
+ tracker.clear_memory()
+ # ----------------------------------------------
+ # end
+ # ----------------------------------------------
+ print(f'max memory allocated: {torch.cuda.max_memory_allocated()/(2**20)} MB')
+ # set saving path
+ save_path = '/ssd1/gaomingqi/results/TAM/blackswan'
+ if not os.path.exists(save_path):
+ os.mkdir(save_path)
+ # save
+ for painted_frame in progressbar.progressbar(painted_frames):
+ painted_frame = Image.fromarray(painted_frame)
+ painted_frame.save(f'{save_path}/{ti:05d}.png')
+
+ # tracker.clear_memory()
+ # for ti, frame in enumerate(frames):
+ # print(ti)
+ # # if ti > 200:
+ # # break
+ # if ti == 0:
+ # mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
+ # else:
+ # mask, prob, painted_image = tracker.track(frame)
+ # # save
+ # painted_image = Image.fromarray(painted_image)
+ # painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
+
+ # # track anything given in the first frame annotation
+ # for ti, frame in enumerate(frames):
+ # if ti == 0:
+ # mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
+ # else:
+ # mask, prob, painted_image = tracker.track(frame)
+ # # save
+ # painted_image = Image.fromarray(painted_image)
+ # painted_image.save(f'/ssd1/gaomingqi/results/TrackA/horsejump-high/{ti:05d}.png')
+
+ # # ----------------------------------------------------------
+ # # another video
+ # # ----------------------------------------------------------
+ # # video frames
+ # video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/camel', '*.jpg'))
+ # video_path_list.sort()
+ # # first frame
+ # first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/camel/00000.png'
+ # # load frames
+ # frames = []
+ # for video_path in video_path_list:
+ # frames.append(np.array(Image.open(video_path).convert('RGB')))
+ # frames = np.stack(frames, 0) # N, H, W, C
+ # # load first frame annotation
+ # first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W, C
+
+ # print('first video done. clear.')
+
+ # tracker.clear_memory()
+ # # track anything given in the first frame annotation
+ # for ti, frame in enumerate(frames):
+ # if ti == 0:
+ # mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
+ # else:
+ # mask, prob, painted_image = tracker.track(frame)
+ # # save
+ # painted_image = Image.fromarray(painted_image)
+ # painted_image.save(f'/ssd1/gaomingqi/results/TrackA/camel/{ti:05d}.png')
+
+ # # failure case test
+ # failure_path = '/ssd1/gaomingqi/failure'
+ # frames = np.load(os.path.join(failure_path, 'video_frames.npy'))
+ # # first_frame = np.array(Image.open(os.path.join(failure_path, 'template_frame.png')).convert('RGB'))
+ # first_mask = np.array(Image.open(os.path.join(failure_path, 'template_mask.png')).convert('P'))
+ # first_mask = np.clip(first_mask, 0, 1)
+
+ # for ti, frame in enumerate(frames):
+ # if ti == 0:
+ # mask, probs, painted_image = tracker.track(frame, first_mask)
+ # else:
+ # mask, probs, painted_image = tracker.track(frame)
+ # # save
+ # painted_image = Image.fromarray(painted_image)
+ # painted_image.save(f'/ssd1/gaomingqi/failure/LJ/{ti:05d}.png')
+ # prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8'))
+
+ # # prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png')
diff --git a/tracker/config/config.yaml b/tracker/config/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3c99064e04262eb50827056bef225877bbc12822
--- /dev/null
+++ b/tracker/config/config.yaml
@@ -0,0 +1,15 @@
+# config info for XMem
+benchmark: False
+disable_long_term: False
+max_mid_term_frames: 10
+min_mid_term_frames: 5
+max_long_term_elements: 1000
+num_prototypes: 128
+top_k: 30
+mem_every: 5
+deep_update_every: -1
+save_scores: False
+flip: False
+size: 480
+enable_long_term: True
+enable_long_term_count_usage: True
diff --git a/tracker/inference/__init__.py b/tracker/inference/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tracker/inference/inference_core.py b/tracker/inference/inference_core.py
new file mode 100644
index 0000000000000000000000000000000000000000..e77f0805e30d3967265ed458dd7357e65a20c24f
--- /dev/null
+++ b/tracker/inference/inference_core.py
@@ -0,0 +1,115 @@
+from inference.memory_manager import MemoryManager
+from model.network import XMem
+from model.aggregate import aggregate
+
+from tracker.util.tensor_util import pad_divide_by, unpad
+
+
+class InferenceCore:
+ def __init__(self, network:XMem, config):
+ self.config = config
+ self.network = network
+ self.mem_every = config['mem_every']
+ self.deep_update_every = config['deep_update_every']
+ self.enable_long_term = config['enable_long_term']
+
+ # if deep_update_every < 0, synchronize deep update with memory frame
+ self.deep_update_sync = (self.deep_update_every < 0)
+
+ self.clear_memory()
+ self.all_labels = None
+
+ def clear_memory(self):
+ self.curr_ti = -1
+ self.last_mem_ti = 0
+ if not self.deep_update_sync:
+ self.last_deep_update_ti = -self.deep_update_every
+ self.memory = MemoryManager(config=self.config)
+
+ def update_config(self, config):
+ self.mem_every = config['mem_every']
+ self.deep_update_every = config['deep_update_every']
+ self.enable_long_term = config['enable_long_term']
+
+ # if deep_update_every < 0, synchronize deep update with memory frame
+ self.deep_update_sync = (self.deep_update_every < 0)
+ self.memory.update_config(config)
+
+ def set_all_labels(self, all_labels):
+ # self.all_labels = [l.item() for l in all_labels]
+ self.all_labels = all_labels
+
+ def step(self, image, mask=None, valid_labels=None, end=False):
+ # image: 3*H*W
+ # mask: num_objects*H*W or None
+ self.curr_ti += 1
+ image, self.pad = pad_divide_by(image, 16)
+ image = image.unsqueeze(0) # add the batch dimension
+
+ is_mem_frame = ((self.curr_ti-self.last_mem_ti >= self.mem_every) or (mask is not None)) and (not end)
+ need_segment = (self.curr_ti > 0) and ((valid_labels is None) or (len(self.all_labels) != len(valid_labels)))
+ is_deep_update = (
+ (self.deep_update_sync and is_mem_frame) or # synchronized
+ (not self.deep_update_sync and self.curr_ti-self.last_deep_update_ti >= self.deep_update_every) # no-sync
+ ) and (not end)
+ is_normal_update = (not self.deep_update_sync or not is_deep_update) and (not end)
+
+ key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(image,
+ need_ek=(self.enable_long_term or need_segment),
+ need_sk=is_mem_frame)
+ multi_scale_features = (f16, f8, f4)
+
+ # segment the current frame is needed
+ if need_segment:
+ memory_readout = self.memory.match_memory(key, selection).unsqueeze(0)
+
+ hidden, pred_logits_with_bg, pred_prob_with_bg = self.network.segment(multi_scale_features, memory_readout,
+ self.memory.get_hidden(), h_out=is_normal_update, strip_bg=False)
+ # remove batch dim
+ pred_prob_with_bg = pred_prob_with_bg[0]
+ pred_prob_no_bg = pred_prob_with_bg[1:]
+
+ pred_logits_with_bg = pred_logits_with_bg[0]
+ pred_logits_no_bg = pred_logits_with_bg[1:]
+
+ if is_normal_update:
+ self.memory.set_hidden(hidden)
+ else:
+ pred_prob_no_bg = pred_prob_with_bg = pred_logits_with_bg = pred_logits_no_bg = None
+
+ # use the input mask if any
+ if mask is not None:
+ mask, _ = pad_divide_by(mask, 16)
+
+ if pred_prob_no_bg is not None:
+ # if we have a predicted mask, we work on it
+ # make pred_prob_no_bg consistent with the input mask
+ mask_regions = (mask.sum(0) > 0.5)
+ pred_prob_no_bg[:, mask_regions] = 0
+ # shift by 1 because mask/pred_prob_no_bg do not contain background
+ mask = mask.type_as(pred_prob_no_bg)
+ if valid_labels is not None:
+ shift_by_one_non_labels = [i for i in range(pred_prob_no_bg.shape[0]) if (i+1) not in valid_labels]
+ # non-labelled objects are copied from the predicted mask
+ mask[shift_by_one_non_labels] = pred_prob_no_bg[shift_by_one_non_labels]
+ pred_prob_with_bg = aggregate(mask, dim=0)
+
+ # also create new hidden states
+ self.memory.create_hidden_state(len(self.all_labels), key)
+
+ # save as memory if needed
+ if is_mem_frame:
+ value, hidden = self.network.encode_value(image, f16, self.memory.get_hidden(),
+ pred_prob_with_bg[1:].unsqueeze(0), is_deep_update=is_deep_update)
+ self.memory.add_memory(key, shrinkage, value, self.all_labels,
+ selection=selection if self.enable_long_term else None)
+ self.last_mem_ti = self.curr_ti
+
+ if is_deep_update:
+ self.memory.set_hidden(hidden)
+ self.last_deep_update_ti = self.curr_ti
+
+ if pred_logits_with_bg is None:
+ return unpad(pred_prob_with_bg, self.pad), None
+ else:
+ return unpad(pred_prob_with_bg, self.pad), unpad(pred_logits_with_bg, self.pad)
diff --git a/tracker/inference/kv_memory_store.py b/tracker/inference/kv_memory_store.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e1113096c652ef8ce0504a4e8583007914e1957
--- /dev/null
+++ b/tracker/inference/kv_memory_store.py
@@ -0,0 +1,214 @@
+import torch
+from typing import List
+
+class KeyValueMemoryStore:
+ """
+ Works for key/value pairs type storage
+ e.g., working and long-term memory
+ """
+
+ """
+ An object group is created when new objects enter the video
+ Objects in the same group share the same temporal extent
+ i.e., objects initialized in the same frame are in the same group
+ For DAVIS/interactive, there is only one object group
+ For YouTubeVOS, there can be multiple object groups
+ """
+
+ def __init__(self, count_usage: bool):
+ self.count_usage = count_usage
+
+ # keys are stored in a single tensor and are shared between groups/objects
+ # values are stored as a list indexed by object groups
+ self.k = None
+ self.v = []
+ self.obj_groups = []
+ # for debugging only
+ self.all_objects = []
+
+ # shrinkage and selection are also single tensors
+ self.s = self.e = None
+
+ # usage
+ if self.count_usage:
+ self.use_count = self.life_count = None
+
+ def add(self, key, value, shrinkage, selection, objects: List[int]):
+ new_count = torch.zeros((key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32)
+ new_life = torch.zeros((key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32) + 1e-7
+
+ # add the key
+ if self.k is None:
+ self.k = key
+ self.s = shrinkage
+ self.e = selection
+ if self.count_usage:
+ self.use_count = new_count
+ self.life_count = new_life
+ else:
+ self.k = torch.cat([self.k, key], -1)
+ if shrinkage is not None:
+ self.s = torch.cat([self.s, shrinkage], -1)
+ if selection is not None:
+ self.e = torch.cat([self.e, selection], -1)
+ if self.count_usage:
+ self.use_count = torch.cat([self.use_count, new_count], -1)
+ self.life_count = torch.cat([self.life_count, new_life], -1)
+
+ # add the value
+ if objects is not None:
+ # When objects is given, v is a tensor; used in working memory
+ assert isinstance(value, torch.Tensor)
+ # First consume objects that are already in the memory bank
+ # cannot use set here because we need to preserve order
+ # shift by one as background is not part of value
+ remaining_objects = [obj-1 for obj in objects]
+ for gi, group in enumerate(self.obj_groups):
+ for obj in group:
+ # should properly raise an error if there are overlaps in obj_groups
+ remaining_objects.remove(obj)
+ self.v[gi] = torch.cat([self.v[gi], value[group]], -1)
+
+ # If there are remaining objects, add them as a new group
+ if len(remaining_objects) > 0:
+ new_group = list(remaining_objects)
+ self.v.append(value[new_group])
+ self.obj_groups.append(new_group)
+ self.all_objects.extend(new_group)
+
+ assert sorted(self.all_objects) == self.all_objects, 'Objects MUST be inserted in sorted order '
+ else:
+ # When objects is not given, v is a list that already has the object groups sorted
+ # used in long-term memory
+ assert isinstance(value, list)
+ for gi, gv in enumerate(value):
+ if gv is None:
+ continue
+ if gi < self.num_groups:
+ self.v[gi] = torch.cat([self.v[gi], gv], -1)
+ else:
+ self.v.append(gv)
+
+ def update_usage(self, usage):
+ # increase all life count by 1
+ # increase use of indexed elements
+ if not self.count_usage:
+ return
+
+ self.use_count += usage.view_as(self.use_count)
+ self.life_count += 1
+
+ def sieve_by_range(self, start: int, end: int, min_size: int):
+ # keep only the elements *outside* of this range (with some boundary conditions)
+ # i.e., concat (a[:start], a[end:])
+ # min_size is only used for values, we do not sieve values under this size
+ # (because they are not consolidated)
+
+ if end == 0:
+ # negative 0 would not work as the end index!
+ self.k = self.k[:,:,:start]
+ if self.count_usage:
+ self.use_count = self.use_count[:,:,:start]
+ self.life_count = self.life_count[:,:,:start]
+ if self.s is not None:
+ self.s = self.s[:,:,:start]
+ if self.e is not None:
+ self.e = self.e[:,:,:start]
+
+ for gi in range(self.num_groups):
+ if self.v[gi].shape[-1] >= min_size:
+ self.v[gi] = self.v[gi][:,:,:start]
+ else:
+ self.k = torch.cat([self.k[:,:,:start], self.k[:,:,end:]], -1)
+ if self.count_usage:
+ self.use_count = torch.cat([self.use_count[:,:,:start], self.use_count[:,:,end:]], -1)
+ self.life_count = torch.cat([self.life_count[:,:,:start], self.life_count[:,:,end:]], -1)
+ if self.s is not None:
+ self.s = torch.cat([self.s[:,:,:start], self.s[:,:,end:]], -1)
+ if self.e is not None:
+ self.e = torch.cat([self.e[:,:,:start], self.e[:,:,end:]], -1)
+
+ for gi in range(self.num_groups):
+ if self.v[gi].shape[-1] >= min_size:
+ self.v[gi] = torch.cat([self.v[gi][:,:,:start], self.v[gi][:,:,end:]], -1)
+
+ def remove_obsolete_features(self, max_size: int):
+ # normalize with life duration
+ usage = self.get_usage().flatten()
+
+ values, _ = torch.topk(usage, k=(self.size-max_size), largest=False, sorted=True)
+ survived = (usage > values[-1])
+
+ self.k = self.k[:, :, survived]
+ self.s = self.s[:, :, survived] if self.s is not None else None
+ # Long-term memory does not store ek so this should not be needed
+ self.e = self.e[:, :, survived] if self.e is not None else None
+ if self.num_groups > 1:
+ raise NotImplementedError("""The current data structure does not support feature removal with
+ multiple object groups (e.g., some objects start to appear later in the video)
+ The indices for "survived" is based on keys but not all values are present for every key
+ Basically we need to remap the indices for keys to values
+ """)
+ for gi in range(self.num_groups):
+ self.v[gi] = self.v[gi][:, :, survived]
+
+ self.use_count = self.use_count[:, :, survived]
+ self.life_count = self.life_count[:, :, survived]
+
+ def get_usage(self):
+ # return normalized usage
+ if not self.count_usage:
+ raise RuntimeError('I did not count usage!')
+ else:
+ usage = self.use_count / self.life_count
+ return usage
+
+ def get_all_sliced(self, start: int, end: int):
+ # return k, sk, ek, usage in order, sliced by start and end
+
+ if end == 0:
+ # negative 0 would not work as the end index!
+ k = self.k[:,:,start:]
+ sk = self.s[:,:,start:] if self.s is not None else None
+ ek = self.e[:,:,start:] if self.e is not None else None
+ usage = self.get_usage()[:,:,start:]
+ else:
+ k = self.k[:,:,start:end]
+ sk = self.s[:,:,start:end] if self.s is not None else None
+ ek = self.e[:,:,start:end] if self.e is not None else None
+ usage = self.get_usage()[:,:,start:end]
+
+ return k, sk, ek, usage
+
+ def get_v_size(self, ni: int):
+ return self.v[ni].shape[2]
+
+ def engaged(self):
+ return self.k is not None
+
+ @property
+ def size(self):
+ if self.k is None:
+ return 0
+ else:
+ return self.k.shape[-1]
+
+ @property
+ def num_groups(self):
+ return len(self.v)
+
+ @property
+ def key(self):
+ return self.k
+
+ @property
+ def value(self):
+ return self.v
+
+ @property
+ def shrinkage(self):
+ return self.s
+
+ @property
+ def selection(self):
+ return self.e
diff --git a/tracker/inference/memory_manager.py b/tracker/inference/memory_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..d47d96e400ba6050e6bb4325cdb21a1c3a25edc6
--- /dev/null
+++ b/tracker/inference/memory_manager.py
@@ -0,0 +1,286 @@
+import torch
+import warnings
+
+from inference.kv_memory_store import KeyValueMemoryStore
+from model.memory_util import *
+
+
+class MemoryManager:
+ """
+ Manages all three memory stores and the transition between working/long-term memory
+ """
+ def __init__(self, config):
+ self.hidden_dim = config['hidden_dim']
+ self.top_k = config['top_k']
+
+ self.enable_long_term = config['enable_long_term']
+ self.enable_long_term_usage = config['enable_long_term_count_usage']
+ if self.enable_long_term:
+ self.max_mt_frames = config['max_mid_term_frames']
+ self.min_mt_frames = config['min_mid_term_frames']
+ self.num_prototypes = config['num_prototypes']
+ self.max_long_elements = config['max_long_term_elements']
+
+ # dimensions will be inferred from input later
+ self.CK = self.CV = None
+ self.H = self.W = None
+
+ # The hidden state will be stored in a single tensor for all objects
+ # B x num_objects x CH x H x W
+ self.hidden = None
+
+ self.work_mem = KeyValueMemoryStore(count_usage=self.enable_long_term)
+ if self.enable_long_term:
+ self.long_mem = KeyValueMemoryStore(count_usage=self.enable_long_term_usage)
+
+ self.reset_config = True
+
+ def update_config(self, config):
+ self.reset_config = True
+ self.hidden_dim = config['hidden_dim']
+ self.top_k = config['top_k']
+
+ assert self.enable_long_term == config['enable_long_term'], 'cannot update this'
+ assert self.enable_long_term_usage == config['enable_long_term_count_usage'], 'cannot update this'
+
+ self.enable_long_term_usage = config['enable_long_term_count_usage']
+ if self.enable_long_term:
+ self.max_mt_frames = config['max_mid_term_frames']
+ self.min_mt_frames = config['min_mid_term_frames']
+ self.num_prototypes = config['num_prototypes']
+ self.max_long_elements = config['max_long_term_elements']
+
+ def _readout(self, affinity, v):
+ # this function is for a single object group
+ return v @ affinity
+
+ def match_memory(self, query_key, selection):
+ # query_key: B x C^k x H x W
+ # selection: B x C^k x H x W
+ num_groups = self.work_mem.num_groups
+ h, w = query_key.shape[-2:]
+
+ query_key = query_key.flatten(start_dim=2)
+ selection = selection.flatten(start_dim=2) if selection is not None else None
+
+ """
+ Memory readout using keys
+ """
+
+ if self.enable_long_term and self.long_mem.engaged():
+ # Use long-term memory
+ long_mem_size = self.long_mem.size
+ memory_key = torch.cat([self.long_mem.key, self.work_mem.key], -1)
+ shrinkage = torch.cat([self.long_mem.shrinkage, self.work_mem.shrinkage], -1)
+
+ similarity = get_similarity(memory_key, shrinkage, query_key, selection)
+ work_mem_similarity = similarity[:, long_mem_size:]
+ long_mem_similarity = similarity[:, :long_mem_size]
+
+ # get the usage with the first group
+ # the first group always have all the keys valid
+ affinity, usage = do_softmax(
+ torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(0):], work_mem_similarity], 1),
+ top_k=self.top_k, inplace=True, return_usage=True)
+ affinity = [affinity]
+
+ # compute affinity group by group as later groups only have a subset of keys
+ for gi in range(1, num_groups):
+ if gi < self.long_mem.num_groups:
+ # merge working and lt similarities before softmax
+ affinity_one_group = do_softmax(
+ torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(gi):],
+ work_mem_similarity[:, -self.work_mem.get_v_size(gi):]], 1),
+ top_k=self.top_k, inplace=True)
+ else:
+ # no long-term memory for this group
+ affinity_one_group = do_softmax(work_mem_similarity[:, -self.work_mem.get_v_size(gi):],
+ top_k=self.top_k, inplace=(gi==num_groups-1))
+ affinity.append(affinity_one_group)
+
+ all_memory_value = []
+ for gi, gv in enumerate(self.work_mem.value):
+ # merge the working and lt values before readout
+ if gi < self.long_mem.num_groups:
+ all_memory_value.append(torch.cat([self.long_mem.value[gi], self.work_mem.value[gi]], -1))
+ else:
+ all_memory_value.append(gv)
+
+ """
+ Record memory usage for working and long-term memory
+ """
+ # ignore the index return for long-term memory
+ work_usage = usage[:, long_mem_size:]
+ self.work_mem.update_usage(work_usage.flatten())
+
+ if self.enable_long_term_usage:
+ # ignore the index return for working memory
+ long_usage = usage[:, :long_mem_size]
+ self.long_mem.update_usage(long_usage.flatten())
+ else:
+ # No long-term memory
+ similarity = get_similarity(self.work_mem.key, self.work_mem.shrinkage, query_key, selection)
+
+ if self.enable_long_term:
+ affinity, usage = do_softmax(similarity, inplace=(num_groups==1),
+ top_k=self.top_k, return_usage=True)
+
+ # Record memory usage for working memory
+ self.work_mem.update_usage(usage.flatten())
+ else:
+ affinity = do_softmax(similarity, inplace=(num_groups==1),
+ top_k=self.top_k, return_usage=False)
+
+ affinity = [affinity]
+
+ # compute affinity group by group as later groups only have a subset of keys
+ for gi in range(1, num_groups):
+ affinity_one_group = do_softmax(similarity[:, -self.work_mem.get_v_size(gi):],
+ top_k=self.top_k, inplace=(gi==num_groups-1))
+ affinity.append(affinity_one_group)
+
+ all_memory_value = self.work_mem.value
+
+ # Shared affinity within each group
+ all_readout_mem = torch.cat([
+ self._readout(affinity[gi], gv)
+ for gi, gv in enumerate(all_memory_value)
+ ], 0)
+
+ return all_readout_mem.view(all_readout_mem.shape[0], self.CV, h, w)
+
+ def add_memory(self, key, shrinkage, value, objects, selection=None):
+ # key: 1*C*H*W
+ # value: 1*num_objects*C*H*W
+ # objects contain a list of object indices
+ if self.H is None or self.reset_config:
+ self.reset_config = False
+ self.H, self.W = key.shape[-2:]
+ self.HW = self.H*self.W
+ if self.enable_long_term:
+ # convert from num. frames to num. nodes
+ self.min_work_elements = self.min_mt_frames*self.HW
+ self.max_work_elements = self.max_mt_frames*self.HW
+
+ # key: 1*C*N
+ # value: num_objects*C*N
+ key = key.flatten(start_dim=2)
+ shrinkage = shrinkage.flatten(start_dim=2)
+ value = value[0].flatten(start_dim=2)
+
+ self.CK = key.shape[1]
+ self.CV = value.shape[1]
+
+ if selection is not None:
+ if not self.enable_long_term:
+ warnings.warn('the selection factor is only needed in long-term mode', UserWarning)
+ selection = selection.flatten(start_dim=2)
+
+ self.work_mem.add(key, value, shrinkage, selection, objects)
+
+ # long-term memory cleanup
+ if self.enable_long_term:
+ # Do memory compressed if needed
+ if self.work_mem.size >= self.max_work_elements:
+ # print('remove memory')
+ # Remove obsolete features if needed
+ if self.long_mem.size >= (self.max_long_elements-self.num_prototypes):
+ self.long_mem.remove_obsolete_features(self.max_long_elements-self.num_prototypes)
+
+ self.compress_features()
+
+ def create_hidden_state(self, n, sample_key):
+ # n is the TOTAL number of objects
+ h, w = sample_key.shape[-2:]
+ if self.hidden is None:
+ self.hidden = torch.zeros((1, n, self.hidden_dim, h, w), device=sample_key.device)
+ elif self.hidden.shape[1] != n:
+ self.hidden = torch.cat([
+ self.hidden,
+ torch.zeros((1, n-self.hidden.shape[1], self.hidden_dim, h, w), device=sample_key.device)
+ ], 1)
+
+ assert(self.hidden.shape[1] == n)
+
+ def set_hidden(self, hidden):
+ self.hidden = hidden
+
+ def get_hidden(self):
+ return self.hidden
+
+ def compress_features(self):
+ HW = self.HW
+ candidate_value = []
+ total_work_mem_size = self.work_mem.size
+ for gv in self.work_mem.value:
+ # Some object groups might be added later in the video
+ # So not all keys have values associated with all objects
+ # We need to keep track of the key->value validity
+ mem_size_in_this_group = gv.shape[-1]
+ if mem_size_in_this_group == total_work_mem_size:
+ # full LT
+ candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW])
+ else:
+ # mem_size is smaller than total_work_mem_size, but at least HW
+ assert HW <= mem_size_in_this_group < total_work_mem_size
+ if mem_size_in_this_group > self.min_work_elements+HW:
+ # part of this object group still goes into LT
+ candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW])
+ else:
+ # this object group cannot go to the LT at all
+ candidate_value.append(None)
+
+ # perform memory consolidation
+ prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
+ *self.work_mem.get_all_sliced(HW, -self.min_work_elements+HW), candidate_value)
+
+ # remove consolidated working memory
+ self.work_mem.sieve_by_range(HW, -self.min_work_elements+HW, min_size=self.min_work_elements+HW)
+
+ # add to long-term memory
+ self.long_mem.add(prototype_key, prototype_value, prototype_shrinkage, selection=None, objects=None)
+ # print(f'long memory size: {self.long_mem.size}')
+ # print(f'work memory size: {self.work_mem.size}')
+
+ def consolidation(self, candidate_key, candidate_shrinkage, candidate_selection, usage, candidate_value):
+ # keys: 1*C*N
+ # values: num_objects*C*N
+ N = candidate_key.shape[-1]
+
+ # find the indices with max usage
+ _, max_usage_indices = torch.topk(usage, k=self.num_prototypes, dim=-1, sorted=True)
+ prototype_indices = max_usage_indices.flatten()
+
+ # Prototypes are invalid for out-of-bound groups
+ validity = [prototype_indices >= (N-gv.shape[2]) if gv is not None else None for gv in candidate_value]
+
+ prototype_key = candidate_key[:, :, prototype_indices]
+ prototype_selection = candidate_selection[:, :, prototype_indices] if candidate_selection is not None else None
+
+ """
+ Potentiation step
+ """
+ similarity = get_similarity(candidate_key, candidate_shrinkage, prototype_key, prototype_selection)
+
+ # convert similarity to affinity
+ # need to do it group by group since the softmax normalization would be different
+ affinity = [
+ do_softmax(similarity[:, -gv.shape[2]:, validity[gi]]) if gv is not None else None
+ for gi, gv in enumerate(candidate_value)
+ ]
+
+ # some values can be have all False validity. Weed them out.
+ affinity = [
+ aff if aff is None or aff.shape[-1] > 0 else None for aff in affinity
+ ]
+
+ # readout the values
+ prototype_value = [
+ self._readout(affinity[gi], gv) if affinity[gi] is not None else None
+ for gi, gv in enumerate(candidate_value)
+ ]
+
+ # readout the shrinkage term
+ prototype_shrinkage = self._readout(affinity[0], candidate_shrinkage) if candidate_shrinkage is not None else None
+
+ return prototype_key, prototype_value, prototype_shrinkage
\ No newline at end of file
diff --git a/tracker/model/__init__.py b/tracker/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tracker/model/aggregate.py b/tracker/model/aggregate.py
new file mode 100644
index 0000000000000000000000000000000000000000..7622391fb3ac9aa8b515df88cf3ea5297b367538
--- /dev/null
+++ b/tracker/model/aggregate.py
@@ -0,0 +1,17 @@
+import torch
+import torch.nn.functional as F
+
+
+# Soft aggregation from STM
+def aggregate(prob, dim, return_logits=False):
+ new_prob = torch.cat([
+ torch.prod(1-prob, dim=dim, keepdim=True),
+ prob
+ ], dim).clamp(1e-7, 1-1e-7)
+ logits = torch.log((new_prob /(1-new_prob)))
+ prob = F.softmax(logits, dim=dim)
+
+ if return_logits:
+ return logits, prob
+ else:
+ return prob
\ No newline at end of file
diff --git a/tracker/model/cbam.py b/tracker/model/cbam.py
new file mode 100644
index 0000000000000000000000000000000000000000..6423358429e2843b1f36ceb2bc1a485ea72b8eb4
--- /dev/null
+++ b/tracker/model/cbam.py
@@ -0,0 +1,77 @@
+# Modified from https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class BasicConv(nn.Module):
+ def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+ super(BasicConv, self).__init__()
+ self.out_channels = out_planes
+ self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
+
+ def forward(self, x):
+ x = self.conv(x)
+ return x
+
+class Flatten(nn.Module):
+ def forward(self, x):
+ return x.view(x.size(0), -1)
+
+class ChannelGate(nn.Module):
+ def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
+ super(ChannelGate, self).__init__()
+ self.gate_channels = gate_channels
+ self.mlp = nn.Sequential(
+ Flatten(),
+ nn.Linear(gate_channels, gate_channels // reduction_ratio),
+ nn.ReLU(),
+ nn.Linear(gate_channels // reduction_ratio, gate_channels)
+ )
+ self.pool_types = pool_types
+ def forward(self, x):
+ channel_att_sum = None
+ for pool_type in self.pool_types:
+ if pool_type=='avg':
+ avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
+ channel_att_raw = self.mlp( avg_pool )
+ elif pool_type=='max':
+ max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
+ channel_att_raw = self.mlp( max_pool )
+
+ if channel_att_sum is None:
+ channel_att_sum = channel_att_raw
+ else:
+ channel_att_sum = channel_att_sum + channel_att_raw
+
+ scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
+ return x * scale
+
+class ChannelPool(nn.Module):
+ def forward(self, x):
+ return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
+
+class SpatialGate(nn.Module):
+ def __init__(self):
+ super(SpatialGate, self).__init__()
+ kernel_size = 7
+ self.compress = ChannelPool()
+ self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2)
+ def forward(self, x):
+ x_compress = self.compress(x)
+ x_out = self.spatial(x_compress)
+ scale = torch.sigmoid(x_out) # broadcasting
+ return x * scale
+
+class CBAM(nn.Module):
+ def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
+ super(CBAM, self).__init__()
+ self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
+ self.no_spatial=no_spatial
+ if not no_spatial:
+ self.SpatialGate = SpatialGate()
+ def forward(self, x):
+ x_out = self.ChannelGate(x)
+ if not self.no_spatial:
+ x_out = self.SpatialGate(x_out)
+ return x_out
diff --git a/tracker/model/group_modules.py b/tracker/model/group_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..749ef2386a992a468b7cf631293ebd22036b2777
--- /dev/null
+++ b/tracker/model/group_modules.py
@@ -0,0 +1,82 @@
+"""
+Group-specific modules
+They handle features that also depends on the mask.
+Features are typically of shape
+ batch_size * num_objects * num_channels * H * W
+
+All of them are permutation equivariant w.r.t. to the num_objects dimension
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def interpolate_groups(g, ratio, mode, align_corners):
+ batch_size, num_objects = g.shape[:2]
+ g = F.interpolate(g.flatten(start_dim=0, end_dim=1),
+ scale_factor=ratio, mode=mode, align_corners=align_corners)
+ g = g.view(batch_size, num_objects, *g.shape[1:])
+ return g
+
+def upsample_groups(g, ratio=2, mode='bilinear', align_corners=False):
+ return interpolate_groups(g, ratio, mode, align_corners)
+
+def downsample_groups(g, ratio=1/2, mode='area', align_corners=None):
+ return interpolate_groups(g, ratio, mode, align_corners)
+
+
+class GConv2D(nn.Conv2d):
+ def forward(self, g):
+ batch_size, num_objects = g.shape[:2]
+ g = super().forward(g.flatten(start_dim=0, end_dim=1))
+ return g.view(batch_size, num_objects, *g.shape[1:])
+
+
+class GroupResBlock(nn.Module):
+ def __init__(self, in_dim, out_dim):
+ super().__init__()
+
+ if in_dim == out_dim:
+ self.downsample = None
+ else:
+ self.downsample = GConv2D(in_dim, out_dim, kernel_size=3, padding=1)
+
+ self.conv1 = GConv2D(in_dim, out_dim, kernel_size=3, padding=1)
+ self.conv2 = GConv2D(out_dim, out_dim, kernel_size=3, padding=1)
+
+ def forward(self, g):
+ out_g = self.conv1(F.relu(g))
+ out_g = self.conv2(F.relu(out_g))
+
+ if self.downsample is not None:
+ g = self.downsample(g)
+
+ return out_g + g
+
+
+class MainToGroupDistributor(nn.Module):
+ def __init__(self, x_transform=None, method='cat', reverse_order=False):
+ super().__init__()
+
+ self.x_transform = x_transform
+ self.method = method
+ self.reverse_order = reverse_order
+
+ def forward(self, x, g):
+ num_objects = g.shape[1]
+
+ if self.x_transform is not None:
+ x = self.x_transform(x)
+
+ if self.method == 'cat':
+ if self.reverse_order:
+ g = torch.cat([g, x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1)], 2)
+ else:
+ g = torch.cat([x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1), g], 2)
+ elif self.method == 'add':
+ g = x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1) + g
+ else:
+ raise NotImplementedError
+
+ return g
diff --git a/tracker/model/losses.py b/tracker/model/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..60a2894b6f5b330aa4baa56db226e8a59cb8c1ae
--- /dev/null
+++ b/tracker/model/losses.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from collections import defaultdict
+
+
+def dice_loss(input_mask, cls_gt):
+ num_objects = input_mask.shape[1]
+ losses = []
+ for i in range(num_objects):
+ mask = input_mask[:,i].flatten(start_dim=1)
+ # background not in mask, so we add one to cls_gt
+ gt = (cls_gt==(i+1)).float().flatten(start_dim=1)
+ numerator = 2 * (mask * gt).sum(-1)
+ denominator = mask.sum(-1) + gt.sum(-1)
+ loss = 1 - (numerator + 1) / (denominator + 1)
+ losses.append(loss)
+ return torch.cat(losses).mean()
+
+
+# https://stackoverflow.com/questions/63735255/how-do-i-compute-bootstrapped-cross-entropy-loss-in-pytorch
+class BootstrappedCE(nn.Module):
+ def __init__(self, start_warm, end_warm, top_p=0.15):
+ super().__init__()
+
+ self.start_warm = start_warm
+ self.end_warm = end_warm
+ self.top_p = top_p
+
+ def forward(self, input, target, it):
+ if it < self.start_warm:
+ return F.cross_entropy(input, target), 1.0
+
+ raw_loss = F.cross_entropy(input, target, reduction='none').view(-1)
+ num_pixels = raw_loss.numel()
+
+ if it > self.end_warm:
+ this_p = self.top_p
+ else:
+ this_p = self.top_p + (1-self.top_p)*((self.end_warm-it)/(self.end_warm-self.start_warm))
+ loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False)
+ return loss.mean(), this_p
+
+
+class LossComputer:
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.bce = BootstrappedCE(config['start_warm'], config['end_warm'])
+
+ def compute(self, data, num_objects, it):
+ losses = defaultdict(int)
+
+ b, t = data['rgb'].shape[:2]
+
+ losses['total_loss'] = 0
+ for ti in range(1, t):
+ for bi in range(b):
+ loss, p = self.bce(data[f'logits_{ti}'][bi:bi+1, :num_objects[bi]+1], data['cls_gt'][bi:bi+1,ti,0], it)
+ losses['p'] += p / b / (t-1)
+ losses[f'ce_loss_{ti}'] += loss / b
+
+ losses['total_loss'] += losses['ce_loss_%d'%ti]
+ losses[f'dice_loss_{ti}'] = dice_loss(data[f'masks_{ti}'], data['cls_gt'][:,ti,0])
+ losses['total_loss'] += losses[f'dice_loss_{ti}']
+
+ return losses
diff --git a/tracker/model/memory_util.py b/tracker/model/memory_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..faf6197b8c4ea990317476e2e3aeb8952a78aedf
--- /dev/null
+++ b/tracker/model/memory_util.py
@@ -0,0 +1,80 @@
+import math
+import numpy as np
+import torch
+from typing import Optional
+
+
+def get_similarity(mk, ms, qk, qe):
+ # used for training/inference and memory reading/memory potentiation
+ # mk: B x CK x [N] - Memory keys
+ # ms: B x 1 x [N] - Memory shrinkage
+ # qk: B x CK x [HW/P] - Query keys
+ # qe: B x CK x [HW/P] - Query selection
+ # Dimensions in [] are flattened
+ CK = mk.shape[1]
+ mk = mk.flatten(start_dim=2)
+ ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None
+ qk = qk.flatten(start_dim=2)
+ qe = qe.flatten(start_dim=2) if qe is not None else None
+
+ if qe is not None:
+ # See appendix for derivation
+ # or you can just trust me ヽ(ー_ー )ノ
+ mk = mk.transpose(1, 2)
+ a_sq = (mk.pow(2) @ qe)
+ two_ab = 2 * (mk @ (qk * qe))
+ b_sq = (qe * qk.pow(2)).sum(1, keepdim=True)
+ similarity = (-a_sq+two_ab-b_sq)
+ else:
+ # similar to STCN if we don't have the selection term
+ a_sq = mk.pow(2).sum(1).unsqueeze(2)
+ two_ab = 2 * (mk.transpose(1, 2) @ qk)
+ similarity = (-a_sq+two_ab)
+
+ if ms is not None:
+ similarity = similarity * ms / math.sqrt(CK) # B*N*HW
+ else:
+ similarity = similarity / math.sqrt(CK) # B*N*HW
+
+ return similarity
+
+def do_softmax(similarity, top_k: Optional[int]=None, inplace=False, return_usage=False):
+ # normalize similarity with top-k softmax
+ # similarity: B x N x [HW/P]
+ # use inplace with care
+ if top_k is not None:
+ values, indices = torch.topk(similarity, k=top_k, dim=1)
+
+ x_exp = values.exp_()
+ x_exp /= torch.sum(x_exp, dim=1, keepdim=True)
+ if inplace:
+ similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW
+ affinity = similarity
+ else:
+ affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp) # B*N*HW
+ else:
+ maxes = torch.max(similarity, dim=1, keepdim=True)[0]
+ x_exp = torch.exp(similarity - maxes)
+ x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
+ affinity = x_exp / x_exp_sum
+ indices = None
+
+ if return_usage:
+ return affinity, affinity.sum(dim=2)
+
+ return affinity
+
+def get_affinity(mk, ms, qk, qe):
+ # shorthand used in training with no top-k
+ similarity = get_similarity(mk, ms, qk, qe)
+ affinity = do_softmax(similarity)
+ return affinity
+
+def readout(affinity, mv):
+ B, CV, T, H, W = mv.shape
+
+ mo = mv.view(B, CV, T*H*W)
+ mem = torch.bmm(mo, affinity)
+ mem = mem.view(B, CV, H, W)
+
+ return mem
diff --git a/tracker/model/modules.py b/tracker/model/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..99207996e6d68dcf74da314dbd7cce21f65ac71e
--- /dev/null
+++ b/tracker/model/modules.py
@@ -0,0 +1,250 @@
+"""
+modules.py - This file stores the rather boring network blocks.
+
+x - usually means features that only depends on the image
+g - usually means features that also depends on the mask.
+ They might have an extra "group" or "num_objects" dimension, hence
+ batch_size * num_objects * num_channels * H * W
+
+The trailing number of a variable usually denote the stride
+
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from model.group_modules import *
+from model import resnet
+from model.cbam import CBAM
+
+
+class FeatureFusionBlock(nn.Module):
+ def __init__(self, x_in_dim, g_in_dim, g_mid_dim, g_out_dim):
+ super().__init__()
+
+ self.distributor = MainToGroupDistributor()
+ self.block1 = GroupResBlock(x_in_dim+g_in_dim, g_mid_dim)
+ self.attention = CBAM(g_mid_dim)
+ self.block2 = GroupResBlock(g_mid_dim, g_out_dim)
+
+ def forward(self, x, g):
+ batch_size, num_objects = g.shape[:2]
+
+ g = self.distributor(x, g)
+ g = self.block1(g)
+ r = self.attention(g.flatten(start_dim=0, end_dim=1))
+ r = r.view(batch_size, num_objects, *r.shape[1:])
+
+ g = self.block2(g+r)
+
+ return g
+
+
+class HiddenUpdater(nn.Module):
+ # Used in the decoder, multi-scale feature + GRU
+ def __init__(self, g_dims, mid_dim, hidden_dim):
+ super().__init__()
+ self.hidden_dim = hidden_dim
+
+ self.g16_conv = GConv2D(g_dims[0], mid_dim, kernel_size=1)
+ self.g8_conv = GConv2D(g_dims[1], mid_dim, kernel_size=1)
+ self.g4_conv = GConv2D(g_dims[2], mid_dim, kernel_size=1)
+
+ self.transform = GConv2D(mid_dim+hidden_dim, hidden_dim*3, kernel_size=3, padding=1)
+
+ nn.init.xavier_normal_(self.transform.weight)
+
+ def forward(self, g, h):
+ g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
+ self.g4_conv(downsample_groups(g[2], ratio=1/4))
+
+ g = torch.cat([g, h], 2)
+
+ # defined slightly differently than standard GRU,
+ # namely the new value is generated before the forget gate.
+ # might provide better gradient but frankly it was initially just an
+ # implementation error that I never bothered fixing
+ values = self.transform(g)
+ forget_gate = torch.sigmoid(values[:,:,:self.hidden_dim])
+ update_gate = torch.sigmoid(values[:,:,self.hidden_dim:self.hidden_dim*2])
+ new_value = torch.tanh(values[:,:,self.hidden_dim*2:])
+ new_h = forget_gate*h*(1-update_gate) + update_gate*new_value
+
+ return new_h
+
+
+class HiddenReinforcer(nn.Module):
+ # Used in the value encoder, a single GRU
+ def __init__(self, g_dim, hidden_dim):
+ super().__init__()
+ self.hidden_dim = hidden_dim
+ self.transform = GConv2D(g_dim+hidden_dim, hidden_dim*3, kernel_size=3, padding=1)
+
+ nn.init.xavier_normal_(self.transform.weight)
+
+ def forward(self, g, h):
+ g = torch.cat([g, h], 2)
+
+ # defined slightly differently than standard GRU,
+ # namely the new value is generated before the forget gate.
+ # might provide better gradient but frankly it was initially just an
+ # implementation error that I never bothered fixing
+ values = self.transform(g)
+ forget_gate = torch.sigmoid(values[:,:,:self.hidden_dim])
+ update_gate = torch.sigmoid(values[:,:,self.hidden_dim:self.hidden_dim*2])
+ new_value = torch.tanh(values[:,:,self.hidden_dim*2:])
+ new_h = forget_gate*h*(1-update_gate) + update_gate*new_value
+
+ return new_h
+
+
+class ValueEncoder(nn.Module):
+ def __init__(self, value_dim, hidden_dim, single_object=False):
+ super().__init__()
+
+ self.single_object = single_object
+ network = resnet.resnet18(pretrained=True, extra_dim=1 if single_object else 2)
+ self.conv1 = network.conv1
+ self.bn1 = network.bn1
+ self.relu = network.relu # 1/2, 64
+ self.maxpool = network.maxpool
+
+ self.layer1 = network.layer1 # 1/4, 64
+ self.layer2 = network.layer2 # 1/8, 128
+ self.layer3 = network.layer3 # 1/16, 256
+
+ self.distributor = MainToGroupDistributor()
+ self.fuser = FeatureFusionBlock(1024, 256, value_dim, value_dim)
+ if hidden_dim > 0:
+ self.hidden_reinforce = HiddenReinforcer(value_dim, hidden_dim)
+ else:
+ self.hidden_reinforce = None
+
+ def forward(self, image, image_feat_f16, h, masks, others, is_deep_update=True):
+ # image_feat_f16 is the feature from the key encoder
+ if not self.single_object:
+ g = torch.stack([masks, others], 2)
+ else:
+ g = masks.unsqueeze(2)
+ g = self.distributor(image, g)
+
+ batch_size, num_objects = g.shape[:2]
+ g = g.flatten(start_dim=0, end_dim=1)
+
+ g = self.conv1(g)
+ g = self.bn1(g) # 1/2, 64
+ g = self.maxpool(g) # 1/4, 64
+ g = self.relu(g)
+
+ g = self.layer1(g) # 1/4
+ g = self.layer2(g) # 1/8
+ g = self.layer3(g) # 1/16
+
+ g = g.view(batch_size, num_objects, *g.shape[1:])
+ g = self.fuser(image_feat_f16, g)
+
+ if is_deep_update and self.hidden_reinforce is not None:
+ h = self.hidden_reinforce(g, h)
+
+ return g, h
+
+
+class KeyEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+ network = resnet.resnet50(pretrained=True)
+ self.conv1 = network.conv1
+ self.bn1 = network.bn1
+ self.relu = network.relu # 1/2, 64
+ self.maxpool = network.maxpool
+
+ self.res2 = network.layer1 # 1/4, 256
+ self.layer2 = network.layer2 # 1/8, 512
+ self.layer3 = network.layer3 # 1/16, 1024
+
+ def forward(self, f):
+ x = self.conv1(f)
+ x = self.bn1(x)
+ x = self.relu(x) # 1/2, 64
+ x = self.maxpool(x) # 1/4, 64
+ f4 = self.res2(x) # 1/4, 256
+ f8 = self.layer2(f4) # 1/8, 512
+ f16 = self.layer3(f8) # 1/16, 1024
+
+ return f16, f8, f4
+
+
+class UpsampleBlock(nn.Module):
+ def __init__(self, skip_dim, g_up_dim, g_out_dim, scale_factor=2):
+ super().__init__()
+ self.skip_conv = nn.Conv2d(skip_dim, g_up_dim, kernel_size=3, padding=1)
+ self.distributor = MainToGroupDistributor(method='add')
+ self.out_conv = GroupResBlock(g_up_dim, g_out_dim)
+ self.scale_factor = scale_factor
+
+ def forward(self, skip_f, up_g):
+ skip_f = self.skip_conv(skip_f)
+ g = upsample_groups(up_g, ratio=self.scale_factor)
+ g = self.distributor(skip_f, g)
+ g = self.out_conv(g)
+ return g
+
+
+class KeyProjection(nn.Module):
+ def __init__(self, in_dim, keydim):
+ super().__init__()
+
+ self.key_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1)
+ # shrinkage
+ self.d_proj = nn.Conv2d(in_dim, 1, kernel_size=3, padding=1)
+ # selection
+ self.e_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1)
+
+ nn.init.orthogonal_(self.key_proj.weight.data)
+ nn.init.zeros_(self.key_proj.bias.data)
+
+ def forward(self, x, need_s, need_e):
+ shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None
+ selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None
+
+ return self.key_proj(x), shrinkage, selection
+
+
+class Decoder(nn.Module):
+ def __init__(self, val_dim, hidden_dim):
+ super().__init__()
+
+ self.fuser = FeatureFusionBlock(1024, val_dim+hidden_dim, 512, 512)
+ if hidden_dim > 0:
+ self.hidden_update = HiddenUpdater([512, 256, 256+1], 256, hidden_dim)
+ else:
+ self.hidden_update = None
+
+ self.up_16_8 = UpsampleBlock(512, 512, 256) # 1/16 -> 1/8
+ self.up_8_4 = UpsampleBlock(256, 256, 256) # 1/8 -> 1/4
+
+ self.pred = nn.Conv2d(256, 1, kernel_size=3, padding=1, stride=1)
+
+ def forward(self, f16, f8, f4, hidden_state, memory_readout, h_out=True):
+ batch_size, num_objects = memory_readout.shape[:2]
+
+ if self.hidden_update is not None:
+ g16 = self.fuser(f16, torch.cat([memory_readout, hidden_state], 2))
+ else:
+ g16 = self.fuser(f16, memory_readout)
+
+ g8 = self.up_16_8(f8, g16)
+ g4 = self.up_8_4(f4, g8)
+ logits = self.pred(F.relu(g4.flatten(start_dim=0, end_dim=1)))
+
+ if h_out and self.hidden_update is not None:
+ g4 = torch.cat([g4, logits.view(batch_size, num_objects, 1, *logits.shape[-2:])], 2)
+ hidden_state = self.hidden_update([g16, g8, g4], hidden_state)
+ else:
+ hidden_state = None
+
+ logits = F.interpolate(logits, scale_factor=4, mode='bilinear', align_corners=False)
+ logits = logits.view(batch_size, num_objects, *logits.shape[-2:])
+
+ return hidden_state, logits
diff --git a/tracker/model/network.py b/tracker/model/network.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5f179db17ac424ffee2951ade3934e08cd6276a
--- /dev/null
+++ b/tracker/model/network.py
@@ -0,0 +1,198 @@
+"""
+This file defines XMem, the highest level nn.Module interface
+During training, it is used by trainer.py
+During evaluation, it is used by inference_core.py
+
+It further depends on modules.py which gives more detailed implementations of sub-modules
+"""
+
+import torch
+import torch.nn as nn
+
+from model.aggregate import aggregate
+from model.modules import *
+from model.memory_util import *
+
+
+class XMem(nn.Module):
+ def __init__(self, config, model_path=None, map_location=None):
+ """
+ model_path/map_location are used in evaluation only
+ map_location is for converting models saved in cuda to cpu
+ """
+ super().__init__()
+ model_weights = self.init_hyperparameters(config, model_path, map_location)
+
+ self.single_object = config.get('single_object', False)
+ print(f'Single object mode: {self.single_object}')
+
+ self.key_encoder = KeyEncoder()
+ self.value_encoder = ValueEncoder(self.value_dim, self.hidden_dim, self.single_object)
+
+ # Projection from f16 feature space to key/value space
+ self.key_proj = KeyProjection(1024, self.key_dim)
+
+ self.decoder = Decoder(self.value_dim, self.hidden_dim)
+
+ if model_weights is not None:
+ self.load_weights(model_weights, init_as_zero_if_needed=True)
+
+ def encode_key(self, frame, need_sk=True, need_ek=True):
+ # Determine input shape
+ if len(frame.shape) == 5:
+ # shape is b*t*c*h*w
+ need_reshape = True
+ b, t = frame.shape[:2]
+ # flatten so that we can feed them into a 2D CNN
+ frame = frame.flatten(start_dim=0, end_dim=1)
+ elif len(frame.shape) == 4:
+ # shape is b*c*h*w
+ need_reshape = False
+ else:
+ raise NotImplementedError
+
+ f16, f8, f4 = self.key_encoder(frame)
+ key, shrinkage, selection = self.key_proj(f16, need_sk, need_ek)
+
+ if need_reshape:
+ # B*C*T*H*W
+ key = key.view(b, t, *key.shape[-3:]).transpose(1, 2).contiguous()
+ if shrinkage is not None:
+ shrinkage = shrinkage.view(b, t, *shrinkage.shape[-3:]).transpose(1, 2).contiguous()
+ if selection is not None:
+ selection = selection.view(b, t, *selection.shape[-3:]).transpose(1, 2).contiguous()
+
+ # B*T*C*H*W
+ f16 = f16.view(b, t, *f16.shape[-3:])
+ f8 = f8.view(b, t, *f8.shape[-3:])
+ f4 = f4.view(b, t, *f4.shape[-3:])
+
+ return key, shrinkage, selection, f16, f8, f4
+
+ def encode_value(self, frame, image_feat_f16, h16, masks, is_deep_update=True):
+ num_objects = masks.shape[1]
+ if num_objects != 1:
+ others = torch.cat([
+ torch.sum(
+ masks[:, [j for j in range(num_objects) if i!=j]]
+ , dim=1, keepdim=True)
+ for i in range(num_objects)], 1)
+ else:
+ others = torch.zeros_like(masks)
+
+ g16, h16 = self.value_encoder(frame, image_feat_f16, h16, masks, others, is_deep_update)
+
+ return g16, h16
+
+ # Used in training only.
+ # This step is replaced by MemoryManager in test time
+ def read_memory(self, query_key, query_selection, memory_key,
+ memory_shrinkage, memory_value):
+ """
+ query_key : B * CK * H * W
+ query_selection : B * CK * H * W
+ memory_key : B * CK * T * H * W
+ memory_shrinkage: B * 1 * T * H * W
+ memory_value : B * num_objects * CV * T * H * W
+ """
+ batch_size, num_objects = memory_value.shape[:2]
+ memory_value = memory_value.flatten(start_dim=1, end_dim=2)
+
+ affinity = get_affinity(memory_key, memory_shrinkage, query_key, query_selection)
+ memory = readout(affinity, memory_value)
+ memory = memory.view(batch_size, num_objects, self.value_dim, *memory.shape[-2:])
+
+ return memory
+
+ def segment(self, multi_scale_features, memory_readout,
+ hidden_state, selector=None, h_out=True, strip_bg=True):
+
+ hidden_state, logits = self.decoder(*multi_scale_features, hidden_state, memory_readout, h_out=h_out)
+ prob = torch.sigmoid(logits)
+ if selector is not None:
+ prob = prob * selector
+
+ logits, prob = aggregate(prob, dim=1, return_logits=True)
+ if strip_bg:
+ # Strip away the background
+ prob = prob[:, 1:]
+
+ return hidden_state, logits, prob
+
+ def forward(self, mode, *args, **kwargs):
+ if mode == 'encode_key':
+ return self.encode_key(*args, **kwargs)
+ elif mode == 'encode_value':
+ return self.encode_value(*args, **kwargs)
+ elif mode == 'read_memory':
+ return self.read_memory(*args, **kwargs)
+ elif mode == 'segment':
+ return self.segment(*args, **kwargs)
+ else:
+ raise NotImplementedError
+
+ def init_hyperparameters(self, config, model_path=None, map_location=None):
+ """
+ Init three hyperparameters: key_dim, value_dim, and hidden_dim
+ If model_path is provided, we load these from the model weights
+ The actual parameters are then updated to the config in-place
+
+ Otherwise we load it either from the config or default
+ """
+ if model_path is not None:
+ # load the model and key/value/hidden dimensions with some hacks
+ # config is updated with the loaded parameters
+ model_weights = torch.load(model_path, map_location=map_location)
+ self.key_dim = model_weights['key_proj.key_proj.weight'].shape[0]
+ self.value_dim = model_weights['value_encoder.fuser.block2.conv2.weight'].shape[0]
+ self.disable_hidden = 'decoder.hidden_update.transform.weight' not in model_weights
+ if self.disable_hidden:
+ self.hidden_dim = 0
+ else:
+ self.hidden_dim = model_weights['decoder.hidden_update.transform.weight'].shape[0]//3
+ print(f'Hyperparameters read from the model weights: '
+ f'C^k={self.key_dim}, C^v={self.value_dim}, C^h={self.hidden_dim}')
+ else:
+ model_weights = None
+ # load dimensions from config or default
+ if 'key_dim' not in config:
+ self.key_dim = 64
+ print(f'key_dim not found in config. Set to default {self.key_dim}')
+ else:
+ self.key_dim = config['key_dim']
+
+ if 'value_dim' not in config:
+ self.value_dim = 512
+ print(f'value_dim not found in config. Set to default {self.value_dim}')
+ else:
+ self.value_dim = config['value_dim']
+
+ if 'hidden_dim' not in config:
+ self.hidden_dim = 64
+ print(f'hidden_dim not found in config. Set to default {self.hidden_dim}')
+ else:
+ self.hidden_dim = config['hidden_dim']
+
+ self.disable_hidden = (self.hidden_dim <= 0)
+
+ config['key_dim'] = self.key_dim
+ config['value_dim'] = self.value_dim
+ config['hidden_dim'] = self.hidden_dim
+
+ return model_weights
+
+ def load_weights(self, src_dict, init_as_zero_if_needed=False):
+ # Maps SO weight (without other_mask) to MO weight (with other_mask)
+ for k in list(src_dict.keys()):
+ if k == 'value_encoder.conv1.weight':
+ if src_dict[k].shape[1] == 4:
+ print('Converting weights from single object to multiple objects.')
+ pads = torch.zeros((64,1,7,7), device=src_dict[k].device)
+ if not init_as_zero_if_needed:
+ print('Randomly initialized padding.')
+ nn.init.orthogonal_(pads)
+ else:
+ print('Zero-initialized padding.')
+ src_dict[k] = torch.cat([src_dict[k], pads], 1)
+
+ self.load_state_dict(src_dict)
diff --git a/tracker/model/resnet.py b/tracker/model/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..984ea3cbfac047537e7de6cfc47108e637e9dde7
--- /dev/null
+++ b/tracker/model/resnet.py
@@ -0,0 +1,165 @@
+"""
+resnet.py - A modified ResNet structure
+We append extra channels to the first conv by some network surgery
+"""
+
+from collections import OrderedDict
+import math
+
+import torch
+import torch.nn as nn
+from torch.utils import model_zoo
+
+
+def load_weights_add_extra_dim(target, source_state, extra_dim=1):
+ new_dict = OrderedDict()
+
+ for k1, v1 in target.state_dict().items():
+ if not 'num_batches_tracked' in k1:
+ if k1 in source_state:
+ tar_v = source_state[k1]
+
+ if v1.shape != tar_v.shape:
+ # Init the new segmentation channel with zeros
+ # print(v1.shape, tar_v.shape)
+ c, _, w, h = v1.shape
+ pads = torch.zeros((c,extra_dim,w,h), device=tar_v.device)
+ nn.init.orthogonal_(pads)
+ tar_v = torch.cat([tar_v, pads], 1)
+
+ new_dict[k1] = tar_v
+
+ target.load_state_dict(new_dict)
+
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, dilation=1):
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, dilation=dilation, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation,
+ padding=dilation, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * 4)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+ def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0):
+ self.inplanes = 64
+ super(ResNet, self).__init__()
+ self.conv1 = nn.Conv2d(3+extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = [block(self.inplanes, planes, stride, downsample)]
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes, dilation=dilation))
+
+ return nn.Sequential(*layers)
+
+def resnet18(pretrained=True, extra_dim=0):
+ model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim)
+ if pretrained:
+ load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet18']), extra_dim)
+ return model
+
+def resnet50(pretrained=True, extra_dim=0):
+ model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim)
+ if pretrained:
+ load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet50']), extra_dim)
+ return model
+
diff --git a/tracker/model/trainer.py b/tracker/model/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..05b4e191a1a9f71db5ef904b275ef5077e8cc7c0
--- /dev/null
+++ b/tracker/model/trainer.py
@@ -0,0 +1,244 @@
+"""
+trainer.py - warpper and utility functions for network training
+Compute loss, back-prop, update parameters, logging, etc.
+"""
+import datetime
+import os
+import time
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from model.network import XMem
+from model.losses import LossComputer
+from util.log_integrator import Integrator
+from util.image_saver import pool_pairs
+
+
+class XMemTrainer:
+ def __init__(self, config, logger=None, save_path=None, local_rank=0, world_size=1):
+ self.config = config
+ self.num_frames = config['num_frames']
+ self.num_ref_frames = config['num_ref_frames']
+ self.deep_update_prob = config['deep_update_prob']
+ self.local_rank = local_rank
+
+ self.XMem = nn.parallel.DistributedDataParallel(
+ XMem(config).cuda(),
+ device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False)
+
+ # Set up logger when local_rank=0
+ self.logger = logger
+ self.save_path = save_path
+ if logger is not None:
+ self.last_time = time.time()
+ self.logger.log_string('model_size', str(sum([param.nelement() for param in self.XMem.parameters()])))
+ self.train_integrator = Integrator(self.logger, distributed=True, local_rank=local_rank, world_size=world_size)
+ self.loss_computer = LossComputer(config)
+
+ self.train()
+ self.optimizer = optim.AdamW(filter(
+ lambda p: p.requires_grad, self.XMem.parameters()), lr=config['lr'], weight_decay=config['weight_decay'])
+ self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, config['steps'], config['gamma'])
+ if config['amp']:
+ self.scaler = torch.cuda.amp.GradScaler()
+
+ # Logging info
+ self.log_text_interval = config['log_text_interval']
+ self.log_image_interval = config['log_image_interval']
+ self.save_network_interval = config['save_network_interval']
+ self.save_checkpoint_interval = config['save_checkpoint_interval']
+ if config['debug']:
+ self.log_text_interval = self.log_image_interval = 1
+
+ def do_pass(self, data, max_it, it=0):
+ # No need to store the gradient outside training
+ torch.set_grad_enabled(self._is_train)
+
+ for k, v in data.items():
+ if type(v) != list and type(v) != dict and type(v) != int:
+ data[k] = v.cuda(non_blocking=True)
+
+ out = {}
+ frames = data['rgb']
+ first_frame_gt = data['first_frame_gt'].float()
+ b = frames.shape[0]
+ num_filled_objects = [o.item() for o in data['info']['num_objects']]
+ num_objects = first_frame_gt.shape[2]
+ selector = data['selector'].unsqueeze(2).unsqueeze(2)
+
+ global_avg = 0
+
+ with torch.cuda.amp.autocast(enabled=self.config['amp']):
+ # image features never change, compute once
+ key, shrinkage, selection, f16, f8, f4 = self.XMem('encode_key', frames)
+
+ filler_one = torch.zeros(1, dtype=torch.int64)
+ hidden = torch.zeros((b, num_objects, self.config['hidden_dim'], *key.shape[-2:]))
+ v16, hidden = self.XMem('encode_value', frames[:,0], f16[:,0], hidden, first_frame_gt[:,0])
+ values = v16.unsqueeze(3) # add the time dimension
+
+ for ti in range(1, self.num_frames):
+ if ti <= self.num_ref_frames:
+ ref_values = values
+ ref_keys = key[:,:,:ti]
+ ref_shrinkage = shrinkage[:,:,:ti] if shrinkage is not None else None
+ else:
+ # pick num_ref_frames random frames
+ # this is not very efficient but I think we would
+ # need broadcasting in gather which we don't have
+ indices = [
+ torch.cat([filler_one, torch.randperm(ti-1)[:self.num_ref_frames-1]+1])
+ for _ in range(b)]
+ ref_values = torch.stack([
+ values[bi, :, :, indices[bi]] for bi in range(b)
+ ], 0)
+ ref_keys = torch.stack([
+ key[bi, :, indices[bi]] for bi in range(b)
+ ], 0)
+ ref_shrinkage = torch.stack([
+ shrinkage[bi, :, indices[bi]] for bi in range(b)
+ ], 0) if shrinkage is not None else None
+
+ # Segment frame ti
+ memory_readout = self.XMem('read_memory', key[:,:,ti], selection[:,:,ti] if selection is not None else None,
+ ref_keys, ref_shrinkage, ref_values)
+ hidden, logits, masks = self.XMem('segment', (f16[:,ti], f8[:,ti], f4[:,ti]), memory_readout,
+ hidden, selector, h_out=(ti < (self.num_frames-1)))
+
+ # No need to encode the last frame
+ if ti < (self.num_frames-1):
+ is_deep_update = np.random.rand() < self.deep_update_prob
+ v16, hidden = self.XMem('encode_value', frames[:,ti], f16[:,ti], hidden, masks, is_deep_update=is_deep_update)
+ values = torch.cat([values, v16.unsqueeze(3)], 3)
+
+ out[f'masks_{ti}'] = masks
+ out[f'logits_{ti}'] = logits
+
+ if self._do_log or self._is_train:
+ losses = self.loss_computer.compute({**data, **out}, num_filled_objects, it)
+
+ # Logging
+ if self._do_log:
+ self.integrator.add_dict(losses)
+ if self._is_train:
+ if it % self.log_image_interval == 0 and it != 0:
+ if self.logger is not None:
+ images = {**data, **out}
+ size = (384, 384)
+ self.logger.log_cv2('train/pairs', pool_pairs(images, size, num_filled_objects), it)
+
+ if self._is_train:
+
+ if (it) % self.log_text_interval == 0 and it != 0:
+ time_spent = time.time()-self.last_time
+
+ if self.logger is not None:
+ self.logger.log_scalar('train/lr', self.scheduler.get_last_lr()[0], it)
+ self.logger.log_metrics('train', 'time', (time_spent)/self.log_text_interval, it)
+
+ global_avg = 0.5*(global_avg) + 0.5*(time_spent)
+ eta_seconds = global_avg * (max_it - it) / 100
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ print(f'ETA: {eta_string}')
+
+ self.last_time = time.time()
+ self.train_integrator.finalize('train', it)
+ self.train_integrator.reset_except_hooks()
+
+ if it % self.save_network_interval == 0 and it != 0:
+ if self.logger is not None:
+ self.save_network(it)
+
+ if it % self.save_checkpoint_interval == 0 and it != 0:
+ if self.logger is not None:
+ self.save_checkpoint(it)
+
+ # Backward pass
+ self.optimizer.zero_grad(set_to_none=True)
+ if self.config['amp']:
+ self.scaler.scale(losses['total_loss']).backward()
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+ else:
+ losses['total_loss'].backward()
+ self.optimizer.step()
+
+ self.scheduler.step()
+
+ def save_network(self, it):
+ if self.save_path is None:
+ print('Saving has been disabled.')
+ return
+
+ os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
+ model_path = f'{self.save_path}_{it}.pth'
+ torch.save(self.XMem.module.state_dict(), model_path)
+ print(f'Network saved to {model_path}.')
+
+ def save_checkpoint(self, it):
+ if self.save_path is None:
+ print('Saving has been disabled.')
+ return
+
+ os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
+ checkpoint_path = f'{self.save_path}_checkpoint_{it}.pth'
+ checkpoint = {
+ 'it': it,
+ 'network': self.XMem.module.state_dict(),
+ 'optimizer': self.optimizer.state_dict(),
+ 'scheduler': self.scheduler.state_dict()}
+ torch.save(checkpoint, checkpoint_path)
+ print(f'Checkpoint saved to {checkpoint_path}.')
+
+ def load_checkpoint(self, path):
+ # This method loads everything and should be used to resume training
+ map_location = 'cuda:%d' % self.local_rank
+ checkpoint = torch.load(path, map_location={'cuda:0': map_location})
+
+ it = checkpoint['it']
+ network = checkpoint['network']
+ optimizer = checkpoint['optimizer']
+ scheduler = checkpoint['scheduler']
+
+ map_location = 'cuda:%d' % self.local_rank
+ self.XMem.module.load_state_dict(network)
+ self.optimizer.load_state_dict(optimizer)
+ self.scheduler.load_state_dict(scheduler)
+
+ print('Network weights, optimizer states, and scheduler states loaded.')
+
+ return it
+
+ def load_network_in_memory(self, src_dict):
+ self.XMem.module.load_weights(src_dict)
+ print('Network weight loaded from memory.')
+
+ def load_network(self, path):
+ # This method loads only the network weight and should be used to load a pretrained model
+ map_location = 'cuda:%d' % self.local_rank
+ src_dict = torch.load(path, map_location={'cuda:0': map_location})
+
+ self.load_network_in_memory(src_dict)
+ print(f'Network weight loaded from {path}')
+
+ def train(self):
+ self._is_train = True
+ self._do_log = True
+ self.integrator = self.train_integrator
+ self.XMem.eval()
+ return self
+
+ def val(self):
+ self._is_train = False
+ self._do_log = True
+ self.XMem.eval()
+ return self
+
+ def test(self):
+ self._is_train = False
+ self._do_log = False
+ self.XMem.eval()
+ return self
+
diff --git a/tracker/util/__init__.py b/tracker/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tracker/util/mask_mapper.py b/tracker/util/mask_mapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..815807bf4b98c6674ab3ede55517f38a29bb59fb
--- /dev/null
+++ b/tracker/util/mask_mapper.py
@@ -0,0 +1,78 @@
+import numpy as np
+import torch
+
+def all_to_onehot(masks, labels):
+ if len(masks.shape) == 3:
+ Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8)
+ else:
+ Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8)
+
+ for ni, l in enumerate(labels):
+ Ms[ni] = (masks == l).astype(np.uint8)
+
+ return Ms
+
+class MaskMapper:
+ """
+ This class is used to convert a indexed-mask to a one-hot representation.
+ It also takes care of remapping non-continuous indices
+ It has two modes:
+ 1. Default. Only masks with new indices are supposed to go into the remapper.
+ This is also the case for YouTubeVOS.
+ i.e., regions with index 0 are not "background", but "don't care".
+
+ 2. Exhaustive. Regions with index 0 are considered "background".
+ Every single pixel is considered to be "labeled".
+ """
+ def __init__(self):
+ self.labels = []
+ self.remappings = {}
+
+ # if coherent, no mapping is required
+ self.coherent = True
+
+ def clear_labels(self):
+ self.labels = []
+ self.remappings = {}
+ # if coherent, no mapping is required
+ self.coherent = True
+
+ def convert_mask(self, mask, exhaustive=False):
+ # mask is in index representation, H*W numpy array
+ labels = np.unique(mask).astype(np.uint8)
+ labels = labels[labels!=0].tolist()
+
+ new_labels = list(set(labels) - set(self.labels))
+ if not exhaustive:
+ assert len(new_labels) == len(labels), 'Old labels found in non-exhaustive mode'
+
+ # add new remappings
+ for i, l in enumerate(new_labels):
+ self.remappings[l] = i+len(self.labels)+1
+ if self.coherent and i+len(self.labels)+1 != l:
+ self.coherent = False
+
+ if exhaustive:
+ new_mapped_labels = range(1, len(self.labels)+len(new_labels)+1)
+ else:
+ if self.coherent:
+ new_mapped_labels = new_labels
+ else:
+ new_mapped_labels = range(len(self.labels)+1, len(self.labels)+len(new_labels)+1)
+
+ self.labels.extend(new_labels)
+ mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float()
+
+ # mask num_objects*H*W
+ return mask, new_mapped_labels
+
+
+ def remap_index_mask(self, mask):
+ # mask is in index representation, H*W numpy array
+ if self.coherent:
+ return mask
+
+ new_mask = np.zeros_like(mask)
+ for l, i in self.remappings.items():
+ new_mask[mask==i] = l
+ return new_mask
\ No newline at end of file
diff --git a/tracker/util/range_transform.py b/tracker/util/range_transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae1b0b3b2a01a061b9b2220a93cdf7f7a6357bfb
--- /dev/null
+++ b/tracker/util/range_transform.py
@@ -0,0 +1,12 @@
+import torchvision.transforms as transforms
+
+im_mean = (124, 116, 104)
+
+im_normalization = transforms.Normalize(
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]
+ )
+
+inv_im_trans = transforms.Normalize(
+ mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
+ std=[1/0.229, 1/0.224, 1/0.225])
diff --git a/tracker/util/tensor_util.py b/tracker/util/tensor_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..05189d38e2b0b0d1d08bd7804b8e43418d6da637
--- /dev/null
+++ b/tracker/util/tensor_util.py
@@ -0,0 +1,47 @@
+import torch.nn.functional as F
+
+
+def compute_tensor_iu(seg, gt):
+ intersection = (seg & gt).float().sum()
+ union = (seg | gt).float().sum()
+
+ return intersection, union
+
+def compute_tensor_iou(seg, gt):
+ intersection, union = compute_tensor_iu(seg, gt)
+ iou = (intersection + 1e-6) / (union + 1e-6)
+
+ return iou
+
+# STM
+def pad_divide_by(in_img, d):
+ h, w = in_img.shape[-2:]
+
+ if h % d > 0:
+ new_h = h + d - h % d
+ else:
+ new_h = h
+ if w % d > 0:
+ new_w = w + d - w % d
+ else:
+ new_w = w
+ lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2)
+ lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2)
+ pad_array = (int(lw), int(uw), int(lh), int(uh))
+ out = F.pad(in_img, pad_array)
+ return out, pad_array
+
+def unpad(img, pad):
+ if len(img.shape) == 4:
+ if pad[2]+pad[3] > 0:
+ img = img[:,:,pad[2]:-pad[3],:]
+ if pad[0]+pad[1] > 0:
+ img = img[:,:,:,pad[0]:-pad[1]]
+ elif len(img.shape) == 3:
+ if pad[2]+pad[3] > 0:
+ img = img[:,pad[2]:-pad[3],:]
+ if pad[0]+pad[1] > 0:
+ img = img[:,:,pad[0]:-pad[1]]
+ else:
+ raise NotImplementedError
+ return img
\ No newline at end of file