EUNPYOHONG watchtowerss committed on
Commit
7d90e8e
·
0 Parent(s):

Duplicate from VIPLab/Track-Anything


Co-authored-by: zhe li <watchtowerss@users.noreply.huggingface.co>

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +46 -0
  2. LICENSE +21 -0
  3. README.md +94 -0
  4. XMem-s012.pth +3 -0
  5. app.py +665 -0
  6. app_save.py +381 -0
  7. app_test.py +46 -0
  8. assets/avengers.gif +3 -0
  9. assets/demo_version_1.MP4 +3 -0
  10. assets/inpainting.gif +3 -0
  11. assets/poster_demo_version_1.png +0 -0
  12. assets/qingming.mp4 +3 -0
  13. assets/track-anything-logo.jpg +0 -0
  14. checkpoints/E2FGVI-HQ-CVPR22.pth +3 -0
  15. demo.py +87 -0
  16. images/groceries.jpg +0 -0
  17. images/mask_painter.png +0 -0
  18. images/painter_input_image.jpg +0 -0
  19. images/painter_input_mask.jpg +0 -0
  20. images/painter_output_image.png +0 -0
  21. images/painter_output_image__.png +0 -0
  22. images/point_painter.png +0 -0
  23. images/point_painter_1.png +0 -0
  24. images/point_painter_2.png +0 -0
  25. images/truck.jpg +0 -0
  26. images/truck_both.jpg +0 -0
  27. images/truck_mask.jpg +0 -0
  28. images/truck_point.jpg +0 -0
  29. inpainter/.DS_Store +0 -0
  30. inpainter/base_inpainter.py +287 -0
  31. inpainter/config/config.yaml +4 -0
  32. inpainter/model/e2fgvi.py +350 -0
  33. inpainter/model/e2fgvi_hq.py +350 -0
  34. inpainter/model/modules/feat_prop.py +149 -0
  35. inpainter/model/modules/flow_comp.py +450 -0
  36. inpainter/model/modules/spectral_norm.py +288 -0
  37. inpainter/model/modules/tfocal_transformer.py +536 -0
  38. inpainter/model/modules/tfocal_transformer_hq.py +567 -0
  39. inpainter/util/__init__.py +0 -0
  40. inpainter/util/tensor_util.py +24 -0
  41. overleaf/.DS_Store +0 -0
  42. overleaf/Track Anything.zip +3 -0
  43. overleaf/Track Anything/figs/avengers_1.pdf +3 -0
  44. overleaf/Track Anything/figs/davisresults.pdf +3 -0
  45. overleaf/Track Anything/figs/failedcases.pdf +3 -0
  46. overleaf/Track Anything/figs/overview_4.pdf +0 -0
  47. overleaf/Track Anything/neurips_2022.bbl +105 -0
  48. overleaf/Track Anything/neurips_2022.bib +187 -0
  49. overleaf/Track Anything/neurips_2022.sty +381 -0
  50. overleaf/Track Anything/neurips_2022.tex +378 -0
.gitattributes ADDED
@@ -0,0 +1,46 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ assets/demo_version_1.MP4 filter=lfs diff=lfs merge=lfs -text
36
+ assets/inpainting.gif filter=lfs diff=lfs merge=lfs -text
37
+ assets/qingming.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ test_sample/test-sample1.mp4 filter=lfs diff=lfs merge=lfs -text
39
+ assets/avengers.gif filter=lfs diff=lfs merge=lfs -text
40
+ overleaf/Track[[:space:]]Anything/figs/avengers_1.pdf filter=lfs diff=lfs merge=lfs -text
41
+ overleaf/Track[[:space:]]Anything/figs/davisresults.pdf filter=lfs diff=lfs merge=lfs -text
42
+ overleaf/Track[[:space:]]Anything/figs/failedcases.pdf filter=lfs diff=lfs merge=lfs -text
43
+ test_sample/test-sample13.mp4 filter=lfs diff=lfs merge=lfs -text
44
+ test_sample/test-sample4.mp4 filter=lfs diff=lfs merge=lfs -text
45
+ test_sample/test-sample8.mp4 filter=lfs diff=lfs merge=lfs -text
46
+ test_sample/huggingface_demo_operation.mp4 filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Mingqi Gao
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,94 @@
1
+ ---
2
+ title: Track Anything
3
+ emoji: 🐠
4
+ colorFrom: purple
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 3.27.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ duplicated_from: VIPLab/Track-Anything
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
15
+ <!-- ![](./assets/track-anything-logo.jpg) -->
16
+
17
+ <div align=center>
18
+ <img src="./assets/track-anything-logo.jpg"/>
19
+ </div>
20
+ <br/>
21
+ <div align=center>
22
+ <a src="https://img.shields.io/badge/%F0%9F%93%96-Open_in_Spaces-informational.svg?style=flat-square" href="https://arxiv.org/abs/2304.11968">
23
+ <img src="https://img.shields.io/badge/%F0%9F%93%96-Arxiv_2304.11968-red.svg?style=flat-square">
24
+ </a>
25
+ <a src="https://img.shields.io/badge/%F0%9F%A4%97-Open_in_Spaces-informational.svg?style=flat-square" href="https://huggingface.co/spaces/watchtowerss/Track-Anything">
26
+ <img src="https://img.shields.io/badge/%F0%9F%A4%97-Open_in_Spaces-informational.svg?style=flat-square">
27
+ </a>
28
+ <a src="https://img.shields.io/badge/%F0%9F%9A%80-SUSTech_VIP_Lab-important.svg?style=flat-square" href="https://zhengfenglab.com/">
29
+ <img src="https://img.shields.io/badge/%F0%9F%9A%80-SUSTech_VIP_Lab-important.svg?style=flat-square">
30
+ </a>
31
+ </div>
32
+
33
+ ***Track-Anything*** is a flexible and interactive tool for video object tracking and segmentation. It is built upon [Segment Anything](https://github.com/facebookresearch/segment-anything) and can track and segment anything specified via user clicks only. During tracking, users can flexibly change the objects they want to track or correct the region of interest if there is any ambiguity. These characteristics make ***Track-Anything*** suitable for:
34
+ - Video object tracking and segmentation with shot changes.
35
+ - Visualized development and data annotation for video object tracking and segmentation.
36
+ - Object-centric downstream video tasks, such as video inpainting and editing.
37
+
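+ As a rough sketch (illustrative only, not the exact code path — the real logic lives in `app.py`), the click-only interaction boils down to turning user clicks into a SAM-style point prompt; the coordinates below are made-up example values:
+ 
+ ```python
+ # each click is (x, y, label): label 1 marks a positive click, 0 a negative click
+ clicks = [(320, 180, 1), (400, 210, 0)]  # hypothetical example clicks
+ 
+ prompt = {
+     "prompt_type": ["click"],
+     "input_point": [[x, y] for x, y, _ in clicks],
+     "input_label": [label for _, _, label in clicks],
+     "multimask_output": "True",
+ }
+ print(prompt)
+ # in the demo, a prompt like this is passed to the SAM controller
+ # (model.first_frame_click) to segment the clicked object on the template frame
+ ```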
38
+ <div align=center>
39
+ <img src="./assets/avengers.gif"/>
40
+ </div>
41
+
42
+ <!-- ![avengers]() -->
43
+
44
+ ## :rocket: Updates
45
+ - 2023/04/25: We are delighted to introduce [Caption-Anything](https://github.com/ttengwang/Caption-Anything) :writing_hand:, an inventive project from our lab that combines the capabilities of Segment Anything, Visual Captioning, and ChatGPT.
46
+
47
+ - 2023/04/20: We deployed [[DEMO]](https://huggingface.co/spaces/watchtowerss/Track-Anything) on Hugging Face :hugs:!
48
+
49
+ ## Demo
50
+
51
+ https://user-images.githubusercontent.com/28050374/232842703-8395af24-b13e-4b8e-aafb-e94b61e6c449.MP4
52
+
53
+ ### Multiple Object Tracking and Segmentation (with [XMem](https://github.com/hkchengrex/XMem))
54
+
55
+ https://user-images.githubusercontent.com/39208339/233035206-0a151004-6461-4deb-b782-d1dbfe691493.mp4
56
+
57
+ ### Video Object Tracking and Segmentation with Shot Changes (with [XMem](https://github.com/hkchengrex/XMem))
58
+
59
+ https://user-images.githubusercontent.com/30309970/232848349-f5e29e71-2ea4-4529-ac9a-94b9ca1e7055.mp4
60
+
61
+ ### Video Inpainting (with [E2FGVI](https://github.com/MCG-NKU/E2FGVI))
62
+
63
+ https://user-images.githubusercontent.com/28050374/232959816-07f2826f-d267-4dda-8ae5-a5132173b8f4.mp4
64
+
65
+ ## Get Started
66
+ #### Linux
67
+ ```bash
68
+ # Clone the repository:
69
+ git clone https://github.com/gaomingqi/Track-Anything.git
70
+ cd Track-Anything
71
+
72
+ # Install dependencies:
73
+ pip install -r requirements.txt
74
+
75
+ # Run the Track-Anything gradio demo.
76
+ python app.py --device cuda:0 --sam_model_type vit_h --port 12212
77
+ ```
78
+
79
+ ## Citation
80
+ If you find this work useful for your research or applications, please cite using this BibTeX:
81
+ ```bibtex
82
+ @misc{yang2023track,
83
+ title={Track Anything: Segment Anything Meets Videos},
84
+ author={Jinyu Yang and Mingqi Gao and Zhe Li and Shang Gao and Fangjing Wang and Feng Zheng},
85
+ year={2023},
86
+ eprint={2304.11968},
87
+ archivePrefix={arXiv},
88
+ primaryClass={cs.CV}
89
+ }
90
+ ```
91
+
92
+ ## Acknowledgements
93
+
94
+ The project is based on [Segment Anything](https://github.com/facebookresearch/segment-anything), [XMem](https://github.com/hkchengrex/XMem), and [E2FGVI](https://github.com/MCG-NKU/E2FGVI). Thanks to the authors for their efforts.
XMem-s012.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:16205ad04bfc55b442bd4d7af894382e09868b35e10721c5afc09a24ea8d72d9
3
+ size 249026057
app.py ADDED
@@ -0,0 +1,665 @@
1
+ import gradio as gr
2
+ import argparse
3
+ import gdown
4
+ import cv2
5
+ import numpy as np
6
+ import os
7
+ import sys
8
+ sys.path.append(sys.path[0]+"/tracker")
9
+ sys.path.append(sys.path[0]+"/tracker/model")
10
+ from track_anything import TrackingAnything
11
+ from track_anything import parse_augment, save_image_to_userfolder, read_image_from_userfolder
12
+ import requests
13
+ import json
14
+ import torchvision
15
+ import torch
16
+ from tools.painter import mask_painter
17
+ import psutil
18
+ import time
19
+ try:
20
+ from mmcv.cnn import ConvModule
21
+ except:
22
+ os.system("mim install mmcv")
23
+
24
+ # download checkpoints
25
+ def download_checkpoint(url, folder, filename):
26
+ os.makedirs(folder, exist_ok=True)
27
+ filepath = os.path.join(folder, filename)
28
+
29
+ if not os.path.exists(filepath):
30
+ print("download checkpoints ......")
31
+ response = requests.get(url, stream=True)
32
+ with open(filepath, "wb") as f:
33
+ for chunk in response.iter_content(chunk_size=8192):
34
+ if chunk:
35
+ f.write(chunk)
36
+
37
+ print("download successfully!")
38
+
39
+ return filepath
40
+
41
+ def download_checkpoint_from_google_drive(file_id, folder, filename):
42
+ os.makedirs(folder, exist_ok=True)
43
+ filepath = os.path.join(folder, filename)
44
+
45
+ if not os.path.exists(filepath):
46
+ print("Downloading checkpoints from Google Drive... tips: If you cannot see the progress bar, please try to download it manuall \
47
+ and put it in the checkpointes directory. E2FGVI-HQ-CVPR22.pth: https://github.com/MCG-NKU/E2FGVI(E2FGVI-HQ model)")
48
+ url = f"https://drive.google.com/uc?id={file_id}"
49
+ gdown.download(url, filepath, quiet=False)
50
+ print("Downloaded successfully!")
51
+
52
+ return filepath
53
+
54
+ # convert points input to prompt state
55
+ def get_prompt(click_state, click_input):
56
+ inputs = json.loads(click_input)
57
+ points = click_state[0]
58
+ labels = click_state[1]
59
+ for input in inputs:
60
+ points.append(input[:2])
61
+ labels.append(input[2])
62
+ click_state[0] = points
63
+ click_state[1] = labels
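+ # build the point-prompt dict consumed by the SAM controller:
+ # input_label 1 marks a positive click, 0 a negative click (see sam_refine below)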
64
+ prompt = {
65
+ "prompt_type":["click"],
66
+ "input_point":click_state[0],
67
+ "input_label":click_state[1],
68
+ "multimask_output":"True",
69
+ }
70
+ return prompt
71
+
72
+
73
+
74
+ # extract frames from upload video
75
+ def get_frames_from_video(video_input, video_state):
76
+ """
77
+ Args:
78
+ video_input (str): path of the uploaded video
79
+ video_state (dict): per-session state of the current video
80
+ Return:
81
+ updated video_state, video info text, the first frame, and gr.update() objects for the UI components
82
+ """
83
+ video_path = video_input
84
+ frames = [] # save image path
85
+ user_name = time.time()
86
+ video_state["video_name"] = os.path.split(video_path)[-1]
87
+ video_state["user_name"] = user_name
88
+
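+ # create per-user temp folders (keyed by the upload timestamp) for original and painted frames,
+ # so frames are kept on disk rather than in RAM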
89
+ os.makedirs(os.path.join("/tmp/{}/originimages/{}".format(video_state["user_name"], video_state["video_name"])), exist_ok=True)
90
+ os.makedirs(os.path.join("/tmp/{}/paintedimages/{}".format(video_state["user_name"], video_state["video_name"])), exist_ok=True)
91
+ operation_log = [("",""),("Upload video already. Try click the image for adding targets to track and inpaint.","Normal")]
92
+ try:
93
+ cap = cv2.VideoCapture(video_path)
94
+ fps = cap.get(cv2.CAP_PROP_FPS)
95
+ if not cap.isOpened():
96
+ operation_log = [("No frames extracted, please input video file with '.mp4.' '.mov'.", "Error")]
97
+ print("No frames extracted, please input video file with '.mp4.' '.mov'.")
98
+ return None, None, None, None, \
99
+ None, None, None, None, \
100
+ None, None, None, None, \
101
+ None, None, gr.update(visible=True, value=operation_log)
102
+ image_index = 0
103
+ while cap.isOpened():
104
+ ret, frame = cap.read()
105
+ if ret == True:
106
+ current_memory_usage = psutil.virtual_memory().percent
107
+
108
+ # to reduce memory usage, save each frame to disk instead of keeping it in memory
109
+ frames.append(save_image_to_userfolder(video_state, image_index, frame, True))
110
+ image_index +=1
111
+ # frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
112
+ if current_memory_usage > 90:
113
+ operation_log = [("Memory usage is too high (>90%). Stop the video extraction. Please reduce the video resolution or frame rate.", "Error")]
114
+ print("Memory usage is too high (>90%). Please reduce the video resolution or frame rate.")
115
+ break
116
+ else:
117
+ break
118
+
119
+ except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
120
+ # except:
121
+ operation_log = [("read_frame_source:{} error. {}\n".format(video_path, str(e)), "Error")]
122
+ print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
123
+ return None, None, None, None, \
124
+ None, None, None, None, \
125
+ None, None, None, None, \
126
+ None, None, gr.update(visible=True, value=operation_log)
127
+ first_image = read_image_from_userfolder(frames[0])
128
+ image_size = (first_image.shape[0], first_image.shape[1])
129
+ # initialize video_state
130
+ video_state = {
131
+ "user_name": user_name,
132
+ "video_name": os.path.split(video_path)[-1],
133
+ "origin_images": frames,
134
+ "painted_images": frames.copy(),
135
+ "masks": [np.zeros((image_size[0], image_size[1]), np.uint8)]*len(frames),
136
+ "logits": [None]*len(frames),
137
+ "select_frame_number": 0,
138
+ "fps": fps
139
+ }
140
+ video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size)
141
+ model.samcontroler.sam_controler.reset_image()
142
+ model.samcontroler.sam_controler.set_image(first_image)
143
+ return video_state, video_info, first_image, gr.update(visible=True, maximum=len(frames), value=1), \
144
+ gr.update(visible=True, maximum=len(frames), value=len(frames)), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
145
+ gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
146
+ gr.update(visible=True), gr.update(visible=True), gr.update(visible=True, value=operation_log),
147
+
148
+ def run_example(example):
149
+ return example
150
+ # get the selected frame from the gradio slider
151
+ def select_template(image_selection_slider, video_state, interactive_state):
152
+
153
+ # images = video_state[1]
154
+ image_selection_slider -= 1
155
+ video_state["select_frame_number"] = image_selection_slider
156
+
157
+ # once select a new template frame, set the image in sam
158
+
159
+ model.samcontroler.sam_controler.reset_image()
160
+ model.samcontroler.sam_controler.set_image(read_image_from_userfolder(video_state["origin_images"][image_selection_slider]))
161
+
162
+ # update the masks when select a new template frame
163
+ # if video_state["masks"][image_selection_slider] is not None:
164
+ # video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider])
165
+ operation_log = [("",""), ("Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider),"Normal")]
166
+
167
+ return read_image_from_userfolder(video_state["painted_images"][image_selection_slider]), video_state, interactive_state, operation_log
168
+
169
+ # set the tracking end frame
170
+ def get_end_number(track_pause_number_slider, video_state, interactive_state):
171
+ track_pause_number_slider -= 1
172
+ interactive_state["track_end_number"] = track_pause_number_slider
173
+ operation_log = [("",""),("Set the tracking finish at frame {}".format(track_pause_number_slider),"Normal")]
174
+
175
+ return read_image_from_userfolder(video_state["painted_images"][track_pause_number_slider]),interactive_state, operation_log
176
+
177
+ def get_resize_ratio(resize_ratio_slider, interactive_state):
178
+ interactive_state["resize_ratio"] = resize_ratio_slider
179
+
180
+ return interactive_state
181
+
182
+ # use sam to get the mask
183
+ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
184
+ """
185
+ Args:
186
+ template_frame: PIL.Image
187
+ point_prompt: flag for positive or negative button click
188
+ click_state: [[points], [labels]]
189
+ """
190
+ if point_prompt == "Positive":
191
+ coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
192
+ interactive_state["positive_click_times"] += 1
193
+ else:
194
+ coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
195
+ interactive_state["negative_click_times"] += 1
196
+
197
+ # prompt for sam model
198
+ model.samcontroler.sam_controler.reset_image()
199
+ model.samcontroler.sam_controler.set_image(read_image_from_userfolder(video_state["origin_images"][video_state["select_frame_number"]]))
200
+ prompt = get_prompt(click_state=click_state, click_input=coordinate)
201
+
202
+ mask, logit, painted_image = model.first_frame_click(
203
+ image=read_image_from_userfolder(video_state["origin_images"][video_state["select_frame_number"]]),
204
+ points=np.array(prompt["input_point"]),
205
+ labels=np.array(prompt["input_label"]),
206
+ multimask=prompt["multimask_output"],
207
+ )
208
+ video_state["masks"][video_state["select_frame_number"]] = mask
209
+ video_state["logits"][video_state["select_frame_number"]] = logit
210
+ video_state["painted_images"][video_state["select_frame_number"]] = save_image_to_userfolder(video_state, index=video_state["select_frame_number"], image=cv2.cvtColor(np.asarray(painted_image),cv2.COLOR_BGR2RGB),type=False)
211
+
212
+ operation_log = [("",""), ("Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment","Normal")]
213
+ return painted_image, video_state, interactive_state, operation_log
214
+
215
+ def add_multi_mask(video_state, interactive_state, mask_dropdown):
216
+ try:
217
+ mask = video_state["masks"][video_state["select_frame_number"]]
218
+ interactive_state["multi_mask"]["masks"].append(mask)
219
+ interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
220
+ mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
221
+ select_frame, run_status = show_mask(video_state, interactive_state, mask_dropdown)
222
+
223
+ operation_log = [("",""),("Added a mask, use the mask select for target tracking or inpainting.","Normal")]
224
+ except:
225
+ operation_log = [("Please click the left image to generate mask.", "Error"), ("","")]
226
+ return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log
227
+
228
+ def clear_click(video_state, click_state):
229
+ click_state = [[],[]]
230
+ template_frame = read_image_from_userfolder(video_state["origin_images"][video_state["select_frame_number"]])
231
+ operation_log = [("",""), ("Clear points history and refresh the image.","Normal")]
232
+ return template_frame, click_state, operation_log
233
+
234
+ def remove_multi_mask(interactive_state, mask_dropdown):
235
+ interactive_state["multi_mask"]["mask_names"]= []
236
+ interactive_state["multi_mask"]["masks"] = []
237
+
238
+ operation_log = [("",""), ("Remove all mask, please add new masks","Normal")]
239
+ return interactive_state, gr.update(choices=[],value=[]), operation_log
240
+
241
+ def show_mask(video_state, interactive_state, mask_dropdown):
242
+ mask_dropdown.sort()
243
+ select_frame = read_image_from_userfolder(video_state["origin_images"][video_state["select_frame_number"]])
244
+
245
+ for i in range(len(mask_dropdown)):
246
+ mask_number = int(mask_dropdown[i].split("_")[1]) - 1
247
+ mask = interactive_state["multi_mask"]["masks"][mask_number]
248
+ select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
249
+
250
+ operation_log = [("",""), ("Select {} for tracking or inpainting".format(mask_dropdown),"Normal")]
251
+ return select_frame, operation_log
252
+
253
+ # tracking vos
254
+ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
255
+ operation_log = [("",""), ("Track the selected masks, and then you can select the masks for inpainting.","Normal")]
256
+ model.xmem.clear_memory()
257
+ if interactive_state["track_end_number"]:
258
+ following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
259
+ else:
260
+ following_frames = video_state["origin_images"][video_state["select_frame_number"]:]
261
+
262
+ if interactive_state["multi_mask"]["masks"]:
263
+ if len(mask_dropdown) == 0:
264
+ mask_dropdown = ["mask_001"]
265
+ mask_dropdown.sort()
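+ # merge the selected binary masks into a single label map: pixels of mask_k take the value k,
+ # so XMem can track several objects at once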
266
+ template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
267
+ for i in range(1,len(mask_dropdown)):
268
+ mask_number = int(mask_dropdown[i].split("_")[1]) - 1
269
+ template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1)
270
+ video_state["masks"][video_state["select_frame_number"]]= template_mask
271
+ else:
272
+ template_mask = video_state["masks"][video_state["select_frame_number"]]
273
+ fps = video_state["fps"]
274
+
275
+ # operation error
276
+ if len(np.unique(template_mask))==1:
277
+ template_mask[0][0]=1
278
+ operation_log = [("Error! Please add at least one mask to track by clicking the left image.","Error"), ("","")]
279
+ # return video_output, video_state, interactive_state, operation_error
280
+ masks, logits, painted_images_path = model.generator(images=following_frames, template_mask=template_mask, video_state=video_state)
281
+ # clear GPU memory
282
+ model.xmem.clear_memory()
283
+
284
+ if interactive_state["track_end_number"]:
285
+ video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
286
+ video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits
287
+ video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = painted_images_path
288
+ else:
289
+ video_state["masks"][video_state["select_frame_number"]:] = masks
290
+ video_state["logits"][video_state["select_frame_number"]:] = logits
291
+ video_state["painted_images"][video_state["select_frame_number"]:] = painted_images_path
292
+
293
+ video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/track/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
294
+ interactive_state["inference_times"] += 1
295
+
296
+ print("For generating this tracking result, inference times: {}, click times: {}, positive: {}, negative: {}".format(interactive_state["inference_times"],
297
+ interactive_state["positive_click_times"]+interactive_state["negative_click_times"],
298
+ interactive_state["positive_click_times"],
299
+ interactive_state["negative_click_times"]))
300
+
301
+ #### shanggao code for mask save
302
+ if interactive_state["mask_save"]:
303
+ if not os.path.exists('./result/mask/{}'.format(video_state["video_name"].split('.')[0])):
304
+ os.makedirs('./result/mask/{}'.format(video_state["video_name"].split('.')[0]))
305
+ i = 0
306
+ print("save mask")
307
+ for mask in video_state["masks"]:
308
+ np.save(os.path.join('./result/mask/{}'.format(video_state["video_name"].split('.')[0]), '{:05d}.npy'.format(i)), mask)
309
+ i+=1
310
+ #### shanggao code for mask save
311
+ return video_output, video_state, interactive_state, operation_log
312
+
313
+
314
+
315
+ # inpaint
316
+ def inpaint_video(video_state, interactive_state, mask_dropdown):
317
+ operation_log = [("",""), ("Removed the selected masks.","Normal")]
318
+
319
+ # solve memory
320
+ frames = np.asarray(video_state["origin_images"])
321
+ fps = video_state["fps"]
322
+ inpaint_masks = np.asarray(video_state["masks"])
323
+ if len(mask_dropdown) == 0:
324
+ mask_dropdown = ["mask_001"]
325
+ mask_dropdown.sort()
326
+ # convert mask_dropdown to mask numbers
327
+ inpaint_mask_numbers = [int(mask_dropdown[i].split("_")[1]) for i in range(len(mask_dropdown))]
328
+ # iterate through all masks and remove those that are not selected in mask_dropdown
329
+ unique_masks = np.unique(inpaint_masks)
330
+ num_masks = len(unique_masks) - 1
331
+ for i in range(1, num_masks + 1):
332
+ if i in inpaint_mask_numbers:
333
+ continue
334
+ inpaint_masks[inpaint_masks==i] = 0
335
+ # inpaint for videos
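+ # E2FGVI removes the selected objects from the video; interactive_state["resize_ratio"] scales the
+ # frames down first, trading output resolution for lower VRAM usage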
336
+
337
+ try:
338
+ inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3
339
+ video_output = generate_video_from_paintedframes(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps)
340
+ except:
341
+ operation_log = [("Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size.","Error"), ("","")]
342
+ inpainted_frames = video_state["origin_images"]
343
+ video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
344
+ return video_output, operation_log
345
+
346
+
347
+ # generate video after vos inference
348
+ def generate_video_from_frames(frames_path, output_path, fps=30):
349
+ """
350
+ Generates a video from a list of frames.
351
+
352
+ Args:
353
+ frames (list of numpy arrays): The frames to include in the video.
354
+ output_path (str): The path to save the generated video.
355
+ fps (int, optional): The frame rate of the output video. Defaults to 30.
356
+ """
357
+ # height, width, layers = frames[0].shape
358
+ # fourcc = cv2.VideoWriter_fourcc(*"mp4v")
359
+ # video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
360
+ # print(output_path)
361
+ # for frame in frames:
362
+ # video.write(frame)
363
+
364
+ # video.release()
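+ # frames were written to the per-user folder during extraction; read them back from disk
+ # before encoding the result with libx264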
365
+ frames = []
366
+ for file in frames_path:
367
+ frames.append(read_image_from_userfolder(file))
368
+ frames = torch.from_numpy(np.asarray(frames))
369
+ if not os.path.exists(os.path.dirname(output_path)):
370
+ os.makedirs(os.path.dirname(output_path))
371
+ torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
372
+ return output_path
373
+
374
+ def generate_video_from_paintedframes(frames, output_path, fps=30):
375
+ """
376
+ Generates a video from a list of frames.
377
+
378
+ Args:
379
+ frames (list of numpy arrays): The frames to include in the video.
380
+ output_path (str): The path to save the generated video.
381
+ fps (int, optional): The frame rate of the output video. Defaults to 30.
382
+ """
383
+ # height, width, layers = frames[0].shape
384
+ # fourcc = cv2.VideoWriter_fourcc(*"mp4v")
385
+ # video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
386
+ # print(output_path)
387
+ # for frame in frames:
388
+ # video.write(frame)
389
+
390
+ # video.release()
391
+ frames = torch.from_numpy(np.asarray(frames))
392
+ if not os.path.exists(os.path.dirname(output_path)):
393
+ os.makedirs(os.path.dirname(output_path))
394
+ torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
395
+ return output_path
396
+
397
+
398
+ # args, defined in track_anything.py
399
+ args = parse_augment()
400
+
401
+ # check and download checkpoints if needed
402
+ SAM_checkpoint_dict = {
403
+ 'vit_h': "sam_vit_h_4b8939.pth",
404
+ 'vit_l': "sam_vit_l_0b3195.pth",
405
+ "vit_b": "sam_vit_b_01ec64.pth"
406
+ }
407
+ SAM_checkpoint_url_dict = {
408
+ 'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
409
+ 'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
410
+ 'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
411
+ }
412
+ sam_checkpoint = SAM_checkpoint_dict[args.sam_model_type]
413
+ sam_checkpoint_url = SAM_checkpoint_url_dict[args.sam_model_type]
414
+ xmem_checkpoint = "XMem-s012.pth"
415
+ xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
416
+ e2fgvi_checkpoint = "E2FGVI-HQ-CVPR22.pth"
417
+ e2fgvi_checkpoint_id = "10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3"
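+ # the E2FGVI-HQ weights are fetched from Google Drive by file id, hence the separate gdown-based helper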
418
+
419
+
420
+ folder ="./checkpoints"
421
+ SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
422
+ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
423
+ e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
424
+ # args.port = 12213
425
+ # args.device = "cuda:8"
426
+ # args.mask_save = True
427
+
428
+ # initialize sam, xmem, e2fgvi models
429
+ model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)
430
+
431
+
432
+ title = """<p><h1 align="center">Track-Anything</h1></p>
433
+ """
434
+ description = """<p>Gradio demo for Track Anything, a flexible and interactive tool for video object tracking, segmentation, and inpainting. To use it, simply upload your video, or click one of the examples to load them. Code: <a href="https://github.com/gaomingqi/Track-Anything">Track-Anything</a> <a href="https://huggingface.co/spaces/VIPLab/Track-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a> If you stuck in unknown errors, please feel free to watch the Tutorial video.</p>"""
435
+
436
+
437
+ with gr.Blocks() as iface:
438
+ """
439
+ states for the interactive demo
440
+ """
441
+ click_state = gr.State([[],[]])
442
+ interactive_state = gr.State({
443
+ "inference_times": 0,
444
+ "negative_click_times" : 0,
445
+ "positive_click_times": 0,
446
+ "mask_save": args.mask_save,
447
+ "multi_mask": {
448
+ "mask_names": [],
449
+ "masks": []
450
+ },
451
+ "track_end_number": None,
452
+ "resize_ratio": 0.6
453
+ }
454
+ )
455
+
456
+ video_state = gr.State(
457
+ {
458
+ "user_name": "",
459
+ "video_name": "",
460
+ "origin_images": None,
461
+ "painted_images": None,
462
+ "masks": None,
463
+ "inpaint_masks": None,
464
+ "logits": None,
465
+ "select_frame_number": 0,
466
+ "fps": 30
467
+ }
468
+ )
469
+ gr.Markdown(title)
470
+ gr.Markdown(description)
471
+ with gr.Row():
472
+ with gr.Column():
473
+ with gr.Tab("Test"):
474
+ # for user video input
475
+ with gr.Column():
476
+ with gr.Row(scale=0.4):
477
+ video_input = gr.Video(autosize=True)
478
+ with gr.Column():
479
+ video_info = gr.Textbox(label="Video Info")
480
+ resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to git clone the repo and use a machine with more VRAM locally. \
481
+ Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.", label="Tips for running this demo.")
482
+ resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=0.6, label="Resize ratio", visible=True)
483
+
484
+
485
+ with gr.Row():
486
+ # put the template frame under the radio button
487
+ with gr.Column():
488
+ # extract frames
489
+ with gr.Column():
490
+ extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
491
+
492
+ # click point settings: negative or positive; mode: continuous or single
493
+ with gr.Row():
494
+ with gr.Row():
495
+ point_prompt = gr.Radio(
496
+ choices=["Positive", "Negative"],
497
+ value="Positive",
498
+ label="Point prompt",
499
+ interactive=True,
500
+ visible=False)
501
+ remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
502
+ clear_button_click = gr.Button(value="Clear clicks", interactive=True, visible=False).style(height=160)
503
+ Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False)
504
+ template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360)
505
+ image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track start frame", visible=False)
506
+ track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False)
507
+
508
+ with gr.Column():
509
+ run_status = gr.HighlightedText(value=[("Run","Error"),("Status","Normal")], visible=True)
510
+ mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
511
+ video_output = gr.Video(autosize=True, visible=False).style(height=360)
512
+ with gr.Row():
513
+ tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
514
+ inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False)
515
+ # set example
516
+ gr.Markdown("## Examples")
517
+ gr.Examples(
518
+ examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample8.mp4","test-sample4.mp4", \
519
+ "test-sample2.mp4","test-sample13.mp4"]],
520
+ fn=run_example,
521
+ inputs=[
522
+ video_input
523
+ ],
524
+ outputs=[video_input],
525
+ # cache_examples=True,
526
+ )
527
+
528
+ with gr.Tab("Tutorial"):
529
+ with gr.Column():
530
+ with gr.Row(scale=0.4):
531
+ video_demo_operation = gr.Video(autosize=True)
532
+
533
+ # set example
534
+ gr.Markdown("## Operation tutorial video")
535
+ gr.Examples(
536
+ examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["huggingface_demo_operation.mp4"]],
537
+ fn=run_example,
538
+ inputs=[
539
+ video_demo_operation
540
+ ],
541
+ outputs=[video_demo_operation],
542
+ # cache_examples=True,
543
+ )
544
+
545
+ # first step: get the video information
546
+ extract_frames_button.click(
547
+ fn=get_frames_from_video,
548
+ inputs=[
549
+ video_input, video_state
550
+ ],
551
+ outputs=[video_state, video_info, template_frame, image_selection_slider,
552
+ track_pause_number_slider,point_prompt, clear_button_click, Add_mask_button,
553
+ template_frame, tracking_video_predict_button, video_output, mask_dropdown,
554
+ remove_mask_button, inpaint_video_predict_button, run_status]
555
+ )
556
+
557
+ # second step: select images from slider
558
+ image_selection_slider.release(fn=select_template,
559
+ inputs=[image_selection_slider, video_state, interactive_state],
560
+ outputs=[template_frame, video_state, interactive_state, run_status], api_name="select_image")
561
+ track_pause_number_slider.release(fn=get_end_number,
562
+ inputs=[track_pause_number_slider, video_state, interactive_state],
563
+ outputs=[template_frame, interactive_state, run_status], api_name="end_image")
564
+ resize_ratio_slider.release(fn=get_resize_ratio,
565
+ inputs=[resize_ratio_slider, interactive_state],
566
+ outputs=[interactive_state], api_name="resize_ratio")
567
+
568
+ # click select image to get mask using sam
569
+ template_frame.select(
570
+ fn=sam_refine,
571
+ inputs=[video_state, point_prompt, click_state, interactive_state],
572
+ outputs=[template_frame, video_state, interactive_state, run_status]
573
+ )
574
+
575
+ # add different mask
576
+ Add_mask_button.click(
577
+ fn=add_multi_mask,
578
+ inputs=[video_state, interactive_state, mask_dropdown],
579
+ outputs=[interactive_state, mask_dropdown, template_frame, click_state, run_status]
580
+ )
581
+
582
+ remove_mask_button.click(
583
+ fn=remove_multi_mask,
584
+ inputs=[interactive_state, mask_dropdown],
585
+ outputs=[interactive_state, mask_dropdown, run_status]
586
+ )
587
+
588
+ # tracking video from select image and mask
589
+ tracking_video_predict_button.click(
590
+ fn=vos_tracking_video,
591
+ inputs=[video_state, interactive_state, mask_dropdown],
592
+ outputs=[video_output, video_state, interactive_state, run_status]
593
+ )
594
+
595
+ # inpaint video from select image and mask
596
+ inpaint_video_predict_button.click(
597
+ fn=inpaint_video,
598
+ inputs=[video_state, interactive_state, mask_dropdown],
599
+ outputs=[video_output, run_status]
600
+ )
601
+
602
+ # click to get mask
603
+ mask_dropdown.change(
604
+ fn=show_mask,
605
+ inputs=[video_state, interactive_state, mask_dropdown],
606
+ outputs=[template_frame, run_status]
607
+ )
608
+
609
+ # clear input
610
+ video_input.clear(
611
+ lambda: (
612
+ {
613
+ "user_name": "",
614
+ "video_name": "",
615
+ "origin_images": None,
616
+ "painted_images": None,
617
+ "masks": None,
618
+ "inpaint_masks": None,
619
+ "logits": None,
620
+ "select_frame_number": 0,
621
+ "fps": 30
622
+ },
623
+ {
624
+ "inference_times": 0,
625
+ "negative_click_times" : 0,
626
+ "positive_click_times": 0,
627
+ "mask_save": args.mask_save,
628
+ "multi_mask": {
629
+ "mask_names": [],
630
+ "masks": []
631
+ },
632
+ "track_end_number": 0,
633
+ "resize_ratio": 0.6
634
+ },
635
+ [[],[]],
636
+ None,
637
+ None,
638
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
639
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
640
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), \
641
+ gr.update(visible=False), gr.update(visible=True)
642
+
643
+ ),
644
+ [],
645
+ [
646
+ video_state,
647
+ interactive_state,
648
+ click_state,
649
+ video_output,
650
+ template_frame,
651
+ tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
652
+ Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button, run_status
653
+ ],
654
+ queue=False,
655
+ show_progress=False)
656
+
657
+ # points clear
658
+ clear_button_click.click(
659
+ fn = clear_click,
660
+ inputs = [video_state, click_state,],
661
+ outputs = [template_frame,click_state, run_status],
662
+ )
663
+ iface.queue(concurrency_count=1)
664
+ # iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
665
+ iface.launch(debug=True, enable_queue=True)
app_save.py ADDED
@@ -0,0 +1,381 @@
1
+ import gradio as gr
2
+ from demo import automask_image_app, automask_video_app, sahi_autoseg_app
3
+ import argparse
4
+ import cv2
5
+ import time
6
+ from PIL import Image
7
+ import numpy as np
8
+ import os
9
+ import sys
10
+ sys.path.append(sys.path[0]+"/tracker")
11
+ sys.path.append(sys.path[0]+"/tracker/model")
12
+ from track_anything import TrackingAnything
13
+ from track_anything import parse_augment
14
+ import requests
15
+ import json
16
+ import torchvision
17
+ import torch
18
+ import concurrent.futures
19
+ import queue
20
+
21
+ def download_checkpoint(url, folder, filename):
22
+ os.makedirs(folder, exist_ok=True)
23
+ filepath = os.path.join(folder, filename)
24
+
25
+ if not os.path.exists(filepath):
26
+ print("download checkpoints ......")
27
+ response = requests.get(url, stream=True)
28
+ with open(filepath, "wb") as f:
29
+ for chunk in response.iter_content(chunk_size=8192):
30
+ if chunk:
31
+ f.write(chunk)
32
+
33
+ print("download successfully!")
34
+
35
+ return filepath
36
+
37
+ def pause_video(play_state):
38
+ print("user pause_video")
39
+ play_state.append(time.time())
40
+ return play_state
41
+
42
+ def play_video(play_state):
43
+ print("user play_video")
44
+ play_state.append(time.time())
45
+ return play_state
46
+
47
+ # convert points input to prompt state
48
+ def get_prompt(click_state, click_input):
49
+ inputs = json.loads(click_input)
50
+ points = click_state[0]
51
+ labels = click_state[1]
52
+ for input in inputs:
53
+ points.append(input[:2])
54
+ labels.append(input[2])
55
+ click_state[0] = points
56
+ click_state[1] = labels
57
+ prompt = {
58
+ "prompt_type":["click"],
59
+ "input_point":click_state[0],
60
+ "input_label":click_state[1],
61
+ "multimask_output":"True",
62
+ }
63
+ return prompt
64
+
65
+ def get_frames_from_video(video_input, play_state):
66
+ """
67
+ Args:
68
+ video_path:str
69
+ timestamp:float64
70
+ Return
71
+ [[0:nearest_frame], [nearest_frame:], nearest_frame]
72
+ """
73
+ video_path = video_input
74
+ # video_name = video_path.split('/')[-1]
75
+
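+ # the gap between the recorded play and pause timestamps gives the position (in seconds) of the
+ # user-selected template frame; it falls back to 0 if the video was never played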
76
+ try:
77
+ timestamp = play_state[1] - play_state[0]
78
+ except:
79
+ timestamp = 0
80
+ frames = []
81
+ try:
82
+ cap = cv2.VideoCapture(video_path)
83
+ fps = cap.get(cv2.CAP_PROP_FPS)
84
+ while cap.isOpened():
85
+ ret, frame = cap.read()
86
+ if ret == True:
87
+ frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
88
+ else:
89
+ break
90
+ except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
91
+ print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
92
+
93
+ # for index, frame in enumerate(frames):
94
+ # frames[index] = np.asarray(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
95
+
96
+ key_frame_index = int(timestamp * fps)
97
+ nearest_frame = frames[key_frame_index]
98
+ frames_split = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
99
+ # output_path='./seperate.mp4'
100
+ # torchvision.io.write_video(output_path, frames[1], fps=fps, video_codec="libx264")
101
+
102
+ # set image in sam when select the template frame
103
+ model.samcontroler.sam_controler.set_image(nearest_frame)
104
+ return frames_split, nearest_frame, nearest_frame, fps
105
+
106
+ def generate_video_from_frames(frames, output_path, fps=30):
107
+ """
108
+ Generates a video from a list of frames.
109
+
110
+ Args:
111
+ frames (list of numpy arrays): The frames to include in the video.
112
+ output_path (str): The path to save the generated video.
113
+ fps (int, optional): The frame rate of the output video. Defaults to 30.
114
+ """
115
+ # height, width, layers = frames[0].shape
116
+ # fourcc = cv2.VideoWriter_fourcc(*"mp4v")
117
+ # video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
118
+
119
+ # for frame in frames:
120
+ # video.write(frame)
121
+
122
+ # video.release()
123
+ frames = torch.from_numpy(np.asarray(frames))
124
+ output_path='./output.mp4'
125
+ torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
126
+ return output_path
127
+
128
+ def model_reset():
129
+ model.xmem.clear_memory()
130
+ return None
131
+
132
+ def sam_refine(origin_frame, point_prompt, click_state, logit, evt:gr.SelectData):
133
+ """
134
+ Args:
135
+ template_frame: PIL.Image
136
+ point_prompt: flag for positive or negative button click
137
+ click_state: [[points], [labels]]
138
+ """
139
+ if point_prompt == "Positive":
140
+ coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
141
+ else:
142
+ coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
143
+
144
+ # prompt for sam model
145
+ prompt = get_prompt(click_state=click_state, click_input=coordinate)
146
+
147
+ # default value
148
+ # points = np.array([[evt.index[0],evt.index[1]]])
149
+ # labels= np.array([1])
150
+ if len(logit)==0:
151
+ logit = None
152
+
153
+ mask, logit, painted_image = model.first_frame_click(
154
+ image=origin_frame,
155
+ points=np.array(prompt["input_point"]),
156
+ labels=np.array(prompt["input_label"]),
157
+ multimask=prompt["multimask_output"],
158
+ )
159
+ return painted_image, click_state, logit, mask
160
+
161
+
162
+
163
+ def vos_tracking_video(video_state, template_mask,fps,video_input):
164
+
165
+ masks, logits, painted_images = model.generator(images=video_state[1], template_mask=template_mask)
166
+ video_output = generate_video_from_frames(painted_images, output_path="./output.mp4", fps=fps)
167
+ # image_selection_slider = gr.Slider(minimum=1, maximum=len(video_state[1]), value=1, label="Image Selection", interactive=True)
168
+ video_name = video_input.split('/')[-1].split('.')[0]
169
+ result_path = os.path.join('/hhd3/gaoshang/Track-Anything/results/'+video_name)
170
+ if not os.path.exists(result_path):
171
+ os.makedirs(result_path)
172
+ i=0
173
+ for mask in masks:
174
+ np.save(os.path.join(result_path,'{:05}.npy'.format(i)), mask)
175
+ i+=1
176
+ return video_output, painted_images, masks, logits
177
+
178
+ def vos_tracking_image(image_selection_slider, painted_images):
179
+
180
+ # images = video_state[1]
181
+ percentage = image_selection_slider / 100
182
+ select_frame_num = int(percentage * len(painted_images))
183
+ return painted_images[select_frame_num], select_frame_num
184
+
185
+ def interactive_correction(video_state, point_prompt, click_state, select_correction_frame, evt: gr.SelectData):
186
+ """
187
+ Args:
188
+ template_frame: PIL.Image
189
+ point_prompt: flag for positive or negative button click
190
+ click_state: [[points], [labels]]
191
+ """
192
+ refine_image = video_state[1][select_correction_frame]
193
+ if point_prompt == "Positive":
194
+ coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
195
+ else:
196
+ coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
197
+
198
+ # prompt for sam model
199
+ prompt = get_prompt(click_state=click_state, click_input=coordinate)
200
+ model.samcontroler.seg_again(refine_image)
201
+ corrected_mask, corrected_logit, corrected_painted_image = model.first_frame_click(
202
+ image=refine_image,
203
+ points=np.array(prompt["input_point"]),
204
+ labels=np.array(prompt["input_label"]),
205
+ multimask=prompt["multimask_output"],
206
+ )
207
+ return corrected_painted_image, [corrected_mask, corrected_logit, corrected_painted_image]
208
+
209
+ def correct_track(video_state, select_correction_frame, corrected_state, masks, logits, painted_images, fps, video_input):
210
+ model.xmem.clear_memory()
211
+ # run inference on the frames that follow the corrected frame
212
+ following_images = video_state[1][select_correction_frame:]
213
+ corrected_masks, corrected_logits, corrected_painted_images = model.generator(images=following_images, template_mask=corrected_state[0])
214
+ masks = masks[:select_correction_frame] + corrected_masks
215
+ logits = logits[:select_correction_frame] + corrected_logits
216
+ painted_images = painted_images[:select_correction_frame] + corrected_painted_images
217
+ video_output = generate_video_from_frames(painted_images, output_path="./output.mp4", fps=fps)
218
+
219
+ video_name = video_input.split('/')[-1].split('.')[0]
220
+ result_path = os.path.join('/hhd3/gaoshang/Track-Anything/results/'+video_name)
221
+ if not os.path.exists(result_path):
222
+ os.makedirs(result_path)
223
+ i=0
224
+ for mask in masks:
225
+ np.save(os.path.join(result_path,'{:05}.npy'.format(i)), mask)
226
+ i+=1
227
+ return video_output, painted_images, logits, masks
228
+
229
+ # check and download checkpoints if needed
230
+ SAM_checkpoint = "sam_vit_h_4b8939.pth"
231
+ sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
232
+ xmem_checkpoint = "XMem-s012.pth"
233
+ xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
234
+ folder ="./checkpoints"
235
+ SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
236
+ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
237
+
238
+ # args, defined in track_anything.py
239
+ args = parse_augment()
240
+ args.port = 12207
241
+ args.device = "cuda:5"
242
+
243
+ model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
244
+
245
+ with gr.Blocks() as iface:
246
+ """
247
+ states for the interactive demo
248
+ """
249
+ state = gr.State([])
250
+ play_state = gr.State([])
251
+ video_state = gr.State([[],[],[]])
252
+ click_state = gr.State([[],[]])
253
+ logits = gr.State([])
254
+ masks = gr.State([])
255
+ painted_images = gr.State([])
256
+ origin_image = gr.State(None)
257
+ template_mask = gr.State(None)
258
+ select_correction_frame = gr.State(None)
259
+ corrected_state = gr.State([[],[],[]])
260
+ fps = gr.State([])
261
+ # video_name = gr.State([])
262
+ # queue value for image refresh, origin image, mask, logits, painted image
263
+
264
+
265
+
266
+ with gr.Row():
267
+
268
+ # for user video input
269
+ with gr.Column(scale=1.0):
270
+ video_input = gr.Video().style(height=720)
271
+
272
+ # listen to the user action for play and pause input video
273
+ video_input.play(fn=play_video, inputs=play_state, outputs=play_state, scroll_to_output=True, show_progress=True)
274
+ video_input.pause(fn=pause_video, inputs=play_state, outputs=play_state)
275
+
276
+
277
+ with gr.Row(scale=1):
278
+ # put the template frame under the radio button
279
+ with gr.Column(scale=0.5):
280
+ # click point settings: negative or positive; mode: continuous or single
281
+ with gr.Row():
282
+ with gr.Row(scale=0.5):
283
+ point_prompt = gr.Radio(
284
+ choices=["Positive", "Negative"],
285
+ value="Positive",
286
+ label="Point Prompt",
287
+ interactive=True)
288
+ click_mode = gr.Radio(
289
+ choices=["Continuous", "Single"],
290
+ value="Continuous",
291
+ label="Clicking Mode",
292
+ interactive=True)
293
+ with gr.Row(scale=0.5):
294
+ clear_button_click = gr.Button(value="Clear Clicks", interactive=True).style(height=160)
295
+ clear_button_image = gr.Button(value="Clear Image", interactive=True)
296
+ template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame").style(height=360)
297
+ with gr.Column():
298
+ template_select_button = gr.Button(value="Template select", interactive=True, variant="primary")
299
+
300
+
301
+
302
+ with gr.Column(scale=0.5):
303
+
304
+
305
+ # for intermedia result check and correction
306
+ # intermedia_image = gr.Image(type="pil", interactive=True, elem_id="intermedia_frame").style(height=360)
307
+ video_output = gr.Video().style(height=360)
308
+ tracking_video_predict_button = gr.Button(value="Tracking")
309
+
310
+ image_output = gr.Image(type="pil", interactive=True, elem_id="image_output").style(height=360)
311
+ image_selection_slider = gr.Slider(minimum=0, maximum=100, step=0.1, value=0, label="Image Selection", interactive=True)
312
+ correct_track_button = gr.Button(value="Interactive Correction")
313
+
314
+ template_frame.select(
315
+ fn=sam_refine,
316
+ inputs=[
317
+ origin_image, point_prompt, click_state, logits
318
+ ],
319
+ outputs=[
320
+ template_frame, click_state, logits, template_mask
321
+ ]
322
+ )
323
+
324
+ template_select_button.click(
325
+ fn=get_frames_from_video,
326
+ inputs=[
327
+ video_input,
328
+ play_state
329
+ ],
330
+ # outputs=[video_state, template_frame, origin_image, fps, video_name],
331
+ outputs=[video_state, template_frame, origin_image, fps],
332
+ )
333
+
334
+ tracking_video_predict_button.click(
335
+ fn=vos_tracking_video,
336
+ inputs=[video_state, template_mask, fps, video_input],
337
+ outputs=[video_output, painted_images, masks, logits]
338
+ )
339
+ image_selection_slider.release(fn=vos_tracking_image,
340
+ inputs=[image_selection_slider, painted_images], outputs=[image_output, select_correction_frame], api_name="select_image")
341
+ # correction
342
+ image_output.select(
343
+ fn=interactive_correction,
344
+ inputs=[video_state, point_prompt, click_state, select_correction_frame],
345
+ outputs=[image_output, corrected_state]
346
+ )
347
+ correct_track_button.click(
348
+ fn=correct_track,
349
+ inputs=[video_state, select_correction_frame, corrected_state, masks, logits, painted_images, fps,video_input],
350
+ outputs=[video_output, painted_images, logits, masks ]
351
+ )
352
+
353
+
354
+
355
+ # clear input
356
+ video_input.clear(
357
+ lambda: ([], [], [[], [], []],
358
+ None, "", "", "", "", "", "", "", [[],[]],
359
+ None),
360
+ [],
361
+ [ state, play_state, video_state,
362
+ template_frame, video_output, image_output, origin_image, template_mask, painted_images, masks, logits, click_state,
363
+ select_correction_frame],
364
+ queue=False,
365
+ show_progress=False
366
+ )
367
+ clear_button_image.click(
368
+ fn=model_reset
369
+ )
370
+ clear_button_click.click(
371
+ lambda: ([[],[]]),
372
+ [],
373
+ [click_state],
374
+ queue=False,
375
+ show_progress=False
376
+ )
377
+ iface.queue(concurrency_count=1)
378
+ iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
379
+
380
+
381
+
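The block above wires the Gradio UI: each component event (.play, .pause, .select, .click, .release, .clear) is bound to a callback with explicit `inputs` and `outputs` lists, and per-session data is threaded through `gr.State` objects. A minimal, self-contained sketch of the same pattern (the `greet` function and component names below are illustrative, not part of app.py):

import gradio as gr

def greet(name, history):
    # append the new name to the running state and echo a greeting back
    history = history + [name]
    return f"Hello, {name}!", history

with gr.Blocks() as demo:
    history = gr.State([])            # session state, like click_state / video_state above
    name_box = gr.Textbox(label="Name")
    greeting = gr.Textbox(label="Greeting")
    greet_button = gr.Button("Greet")
    # bind the callback: inputs and outputs are explicit component lists
    greet_button.click(fn=greet, inputs=[name_box, history], outputs=[greeting, history])

demo.launch()  # app.py additionally sets server_name, server_port and queueing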
app_test.py ADDED
@@ -0,0 +1,46 @@
1
+ # import gradio as gr
2
+
3
+ # def update_iframe(slider_value):
4
+ # return f'''
5
+ # <script>
6
+ # window.addEventListener('message', function(event) {{
7
+ # if (event.data.sliderValue !== undefined) {{
8
+ # var iframe = document.getElementById("text_iframe");
9
+ # iframe.src = "http://localhost:5001/get_text?slider_value=" + event.data.sliderValue;
10
+ # }}
11
+ # }}, false);
12
+ # </script>
13
+ # <iframe id="text_iframe" src="http://localhost:5001/get_text?slider_value={slider_value}" style="width: 100%; height: 100%; border: none;"></iframe>
14
+ # '''
15
+
16
+ # iface = gr.Interface(
17
+ # fn=update_iframe,
18
+ # inputs=gr.inputs.Slider(minimum=0, maximum=100, step=1, default=50),
19
+ # outputs=gr.outputs.HTML(),
20
+ # allow_flagging=False,
21
+ # )
22
+
23
+ # iface.launch(server_name='0.0.0.0', server_port=12212)
24
+
25
+ import gradio as gr
26
+
27
+
28
+ def change_mask(drop):
29
+ return gr.update(choices=["hello", "kitty"])
30
+
31
+ with gr.Blocks() as iface:
32
+ drop = gr.Dropdown(
33
+ choices=["cat", "dog", "bird"], label="Animal", info="Will add more animals later!"
34
+ )
35
+ radio = gr.Radio(["park", "zoo", "road"], label="Location", info="Where did they go?")
36
+ multi_drop = gr.Dropdown(
37
+ ["ran", "swam", "ate", "slept"], value=["swam", "slept"], multiselect=True, label="Activity", info="Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed auctor, nisl eget ultricies aliquam, nunc nisl aliquet nunc, eget aliquam nisl nunc vel nisl."
38
+ )
39
+
40
+ multi_drop.change(
41
+ fn=change_mask,
42
+ inputs = multi_drop,
43
+ outputs=multi_drop
44
+ )
45
+
46
+ iface.launch(server_name='0.0.0.0', server_port=1223)
assets/avengers.gif ADDED

Git LFS Details

  • SHA256: 5e07b86ee4cf002b3481c71e2038c03f4420883c3be78220dafbc4b59abfb32d
  • Pointer size: 133 Bytes
  • Size of remote file: 30 MB
assets/demo_version_1.MP4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b61b54bc6eb0d0f7416f95aa3cd6a48d850ca7473022ec1aff48310911b0233
3
+ size 27053146
assets/inpainting.gif ADDED

Git LFS Details

  • SHA256: 5e99bd697bccaed7a0dded7f00855f222031b7dcefd8f64f22f374fcdab390d2
  • Pointer size: 133 Bytes
  • Size of remote file: 22.2 MB
assets/poster_demo_version_1.png ADDED
assets/qingming.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:58b34bbce0bd0a18ab5fc5450d4046e1cfc6bd55c508046695545819d8fc46dc
3
+ size 4483842
assets/track-anything-logo.jpg ADDED
checkpoints/E2FGVI-HQ-CVPR22.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:afff989d41205598a79ce24630b9c83af4b0a06f45b137979a25937d94c121a5
3
+ size 164535938
demo.py ADDED
@@ -0,0 +1,87 @@
1
+ from metaseg import SegAutoMaskPredictor, SegManualMaskPredictor, SahiAutoSegmentation, sahi_sliced_predict
2
+
3
+ # For image
4
+
5
+ def automask_image_app(image_path, model_type, points_per_side, points_per_batch, min_area):
6
+ SegAutoMaskPredictor().image_predict(
7
+ source=image_path,
8
+ model_type=model_type, # vit_l, vit_h, vit_b
9
+ points_per_side=points_per_side,
10
+ points_per_batch=points_per_batch,
11
+ min_area=min_area,
12
+ output_path="output.png",
13
+ show=False,
14
+ save=True,
15
+ )
16
+ return "output.png"
17
+
18
+
19
+ # For video
20
+
21
+ def automask_video_app(video_path, model_type, points_per_side, points_per_batch, min_area):
22
+ SegAutoMaskPredictor().video_predict(
23
+ source=video_path,
24
+ model_type=model_type, # vit_l, vit_h, vit_b
25
+ points_per_side=points_per_side,
26
+ points_per_batch=points_per_batch,
27
+ min_area=min_area,
28
+ output_path="output.mp4",
29
+ )
30
+ return "output.mp4"
31
+
32
+
33
+ # For manual box and point selection
34
+
35
+ def manual_app(image_path, model_type, input_point, input_label, input_box, multimask_output, random_color):
36
+ SegManualMaskPredictor().image_predict(
37
+ source=image_path,
38
+ model_type=model_type, # vit_l, vit_h, vit_b
39
+ input_point=input_point,
40
+ input_label=input_label,
41
+ input_box=input_box,
42
+ multimask_output=multimask_output,
43
+ random_color=random_color,
44
+ output_path="output.png",
45
+ show=False,
46
+ save=True,
47
+ )
48
+ return "output.png"
49
+
50
+
51
+ # For SAHI sliced prediction
52
+
53
+ def sahi_autoseg_app(
54
+ image_path,
55
+ sam_model_type,
56
+ detection_model_type,
57
+ detection_model_path,
58
+ conf_th,
59
+ image_size,
60
+ slice_height,
61
+ slice_width,
62
+ overlap_height_ratio,
63
+ overlap_width_ratio,
64
+ ):
65
+ boxes = sahi_sliced_predict(
66
+ image_path=image_path,
67
+ detection_model_type=detection_model_type, # yolov8, detectron2, mmdetection, torchvision
68
+ detection_model_path=detection_model_path,
69
+ conf_th=conf_th,
70
+ image_size=image_size,
71
+ slice_height=slice_height,
72
+ slice_width=slice_width,
73
+ overlap_height_ratio=overlap_height_ratio,
74
+ overlap_width_ratio=overlap_width_ratio,
75
+ )
76
+
77
+ SahiAutoSegmentation().predict(
78
+ source=image_path,
79
+ model_type=sam_model_type,
80
+ input_box=boxes,
81
+ multimask_output=False,
82
+ random_color=False,
83
+ show=False,
84
+ save=True,
85
+ )
86
+
87
+ return "output.png"
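demo.py only defines thin wrappers around the metaseg predictors; nothing calls them here. A hedged usage sketch for the automatic-mask wrapper, run from the repository root against one of the bundled sample images (the parameter values are illustrative assumptions, not repository defaults):

from demo import automask_image_app

# illustrative parameter values (assumptions, not defaults from this repo)
result_path = automask_image_app(
    image_path="images/truck.jpg",  # sample image shipped with the repo
    model_type="vit_b",             # smallest SAM backbone
    points_per_side=16,
    points_per_batch=64,
    min_area=0,
)
print(result_path)  # "output.png", written to the working directory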
images/groceries.jpg ADDED
images/mask_painter.png ADDED
images/painter_input_image.jpg ADDED
images/painter_input_mask.jpg ADDED
images/painter_output_image.png ADDED
images/painter_output_image__.png ADDED
images/point_painter.png ADDED
images/point_painter_1.png ADDED
images/point_painter_2.png ADDED
images/truck.jpg ADDED
images/truck_both.jpg ADDED
images/truck_mask.jpg ADDED
images/truck_point.jpg ADDED
inpainter/.DS_Store ADDED
Binary file (6.15 kB). View file
 
inpainter/base_inpainter.py ADDED
@@ -0,0 +1,287 @@
1
+ import os
2
+ import glob
3
+ from PIL import Image
4
+ import torch
5
+ import yaml
6
+ import cv2
7
+ import importlib
8
+ import numpy as np
9
+ from tqdm import tqdm
10
+ from inpainter.util.tensor_util import resize_frames, resize_masks
11
+
12
+ def read_image_from_split(video_split_path):
13
+ # if type:
14
+ image = np.asarray([np.asarray(Image.open(path)) for path in video_split_path])
15
+ # else:
16
+ # image = cv2.cvtColor(cv2.imread("/tmp/{}/paintedimages/{}/{:08d}.png".format(username, video_state["video_name"], index+ ".png")), cv2.COLOR_BGR2RGB)
17
+ return image
18
+
19
+ def save_image_to_userfolder(video_state, index, image, type:bool):
20
+ if type:
21
+ image_path = "/tmp/{}/originimages/{}/{:08d}.png".format(video_state["user_name"], video_state["video_name"], index)
22
+ else:
23
+ image_path = "/tmp/{}/paintedimages/{}/{:08d}.png".format(video_state["user_name"], video_state["video_name"], index)
24
+ cv2.imwrite(image_path, image)
25
+ return image_path
26
+ class BaseInpainter:
27
+ def __init__(self, E2FGVI_checkpoint, device) -> None:
28
+ """
29
+ E2FGVI_checkpoint: checkpoint of inpainter (version hq, with multi-resolution support)
30
+ """
31
+ net = importlib.import_module('inpainter.model.e2fgvi_hq')
32
+ self.model = net.InpaintGenerator().to(device)
33
+ self.model.load_state_dict(torch.load(E2FGVI_checkpoint, map_location=device))
34
+ self.model.eval()
35
+ self.device = device
36
+ # load configurations
37
+ with open("inpainter/config/config.yaml", 'r') as stream:
38
+ config = yaml.safe_load(stream)
39
+ self.neighbor_stride = config['neighbor_stride']
40
+ self.num_ref = config['num_ref']
41
+ self.step = config['step']
42
+
43
+ # sample reference frames from the whole video
44
+ def get_ref_index(self, f, neighbor_ids, length):
45
+ ref_index = []
46
+ if self.num_ref == -1:
47
+ for i in range(0, length, self.step):
48
+ if i not in neighbor_ids:
49
+ ref_index.append(i)
50
+ else:
51
+ start_idx = max(0, f - self.step * (self.num_ref // 2))
52
+ end_idx = min(length, f + self.step * (self.num_ref // 2))
53
+ for i in range(start_idx, end_idx + 1, self.step):
54
+ if i not in neighbor_ids:
55
+ if len(ref_index) > self.num_ref:
56
+ break
57
+ ref_index.append(i)
58
+ return ref_index
59
+
60
+ def inpaint_efficient(self, frames, masks, num_tcb, num_tca, dilate_radius=15, ratio=1):
61
+ """
62
+ Perform Inpainting for video subsets
63
+ frames: numpy array, T, H, W, 3
64
+ masks: numpy array, T, H, W
65
+ num_tcb: constant, number of temporal context before, frames
66
+ num_tca: constant, number of temporal context after, frames
67
+ dilate_radius: radius when applying dilation on masks
68
+ ratio: down-sample ratio
69
+
70
+ Output:
71
+ inpainted_frames: numpy array, T, H, W, 3
72
+ """
73
+ assert frames.shape[:3] == masks.shape, 'frames and masks have different sizes'
74
+ assert ratio > 0 and ratio <= 1, 'ratio must be in (0, 1]'
75
+
76
+ # --------------------
77
+ # pre-processing
78
+ # --------------------
79
+ masks = masks.copy()
80
+ masks = np.clip(masks, 0, 1)
81
+ kernel = cv2.getStructuringElement(2, (dilate_radius, dilate_radius))
82
+ masks = np.stack([cv2.dilate(mask, kernel) for mask in masks], 0)
83
+ T, H, W = masks.shape
84
+ masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
85
+ # size: (w, h)
86
+ if ratio == 1:
87
+ size = None
88
+ binary_masks = masks
89
+ else:
90
+ size = [int(W*ratio), int(H*ratio)]
91
+ size = [si+1 if si%2>0 else si for si in size] # round odd values up so both sides are even
92
+ # shortest side should be larger than 50
93
+ if min(size) < 50:
94
+ ratio = 50. / min(H, W)
95
+ size = [int(W*ratio), int(H*ratio)]
96
+ binary_masks = resize_masks(masks, tuple(size))
97
+ frames = resize_frames(frames, tuple(size)) # T, H, W, 3
98
+ # frames and binary_masks are numpy arrays
99
+ h, w = frames.shape[1:3]
100
+ video_length = T - (num_tca + num_tcb) # real video length
101
+ # convert to tensor
102
+ imgs = (torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous().unsqueeze(0).float().div(255)) * 2 - 1
103
+ masks = torch.from_numpy(binary_masks).permute(0, 3, 1, 2).contiguous().unsqueeze(0)
104
+ imgs, masks = imgs.to(self.device), masks.to(self.device)
105
+ comp_frames = [None] * video_length
106
+ tcb_imgs = None
107
+ tca_imgs = None
108
+ tcb_masks = None
109
+ tca_masks = None
110
+ # --------------------
111
+ # end of pre-processing
112
+ # --------------------
113
+
114
+ # separate tc frames/masks from imgs and masks
115
+ if num_tcb > 0:
116
+ tcb_imgs = imgs[:, :num_tcb]
117
+ tcb_masks = masks[:, :num_tcb]
118
+ tcb_binary = binary_masks[:num_tcb]
119
+ if num_tca > 0:
120
+ tca_imgs = imgs[:, -num_tca:]
121
+ tca_masks = masks[:, -num_tca:]
122
+ tca_binary = binary_masks[-num_tca:]
123
+ end_idx = -num_tca
124
+ else:
125
+ end_idx = T
126
+
127
+ imgs = imgs[:, num_tcb:end_idx]
128
+ masks = masks[:, num_tcb:end_idx]
129
+ binary_masks = binary_masks[num_tcb:end_idx] # only neighbor area are involved
130
+ frames = frames[num_tcb:end_idx] # only neighbor area are involved
131
+
132
+ for f in tqdm(range(0, video_length, self.neighbor_stride), desc='Inpainting image'):
133
+ neighbor_ids = [
134
+ i for i in range(max(0, f - self.neighbor_stride),
135
+ min(video_length, f + self.neighbor_stride + 1))
136
+ ]
137
+ ref_ids = self.get_ref_index(f, neighbor_ids, video_length)
138
+
139
+ # selected_imgs = imgs[:1, neighbor_ids + ref_ids, :, :, :]
140
+ # selected_masks = masks[:1, neighbor_ids + ref_ids, :, :, :]
141
+
142
+ selected_imgs = imgs[:, neighbor_ids]
143
+ selected_masks = masks[:, neighbor_ids]
144
+ # pad before
145
+ if tcb_imgs is not None:
146
+ selected_imgs = torch.concat([selected_imgs, tcb_imgs], dim=1)
147
+ selected_masks = torch.concat([selected_masks, tcb_masks], dim=1)
148
+ # integrate ref frames
149
+ selected_imgs = torch.concat([selected_imgs, imgs[:, ref_ids]], dim=1)
150
+ selected_masks = torch.concat([selected_masks, masks[:, ref_ids]], dim=1)
151
+ # pad after
152
+ if tca_imgs is not None:
153
+ selected_imgs = torch.concat([selected_imgs, tca_imgs], dim=1)
154
+ selected_masks = torch.concat([selected_masks, tca_masks], dim=1)
155
+
156
+ with torch.no_grad():
157
+ masked_imgs = selected_imgs * (1 - selected_masks)
158
+ mod_size_h = 60
159
+ mod_size_w = 108
160
+ h_pad = (mod_size_h - h % mod_size_h) % mod_size_h
161
+ w_pad = (mod_size_w - w % mod_size_w) % mod_size_w
162
+ masked_imgs = torch.cat(
163
+ [masked_imgs, torch.flip(masked_imgs, [3])],
164
+ 3)[:, :, :, :h + h_pad, :]
165
+ masked_imgs = torch.cat(
166
+ [masked_imgs, torch.flip(masked_imgs, [4])],
167
+ 4)[:, :, :, :, :w + w_pad]
168
+ pred_imgs, _ = self.model(masked_imgs, len(neighbor_ids))
169
+ pred_imgs = pred_imgs[:, :, :h, :w]
170
+ pred_imgs = (pred_imgs + 1) / 2
171
+ pred_imgs = pred_imgs.cpu().permute(0, 2, 3, 1).numpy() * 255
172
+ for i in range(len(neighbor_ids)):
173
+ idx = neighbor_ids[i]
174
+ img = pred_imgs[i].astype(np.uint8) * binary_masks[idx] + frames[idx] * (
175
+ 1 - binary_masks[idx])
176
+ if comp_frames[idx] is None:
177
+ comp_frames[idx] = img
178
+ else:
179
+ comp_frames[idx] = comp_frames[idx].astype(
180
+ np.float32) * 0.5 + img.astype(np.float32) * 0.5
181
+ torch.cuda.empty_cache()
182
+ inpainted_frames = np.stack(comp_frames, 0)
183
+ return inpainted_frames.astype(np.uint8)
184
+
185
+ def inpaint(self, frames_path, masks, dilate_radius=15, ratio=1):
186
+ """
187
+ Perform inpainting for the whole video (processed split by split)
188
+ frames_path: list of frame file paths, length T
189
+ masks: numpy array, T, H, W
190
+ dilate_radius: radius when applying dilation on masks
191
+ ratio: down-sample ratio
192
+
193
+ Output:
194
+ inpainted_frames: numpy array, T, H, W, 3
195
+ """
196
+ # assert frames.shape[:3] == masks.shape, 'different size between frames and masks'
197
+ assert ratio > 0 and ratio <= 1, 'ratio must be in (0, 1]'
198
+
199
+ # set interval
200
+ interval = 45
201
+ context_range = 10 # for each split, also consider up to context_range temporal-context frames (sampled every self.step frames) before and after it
202
+ # split frames into subsets
203
+ video_length = len(frames_path)
204
+ num_splits = video_length // interval
205
+ id_splits = [[i*interval, (i+1)*interval] for i in range(num_splits)] # id splits
206
+ # if the remaining frames exceed interval/2, add a new split; otherwise, merge them into the last split
207
+ if video_length - id_splits[-1][-1] > interval / 2:
208
+ id_splits.append([num_splits*interval, video_length])
209
+ else:
210
+ id_splits[-1][-1] = video_length
211
+
212
+ # perform inpainting for each split
213
+ inpainted_splits = []
214
+ for id_split in id_splits:
215
+ video_split_path = frames_path[id_split[0]:id_split[1]]
216
+ video_split = read_image_from_split(video_split_path)
217
+ mask_split = masks[id_split[0]:id_split[1]]
218
+
219
+ # | id_before | ----- | id_split[0] | ----- | id_split[1] | ----- | id_after |
220
+ # add temporal context
221
+ id_before = max(0, id_split[0] - self.step * context_range)
222
+ try:
223
+ tcb_frames = np.stack([np.array(Image.open(frames_path[idb])) for idb in range(id_before, id_split[0]-self.step, self.step)], 0)
224
+ tcb_masks = np.stack([masks[idb] for idb in range(id_before, id_split[0]-self.step, self.step)], 0)
225
+ num_tcb = len(tcb_frames)
226
+ except:
227
+ num_tcb = 0
228
+ id_after = min(video_length, id_split[1] + self.step * context_range)
229
+ try:
230
+ tca_frames = np.stack([np.array(Image.open(frames_path[ida])) for ida in range(id_split[1]+self.step, id_after, self.step)], 0)
231
+ tca_masks = np.stack([masks[ida] for ida in range(id_split[1]+self.step, id_after, self.step)], 0)
232
+ num_tca = len(tca_frames)
233
+ except:
234
+ num_tca = 0
235
+
236
+ # concatenate temporal context frames/masks with input frames/masks (for parallel pre-processing)
237
+ if num_tcb > 0:
238
+ video_split = np.concatenate([tcb_frames, video_split], 0)
239
+ mask_split = np.concatenate([tcb_masks, mask_split], 0)
240
+ if num_tca > 0:
241
+ video_split = np.concatenate([video_split, tca_frames], 0)
242
+ mask_split = np.concatenate([mask_split, tca_masks], 0)
243
+
244
+ # inpaint each split
245
+ inpainted_splits.append(self.inpaint_efficient(video_split, mask_split, num_tcb, num_tca, dilate_radius, ratio))
246
+
247
+ inpainted_frames = np.concatenate(inpainted_splits, 0)
248
+ return inpainted_frames.astype(np.uint8)
249
+
250
+ if __name__ == '__main__':
251
+
252
+ frame_path = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/parkour', '*.jpg'))
253
+ frame_path.sort()
254
+ mask_path = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/Annotations/480p/parkour', "*.png"))
255
+ mask_path.sort()
256
+ save_path = '/ssd1/gaomingqi/results/inpainting/parkour'
257
+
258
+ if not os.path.exists(save_path):
259
+ os.mkdir(save_path)
260
+
261
+ frames = []
262
+ masks = []
263
+ for fid, mid in zip(frame_path, mask_path):
264
+ frames.append(Image.open(fid).convert('RGB'))
265
+ masks.append(Image.open(mid).convert('P'))
266
+
267
+ frames = np.stack(frames, 0)
268
+ masks = np.stack(masks, 0)
269
+
270
+ # ----------------------------------------------
271
+ # how to use
272
+ # ----------------------------------------------
273
+ # 1/3: set checkpoint and device
274
+ checkpoint = '/ssd1/gaomingqi/checkpoints/E2FGVI-HQ-CVPR22.pth'
275
+ device = 'cuda:6'
276
+ # 2/3: initialise inpainter
277
+ base_inpainter = BaseInpainter(checkpoint, device)
278
+ # 3/3: inpainting (frames_path: list of frame file paths, length T; masks: numpy array, T, H, W)
279
+ # ratio: (0, 1], down-sampling ratio, default value is 1
280
+ inpainted_frames = base_inpainter.inpaint(frame_path, masks, ratio=0.01) # numpy array, T, H, W, 3
281
+ # ----------------------------------------------
282
+ # end
283
+ # ----------------------------------------------
284
+ # save
285
+ for ti, inpainted_frame in enumerate(inpainted_frames):
286
+ frame = Image.fromarray(inpainted_frame).convert('RGB')
287
+ frame.save(os.path.join(save_path, f'{ti:05d}.jpg'))
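BaseInpainter.inpaint above cuts long videos into windows of `interval` frames and pads each window with temporal-context frames sampled every `step` frames on both sides. A small stand-alone sketch of just that index arithmetic, for a hypothetical 100-frame clip (the constants mirror the ones used above):

# index arithmetic from BaseInpainter.inpaint, reproduced as a sketch
interval, step, context_range = 45, 10, 10
video_length = 100  # hypothetical clip length

num_splits = video_length // interval
id_splits = [[i * interval, (i + 1) * interval] for i in range(num_splits)]
# a leftover shorter than interval/2 is merged into the last split,
# otherwise it becomes its own split
if video_length - id_splits[-1][-1] > interval / 2:
    id_splits.append([num_splits * interval, video_length])
else:
    id_splits[-1][-1] = video_length

for start, end in id_splits:
    # temporal context: up to context_range frames on each side, sampled every `step` frames
    before = list(range(max(0, start - step * context_range), start - step, step))
    after = list(range(end + step, min(video_length, end + step * context_range), step))
    print((start, end), "context before:", before, "context after:", after)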
inpainter/config/config.yaml ADDED
@@ -0,0 +1,4 @@
1
+ # config info for E2FGVI
2
+ neighbor_stride: 5
3
+ num_ref: -1
4
+ step: 10
inpainter/model/e2fgvi.py ADDED
@@ -0,0 +1,350 @@
1
+ ''' Towards An End-to-End Framework for Video Inpainting
2
+ '''
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from model.modules.flow_comp import SPyNet
9
+ from model.modules.feat_prop import BidirectionalPropagation, SecondOrderDeformableAlignment
10
+ from model.modules.tfocal_transformer import TemporalFocalTransformerBlock, SoftSplit, SoftComp
11
+ from model.modules.spectral_norm import spectral_norm as _spectral_norm
12
+
13
+
14
+ class BaseNetwork(nn.Module):
15
+ def __init__(self):
16
+ super(BaseNetwork, self).__init__()
17
+
18
+ def print_network(self):
19
+ if isinstance(self, list):
20
+ self = self[0]
21
+ num_params = 0
22
+ for param in self.parameters():
23
+ num_params += param.numel()
24
+ print(
25
+ 'Network [%s] was created. Total number of parameters: %.1f million. '
26
+ 'To see the architecture, do print(network).' %
27
+ (type(self).__name__, num_params / 1000000))
28
+
29
+ def init_weights(self, init_type='normal', gain=0.02):
30
+ '''
31
+ initialize network's weights
32
+ init_type: normal | xavier | kaiming | orthogonal
33
+ https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
34
+ '''
35
+ def init_func(m):
36
+ classname = m.__class__.__name__
37
+ if classname.find('InstanceNorm2d') != -1:
38
+ if hasattr(m, 'weight') and m.weight is not None:
39
+ nn.init.constant_(m.weight.data, 1.0)
40
+ if hasattr(m, 'bias') and m.bias is not None:
41
+ nn.init.constant_(m.bias.data, 0.0)
42
+ elif hasattr(m, 'weight') and (classname.find('Conv') != -1
43
+ or classname.find('Linear') != -1):
44
+ if init_type == 'normal':
45
+ nn.init.normal_(m.weight.data, 0.0, gain)
46
+ elif init_type == 'xavier':
47
+ nn.init.xavier_normal_(m.weight.data, gain=gain)
48
+ elif init_type == 'xavier_uniform':
49
+ nn.init.xavier_uniform_(m.weight.data, gain=1.0)
50
+ elif init_type == 'kaiming':
51
+ nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
52
+ elif init_type == 'orthogonal':
53
+ nn.init.orthogonal_(m.weight.data, gain=gain)
54
+ elif init_type == 'none': # uses pytorch's default init method
55
+ m.reset_parameters()
56
+ else:
57
+ raise NotImplementedError(
58
+ 'initialization method [%s] is not implemented' %
59
+ init_type)
60
+ if hasattr(m, 'bias') and m.bias is not None:
61
+ nn.init.constant_(m.bias.data, 0.0)
62
+
63
+ self.apply(init_func)
64
+
65
+ # propagate to children
66
+ for m in self.children():
67
+ if hasattr(m, 'init_weights'):
68
+ m.init_weights(init_type, gain)
69
+
70
+
71
+ class Encoder(nn.Module):
72
+ def __init__(self):
73
+ super(Encoder, self).__init__()
74
+ self.group = [1, 2, 4, 8, 1]
75
+ self.layers = nn.ModuleList([
76
+ nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
77
+ nn.LeakyReLU(0.2, inplace=True),
78
+ nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
79
+ nn.LeakyReLU(0.2, inplace=True),
80
+ nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
81
+ nn.LeakyReLU(0.2, inplace=True),
82
+ nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
83
+ nn.LeakyReLU(0.2, inplace=True),
84
+ nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
85
+ nn.LeakyReLU(0.2, inplace=True),
86
+ nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
87
+ nn.LeakyReLU(0.2, inplace=True),
88
+ nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
89
+ nn.LeakyReLU(0.2, inplace=True),
90
+ nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
91
+ nn.LeakyReLU(0.2, inplace=True),
92
+ nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
93
+ nn.LeakyReLU(0.2, inplace=True)
94
+ ])
95
+
96
+ def forward(self, x):
97
+ bt, c, h, w = x.size()
98
+ h, w = h // 4, w // 4
99
+ out = x
100
+ for i, layer in enumerate(self.layers):
101
+ if i == 8:
102
+ x0 = out
103
+ if i > 8 and i % 2 == 0:
104
+ g = self.group[(i - 8) // 2]
105
+ x = x0.view(bt, g, -1, h, w)
106
+ o = out.view(bt, g, -1, h, w)
107
+ out = torch.cat([x, o], 2).view(bt, -1, h, w)
108
+ out = layer(out)
109
+ return out
110
+
111
+
112
+ class deconv(nn.Module):
113
+ def __init__(self,
114
+ input_channel,
115
+ output_channel,
116
+ kernel_size=3,
117
+ padding=0):
118
+ super().__init__()
119
+ self.conv = nn.Conv2d(input_channel,
120
+ output_channel,
121
+ kernel_size=kernel_size,
122
+ stride=1,
123
+ padding=padding)
124
+
125
+ def forward(self, x):
126
+ x = F.interpolate(x,
127
+ scale_factor=2,
128
+ mode='bilinear',
129
+ align_corners=True)
130
+ return self.conv(x)
131
+
132
+
133
+ class InpaintGenerator(BaseNetwork):
134
+ def __init__(self, init_weights=True):
135
+ super(InpaintGenerator, self).__init__()
136
+ channel = 256
137
+ hidden = 512
138
+
139
+ # encoder
140
+ self.encoder = Encoder()
141
+
142
+ # decoder
143
+ self.decoder = nn.Sequential(
144
+ deconv(channel // 2, 128, kernel_size=3, padding=1),
145
+ nn.LeakyReLU(0.2, inplace=True),
146
+ nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
147
+ nn.LeakyReLU(0.2, inplace=True),
148
+ deconv(64, 64, kernel_size=3, padding=1),
149
+ nn.LeakyReLU(0.2, inplace=True),
150
+ nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))
151
+
152
+ # feature propagation module
153
+ self.feat_prop_module = BidirectionalPropagation(channel // 2)
154
+
155
+ # soft split and soft composition
156
+ kernel_size = (7, 7)
157
+ padding = (3, 3)
158
+ stride = (3, 3)
159
+ output_size = (60, 108)
160
+ t2t_params = {
161
+ 'kernel_size': kernel_size,
162
+ 'stride': stride,
163
+ 'padding': padding,
164
+ 'output_size': output_size
165
+ }
166
+ self.ss = SoftSplit(channel // 2,
167
+ hidden,
168
+ kernel_size,
169
+ stride,
170
+ padding,
171
+ t2t_param=t2t_params)
172
+ self.sc = SoftComp(channel // 2, hidden, output_size, kernel_size,
173
+ stride, padding)
174
+
175
+ n_vecs = 1
176
+ for i, d in enumerate(kernel_size):
177
+ n_vecs *= int((output_size[i] + 2 * padding[i] -
178
+ (d - 1) - 1) / stride[i] + 1)
179
+
180
+ blocks = []
181
+ depths = 8
182
+ num_heads = [4] * depths
183
+ window_size = [(5, 9)] * depths
184
+ focal_windows = [(5, 9)] * depths
185
+ focal_levels = [2] * depths
186
+ pool_method = "fc"
187
+
188
+ for i in range(depths):
189
+ blocks.append(
190
+ TemporalFocalTransformerBlock(dim=hidden,
191
+ num_heads=num_heads[i],
192
+ window_size=window_size[i],
193
+ focal_level=focal_levels[i],
194
+ focal_window=focal_windows[i],
195
+ n_vecs=n_vecs,
196
+ t2t_params=t2t_params,
197
+ pool_method=pool_method))
198
+ self.transformer = nn.Sequential(*blocks)
199
+
200
+ if init_weights:
201
+ self.init_weights()
202
+ # Need to initialize the weights of MSDeformAttn specifically
203
+ for m in self.modules():
204
+ if isinstance(m, SecondOrderDeformableAlignment):
205
+ m.init_offset()
206
+
207
+ # flow completion network
208
+ self.update_spynet = SPyNet()
209
+
210
+ def forward_bidirect_flow(self, masked_local_frames):
211
+ b, l_t, c, h, w = masked_local_frames.size()
212
+
213
+ # compute forward and backward flows of masked frames
214
+ masked_local_frames = F.interpolate(masked_local_frames.view(
215
+ -1, c, h, w),
216
+ scale_factor=1 / 4,
217
+ mode='bilinear',
218
+ align_corners=True,
219
+ recompute_scale_factor=True)
220
+ masked_local_frames = masked_local_frames.view(b, l_t, c, h // 4,
221
+ w // 4)
222
+ mlf_1 = masked_local_frames[:, :-1, :, :, :].reshape(
223
+ -1, c, h // 4, w // 4)
224
+ mlf_2 = masked_local_frames[:, 1:, :, :, :].reshape(
225
+ -1, c, h // 4, w // 4)
226
+ pred_flows_forward = self.update_spynet(mlf_1, mlf_2)
227
+ pred_flows_backward = self.update_spynet(mlf_2, mlf_1)
228
+
229
+ pred_flows_forward = pred_flows_forward.view(b, l_t - 1, 2, h // 4,
230
+ w // 4)
231
+ pred_flows_backward = pred_flows_backward.view(b, l_t - 1, 2, h // 4,
232
+ w // 4)
233
+
234
+ return pred_flows_forward, pred_flows_backward
235
+
236
+ def forward(self, masked_frames, num_local_frames):
237
+ l_t = num_local_frames
238
+ b, t, ori_c, ori_h, ori_w = masked_frames.size()
239
+
240
+ # normalization before feeding into the flow completion module
241
+ masked_local_frames = (masked_frames[:, :l_t, ...] + 1) / 2
242
+ pred_flows = self.forward_bidirect_flow(masked_local_frames)
243
+
244
+ # extracting features and performing the feature propagation on local features
245
+ enc_feat = self.encoder(masked_frames.view(b * t, ori_c, ori_h, ori_w))
246
+ _, c, h, w = enc_feat.size()
247
+ local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
248
+ ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
249
+ local_feat = self.feat_prop_module(local_feat, pred_flows[0],
250
+ pred_flows[1])
251
+ enc_feat = torch.cat((local_feat, ref_feat), dim=1)
252
+
253
+ # content hallucination through stacking multiple temporal focal transformer blocks
254
+ trans_feat = self.ss(enc_feat.view(-1, c, h, w), b)
255
+ trans_feat = self.transformer(trans_feat)
256
+ trans_feat = self.sc(trans_feat, t)
257
+ trans_feat = trans_feat.view(b, t, -1, h, w)
258
+ enc_feat = enc_feat + trans_feat
259
+
260
+ # decode frames from features
261
+ output = self.decoder(enc_feat.view(b * t, c, h, w))
262
+ output = torch.tanh(output)
263
+ return output, pred_flows
264
+
265
+
266
+ # ######################################################################
267
+ # Discriminator for Temporal Patch GAN
268
+ # ######################################################################
269
+
270
+
271
+ class Discriminator(BaseNetwork):
272
+ def __init__(self,
273
+ in_channels=3,
274
+ use_sigmoid=False,
275
+ use_spectral_norm=True,
276
+ init_weights=True):
277
+ super(Discriminator, self).__init__()
278
+ self.use_sigmoid = use_sigmoid
279
+ nf = 32
280
+
281
+ self.conv = nn.Sequential(
282
+ spectral_norm(
283
+ nn.Conv3d(in_channels=in_channels,
284
+ out_channels=nf * 1,
285
+ kernel_size=(3, 5, 5),
286
+ stride=(1, 2, 2),
287
+ padding=1,
288
+ bias=not use_spectral_norm), use_spectral_norm),
289
+ # nn.InstanceNorm2d(64, track_running_stats=False),
290
+ nn.LeakyReLU(0.2, inplace=True),
291
+ spectral_norm(
292
+ nn.Conv3d(nf * 1,
293
+ nf * 2,
294
+ kernel_size=(3, 5, 5),
295
+ stride=(1, 2, 2),
296
+ padding=(1, 2, 2),
297
+ bias=not use_spectral_norm), use_spectral_norm),
298
+ # nn.InstanceNorm2d(128, track_running_stats=False),
299
+ nn.LeakyReLU(0.2, inplace=True),
300
+ spectral_norm(
301
+ nn.Conv3d(nf * 2,
302
+ nf * 4,
303
+ kernel_size=(3, 5, 5),
304
+ stride=(1, 2, 2),
305
+ padding=(1, 2, 2),
306
+ bias=not use_spectral_norm), use_spectral_norm),
307
+ # nn.InstanceNorm2d(256, track_running_stats=False),
308
+ nn.LeakyReLU(0.2, inplace=True),
309
+ spectral_norm(
310
+ nn.Conv3d(nf * 4,
311
+ nf * 4,
312
+ kernel_size=(3, 5, 5),
313
+ stride=(1, 2, 2),
314
+ padding=(1, 2, 2),
315
+ bias=not use_spectral_norm), use_spectral_norm),
316
+ # nn.InstanceNorm2d(256, track_running_stats=False),
317
+ nn.LeakyReLU(0.2, inplace=True),
318
+ spectral_norm(
319
+ nn.Conv3d(nf * 4,
320
+ nf * 4,
321
+ kernel_size=(3, 5, 5),
322
+ stride=(1, 2, 2),
323
+ padding=(1, 2, 2),
324
+ bias=not use_spectral_norm), use_spectral_norm),
325
+ # nn.InstanceNorm2d(256, track_running_stats=False),
326
+ nn.LeakyReLU(0.2, inplace=True),
327
+ nn.Conv3d(nf * 4,
328
+ nf * 4,
329
+ kernel_size=(3, 5, 5),
330
+ stride=(1, 2, 2),
331
+ padding=(1, 2, 2)))
332
+
333
+ if init_weights:
334
+ self.init_weights()
335
+
336
+ def forward(self, xs):
337
+ # T, C, H, W = xs.shape (old)
338
+ # B, T, C, H, W (new)
339
+ xs_t = torch.transpose(xs, 1, 2)
340
+ feat = self.conv(xs_t)
341
+ if self.use_sigmoid:
342
+ feat = torch.sigmoid(feat)
343
+ out = torch.transpose(feat, 1, 2) # B, T, C, H, W
344
+ return out
345
+
346
+
347
+ def spectral_norm(module, mode=True):
348
+ if mode:
349
+ return _spectral_norm(module)
350
+ return module
inpainter/model/e2fgvi_hq.py ADDED
@@ -0,0 +1,350 @@
1
+ ''' Towards An End-to-End Framework for Video Inpainting
2
+ '''
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from inpainter.model.modules.flow_comp import SPyNet
9
+ from inpainter.model.modules.feat_prop import BidirectionalPropagation, SecondOrderDeformableAlignment
10
+ from inpainter.model.modules.tfocal_transformer_hq import TemporalFocalTransformerBlock, SoftSplit, SoftComp
11
+ from inpainter.model.modules.spectral_norm import spectral_norm as _spectral_norm
12
+
13
+
14
+ class BaseNetwork(nn.Module):
15
+ def __init__(self):
16
+ super(BaseNetwork, self).__init__()
17
+
18
+ def print_network(self):
19
+ if isinstance(self, list):
20
+ self = self[0]
21
+ num_params = 0
22
+ for param in self.parameters():
23
+ num_params += param.numel()
24
+ print(
25
+ 'Network [%s] was created. Total number of parameters: %.1f million. '
26
+ 'To see the architecture, do print(network).' %
27
+ (type(self).__name__, num_params / 1000000))
28
+
29
+ def init_weights(self, init_type='normal', gain=0.02):
30
+ '''
31
+ initialize network's weights
32
+ init_type: normal | xavier | kaiming | orthogonal
33
+ https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
34
+ '''
35
+ def init_func(m):
36
+ classname = m.__class__.__name__
37
+ if classname.find('InstanceNorm2d') != -1:
38
+ if hasattr(m, 'weight') and m.weight is not None:
39
+ nn.init.constant_(m.weight.data, 1.0)
40
+ if hasattr(m, 'bias') and m.bias is not None:
41
+ nn.init.constant_(m.bias.data, 0.0)
42
+ elif hasattr(m, 'weight') and (classname.find('Conv') != -1
43
+ or classname.find('Linear') != -1):
44
+ if init_type == 'normal':
45
+ nn.init.normal_(m.weight.data, 0.0, gain)
46
+ elif init_type == 'xavier':
47
+ nn.init.xavier_normal_(m.weight.data, gain=gain)
48
+ elif init_type == 'xavier_uniform':
49
+ nn.init.xavier_uniform_(m.weight.data, gain=1.0)
50
+ elif init_type == 'kaiming':
51
+ nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
52
+ elif init_type == 'orthogonal':
53
+ nn.init.orthogonal_(m.weight.data, gain=gain)
54
+ elif init_type == 'none': # uses pytorch's default init method
55
+ m.reset_parameters()
56
+ else:
57
+ raise NotImplementedError(
58
+ 'initialization method [%s] is not implemented' %
59
+ init_type)
60
+ if hasattr(m, 'bias') and m.bias is not None:
61
+ nn.init.constant_(m.bias.data, 0.0)
62
+
63
+ self.apply(init_func)
64
+
65
+ # propagate to children
66
+ for m in self.children():
67
+ if hasattr(m, 'init_weights'):
68
+ m.init_weights(init_type, gain)
69
+
70
+
71
+ class Encoder(nn.Module):
72
+ def __init__(self):
73
+ super(Encoder, self).__init__()
74
+ self.group = [1, 2, 4, 8, 1]
75
+ self.layers = nn.ModuleList([
76
+ nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
77
+ nn.LeakyReLU(0.2, inplace=True),
78
+ nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
79
+ nn.LeakyReLU(0.2, inplace=True),
80
+ nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
81
+ nn.LeakyReLU(0.2, inplace=True),
82
+ nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
83
+ nn.LeakyReLU(0.2, inplace=True),
84
+ nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
85
+ nn.LeakyReLU(0.2, inplace=True),
86
+ nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
87
+ nn.LeakyReLU(0.2, inplace=True),
88
+ nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
89
+ nn.LeakyReLU(0.2, inplace=True),
90
+ nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
91
+ nn.LeakyReLU(0.2, inplace=True),
92
+ nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
93
+ nn.LeakyReLU(0.2, inplace=True)
94
+ ])
95
+
96
+ def forward(self, x):
97
+ bt, c, _, _ = x.size()
98
+ # h, w = h//4, w//4
99
+ out = x
100
+ for i, layer in enumerate(self.layers):
101
+ if i == 8:
102
+ x0 = out
103
+ _, _, h, w = x0.size()
104
+ if i > 8 and i % 2 == 0:
105
+ g = self.group[(i - 8) // 2]
106
+ x = x0.view(bt, g, -1, h, w)
107
+ o = out.view(bt, g, -1, h, w)
108
+ out = torch.cat([x, o], 2).view(bt, -1, h, w)
109
+ out = layer(out)
110
+ return out
111
+
112
+
113
+ class deconv(nn.Module):
114
+ def __init__(self,
115
+ input_channel,
116
+ output_channel,
117
+ kernel_size=3,
118
+ padding=0):
119
+ super().__init__()
120
+ self.conv = nn.Conv2d(input_channel,
121
+ output_channel,
122
+ kernel_size=kernel_size,
123
+ stride=1,
124
+ padding=padding)
125
+
126
+ def forward(self, x):
127
+ x = F.interpolate(x,
128
+ scale_factor=2,
129
+ mode='bilinear',
130
+ align_corners=True)
131
+ return self.conv(x)
132
+
133
+
134
+ class InpaintGenerator(BaseNetwork):
135
+ def __init__(self, init_weights=True):
136
+ super(InpaintGenerator, self).__init__()
137
+ channel = 256
138
+ hidden = 512
139
+
140
+ # encoder
141
+ self.encoder = Encoder()
142
+
143
+ # decoder
144
+ self.decoder = nn.Sequential(
145
+ deconv(channel // 2, 128, kernel_size=3, padding=1),
146
+ nn.LeakyReLU(0.2, inplace=True),
147
+ nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
148
+ nn.LeakyReLU(0.2, inplace=True),
149
+ deconv(64, 64, kernel_size=3, padding=1),
150
+ nn.LeakyReLU(0.2, inplace=True),
151
+ nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))
152
+
153
+ # feature propagation module
154
+ self.feat_prop_module = BidirectionalPropagation(channel // 2)
155
+
156
+ # soft split and soft composition
157
+ kernel_size = (7, 7)
158
+ padding = (3, 3)
159
+ stride = (3, 3)
160
+ output_size = (60, 108)
161
+ t2t_params = {
162
+ 'kernel_size': kernel_size,
163
+ 'stride': stride,
164
+ 'padding': padding
165
+ }
166
+ self.ss = SoftSplit(channel // 2,
167
+ hidden,
168
+ kernel_size,
169
+ stride,
170
+ padding,
171
+ t2t_param=t2t_params)
172
+ self.sc = SoftComp(channel // 2, hidden, kernel_size, stride, padding)
173
+
174
+ n_vecs = 1
175
+ for i, d in enumerate(kernel_size):
176
+ n_vecs *= int((output_size[i] + 2 * padding[i] -
177
+ (d - 1) - 1) / stride[i] + 1)
178
+
179
+ blocks = []
180
+ depths = 8
181
+ num_heads = [4] * depths
182
+ window_size = [(5, 9)] * depths
183
+ focal_windows = [(5, 9)] * depths
184
+ focal_levels = [2] * depths
185
+ pool_method = "fc"
186
+
187
+ for i in range(depths):
188
+ blocks.append(
189
+ TemporalFocalTransformerBlock(dim=hidden,
190
+ num_heads=num_heads[i],
191
+ window_size=window_size[i],
192
+ focal_level=focal_levels[i],
193
+ focal_window=focal_windows[i],
194
+ n_vecs=n_vecs,
195
+ t2t_params=t2t_params,
196
+ pool_method=pool_method))
197
+ self.transformer = nn.Sequential(*blocks)
198
+
199
+ if init_weights:
200
+ self.init_weights()
201
+ # Need to initial the weights of MSDeformAttn specifically
202
+ for m in self.modules():
203
+ if isinstance(m, SecondOrderDeformableAlignment):
204
+ m.init_offset()
205
+
206
+ # flow completion network
207
+ self.update_spynet = SPyNet()
208
+
209
+ def forward_bidirect_flow(self, masked_local_frames):
210
+ b, l_t, c, h, w = masked_local_frames.size()
211
+
212
+ # compute forward and backward flows of masked frames
213
+ masked_local_frames = F.interpolate(masked_local_frames.view(
214
+ -1, c, h, w),
215
+ scale_factor=1 / 4,
216
+ mode='bilinear',
217
+ align_corners=True,
218
+ recompute_scale_factor=True)
219
+ masked_local_frames = masked_local_frames.view(b, l_t, c, h // 4,
220
+ w // 4)
221
+ mlf_1 = masked_local_frames[:, :-1, :, :, :].reshape(
222
+ -1, c, h // 4, w // 4)
223
+ mlf_2 = masked_local_frames[:, 1:, :, :, :].reshape(
224
+ -1, c, h // 4, w // 4)
225
+ pred_flows_forward = self.update_spynet(mlf_1, mlf_2)
226
+ pred_flows_backward = self.update_spynet(mlf_2, mlf_1)
227
+
228
+ pred_flows_forward = pred_flows_forward.view(b, l_t - 1, 2, h // 4,
229
+ w // 4)
230
+ pred_flows_backward = pred_flows_backward.view(b, l_t - 1, 2, h // 4,
231
+ w // 4)
232
+
233
+ return pred_flows_forward, pred_flows_backward
234
+
235
+ def forward(self, masked_frames, num_local_frames):
236
+ l_t = num_local_frames
237
+ b, t, ori_c, ori_h, ori_w = masked_frames.size()
238
+
239
+ # normalization before feeding into the flow completion module
240
+ masked_local_frames = (masked_frames[:, :l_t, ...] + 1) / 2
241
+ pred_flows = self.forward_bidirect_flow(masked_local_frames)
242
+
243
+ # extracting features and performing the feature propagation on local features
244
+ enc_feat = self.encoder(masked_frames.view(b * t, ori_c, ori_h, ori_w))
245
+ _, c, h, w = enc_feat.size()
246
+ fold_output_size = (h, w)
247
+ local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
248
+ ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
249
+ local_feat = self.feat_prop_module(local_feat, pred_flows[0],
250
+ pred_flows[1])
251
+ enc_feat = torch.cat((local_feat, ref_feat), dim=1)
252
+
253
+ # content hallucination through stacking multiple temporal focal transformer blocks
254
+ trans_feat = self.ss(enc_feat.view(-1, c, h, w), b, fold_output_size)
255
+ trans_feat = self.transformer([trans_feat, fold_output_size])
256
+ trans_feat = self.sc(trans_feat[0], t, fold_output_size)
257
+ trans_feat = trans_feat.view(b, t, -1, h, w)
258
+ enc_feat = enc_feat + trans_feat
259
+
260
+ # decode frames from features
261
+ output = self.decoder(enc_feat.view(b * t, c, h, w))
262
+ output = torch.tanh(output)
263
+ return output, pred_flows
264
+
265
+
266
+ # ######################################################################
267
+ # Discriminator for Temporal Patch GAN
268
+ # ######################################################################
269
+
270
+
271
+ class Discriminator(BaseNetwork):
272
+ def __init__(self,
273
+ in_channels=3,
274
+ use_sigmoid=False,
275
+ use_spectral_norm=True,
276
+ init_weights=True):
277
+ super(Discriminator, self).__init__()
278
+ self.use_sigmoid = use_sigmoid
279
+ nf = 32
280
+
281
+ self.conv = nn.Sequential(
282
+ spectral_norm(
283
+ nn.Conv3d(in_channels=in_channels,
284
+ out_channels=nf * 1,
285
+ kernel_size=(3, 5, 5),
286
+ stride=(1, 2, 2),
287
+ padding=1,
288
+ bias=not use_spectral_norm), use_spectral_norm),
289
+ # nn.InstanceNorm2d(64, track_running_stats=False),
290
+ nn.LeakyReLU(0.2, inplace=True),
291
+ spectral_norm(
292
+ nn.Conv3d(nf * 1,
293
+ nf * 2,
294
+ kernel_size=(3, 5, 5),
295
+ stride=(1, 2, 2),
296
+ padding=(1, 2, 2),
297
+ bias=not use_spectral_norm), use_spectral_norm),
298
+ # nn.InstanceNorm2d(128, track_running_stats=False),
299
+ nn.LeakyReLU(0.2, inplace=True),
300
+ spectral_norm(
301
+ nn.Conv3d(nf * 2,
302
+ nf * 4,
303
+ kernel_size=(3, 5, 5),
304
+ stride=(1, 2, 2),
305
+ padding=(1, 2, 2),
306
+ bias=not use_spectral_norm), use_spectral_norm),
307
+ # nn.InstanceNorm2d(256, track_running_stats=False),
308
+ nn.LeakyReLU(0.2, inplace=True),
309
+ spectral_norm(
310
+ nn.Conv3d(nf * 4,
311
+ nf * 4,
312
+ kernel_size=(3, 5, 5),
313
+ stride=(1, 2, 2),
314
+ padding=(1, 2, 2),
315
+ bias=not use_spectral_norm), use_spectral_norm),
316
+ # nn.InstanceNorm2d(256, track_running_stats=False),
317
+ nn.LeakyReLU(0.2, inplace=True),
318
+ spectral_norm(
319
+ nn.Conv3d(nf * 4,
320
+ nf * 4,
321
+ kernel_size=(3, 5, 5),
322
+ stride=(1, 2, 2),
323
+ padding=(1, 2, 2),
324
+ bias=not use_spectral_norm), use_spectral_norm),
325
+ # nn.InstanceNorm2d(256, track_running_stats=False),
326
+ nn.LeakyReLU(0.2, inplace=True),
327
+ nn.Conv3d(nf * 4,
328
+ nf * 4,
329
+ kernel_size=(3, 5, 5),
330
+ stride=(1, 2, 2),
331
+ padding=(1, 2, 2)))
332
+
333
+ if init_weights:
334
+ self.init_weights()
335
+
336
+ def forward(self, xs):
337
+ # T, C, H, W = xs.shape (old)
338
+ # B, T, C, H, W (new)
339
+ xs_t = torch.transpose(xs, 1, 2)
340
+ feat = self.conv(xs_t)
341
+ if self.use_sigmoid:
342
+ feat = torch.sigmoid(feat)
343
+ out = torch.transpose(feat, 1, 2) # B, T, C, H, W
344
+ return out
345
+
346
+
347
+ def spectral_norm(module, mode=True):
348
+ if mode:
349
+ return _spectral_norm(module)
350
+ return module
inpainter/model/modules/feat_prop.py ADDED
@@ -0,0 +1,149 @@
1
+ """
2
+ BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment, CVPR 2022
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from mmcv.ops import ModulatedDeformConv2d, modulated_deform_conv2d
8
+ from mmengine.model import constant_init
9
+
10
+ from inpainter.model.modules.flow_comp import flow_warp
11
+
12
+
13
+ class SecondOrderDeformableAlignment(ModulatedDeformConv2d):
14
+ """Second-order deformable alignment module."""
15
+ def __init__(self, *args, **kwargs):
16
+ self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
17
+
18
+ super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
19
+
20
+ self.conv_offset = nn.Sequential(
21
+ nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
22
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
23
+ nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
24
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
25
+ nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
26
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
27
+ nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1),
28
+ )
29
+
30
+ self.init_offset()
31
+
32
+ def init_offset(self):
33
+ constant_init(self.conv_offset[-1], val=0, bias=0)
34
+
35
+ def forward(self, x, extra_feat, flow_1, flow_2):
36
+ extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
37
+ out = self.conv_offset(extra_feat)
38
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
39
+
40
+ # offset
41
+ offset = self.max_residue_magnitude * torch.tanh(
42
+ torch.cat((o1, o2), dim=1))
43
+ offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
44
+ offset_1 = offset_1 + flow_1.flip(1).repeat(1,
45
+ offset_1.size(1) // 2, 1,
46
+ 1)
47
+ offset_2 = offset_2 + flow_2.flip(1).repeat(1,
48
+ offset_2.size(1) // 2, 1,
49
+ 1)
50
+ offset = torch.cat([offset_1, offset_2], dim=1)
51
+
52
+ # mask
53
+ mask = torch.sigmoid(mask)
54
+
55
+ return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
56
+ self.stride, self.padding,
57
+ self.dilation, self.groups,
58
+ self.deform_groups)
59
+
60
+
61
+ class BidirectionalPropagation(nn.Module):
62
+ def __init__(self, channel):
63
+ super(BidirectionalPropagation, self).__init__()
64
+ modules = ['backward_', 'forward_']
65
+ self.deform_align = nn.ModuleDict()
66
+ self.backbone = nn.ModuleDict()
67
+ self.channel = channel
68
+
69
+ for i, module in enumerate(modules):
70
+ self.deform_align[module] = SecondOrderDeformableAlignment(
71
+ 2 * channel, channel, 3, padding=1, deform_groups=16)
72
+
73
+ self.backbone[module] = nn.Sequential(
74
+ nn.Conv2d((2 + i) * channel, channel, 3, 1, 1),
75
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
76
+ nn.Conv2d(channel, channel, 3, 1, 1),
77
+ )
78
+
79
+ self.fusion = nn.Conv2d(2 * channel, channel, 1, 1, 0)
80
+
81
+ def forward(self, x, flows_backward, flows_forward):
82
+ """
83
+ x shape : [b, t, c, h, w]
84
+ return [b, t, c, h, w]
85
+ """
86
+ b, t, c, h, w = x.shape
87
+ feats = {}
88
+ feats['spatial'] = [x[:, i, :, :, :] for i in range(0, t)]
89
+
90
+ for module_name in ['backward_', 'forward_']:
91
+
92
+ feats[module_name] = []
93
+
94
+ frame_idx = range(0, t)
95
+ flow_idx = range(-1, t - 1)
96
+ mapping_idx = list(range(0, len(feats['spatial'])))
97
+ mapping_idx += mapping_idx[::-1]
98
+
99
+ if 'backward' in module_name:
100
+ frame_idx = frame_idx[::-1]
101
+ flows = flows_backward
102
+ else:
103
+ flows = flows_forward
104
+
105
+ feat_prop = x.new_zeros(b, self.channel, h, w)
106
+ for i, idx in enumerate(frame_idx):
107
+ feat_current = feats['spatial'][mapping_idx[idx]]
108
+
109
+ if i > 0:
110
+ flow_n1 = flows[:, flow_idx[i], :, :, :]
111
+ cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))
112
+
113
+ # initialize second-order features
114
+ feat_n2 = torch.zeros_like(feat_prop)
115
+ flow_n2 = torch.zeros_like(flow_n1)
116
+ cond_n2 = torch.zeros_like(cond_n1)
117
+ if i > 1:
118
+ feat_n2 = feats[module_name][-2]
119
+ flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
120
+ flow_n2 = flow_n1 + flow_warp(
121
+ flow_n2, flow_n1.permute(0, 2, 3, 1))
122
+ cond_n2 = flow_warp(feat_n2,
123
+ flow_n2.permute(0, 2, 3, 1))
124
+
125
+ cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
126
+ feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
127
+ feat_prop = self.deform_align[module_name](feat_prop, cond,
128
+ flow_n1,
129
+ flow_n2)
130
+
131
+ feat = [feat_current] + [
132
+ feats[k][idx]
133
+ for k in feats if k not in ['spatial', module_name]
134
+ ] + [feat_prop]
135
+
136
+ feat = torch.cat(feat, dim=1)
137
+ feat_prop = feat_prop + self.backbone[module_name](feat)
138
+ feats[module_name].append(feat_prop)
139
+
140
+ if 'backward' in module_name:
141
+ feats[module_name] = feats[module_name][::-1]
142
+
143
+ outputs = []
144
+ for i in range(0, t):
145
+ align_feats = [feats[k].pop(0) for k in feats if k != 'spatial']
146
+ align_feats = torch.cat(align_feats, dim=1)
147
+ outputs.append(self.fusion(align_feats))
148
+
149
+ return torch.stack(outputs, dim=1) + x
inpainter/model/modules/flow_comp.py ADDED
@@ -0,0 +1,450 @@
1
+ import numpy as np
2
+
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch
6
+
7
+ from mmcv.cnn import ConvModule
8
+ from mmengine.runner import load_checkpoint
9
+
10
+
11
+ class FlowCompletionLoss(nn.Module):
12
+ """Flow completion loss"""
13
+ def __init__(self):
14
+ super().__init__()
15
+ self.fix_spynet = SPyNet()
16
+ for p in self.fix_spynet.parameters():
17
+ p.requires_grad = False
18
+
19
+ self.l1_criterion = nn.L1Loss()
20
+
21
+ def forward(self, pred_flows, gt_local_frames):
22
+ b, l_t, c, h, w = gt_local_frames.size()
23
+
24
+ with torch.no_grad():
25
+ # compute gt forward and backward flows
26
+ gt_local_frames = F.interpolate(gt_local_frames.view(-1, c, h, w),
27
+ scale_factor=1 / 4,
28
+ mode='bilinear',
29
+ align_corners=True,
30
+ recompute_scale_factor=True)
31
+ gt_local_frames = gt_local_frames.view(b, l_t, c, h // 4, w // 4)
32
+ gtlf_1 = gt_local_frames[:, :-1, :, :, :].reshape(
33
+ -1, c, h // 4, w // 4)
34
+ gtlf_2 = gt_local_frames[:, 1:, :, :, :].reshape(
35
+ -1, c, h // 4, w // 4)
36
+ gt_flows_forward = self.fix_spynet(gtlf_1, gtlf_2)
37
+ gt_flows_backward = self.fix_spynet(gtlf_2, gtlf_1)
38
+
39
+ # calculate loss for flow completion
40
+ forward_flow_loss = self.l1_criterion(
41
+ pred_flows[0].view(-1, 2, h // 4, w // 4), gt_flows_forward)
42
+ backward_flow_loss = self.l1_criterion(
43
+ pred_flows[1].view(-1, 2, h // 4, w // 4), gt_flows_backward)
44
+ flow_loss = forward_flow_loss + backward_flow_loss
45
+
46
+ return flow_loss
47
+
48
+
49
+ class SPyNet(nn.Module):
50
+ """SPyNet network structure.
51
+ The difference to the SPyNet in [tof.py] is that
52
+ 1. more SPyNetBasicModule is used in this version, and
53
+ 2. no batch normalization is used in this version.
54
+ Paper:
55
+ Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
56
+ Args:
57
+ pretrained (str): path for pre-trained SPyNet. Default: None.
58
+ """
59
+ def __init__(
60
+ self,
61
+ use_pretrain=True,
62
+ pretrained='https://download.openmmlab.com/mmediting/restorers/basicvsr/spynet_20210409-c6c1bd09.pth'
63
+ ):
64
+ super().__init__()
65
+
66
+ self.basic_module = nn.ModuleList(
67
+ [SPyNetBasicModule() for _ in range(6)])
68
+
69
+ if use_pretrain:
70
+ if isinstance(pretrained, str):
71
+ print("load pretrained SPyNet...")
72
+ load_checkpoint(self, pretrained, strict=True)
73
+ elif pretrained is not None:
74
+ raise TypeError('[pretrained] should be str or None, '
75
+ f'but got {type(pretrained)}.')
76
+
77
+ self.register_buffer(
78
+ 'mean',
79
+ torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
80
+ self.register_buffer(
81
+ 'std',
82
+ torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
83
+
84
+ def compute_flow(self, ref, supp):
85
+ """Compute flow from ref to supp.
86
+ Note that in this function, the images are already resized to a
87
+ multiple of 32.
88
+ Args:
89
+ ref (Tensor): Reference image with shape of (n, 3, h, w).
90
+ supp (Tensor): Supporting image with shape of (n, 3, h, w).
91
+ Returns:
92
+ Tensor: Estimated optical flow: (n, 2, h, w).
93
+ """
94
+ n, _, h, w = ref.size()
95
+
96
+ # normalize the input images
97
+ ref = [(ref - self.mean) / self.std]
98
+ supp = [(supp - self.mean) / self.std]
99
+
100
+ # generate downsampled frames
101
+ for level in range(5):
102
+ ref.append(
103
+ F.avg_pool2d(input=ref[-1],
104
+ kernel_size=2,
105
+ stride=2,
106
+ count_include_pad=False))
107
+ supp.append(
108
+ F.avg_pool2d(input=supp[-1],
109
+ kernel_size=2,
110
+ stride=2,
111
+ count_include_pad=False))
112
+ ref = ref[::-1]
113
+ supp = supp[::-1]
114
+
115
+ # flow computation
116
+ flow = ref[0].new_zeros(n, 2, h // 32, w // 32)
117
+ for level in range(len(ref)):
118
+ if level == 0:
119
+ flow_up = flow
120
+ else:
121
+ flow_up = F.interpolate(input=flow,
122
+ scale_factor=2,
123
+ mode='bilinear',
124
+ align_corners=True) * 2.0
125
+
126
+ # add the residue to the upsampled flow
127
+ flow = flow_up + self.basic_module[level](torch.cat([
128
+ ref[level],
129
+ flow_warp(supp[level],
130
+ flow_up.permute(0, 2, 3, 1).contiguous(),
131
+ padding_mode='border'), flow_up
132
+ ], 1))
133
+
134
+ return flow
135
+
136
+ def forward(self, ref, supp):
137
+ """Forward function of SPyNet.
138
+ This function computes the optical flow from ref to supp.
139
+ Args:
140
+ ref (Tensor): Reference image with shape of (n, 3, h, w).
141
+ supp (Tensor): Supporting image with shape of (n, 3, h, w).
142
+ Returns:
143
+ Tensor: Estimated optical flow: (n, 2, h, w).
144
+ """
145
+
146
+ # upsize to a multiple of 32
147
+ h, w = ref.shape[2:4]
148
+ w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1)
149
+ h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1)
150
+ ref = F.interpolate(input=ref,
151
+ size=(h_up, w_up),
152
+ mode='bilinear',
153
+ align_corners=False)
154
+ supp = F.interpolate(input=supp,
155
+ size=(h_up, w_up),
156
+ mode='bilinear',
157
+ align_corners=False)
158
+
159
+ # compute flow, and resize back to the original resolution
160
+ flow = F.interpolate(input=self.compute_flow(ref, supp),
161
+ size=(h, w),
162
+ mode='bilinear',
163
+ align_corners=False)
164
+
165
+ # adjust the flow values
166
+ flow[:, 0, :, :] *= float(w) / float(w_up)
167
+ flow[:, 1, :, :] *= float(h) / float(h_up)
168
+
169
+ return flow
170
+
171
+
172
+ class SPyNetBasicModule(nn.Module):
173
+ """Basic Module for SPyNet.
174
+ Paper:
175
+ Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
176
+ """
177
+ def __init__(self):
178
+ super().__init__()
179
+
180
+ self.basic_module = nn.Sequential(
181
+ ConvModule(in_channels=8,
182
+ out_channels=32,
183
+ kernel_size=7,
184
+ stride=1,
185
+ padding=3,
186
+ norm_cfg=None,
187
+ act_cfg=dict(type='ReLU')),
188
+ ConvModule(in_channels=32,
189
+ out_channels=64,
190
+ kernel_size=7,
191
+ stride=1,
192
+ padding=3,
193
+ norm_cfg=None,
194
+ act_cfg=dict(type='ReLU')),
195
+ ConvModule(in_channels=64,
196
+ out_channels=32,
197
+ kernel_size=7,
198
+ stride=1,
199
+ padding=3,
200
+ norm_cfg=None,
201
+ act_cfg=dict(type='ReLU')),
202
+ ConvModule(in_channels=32,
203
+ out_channels=16,
204
+ kernel_size=7,
205
+ stride=1,
206
+ padding=3,
207
+ norm_cfg=None,
208
+ act_cfg=dict(type='ReLU')),
209
+ ConvModule(in_channels=16,
210
+ out_channels=2,
211
+ kernel_size=7,
212
+ stride=1,
213
+ padding=3,
214
+ norm_cfg=None,
215
+ act_cfg=None))
216
+
217
+ def forward(self, tensor_input):
218
+ """
219
+ Args:
220
+ tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
221
+ 8 channels contain:
222
+ [reference image (3), neighbor image (3), initial flow (2)].
223
+ Returns:
224
+ Tensor: Refined flow with shape (b, 2, h, w)
225
+ """
226
+ return self.basic_module(tensor_input)
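A minimal usage sketch of the refinement module above (not part of the diff): it takes an 8-channel input of [reference frame (3), warped supporting frame (3), current flow (2)] and predicts a 2-channel flow residue. This assumes ConvModule resolves to the mmcv ConvModule imported earlier in flow_comp.py; the shapes are illustrative only.

import torch

module = SPyNetBasicModule()
ref = torch.randn(2, 3, 64, 64)           # reference frames
warped_supp = torch.randn(2, 3, 64, 64)   # supporting frames warped by the current flow
flow_up = torch.zeros(2, 2, 64, 64)       # current (upsampled) flow estimate
residue = module(torch.cat([ref, warped_supp, flow_up], dim=1))
print(residue.shape)                      # torch.Size([2, 2, 64, 64])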
227
+
228
+
229
+ # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
230
+ def make_colorwheel():
231
+ """
232
+ Generates a color wheel for optical flow visualization as presented in:
233
+ Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
234
+ URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
235
+
236
+ Code follows the original C++ source code of Daniel Scharstein.
237
+ Code follows the Matlab source code of Deqing Sun.
238
+
239
+ Returns:
240
+ np.ndarray: Color wheel
241
+ """
242
+
243
+ RY = 15
244
+ YG = 6
245
+ GC = 4
246
+ CB = 11
247
+ BM = 13
248
+ MR = 6
249
+
250
+ ncols = RY + YG + GC + CB + BM + MR
251
+ colorwheel = np.zeros((ncols, 3))
252
+ col = 0
253
+
254
+ # RY
255
+ colorwheel[0:RY, 0] = 255
256
+ colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
257
+ col = col + RY
258
+ # YG
259
+ colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
260
+ colorwheel[col:col + YG, 1] = 255
261
+ col = col + YG
262
+ # GC
263
+ colorwheel[col:col + GC, 1] = 255
264
+ colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
265
+ col = col + GC
266
+ # CB
267
+ colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
268
+ colorwheel[col:col + CB, 2] = 255
269
+ col = col + CB
270
+ # BM
271
+ colorwheel[col:col + BM, 2] = 255
272
+ colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
273
+ col = col + BM
274
+ # MR
275
+ colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
276
+ colorwheel[col:col + MR, 0] = 255
277
+ return colorwheel
278
+
279
+
280
+ def flow_uv_to_colors(u, v, convert_to_bgr=False):
281
+ """
282
+ Applies the flow color wheel to (possibly clipped) flow components u and v.
283
+
284
+ According to the C++ source code of Daniel Scharstein
285
+ According to the Matlab source code of Deqing Sun
286
+
287
+ Args:
288
+ u (np.ndarray): Input horizontal flow of shape [H,W]
289
+ v (np.ndarray): Input vertical flow of shape [H,W]
290
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
291
+
292
+ Returns:
293
+ np.ndarray: Flow visualization image of shape [H,W,3]
294
+ """
295
+ flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
296
+ colorwheel = make_colorwheel() # shape [55x3]
297
+ ncols = colorwheel.shape[0]
298
+ rad = np.sqrt(np.square(u) + np.square(v))
299
+ a = np.arctan2(-v, -u) / np.pi
300
+ fk = (a + 1) / 2 * (ncols - 1)
301
+ k0 = np.floor(fk).astype(np.int32)
302
+ k1 = k0 + 1
303
+ k1[k1 == ncols] = 0
304
+ f = fk - k0
305
+ for i in range(colorwheel.shape[1]):
306
+ tmp = colorwheel[:, i]
307
+ col0 = tmp[k0] / 255.0
308
+ col1 = tmp[k1] / 255.0
309
+ col = (1 - f) * col0 + f * col1
310
+ idx = (rad <= 1)
311
+ col[idx] = 1 - rad[idx] * (1 - col[idx])
312
+ col[~idx] = col[~idx] * 0.75 # out of range
313
+ # Note the 2-i => BGR instead of RGB
314
+ ch_idx = 2 - i if convert_to_bgr else i
315
+ flow_image[:, :, ch_idx] = np.floor(255 * col)
316
+ return flow_image
317
+
318
+
319
+ def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
320
+ """
321
+ Expects a two-dimensional flow image of shape [H,W,2].
322
+
323
+ Args:
324
+ flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
325
+ clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
326
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
327
+
328
+ Returns:
329
+ np.ndarray: Flow visualization image of shape [H,W,3]
330
+ """
331
+ assert flow_uv.ndim == 3, 'input flow must have three dimensions'
332
+ assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
333
+ if clip_flow is not None:
334
+ flow_uv = np.clip(flow_uv, 0, clip_flow)
335
+ u = flow_uv[:, :, 0]
336
+ v = flow_uv[:, :, 1]
337
+ rad = np.sqrt(np.square(u) + np.square(v))
338
+ rad_max = np.max(rad)
339
+ epsilon = 1e-5
340
+ u = u / (rad_max + epsilon)
341
+ v = v / (rad_max + epsilon)
342
+ return flow_uv_to_colors(u, v, convert_to_bgr)
343
+
344
+
345
+ def flow_warp(x,
346
+ flow,
347
+ interpolation='bilinear',
348
+ padding_mode='zeros',
349
+ align_corners=True):
350
+ """Warp an image or a feature map with optical flow.
351
+ Args:
352
+ x (Tensor): Tensor with size (n, c, h, w).
353
+ flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is
354
+ a two-channel map denoting the width and height relative offsets.
355
+ Note that the values are not normalized to [-1, 1].
356
+ interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
357
+ Default: 'bilinear'.
358
+ padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
359
+ Default: 'zeros'.
360
+ align_corners (bool): Whether align corners. Default: True.
361
+ Returns:
362
+ Tensor: Warped image or feature map.
363
+ """
364
+ if x.size()[-2:] != flow.size()[1:3]:
365
+ raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
366
+ f'flow ({flow.size()[1:3]}) are not the same.')
367
+ _, _, h, w = x.size()
368
+ # create mesh grid
369
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
370
+ grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (w, h, 2)
371
+ grid.requires_grad = False
372
+
373
+ grid_flow = grid + flow
374
+ # scale grid_flow to [-1,1]
375
+ grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
376
+ grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
377
+ grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
378
+ output = F.grid_sample(x,
379
+ grid_flow,
380
+ mode=interpolation,
381
+ padding_mode=padding_mode,
382
+ align_corners=align_corners)
383
+ return output
384
+
385
+
386
+ def initial_mask_flow(mask):
387
+ """
388
+ mask: 1 indicates a valid pixel, 0 indicates an unknown pixel
389
+ """
390
+ B, T, C, H, W = mask.shape
391
+
392
+ # calculate relative position
393
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
394
+
395
+ grid_y, grid_x = grid_y.type_as(mask), grid_x.type_as(mask)
396
+ abs_relative_pos_y = H - torch.abs(grid_y[None, :, :] - grid_y[:, None, :])
397
+ relative_pos_y = H - (grid_y[None, :, :] - grid_y[:, None, :])
398
+
399
+ abs_relative_pos_x = W - torch.abs(grid_x[:, None, :] - grid_x[:, :, None])
400
+ relative_pos_x = W - (grid_x[:, None, :] - grid_x[:, :, None])
401
+
402
+ # calculate the nearest indices
403
+ pos_up = mask.unsqueeze(3).repeat(
404
+ 1, 1, 1, H, 1, 1).flip(4) * abs_relative_pos_y[None, None, None] * (
405
+ relative_pos_y <= H)[None, None, None]
406
+ nearest_indice_up = pos_up.max(dim=4)[1]
407
+
408
+ pos_down = mask.unsqueeze(3).repeat(1, 1, 1, H, 1, 1) * abs_relative_pos_y[
409
+ None, None, None] * (relative_pos_y <= H)[None, None, None]
410
+ nearest_indice_down = (pos_down).max(dim=4)[1]
411
+
412
+ pos_left = mask.unsqueeze(4).repeat(
413
+ 1, 1, 1, 1, W, 1).flip(5) * abs_relative_pos_x[None, None, None] * (
414
+ relative_pos_x <= W)[None, None, None]
415
+ nearest_indice_left = (pos_left).max(dim=5)[1]
416
+
417
+ pos_right = mask.unsqueeze(4).repeat(
418
+ 1, 1, 1, 1, W, 1) * abs_relative_pos_x[None, None, None] * (
419
+ relative_pos_x <= W)[None, None, None]
420
+ nearest_indice_right = (pos_right).max(dim=5)[1]
421
+
422
+ # NOTE: IMPORTANT !!! depending on how to use this offset
423
+ initial_offset_up = -(nearest_indice_up - grid_y[None, None, None]).flip(3)
424
+ initial_offset_down = nearest_indice_down - grid_y[None, None, None]
425
+
426
+ initial_offset_left = -(nearest_indice_left -
427
+ grid_x[None, None, None]).flip(4)
428
+ initial_offset_right = nearest_indice_right - grid_x[None, None, None]
429
+
430
+ # nearest_indice_x = (mask.unsqueeze(1).repeat(1, img_width, 1) * relative_pos_x).max(dim=2)[1]
431
+ # initial_offset_x = nearest_indice_x - grid_x
432
+
433
+ # handle the boundary cases
434
+ final_offset_down = (initial_offset_down < 0) * initial_offset_up + (
435
+ initial_offset_down > 0) * initial_offset_down
436
+ final_offset_up = (initial_offset_up > 0) * initial_offset_down + (
437
+ initial_offset_up < 0) * initial_offset_up
438
+ final_offset_right = (initial_offset_right < 0) * initial_offset_left + (
439
+ initial_offset_right > 0) * initial_offset_right
440
+ final_offset_left = (initial_offset_left > 0) * initial_offset_right + (
441
+ initial_offset_left < 0) * initial_offset_left
442
+ zero_offset = torch.zeros_like(final_offset_down)
443
+ # out = torch.cat([final_offset_left, zero_offset, final_offset_right, zero_offset, zero_offset, final_offset_up, zero_offset, final_offset_down], dim=2)
444
+ out = torch.cat([
445
+ zero_offset, final_offset_left, zero_offset, final_offset_right,
446
+ final_offset_up, zero_offset, final_offset_down, zero_offset
447
+ ],
448
+ dim=2)
449
+
450
+ return out
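As a quick sanity check of the helpers above, a hedged sketch (not part of the diff): it visualizes a random flow field with flow_to_image and warps a feature map with flow_warp. Shapes are illustrative only.

import numpy as np
import torch

flow_np = np.random.randn(64, 64, 2).astype(np.float32)   # (H, W, 2) flow in pixels
rgb = flow_to_image(flow_np)                               # uint8 visualization, shape (64, 64, 3)

x = torch.randn(1, 3, 64, 64)                              # feature map (n, c, h, w)
flow_t = torch.from_numpy(flow_np).unsqueeze(0)            # flow (n, h, w, 2)
warped = flow_warp(x, flow_t, padding_mode='border')       # same shape as x
print(rgb.shape, warped.shape)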
inpainter/model/modules/spectral_norm.py ADDED
@@ -0,0 +1,288 @@
1
+ """
2
+ Spectral Normalization from https://arxiv.org/abs/1802.05957
3
+ """
4
+ import torch
5
+ from torch.nn.functional import normalize
6
+
7
+
8
+ class SpectralNorm(object):
9
+ # Invariant before and after each forward call:
10
+ # u = normalize(W @ v)
11
+ # NB: At initialization, this invariant is not enforced
12
+
13
+ _version = 1
14
+
15
+ # At version 1:
16
+ # made `W` not a buffer,
17
+ # added `v` as a buffer, and
18
+ # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.
19
+
20
+ def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
21
+ self.name = name
22
+ self.dim = dim
23
+ if n_power_iterations <= 0:
24
+ raise ValueError(
25
+ 'Expected n_power_iterations to be positive, but '
26
+ 'got n_power_iterations={}'.format(n_power_iterations))
27
+ self.n_power_iterations = n_power_iterations
28
+ self.eps = eps
29
+
30
+ def reshape_weight_to_matrix(self, weight):
31
+ weight_mat = weight
32
+ if self.dim != 0:
33
+ # permute dim to front
34
+ weight_mat = weight_mat.permute(
35
+ self.dim,
36
+ *[d for d in range(weight_mat.dim()) if d != self.dim])
37
+ height = weight_mat.size(0)
38
+ return weight_mat.reshape(height, -1)
39
+
40
+ def compute_weight(self, module, do_power_iteration):
41
+ # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
42
+ # updated in power iteration **in-place**. This is very important
43
+ # because in `DataParallel` forward, the vectors (being buffers) are
44
+ # broadcast from the parallelized module to each module replica,
45
+ # which is a new module object created on the fly. And each replica
46
+ # runs its own spectral norm power iteration. So simply assigning
47
+ # the updated vectors to the module this function runs on will cause
48
+ # the update to be lost forever. And the next time the parallelized
49
+ # module is replicated, the same randomly initialized vectors are
50
+ # broadcast and used!
51
+ #
52
+ # Therefore, to make the change propagate back, we rely on two
53
+ # important behaviors (also enforced via tests):
54
+ # 1. `DataParallel` doesn't clone storage if the broadcast tensor
55
+ # is already on correct device; and it makes sure that the
56
+ # parallelized module is already on `device[0]`.
57
+ # 2. If the out tensor in `out=` kwarg has correct shape, it will
58
+ # just fill in the values.
59
+ # Therefore, since the same power iteration is performed on all
60
+ # devices, simply updating the tensors in-place will make sure that
61
+ # the module replica on `device[0]` will update the _u vector on the
62
+ # parallized module (by shared storage).
63
+ #
64
+ # However, after we update `u` and `v` in-place, we need to **clone**
65
+ # them before using them to normalize the weight. This is to support
66
+ # backproping through two forward passes, e.g., the common pattern in
67
+ # GAN training: loss = D(real) - D(fake). Otherwise, engine will
68
+ # complain that variables needed to do backward for the first forward
69
+ # (i.e., the `u` and `v` vectors) are changed in the second forward.
70
+ weight = getattr(module, self.name + '_orig')
71
+ u = getattr(module, self.name + '_u')
72
+ v = getattr(module, self.name + '_v')
73
+ weight_mat = self.reshape_weight_to_matrix(weight)
74
+
75
+ if do_power_iteration:
76
+ with torch.no_grad():
77
+ for _ in range(self.n_power_iterations):
78
+ # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
79
+ # are the first left and right singular vectors.
80
+ # This power iteration produces approximations of `u` and `v`.
81
+ v = normalize(torch.mv(weight_mat.t(), u),
82
+ dim=0,
83
+ eps=self.eps,
84
+ out=v)
85
+ u = normalize(torch.mv(weight_mat, v),
86
+ dim=0,
87
+ eps=self.eps,
88
+ out=u)
89
+ if self.n_power_iterations > 0:
90
+ # See above on why we need to clone
91
+ u = u.clone()
92
+ v = v.clone()
93
+
94
+ sigma = torch.dot(u, torch.mv(weight_mat, v))
95
+ weight = weight / sigma
96
+ return weight
97
+
98
+ def remove(self, module):
99
+ with torch.no_grad():
100
+ weight = self.compute_weight(module, do_power_iteration=False)
101
+ delattr(module, self.name)
102
+ delattr(module, self.name + '_u')
103
+ delattr(module, self.name + '_v')
104
+ delattr(module, self.name + '_orig')
105
+ module.register_parameter(self.name,
106
+ torch.nn.Parameter(weight.detach()))
107
+
108
+ def __call__(self, module, inputs):
109
+ setattr(
110
+ module, self.name,
111
+ self.compute_weight(module, do_power_iteration=module.training))
112
+
113
+ def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
114
+ # Tries to return a vector `v` s.t. `u = normalize(W @ v)`
115
+ # (the invariant at top of this class) and `u @ W @ v = sigma`.
116
+ # This uses pinverse in case W^T W is not invertible.
117
+ v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(),
118
+ weight_mat.t(), u.unsqueeze(1)).squeeze(1)
119
+ return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))
120
+
121
+ @staticmethod
122
+ def apply(module, name, n_power_iterations, dim, eps):
123
+ for k, hook in module._forward_pre_hooks.items():
124
+ if isinstance(hook, SpectralNorm) and hook.name == name:
125
+ raise RuntimeError(
126
+ "Cannot register two spectral_norm hooks on "
127
+ "the same parameter {}".format(name))
128
+
129
+ fn = SpectralNorm(name, n_power_iterations, dim, eps)
130
+ weight = module._parameters[name]
131
+
132
+ with torch.no_grad():
133
+ weight_mat = fn.reshape_weight_to_matrix(weight)
134
+
135
+ h, w = weight_mat.size()
136
+ # randomly initialize `u` and `v`
137
+ u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
138
+ v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)
139
+
140
+ delattr(module, fn.name)
141
+ module.register_parameter(fn.name + "_orig", weight)
142
+ # We still need to assign weight back as fn.name because all sorts of
143
+ # things may assume that it exists, e.g., when initializing weights.
144
+ # However, we can't directly assign as it could be an nn.Parameter and
145
+ # gets added as a parameter. Instead, we register weight.data as a plain
146
+ # attribute.
147
+ setattr(module, fn.name, weight.data)
148
+ module.register_buffer(fn.name + "_u", u)
149
+ module.register_buffer(fn.name + "_v", v)
150
+
151
+ module.register_forward_pre_hook(fn)
152
+
153
+ module._register_state_dict_hook(SpectralNormStateDictHook(fn))
154
+ module._register_load_state_dict_pre_hook(
155
+ SpectralNormLoadStateDictPreHook(fn))
156
+ return fn
157
+
158
+
159
+ # This is a top level class because Py2 pickle doesn't like inner class nor an
160
+ # instancemethod.
161
+ class SpectralNormLoadStateDictPreHook(object):
162
+ # See docstring of SpectralNorm._version on the changes to spectral_norm.
163
+ def __init__(self, fn):
164
+ self.fn = fn
165
+
166
+ # For state_dict with version None, (assuming that it has gone through at
167
+ # least one training forward), we have
168
+ #
169
+ # u = normalize(W_orig @ v)
170
+ # W = W_orig / sigma, where sigma = u @ W_orig @ v
171
+ #
172
+ # To compute `v`, we solve `W_orig @ x = u`, and let
173
+ # v = x / (u @ W_orig @ x) * (W / W_orig).
174
+ def __call__(self, state_dict, prefix, local_metadata, strict,
175
+ missing_keys, unexpected_keys, error_msgs):
176
+ fn = self.fn
177
+ version = local_metadata.get('spectral_norm',
178
+ {}).get(fn.name + '.version', None)
179
+ if version is None or version < 1:
180
+ with torch.no_grad():
181
+ weight_orig = state_dict[prefix + fn.name + '_orig']
182
+ # weight = state_dict.pop(prefix + fn.name)
183
+ # sigma = (weight_orig / weight).mean()
184
+ weight_mat = fn.reshape_weight_to_matrix(weight_orig)
185
+ u = state_dict[prefix + fn.name + '_u']
186
+ # v = fn._solve_v_and_rescale(weight_mat, u, sigma)
187
+ # state_dict[prefix + fn.name + '_v'] = v
188
+
189
+
190
+ # This is a top level class because Py2 pickle doesn't like inner class nor an
191
+ # instancemethod.
192
+ class SpectralNormStateDictHook(object):
193
+ # See docstring of SpectralNorm._version on the changes to spectral_norm.
194
+ def __init__(self, fn):
195
+ self.fn = fn
196
+
197
+ def __call__(self, module, state_dict, prefix, local_metadata):
198
+ if 'spectral_norm' not in local_metadata:
199
+ local_metadata['spectral_norm'] = {}
200
+ key = self.fn.name + '.version'
201
+ if key in local_metadata['spectral_norm']:
202
+ raise RuntimeError(
203
+ "Unexpected key in metadata['spectral_norm']: {}".format(key))
204
+ local_metadata['spectral_norm'][key] = self.fn._version
205
+
206
+
207
+ def spectral_norm(module,
208
+ name='weight',
209
+ n_power_iterations=1,
210
+ eps=1e-12,
211
+ dim=None):
212
+ r"""Applies spectral normalization to a parameter in the given module.
213
+
214
+ .. math::
215
+ \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
216
+ \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
217
+
218
+ Spectral normalization stabilizes the training of discriminators (critics)
219
+ in Generative Adversarial Networks (GANs) by rescaling the weight tensor
220
+ with spectral norm :math:`\sigma` of the weight matrix calculated using
221
+ power iteration method. If the dimension of the weight tensor is greater
222
+ than 2, it is reshaped to 2D in power iteration method to get spectral
223
+ norm. This is implemented via a hook that calculates spectral norm and
224
+ rescales weight before every :meth:`~Module.forward` call.
225
+
226
+ See `Spectral Normalization for Generative Adversarial Networks`_ .
227
+
228
+ .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
229
+
230
+ Args:
231
+ module (nn.Module): containing module
232
+ name (str, optional): name of weight parameter
233
+ n_power_iterations (int, optional): number of power iterations to
234
+ calculate spectral norm
235
+ eps (float, optional): epsilon for numerical stability in
236
+ calculating norms
237
+ dim (int, optional): dimension corresponding to number of outputs,
238
+ the default is ``0``, except for modules that are instances of
239
+ ConvTranspose{1,2,3}d, when it is ``1``
240
+
241
+ Returns:
242
+ The original module with the spectral norm hook
243
+
244
+ Example::
245
+
246
+ >>> m = spectral_norm(nn.Linear(20, 40))
247
+ >>> m
248
+ Linear(in_features=20, out_features=40, bias=True)
249
+ >>> m.weight_u.size()
250
+ torch.Size([40])
251
+
252
+ """
253
+ if dim is None:
254
+ if isinstance(module,
255
+ (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
256
+ torch.nn.ConvTranspose3d)):
257
+ dim = 1
258
+ else:
259
+ dim = 0
260
+ SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
261
+ return module
262
+
263
+
264
+ def remove_spectral_norm(module, name='weight'):
265
+ r"""Removes the spectral normalization reparameterization from a module.
266
+
267
+ Args:
268
+ module (Module): containing module
269
+ name (str, optional): name of weight parameter
270
+
271
+ Example:
272
+ >>> m = spectral_norm(nn.Linear(40, 10))
273
+ >>> remove_spectral_norm(m)
274
+ """
275
+ for k, hook in module._forward_pre_hooks.items():
276
+ if isinstance(hook, SpectralNorm) and hook.name == name:
277
+ hook.remove(module)
278
+ del module._forward_pre_hooks[k]
279
+ return module
280
+
281
+ raise ValueError("spectral_norm of '{}' not found in {}".format(
282
+ name, module))
283
+
284
+
285
+ def use_spectral_norm(module, use_sn=False):
286
+ if use_sn:
287
+ return spectral_norm(module)
288
+ return module
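A short, hedged usage sketch for the wrapper above (not part of the diff): use_spectral_norm either returns the module unchanged or registers the spectral-norm forward pre-hook defined in this file.

import torch
import torch.nn as nn

conv = use_spectral_norm(nn.Conv2d(3, 64, kernel_size=3, padding=1), use_sn=True)
y = conv(torch.randn(1, 3, 32, 32))          # weight is rescaled by its spectral norm on every forward
print(hasattr(conv, 'weight_u'), y.shape)    # True torch.Size([1, 64, 32, 32])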
inpainter/model/modules/tfocal_transformer.py ADDED
@@ -0,0 +1,536 @@
1
+ """
2
+ This code is based on:
3
+ [1] FuseFormer: Fusing Fine-Grained Information in Transformers for Video Inpainting, ICCV 2021
4
+ https://github.com/ruiliu-ai/FuseFormer
5
+ [2] Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet, ICCV 2021
6
+ https://github.com/yitu-opensource/T2T-ViT
7
+ [3] Focal Self-attention for Local-Global Interactions in Vision Transformers, NeurIPS 2021
8
+ https://github.com/microsoft/Focal-Transformer
9
+ """
10
+
11
+ import math
12
+ from functools import reduce
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+
19
+ class SoftSplit(nn.Module):
20
+ def __init__(self, channel, hidden, kernel_size, stride, padding,
21
+ t2t_param):
22
+ super(SoftSplit, self).__init__()
23
+ self.kernel_size = kernel_size
24
+ self.t2t = nn.Unfold(kernel_size=kernel_size,
25
+ stride=stride,
26
+ padding=padding)
27
+ c_in = reduce((lambda x, y: x * y), kernel_size) * channel
28
+ self.embedding = nn.Linear(c_in, hidden)
29
+
30
+ self.f_h = int(
31
+ (t2t_param['output_size'][0] + 2 * t2t_param['padding'][0] -
32
+ (t2t_param['kernel_size'][0] - 1) - 1) / t2t_param['stride'][0] +
33
+ 1)
34
+ self.f_w = int(
35
+ (t2t_param['output_size'][1] + 2 * t2t_param['padding'][1] -
36
+ (t2t_param['kernel_size'][1] - 1) - 1) / t2t_param['stride'][1] +
37
+ 1)
38
+
39
+ def forward(self, x, b):
40
+ feat = self.t2t(x)
41
+ feat = feat.permute(0, 2, 1)
42
+ # feat shape [b*t, num_vec, ks*ks*c]
43
+ feat = self.embedding(feat)
44
+ # feat shape after embedding [b, t*num_vec, hidden]
45
+ feat = feat.view(b, -1, self.f_h, self.f_w, feat.size(2))
46
+ return feat
47
+
48
+
49
+ class SoftComp(nn.Module):
50
+ def __init__(self, channel, hidden, output_size, kernel_size, stride,
51
+ padding):
52
+ super(SoftComp, self).__init__()
53
+ self.relu = nn.LeakyReLU(0.2, inplace=True)
54
+ c_out = reduce((lambda x, y: x * y), kernel_size) * channel
55
+ self.embedding = nn.Linear(hidden, c_out)
56
+ self.t2t = torch.nn.Fold(output_size=output_size,
57
+ kernel_size=kernel_size,
58
+ stride=stride,
59
+ padding=padding)
60
+ h, w = output_size
61
+ self.bias = nn.Parameter(torch.zeros((channel, h, w),
62
+ dtype=torch.float32),
63
+ requires_grad=True)
64
+
65
+ def forward(self, x, t):
66
+ b_, _, _, _, c_ = x.shape
67
+ x = x.view(b_, -1, c_)
68
+ feat = self.embedding(x)
69
+ b, _, c = feat.size()
70
+ feat = feat.view(b * t, -1, c).permute(0, 2, 1)
71
+ feat = self.t2t(feat) + self.bias[None]
72
+ return feat
73
+
74
+
75
+ class FusionFeedForward(nn.Module):
76
+ def __init__(self, d_model, n_vecs=None, t2t_params=None):
77
+ super(FusionFeedForward, self).__init__()
78
+ # We set d_ff as a default to 1960
79
+ hd = 1960
80
+ self.conv1 = nn.Sequential(nn.Linear(d_model, hd))
81
+ self.conv2 = nn.Sequential(nn.GELU(), nn.Linear(hd, d_model))
82
+ assert t2t_params is not None and n_vecs is not None
83
+ tp = t2t_params.copy()
84
+ self.fold = nn.Fold(**tp)
85
+ del tp['output_size']
86
+ self.unfold = nn.Unfold(**tp)
87
+ self.n_vecs = n_vecs
88
+
89
+ def forward(self, x):
90
+ x = self.conv1(x)
91
+ b, n, c = x.size()
92
+ normalizer = x.new_ones(b, n, 49).view(-1, self.n_vecs,
93
+ 49).permute(0, 2, 1)
94
+ x = self.unfold(
95
+ self.fold(x.view(-1, self.n_vecs, c).permute(0, 2, 1)) /
96
+ self.fold(normalizer)).permute(0, 2, 1).contiguous().view(b, n, c)
97
+ x = self.conv2(x)
98
+ return x
99
+
100
+
101
+ def window_partition(x, window_size):
102
+ """
103
+ Args:
104
+ x: shape is (B, T, H, W, C)
105
+ window_size (tuple[int]): window size
106
+ Returns:
107
+ windows: (B*num_windows, T*window_size*window_size, C)
108
+ """
109
+ B, T, H, W, C = x.shape
110
+ x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
111
+ window_size[1], C)
112
+ windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
113
+ -1, T * window_size[0] * window_size[1], C)
114
+ return windows
115
+
116
+
117
+ def window_partition_noreshape(x, window_size):
118
+ """
119
+ Args:
120
+ x: shape is (B, T, H, W, C)
121
+ window_size (tuple[int]): window size
122
+ Returns:
123
+ windows: (B, num_windows_h, num_windows_w, T, window_size, window_size, C)
124
+ """
125
+ B, T, H, W, C = x.shape
126
+ x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
127
+ window_size[1], C)
128
+ windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous()
129
+ return windows
130
+
131
+
132
+ def window_reverse(windows, window_size, T, H, W):
133
+ """
134
+ Args:
135
+ windows: shape is (num_windows*B, T, window_size, window_size, C)
136
+ window_size (tuple[int]): Window size
137
+ T (int): Temporal length of video
138
+ H (int): Height of image
139
+ W (int): Width of image
140
+ Returns:
141
+ x: (B, T, H, W, C)
142
+ """
143
+ B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
144
+ x = windows.view(B, H // window_size[0], W // window_size[1], T,
145
+ window_size[0], window_size[1], -1)
146
+ x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, T, H, W, -1)
147
+ return x
148
+
149
+
150
+ class WindowAttention(nn.Module):
151
+ """Temporal focal window attention
152
+ """
153
+ def __init__(self, dim, expand_size, window_size, focal_window,
154
+ focal_level, num_heads, qkv_bias, pool_method):
155
+
156
+ super().__init__()
157
+ self.dim = dim
158
+ self.expand_size = expand_size
159
+ self.window_size = window_size # Wh, Ww
160
+ self.pool_method = pool_method
161
+ self.num_heads = num_heads
162
+ head_dim = dim // num_heads
163
+ self.scale = head_dim**-0.5
164
+ self.focal_level = focal_level
165
+ self.focal_window = focal_window
166
+
167
+ if any(i > 0 for i in self.expand_size) and focal_level > 0:
168
+ # get mask for rolled k and rolled v
169
+ mask_tl = torch.ones(self.window_size[0], self.window_size[1])
170
+ mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0
171
+ mask_tr = torch.ones(self.window_size[0], self.window_size[1])
172
+ mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0
173
+ mask_bl = torch.ones(self.window_size[0], self.window_size[1])
174
+ mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0
175
+ mask_br = torch.ones(self.window_size[0], self.window_size[1])
176
+ mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0
177
+ mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br),
178
+ 0).flatten(0)
179
+ self.register_buffer("valid_ind_rolled",
180
+ mask_rolled.nonzero(as_tuple=False).view(-1))
181
+
182
+ if pool_method != "none" and focal_level > 1:
183
+ self.unfolds = nn.ModuleList()
184
+
185
+ # build relative position bias between local patch and pooled windows
186
+ for k in range(focal_level - 1):
187
+ stride = 2**k
188
+ kernel_size = tuple(2 * (i // 2) + 2**k + (2**k - 1)
189
+ for i in self.focal_window)
190
+ # define unfolding operations
191
+ self.unfolds += [
192
+ nn.Unfold(kernel_size=kernel_size,
193
+ stride=stride,
194
+ padding=tuple(i // 2 for i in kernel_size))
195
+ ]
196
+
197
+ # define unfolding index for focal_level > 0
198
+ if k > 0:
199
+ mask = torch.zeros(kernel_size)
200
+ mask[(2**k) - 1:, (2**k) - 1:] = 1
201
+ self.register_buffer(
202
+ "valid_ind_unfold_{}".format(k),
203
+ mask.flatten(0).nonzero(as_tuple=False).view(-1))
204
+
205
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
206
+ self.proj = nn.Linear(dim, dim)
207
+
208
+ self.softmax = nn.Softmax(dim=-1)
209
+
210
+ def forward(self, x_all, mask_all=None):
211
+ """
212
+ Args:
213
+ x: input features with shape of (B, T, Wh, Ww, C)
214
+ mask: (0/-inf) mask with shape of (num_windows, T*Wh*Ww, T*Wh*Ww) or None
215
+
216
+ output: (nW*B, Wh*Ww, C)
217
+ """
218
+ x = x_all[0]
219
+
220
+ B, T, nH, nW, C = x.shape
221
+ qkv = self.qkv(x).reshape(B, T, nH, nW, 3,
222
+ C).permute(4, 0, 1, 2, 3, 5).contiguous()
223
+ q, k, v = qkv[0], qkv[1], qkv[2] # B, T, nH, nW, C
224
+
225
+ # partition q map
226
+ (q_windows, k_windows, v_windows) = map(
227
+ lambda t: window_partition(t, self.window_size).view(
228
+ -1, T, self.window_size[0] * self.window_size[1], self.
229
+ num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).
230
+ contiguous().view(-1, self.num_heads, T * self.window_size[
231
+ 0] * self.window_size[1], C // self.num_heads), (q, k, v))
232
+ # q(k/v)_windows shape : [16, 4, 225, 128]
233
+
234
+ if any(i > 0 for i in self.expand_size) and self.focal_level > 0:
235
+ (k_tl, v_tl) = map(
236
+ lambda t: torch.roll(t,
237
+ shifts=(-self.expand_size[0], -self.
238
+ expand_size[1]),
239
+ dims=(2, 3)), (k, v))
240
+ (k_tr, v_tr) = map(
241
+ lambda t: torch.roll(t,
242
+ shifts=(-self.expand_size[0], self.
243
+ expand_size[1]),
244
+ dims=(2, 3)), (k, v))
245
+ (k_bl, v_bl) = map(
246
+ lambda t: torch.roll(t,
247
+ shifts=(self.expand_size[0], -self.
248
+ expand_size[1]),
249
+ dims=(2, 3)), (k, v))
250
+ (k_br, v_br) = map(
251
+ lambda t: torch.roll(t,
252
+ shifts=(self.expand_size[0], self.
253
+ expand_size[1]),
254
+ dims=(2, 3)), (k, v))
255
+
256
+ (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
257
+ lambda t: window_partition(t, self.window_size).view(
258
+ -1, T, self.window_size[0] * self.window_size[1], self.
259
+ num_heads, C // self.num_heads), (k_tl, k_tr, k_bl, k_br))
260
+ (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
261
+ lambda t: window_partition(t, self.window_size).view(
262
+ -1, T, self.window_size[0] * self.window_size[1], self.
263
+ num_heads, C // self.num_heads), (v_tl, v_tr, v_bl, v_br))
264
+ k_rolled = torch.cat(
265
+ (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows),
266
+ 2).permute(0, 3, 1, 2, 4).contiguous()
267
+ v_rolled = torch.cat(
268
+ (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows),
269
+ 2).permute(0, 3, 1, 2, 4).contiguous()
270
+
271
+ # mask out tokens in current window
272
+ k_rolled = k_rolled[:, :, :, self.valid_ind_rolled]
273
+ v_rolled = v_rolled[:, :, :, self.valid_ind_rolled]
274
+ temp_N = k_rolled.shape[3]
275
+ k_rolled = k_rolled.view(-1, self.num_heads, T * temp_N,
276
+ C // self.num_heads)
277
+ v_rolled = v_rolled.view(-1, self.num_heads, T * temp_N,
278
+ C // self.num_heads)
279
+ k_rolled = torch.cat((k_windows, k_rolled), 2)
280
+ v_rolled = torch.cat((v_windows, v_rolled), 2)
281
+ else:
282
+ k_rolled = k_windows
283
+ v_rolled = v_windows
284
+
285
+ # q(k/v)_windows shape : [16, 4, 225, 128]
286
+ # k_rolled.shape : [16, 4, 5, 165, 128]
287
+ # ideal expanded window size 153 ((5+2*2)*(9+2*4))
288
+ # k_windows=45 expand_window=108 overlap_window=12 (since expand_size < window_size / 2)
289
+
290
+ if self.pool_method != "none" and self.focal_level > 1:
291
+ k_pooled = []
292
+ v_pooled = []
293
+ for k in range(self.focal_level - 1):
294
+ stride = 2**k
295
+ x_window_pooled = x_all[k + 1].permute(
296
+ 0, 3, 1, 2, 4).contiguous() # B, T, nWh, nWw, C
297
+
298
+ nWh, nWw = x_window_pooled.shape[2:4]
299
+
300
+ # generate mask for pooled windows
301
+ mask = x_window_pooled.new(T, nWh, nWw).fill_(1)
302
+ # unfold mask: [nWh*nWw//s//s, k*k, 1]
303
+ unfolded_mask = self.unfolds[k](mask.unsqueeze(1)).view(
304
+ 1, T, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(4, 1, 2, 3, 0).contiguous().\
305
+ view(nWh*nWw // stride // stride, -1, 1)
306
+
307
+ if k > 0:
308
+ valid_ind_unfold_k = getattr(
309
+ self, "valid_ind_unfold_{}".format(k))
310
+ unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]
311
+
312
+ x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
313
+ x_window_masks = x_window_masks.masked_fill(
314
+ x_window_masks == 0,
315
+ float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
316
+ mask_all[k + 1] = x_window_masks
317
+
318
+ # generate k and v for pooled windows
319
+ qkv_pooled = self.qkv(x_window_pooled).reshape(
320
+ B, T, nWh, nWw, 3, C).permute(4, 0, 1, 5, 2,
321
+ 3).view(3, -1, C, nWh,
322
+ nWw).contiguous()
323
+ k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[
324
+ 2] # B*T, C, nWh, nWw
325
+ # k_pooled_k shape: [5, 512, 4, 4]
326
+ # self.unfolds[k](k_pooled_k) shape: [5, 23040 (512 * 5 * 9 ), 16]
327
+
328
+ (k_pooled_k, v_pooled_k) = map(
329
+ lambda t: self.unfolds[k](t).view(
330
+ B, T, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 5, 1, 3, 4, 2).contiguous().\
331
+ view(-1, T, self.unfolds[k].kernel_size[0]*self.unfolds[k].kernel_size[1], self.num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).contiguous(),
332
+ (k_pooled_k, v_pooled_k) # (B x (nH*nW)) x nHeads x T x (unfold_wsize x unfold_wsize) x head_dim
333
+ )
334
+ # k_pooled_k shape : [16, 4, 5, 45, 128]
335
+
336
+ # select valid unfolding index
337
+ if k > 0:
338
+ (k_pooled_k, v_pooled_k) = map(
339
+ lambda t: t[:, :, :, valid_ind_unfold_k],
340
+ (k_pooled_k, v_pooled_k))
341
+
342
+ k_pooled_k = k_pooled_k.view(
343
+ -1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
344
+ self.unfolds[k].kernel_size[1], C // self.num_heads)
345
+ v_pooled_k = v_pooled_k.view(
346
+ -1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
347
+ self.unfolds[k].kernel_size[1], C // self.num_heads)
348
+
349
+ k_pooled += [k_pooled_k]
350
+ v_pooled += [v_pooled_k]
351
+
352
+ # k_all (v_all) shape : [16, 4, 5 * 210, 128]
353
+ k_all = torch.cat([k_rolled] + k_pooled, 2)
354
+ v_all = torch.cat([v_rolled] + v_pooled, 2)
355
+ else:
356
+ k_all = k_rolled
357
+ v_all = v_rolled
358
+
359
+ N = k_all.shape[-2]
360
+ q_windows = q_windows * self.scale
361
+ attn = (
362
+ q_windows @ k_all.transpose(-2, -1)
363
+ ) # B*nW, nHead, T*window_size*window_size, T*focal_window_size*focal_window_size
364
+ # T * 45
365
+ window_area = T * self.window_size[0] * self.window_size[1]
366
+ # T * 165
367
+ window_area_rolled = k_rolled.shape[2]
368
+
369
+ if self.pool_method != "none" and self.focal_level > 1:
370
+ offset = window_area_rolled
371
+ for k in range(self.focal_level - 1):
372
+ # add attentional mask
373
+ # mask_all[1] shape [1, 16, T * 45]
374
+
375
+ bias = tuple((i + 2**k - 1) for i in self.focal_window)
376
+
377
+ if mask_all[k + 1] is not None:
378
+ attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] = \
379
+ attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] + \
380
+ mask_all[k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])
381
+
382
+ offset += T * bias[0] * bias[1]
383
+
384
+ if mask_all[0] is not None:
385
+ nW = mask_all[0].shape[0]
386
+ attn = attn.view(attn.shape[0] // nW, nW, self.num_heads,
387
+ window_area, N)
388
+ attn[:, :, :, :, :
389
+ window_area] = attn[:, :, :, :, :window_area] + mask_all[0][
390
+ None, :, None, :, :]
391
+ attn = attn.view(-1, self.num_heads, window_area, N)
392
+ attn = self.softmax(attn)
393
+ else:
394
+ attn = self.softmax(attn)
395
+
396
+ x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area,
397
+ C)
398
+ x = self.proj(x)
399
+ return x
400
+
401
+
402
+ class TemporalFocalTransformerBlock(nn.Module):
403
+ r""" Temporal Focal Transformer Block.
404
+ Args:
405
+ dim (int): Number of input channels.
406
+ num_heads (int): Number of attention heads.
407
+ window_size (tuple[int]): Window size.
408
+ shift_size (int): Shift size for SW-MSA.
409
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
410
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
411
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
412
+ focal_level (int): The number of focal levels.
413
+ focal_window (int): Window size of each focal window.
414
+ n_vecs (int): Required for F3N.
415
+ t2t_params (int): T2T parameters for F3N.
416
+ """
417
+ def __init__(self,
418
+ dim,
419
+ num_heads,
420
+ window_size=(5, 9),
421
+ mlp_ratio=4.,
422
+ qkv_bias=True,
423
+ pool_method="fc",
424
+ focal_level=2,
425
+ focal_window=(5, 9),
426
+ norm_layer=nn.LayerNorm,
427
+ n_vecs=None,
428
+ t2t_params=None):
429
+ super().__init__()
430
+ self.dim = dim
431
+ self.num_heads = num_heads
432
+ self.window_size = window_size
433
+ self.expand_size = tuple(i // 2 for i in window_size) # TODO
434
+ self.mlp_ratio = mlp_ratio
435
+ self.pool_method = pool_method
436
+ self.focal_level = focal_level
437
+ self.focal_window = focal_window
438
+
439
+ self.window_size_glo = self.window_size
440
+
441
+ self.pool_layers = nn.ModuleList()
442
+ if self.pool_method != "none":
443
+ for k in range(self.focal_level - 1):
444
+ window_size_glo = tuple(
445
+ math.floor(i / (2**k)) for i in self.window_size_glo)
446
+ self.pool_layers.append(
447
+ nn.Linear(window_size_glo[0] * window_size_glo[1], 1))
448
+ self.pool_layers[-1].weight.data.fill_(
449
+ 1. / (window_size_glo[0] * window_size_glo[1]))
450
+ self.pool_layers[-1].bias.data.fill_(0)
451
+
452
+ self.norm1 = norm_layer(dim)
453
+
454
+ self.attn = WindowAttention(dim,
455
+ expand_size=self.expand_size,
456
+ window_size=self.window_size,
457
+ focal_window=focal_window,
458
+ focal_level=focal_level,
459
+ num_heads=num_heads,
460
+ qkv_bias=qkv_bias,
461
+ pool_method=pool_method)
462
+
463
+ self.norm2 = norm_layer(dim)
464
+ self.mlp = FusionFeedForward(dim, n_vecs=n_vecs, t2t_params=t2t_params)
465
+
466
+ def forward(self, x):
467
+ B, T, H, W, C = x.shape
468
+
469
+ shortcut = x
470
+ x = self.norm1(x)
471
+
472
+ shifted_x = x
473
+
474
+ x_windows_all = [shifted_x]
475
+ x_window_masks_all = [None]
476
+
477
+ # partition windows tuple(i // 2 for i in window_size)
478
+ if self.focal_level > 1 and self.pool_method != "none":
479
+ # if we add coarser granularity and the pool method is not none
480
+ for k in range(self.focal_level - 1):
481
+ window_size_glo = tuple(
482
+ math.floor(i / (2**k)) for i in self.window_size_glo)
483
+ pooled_h = math.ceil(H / window_size_glo[0]) * (2**k)
484
+ pooled_w = math.ceil(W / window_size_glo[1]) * (2**k)
485
+ H_pool = pooled_h * window_size_glo[0]
486
+ W_pool = pooled_w * window_size_glo[1]
487
+
488
+ x_level_k = shifted_x
489
+ # trim or pad shifted_x depending on the required size
490
+ if H > H_pool:
491
+ trim_t = (H - H_pool) // 2
492
+ trim_b = H - H_pool - trim_t
493
+ x_level_k = x_level_k[:, :, trim_t:-trim_b]
494
+ elif H < H_pool:
495
+ pad_t = (H_pool - H) // 2
496
+ pad_b = H_pool - H - pad_t
497
+ x_level_k = F.pad(x_level_k, (0, 0, 0, 0, pad_t, pad_b))
498
+
499
+ if W > W_pool:
500
+ trim_l = (W - W_pool) // 2
501
+ trim_r = W - W_pool - trim_l
502
+ x_level_k = x_level_k[:, :, :, trim_l:-trim_r]
503
+ elif W < W_pool:
504
+ pad_l = (W_pool - W) // 2
505
+ pad_r = W_pool - W - pad_l
506
+ x_level_k = F.pad(x_level_k, (0, 0, pad_l, pad_r))
507
+
508
+ x_windows_noreshape = window_partition_noreshape(
509
+ x_level_k.contiguous(), window_size_glo
510
+ ) # B, nw, nw, T, window_size, window_size, C
511
+ nWh, nWw = x_windows_noreshape.shape[1:3]
512
+ x_windows_noreshape = x_windows_noreshape.view(
513
+ B, nWh, nWw, T, window_size_glo[0] * window_size_glo[1],
514
+ C).transpose(4, 5) # B, nWh, nWw, T, C, wsize**2
515
+ x_windows_pooled = self.pool_layers[k](
516
+ x_windows_noreshape).flatten(-2) # B, nWh, nWw, T, C
517
+
518
+ x_windows_all += [x_windows_pooled]
519
+ x_window_masks_all += [None]
520
+
521
+ attn_windows = self.attn(
522
+ x_windows_all,
523
+ mask_all=x_window_masks_all) # nW*B, T*window_size*window_size, C
524
+
525
+ # merge windows
526
+ attn_windows = attn_windows.view(-1, T, self.window_size[0],
527
+ self.window_size[1], C)
528
+ shifted_x = window_reverse(attn_windows, self.window_size, T, H,
529
+ W) # B T H' W' C
530
+
531
+ # FFN
532
+ x = shortcut + shifted_x
533
+ y = self.norm2(x)
534
+ x = x + self.mlp(y.view(B, T * H * W, C)).view(B, T, H, W, C)
535
+
536
+ return x
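A hedged sketch (not part of the diff) of the window partition/reverse round trip that the block above relies on; H and W are chosen to divide evenly by the (5, 9) window so the reshapes are exact.

import torch

B, T, H, W, C = 1, 2, 10, 18, 8
window_size = (5, 9)
x = torch.randn(B, T, H, W, C)
windows = window_partition(x, window_size)               # (B*num_windows, T*5*9, C)
restored = window_reverse(windows.view(-1, T, 5, 9, C), window_size, T, H, W)
assert torch.allclose(x, restored)                       # partition and reverse are exact inverses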
inpainter/model/modules/tfocal_transformer_hq.py ADDED
@@ -0,0 +1,567 @@
1
+ """
2
+ This code is based on:
3
+ [1] FuseFormer: Fusing Fine-Grained Information in Transformers for Video Inpainting, ICCV 2021
4
+ https://github.com/ruiliu-ai/FuseFormer
5
+ [2] Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet, ICCV 2021
6
+ https://github.com/yitu-opensource/T2T-ViT
7
+ [3] Focal Self-attention for Local-Global Interactions in Vision Transformers, NeurIPS 2021
8
+ https://github.com/microsoft/Focal-Transformer
9
+ """
10
+
11
+ import math
12
+ from functools import reduce
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+
19
+ class SoftSplit(nn.Module):
20
+ def __init__(self, channel, hidden, kernel_size, stride, padding,
21
+ t2t_param):
22
+ super(SoftSplit, self).__init__()
23
+ self.kernel_size = kernel_size
24
+ self.t2t = nn.Unfold(kernel_size=kernel_size,
25
+ stride=stride,
26
+ padding=padding)
27
+ c_in = reduce((lambda x, y: x * y), kernel_size) * channel
28
+ self.embedding = nn.Linear(c_in, hidden)
29
+
30
+ self.t2t_param = t2t_param
31
+
32
+ def forward(self, x, b, output_size):
33
+ f_h = int((output_size[0] + 2 * self.t2t_param['padding'][0] -
34
+ (self.t2t_param['kernel_size'][0] - 1) - 1) /
35
+ self.t2t_param['stride'][0] + 1)
36
+ f_w = int((output_size[1] + 2 * self.t2t_param['padding'][1] -
37
+ (self.t2t_param['kernel_size'][1] - 1) - 1) /
38
+ self.t2t_param['stride'][1] + 1)
39
+
40
+ feat = self.t2t(x)
41
+ feat = feat.permute(0, 2, 1)
42
+ # feat shape [b*t, num_vec, ks*ks*c]
43
+ feat = self.embedding(feat)
44
+ # feat shape after embedding [b, t*num_vec, hidden]
45
+ feat = feat.view(b, -1, f_h, f_w, feat.size(2))
46
+ return feat
47
+
48
+
49
+ class SoftComp(nn.Module):
50
+ def __init__(self, channel, hidden, kernel_size, stride, padding):
51
+ super(SoftComp, self).__init__()
52
+ self.relu = nn.LeakyReLU(0.2, inplace=True)
53
+ c_out = reduce((lambda x, y: x * y), kernel_size) * channel
54
+ self.embedding = nn.Linear(hidden, c_out)
55
+ self.kernel_size = kernel_size
56
+ self.stride = stride
57
+ self.padding = padding
58
+ self.bias_conv = nn.Conv2d(channel,
59
+ channel,
60
+ kernel_size=3,
61
+ stride=1,
62
+ padding=1)
63
+ # TODO upsample conv
64
+ # self.bias_conv = nn.Conv2d()
65
+ # self.bias = nn.Parameter(torch.zeros((channel, h, w), dtype=torch.float32), requires_grad=True)
66
+
67
+ def forward(self, x, t, output_size):
68
+ b_, _, _, _, c_ = x.shape
69
+ x = x.view(b_, -1, c_)
70
+ feat = self.embedding(x)
71
+ b, _, c = feat.size()
72
+ feat = feat.view(b * t, -1, c).permute(0, 2, 1)
73
+ feat = F.fold(feat,
74
+ output_size=output_size,
75
+ kernel_size=self.kernel_size,
76
+ stride=self.stride,
77
+ padding=self.padding)
78
+ feat = self.bias_conv(feat)
79
+ return feat
80
+
81
+
82
+ class FusionFeedForward(nn.Module):
83
+ def __init__(self, d_model, n_vecs=None, t2t_params=None):
84
+ super(FusionFeedForward, self).__init__()
85
+ # We set d_ff as a default to 1960
86
+ hd = 1960
87
+ self.conv1 = nn.Sequential(nn.Linear(d_model, hd))
88
+ self.conv2 = nn.Sequential(nn.GELU(), nn.Linear(hd, d_model))
89
+ assert t2t_params is not None and n_vecs is not None
90
+ self.t2t_params = t2t_params
91
+
92
+ def forward(self, x, output_size):
93
+ n_vecs = 1
94
+ for i, d in enumerate(self.t2t_params['kernel_size']):
95
+ n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] -
96
+ (d - 1) - 1) / self.t2t_params['stride'][i] + 1)
97
+
98
+ x = self.conv1(x)
99
+ b, n, c = x.size()
100
+ normalizer = x.new_ones(b, n, 49).view(-1, n_vecs, 49).permute(0, 2, 1)
101
+ normalizer = F.fold(normalizer,
102
+ output_size=output_size,
103
+ kernel_size=self.t2t_params['kernel_size'],
104
+ padding=self.t2t_params['padding'],
105
+ stride=self.t2t_params['stride'])
106
+
107
+ x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1),
108
+ output_size=output_size,
109
+ kernel_size=self.t2t_params['kernel_size'],
110
+ padding=self.t2t_params['padding'],
111
+ stride=self.t2t_params['stride'])
112
+
113
+ x = F.unfold(x / normalizer,
114
+ kernel_size=self.t2t_params['kernel_size'],
115
+ padding=self.t2t_params['padding'],
116
+ stride=self.t2t_params['stride']).permute(
117
+ 0, 2, 1).contiguous().view(b, n, c)
118
+ x = self.conv2(x)
119
+ return x
120
+
121
+
122
+ def window_partition(x, window_size):
123
+ """
124
+ Args:
125
+ x: shape is (B, T, H, W, C)
126
+ window_size (tuple[int]): window size
127
+ Returns:
128
+ windows: (B*num_windows, T*window_size*window_size, C)
129
+ """
130
+ B, T, H, W, C = x.shape
131
+
132
+ x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
133
+ window_size[1], C)
134
+
135
+ windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
136
+ -1, T * window_size[0] * window_size[1], C)
137
+ return windows
138
+
139
+
140
+ def window_partition_noreshape(x, window_size):
141
+ """
142
+ Args:
143
+ x: shape is (B, T, H, W, C)
144
+ window_size (tuple[int]): window size
145
+ Returns:
146
+ windows: (B, num_windows_h, num_windows_w, T, window_size, window_size, C)
147
+ """
148
+ B, T, H, W, C = x.shape
149
+ x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
150
+ window_size[1], C)
151
+ windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous()
152
+ return windows
153
+
154
+
155
+ def window_reverse(windows, window_size, T, H, W):
156
+ """
157
+ Args:
158
+ windows: shape is (num_windows*B, T, window_size, window_size, C)
159
+ window_size (tuple[int]): Window size
160
+ T (int): Temporal length of video
161
+ H (int): Height of image
162
+ W (int): Width of image
163
+ Returns:
164
+ x: (B, T, H, W, C)
165
+ """
166
+ B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
167
+ x = windows.view(B, H // window_size[0], W // window_size[1], T,
168
+ window_size[0], window_size[1], -1)
169
+ x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, T, H, W, -1)
170
+ return x
171
+
172
+
173
+ class WindowAttention(nn.Module):
174
+ """Temporal focal window attention
175
+ """
176
+ def __init__(self, dim, expand_size, window_size, focal_window,
177
+ focal_level, num_heads, qkv_bias, pool_method):
178
+
179
+ super().__init__()
180
+ self.dim = dim
181
+ self.expand_size = expand_size
182
+ self.window_size = window_size # Wh, Ww
183
+ self.pool_method = pool_method
184
+ self.num_heads = num_heads
185
+ head_dim = dim // num_heads
186
+ self.scale = head_dim**-0.5
187
+ self.focal_level = focal_level
188
+ self.focal_window = focal_window
189
+
190
+ if any(i > 0 for i in self.expand_size) and focal_level > 0:
191
+ # get mask for rolled k and rolled v
192
+ mask_tl = torch.ones(self.window_size[0], self.window_size[1])
193
+ mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0
194
+ mask_tr = torch.ones(self.window_size[0], self.window_size[1])
195
+ mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0
196
+ mask_bl = torch.ones(self.window_size[0], self.window_size[1])
197
+ mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0
198
+ mask_br = torch.ones(self.window_size[0], self.window_size[1])
199
+ mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0
200
+ mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br),
201
+ 0).flatten(0)
202
+ self.register_buffer("valid_ind_rolled",
203
+ mask_rolled.nonzero(as_tuple=False).view(-1))
204
+
205
+ if pool_method != "none" and focal_level > 1:
206
+ self.unfolds = nn.ModuleList()
207
+
208
+ # build relative position bias between local patch and pooled windows
209
+ for k in range(focal_level - 1):
210
+ stride = 2**k
211
+ kernel_size = tuple(2 * (i // 2) + 2**k + (2**k - 1)
212
+ for i in self.focal_window)
213
+ # define unfolding operations
214
+ self.unfolds += [
215
+ nn.Unfold(kernel_size=kernel_size,
216
+ stride=stride,
217
+ padding=tuple(i // 2 for i in kernel_size))
218
+ ]
219
+
220
+ # define unfolding index for focal_level > 0
221
+ if k > 0:
222
+ mask = torch.zeros(kernel_size)
223
+ mask[(2**k) - 1:, (2**k) - 1:] = 1
224
+ self.register_buffer(
225
+ "valid_ind_unfold_{}".format(k),
226
+ mask.flatten(0).nonzero(as_tuple=False).view(-1))
227
+
228
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
229
+ self.proj = nn.Linear(dim, dim)
230
+
231
+ self.softmax = nn.Softmax(dim=-1)
232
+
233
+ def forward(self, x_all, mask_all=None):
234
+ """
235
+ Args:
236
+ x: input features with shape of (B, T, Wh, Ww, C)
237
+ mask: (0/-inf) mask with shape of (num_windows, T*Wh*Ww, T*Wh*Ww) or None
238
+
239
+ output: (nW*B, Wh*Ww, C)
240
+ """
241
+ x = x_all[0]
242
+
243
+ B, T, nH, nW, C = x.shape
244
+ qkv = self.qkv(x).reshape(B, T, nH, nW, 3,
245
+ C).permute(4, 0, 1, 2, 3, 5).contiguous()
246
+ q, k, v = qkv[0], qkv[1], qkv[2] # B, T, nH, nW, C
247
+
248
+ # partition q map
249
+ (q_windows, k_windows, v_windows) = map(
250
+ lambda t: window_partition(t, self.window_size).view(
251
+ -1, T, self.window_size[0] * self.window_size[1], self.
252
+ num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).
253
+ contiguous().view(-1, self.num_heads, T * self.window_size[
254
+ 0] * self.window_size[1], C // self.num_heads), (q, k, v))
255
+ # q(k/v)_windows shape : [16, 4, 225, 128]
256
+
257
+ if any(i > 0 for i in self.expand_size) and self.focal_level > 0:
258
+ (k_tl, v_tl) = map(
259
+ lambda t: torch.roll(t,
260
+ shifts=(-self.expand_size[0], -self.
261
+ expand_size[1]),
262
+ dims=(2, 3)), (k, v))
263
+ (k_tr, v_tr) = map(
264
+ lambda t: torch.roll(t,
265
+ shifts=(-self.expand_size[0], self.
266
+ expand_size[1]),
267
+ dims=(2, 3)), (k, v))
268
+ (k_bl, v_bl) = map(
269
+ lambda t: torch.roll(t,
270
+ shifts=(self.expand_size[0], -self.
271
+ expand_size[1]),
272
+ dims=(2, 3)), (k, v))
273
+ (k_br, v_br) = map(
274
+ lambda t: torch.roll(t,
275
+ shifts=(self.expand_size[0], self.
276
+ expand_size[1]),
277
+ dims=(2, 3)), (k, v))
278
+
279
+ (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
280
+ lambda t: window_partition(t, self.window_size).view(
281
+ -1, T, self.window_size[0] * self.window_size[1], self.
282
+ num_heads, C // self.num_heads), (k_tl, k_tr, k_bl, k_br))
283
+ (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
284
+ lambda t: window_partition(t, self.window_size).view(
285
+ -1, T, self.window_size[0] * self.window_size[1], self.
286
+ num_heads, C // self.num_heads), (v_tl, v_tr, v_bl, v_br))
287
+ k_rolled = torch.cat(
288
+ (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows),
289
+ 2).permute(0, 3, 1, 2, 4).contiguous()
290
+ v_rolled = torch.cat(
291
+ (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows),
292
+ 2).permute(0, 3, 1, 2, 4).contiguous()
293
+
294
+ # mask out tokens in current window
295
+ k_rolled = k_rolled[:, :, :, self.valid_ind_rolled]
296
+ v_rolled = v_rolled[:, :, :, self.valid_ind_rolled]
297
+ temp_N = k_rolled.shape[3]
298
+ k_rolled = k_rolled.view(-1, self.num_heads, T * temp_N,
299
+ C // self.num_heads)
300
+ v_rolled = v_rolled.view(-1, self.num_heads, T * temp_N,
301
+ C // self.num_heads)
302
+ k_rolled = torch.cat((k_windows, k_rolled), 2)
303
+ v_rolled = torch.cat((v_windows, v_rolled), 2)
304
+ else:
305
+ k_rolled = k_windows
306
+ v_rolled = v_windows
307
+
308
+ # q(k/v)_windows shape : [16, 4, 225, 128]
309
+ # k_rolled.shape : [16, 4, 5, 165, 128]
310
+ # ideal expanded window size 153 ((5+2*2)*(9+2*4))
311
+ # k_windows=45 expand_window=108 overlap_window=12 (since expand_size < window_size / 2)
312
+
313
+ if self.pool_method != "none" and self.focal_level > 1:
314
+ k_pooled = []
315
+ v_pooled = []
316
+ for k in range(self.focal_level - 1):
317
+ stride = 2**k
318
+ # B, T, nWh, nWw, C
319
+ x_window_pooled = x_all[k + 1].permute(0, 3, 1, 2,
320
+ 4).contiguous()
321
+
322
+ nWh, nWw = x_window_pooled.shape[2:4]
323
+
324
+ # generate mask for pooled windows
325
+ mask = x_window_pooled.new(T, nWh, nWw).fill_(1)
326
+ # unfold mask: [nWh*nWw//s//s, k*k, 1]
327
+ unfolded_mask = self.unfolds[k](mask.unsqueeze(1)).view(
328
+ 1, T, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(4, 1, 2, 3, 0).contiguous().\
329
+ view(nWh*nWw // stride // stride, -1, 1)
330
+
331
+ if k > 0:
332
+ valid_ind_unfold_k = getattr(
333
+ self, "valid_ind_unfold_{}".format(k))
334
+ unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]
335
+
336
+ x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
337
+ x_window_masks = x_window_masks.masked_fill(
338
+ x_window_masks == 0,
339
+ float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
340
+ mask_all[k + 1] = x_window_masks
341
+
342
+ # generate k and v for pooled windows
343
+ qkv_pooled = self.qkv(x_window_pooled).reshape(
344
+ B, T, nWh, nWw, 3, C).permute(4, 0, 1, 5, 2,
345
+ 3).view(3, -1, C, nWh,
346
+ nWw).contiguous()
347
+ # B*T, C, nWh, nWw
348
+ k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2]
349
+ # k_pooled_k shape: [5, 512, 4, 4]
350
+ # self.unfolds[k](k_pooled_k) shape: [5, 23040 (512 * 5 * 9 ), 16]
351
+
352
+ (k_pooled_k, v_pooled_k) = map(
353
+ lambda t: self.unfolds[k]
354
+ (t).view(B, T, C, self.unfolds[k].kernel_size[0], self.
355
+ unfolds[k].kernel_size[1], -1)
356
+ .permute(0, 5, 1, 3, 4, 2).contiguous().view(
357
+ -1, T, self.unfolds[k].kernel_size[0] * self.unfolds[
358
+ k].kernel_size[1], self.num_heads, C // self.
359
+ num_heads).permute(0, 3, 1, 2, 4).contiguous(),
360
+ # (B x (nH*nW)) x nHeads x T x (unfold_wsize x unfold_wsize) x head_dim
361
+ (k_pooled_k, v_pooled_k))
362
+ # k_pooled_k shape : [16, 4, 5, 45, 128]
363
+
364
+ # select valid unfolding index
365
+ if k > 0:
366
+ (k_pooled_k, v_pooled_k) = map(
367
+ lambda t: t[:, :, :, valid_ind_unfold_k],
368
+ (k_pooled_k, v_pooled_k))
369
+
370
+ k_pooled_k = k_pooled_k.view(
371
+ -1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
372
+ self.unfolds[k].kernel_size[1], C // self.num_heads)
373
+ v_pooled_k = v_pooled_k.view(
374
+ -1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
375
+ self.unfolds[k].kernel_size[1], C // self.num_heads)
376
+
377
+ k_pooled += [k_pooled_k]
378
+ v_pooled += [v_pooled_k]
379
+
380
+ # k_all (v_all) shape : [16, 4, 5 * 210, 128]
381
+ k_all = torch.cat([k_rolled] + k_pooled, 2)
382
+ v_all = torch.cat([v_rolled] + v_pooled, 2)
383
+ else:
384
+ k_all = k_rolled
385
+ v_all = v_rolled
386
+
387
+ N = k_all.shape[-2]
388
+ q_windows = q_windows * self.scale
389
+ # B*nW, nHead, T*window_size*window_size, T*focal_window_size*focal_window_size
390
+ attn = (q_windows @ k_all.transpose(-2, -1))
391
+ # T * 45
392
+ window_area = T * self.window_size[0] * self.window_size[1]
393
+ # T * 165
394
+ window_area_rolled = k_rolled.shape[2]
395
+
396
+ if self.pool_method != "none" and self.focal_level > 1:
397
+ offset = window_area_rolled
398
+ for k in range(self.focal_level - 1):
399
+ # add attentional mask
400
+ # mask_all[1] shape [1, 16, T * 45]
401
+
402
+ bias = tuple((i + 2**k - 1) for i in self.focal_window)
403
+
404
+ if mask_all[k + 1] is not None:
405
+ attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] = \
406
+ attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] + \
407
+ mask_all[k+1][:, :, None, None, :].repeat(
408
+ attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])
409
+
410
+ offset += T * bias[0] * bias[1]
411
+
412
+ if mask_all[0] is not None:
413
+ nW = mask_all[0].shape[0]
414
+ attn = attn.view(attn.shape[0] // nW, nW, self.num_heads,
415
+ window_area, N)
416
+ attn[:, :, :, :, :
417
+ window_area] = attn[:, :, :, :, :window_area] + mask_all[0][
418
+ None, :, None, :, :]
419
+ attn = attn.view(-1, self.num_heads, window_area, N)
420
+ attn = self.softmax(attn)
421
+ else:
422
+ attn = self.softmax(attn)
423
+
424
+ x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area,
425
+ C)
426
+ x = self.proj(x)
427
+ return x
428
+
429
+
430
+ class TemporalFocalTransformerBlock(nn.Module):
431
+ r""" Temporal Focal Transformer Block.
432
+ Args:
433
+ dim (int): Number of input channels.
434
+ num_heads (int): Number of attention heads.
435
+ window_size (tuple[int]): Window size.
436
+ pool_method (str): Window pooling method for the coarser focal levels. Default: "fc"
437
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
438
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
439
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
440
+ focal_level (int): The number level of focal window.
441
+ focal_window (int): Window size of each focal window.
442
+ n_vecs (int): Required for F3N.
443
+ t2t_params (int): T2T parameters for F3N.
444
+ """
445
+ def __init__(self,
446
+ dim,
447
+ num_heads,
448
+ window_size=(5, 9),
449
+ mlp_ratio=4.,
450
+ qkv_bias=True,
451
+ pool_method="fc",
452
+ focal_level=2,
453
+ focal_window=(5, 9),
454
+ norm_layer=nn.LayerNorm,
455
+ n_vecs=None,
456
+ t2t_params=None):
457
+ super().__init__()
458
+ self.dim = dim
459
+ self.num_heads = num_heads
460
+ self.window_size = window_size
461
+ self.expand_size = tuple(i // 2 for i in window_size) # TODO
462
+ self.mlp_ratio = mlp_ratio
463
+ self.pool_method = pool_method
464
+ self.focal_level = focal_level
465
+ self.focal_window = focal_window
466
+
467
+ self.window_size_glo = self.window_size
468
+
469
+ self.pool_layers = nn.ModuleList()
470
+ if self.pool_method != "none":
471
+ for k in range(self.focal_level - 1):
472
+ window_size_glo = tuple(
473
+ math.floor(i / (2**k)) for i in self.window_size_glo)
474
+ self.pool_layers.append(
475
+ nn.Linear(window_size_glo[0] * window_size_glo[1], 1))
476
+ self.pool_layers[-1].weight.data.fill_(
477
+ 1. / (window_size_glo[0] * window_size_glo[1]))
478
+ self.pool_layers[-1].bias.data.fill_(0)
479
+
480
+ self.norm1 = norm_layer(dim)
481
+
482
+ self.attn = WindowAttention(dim,
483
+ expand_size=self.expand_size,
484
+ window_size=self.window_size,
485
+ focal_window=focal_window,
486
+ focal_level=focal_level,
487
+ num_heads=num_heads,
488
+ qkv_bias=qkv_bias,
489
+ pool_method=pool_method)
490
+
491
+ self.norm2 = norm_layer(dim)
492
+ self.mlp = FusionFeedForward(dim, n_vecs=n_vecs, t2t_params=t2t_params)
493
+
494
+ def forward(self, x):
495
+ output_size = x[1]
496
+ x = x[0]
497
+
498
+ B, T, H, W, C = x.shape
499
+
500
+ shortcut = x
501
+ x = self.norm1(x)
502
+
503
+ shifted_x = x
504
+
505
+ x_windows_all = [shifted_x]
506
+ x_window_masks_all = [None]
507
+
508
+ # partition windows tuple(i // 2 for i in window_size)
509
+ if self.focal_level > 1 and self.pool_method != "none":
510
+ # if we add coarser granularity and the pool method is not none
511
+ for k in range(self.focal_level - 1):
512
+ window_size_glo = tuple(
513
+ math.floor(i / (2**k)) for i in self.window_size_glo)
514
+ pooled_h = math.ceil(H / window_size_glo[0]) * (2**k)
515
+ pooled_w = math.ceil(W / window_size_glo[1]) * (2**k)
516
+ H_pool = pooled_h * window_size_glo[0]
517
+ W_pool = pooled_w * window_size_glo[1]
518
+
519
+ x_level_k = shifted_x
520
+ # trim or pad shifted_x depending on the required size
521
+ if H > H_pool:
522
+ trim_t = (H - H_pool) // 2
523
+ trim_b = H - H_pool - trim_t
524
+ x_level_k = x_level_k[:, :, trim_t:-trim_b]
525
+ elif H < H_pool:
526
+ pad_t = (H_pool - H) // 2
527
+ pad_b = H_pool - H - pad_t
528
+ x_level_k = F.pad(x_level_k, (0, 0, 0, 0, pad_t, pad_b))
529
+
530
+ if W > W_pool:
531
+ trim_l = (W - W_pool) // 2
532
+ trim_r = W - W_pool - trim_l
533
+ x_level_k = x_level_k[:, :, :, trim_l:-trim_r]
534
+ elif W < W_pool:
535
+ pad_l = (W_pool - W) // 2
536
+ pad_r = W_pool - W - pad_l
537
+ x_level_k = F.pad(x_level_k, (0, 0, pad_l, pad_r))
538
+
539
+ x_windows_noreshape = window_partition_noreshape(
540
+ x_level_k.contiguous(), window_size_glo
541
+ ) # B, nw, nw, T, window_size, window_size, C
542
+ nWh, nWw = x_windows_noreshape.shape[1:3]
543
+ x_windows_noreshape = x_windows_noreshape.view(
544
+ B, nWh, nWw, T, window_size_glo[0] * window_size_glo[1],
545
+ C).transpose(4, 5) # B, nWh, nWw, T, C, wsize**2
546
+ x_windows_pooled = self.pool_layers[k](
547
+ x_windows_noreshape).flatten(-2) # B, nWh, nWw, T, C
548
+
549
+ x_windows_all += [x_windows_pooled]
550
+ x_window_masks_all += [None]
551
+
552
+ # nW*B, T*window_size*window_size, C
553
+ attn_windows = self.attn(x_windows_all, mask_all=x_window_masks_all)
554
+
555
+ # merge windows
556
+ attn_windows = attn_windows.view(-1, T, self.window_size[0],
557
+ self.window_size[1], C)
558
+ shifted_x = window_reverse(attn_windows, self.window_size, T, H,
559
+ W) # B T H' W' C
560
+
561
+ # FFN
562
+ x = shortcut + shifted_x
563
+ y = self.norm2(x)
564
+ x = x + self.mlp(y.view(B, T * H * W, C), output_size).view(
565
+ B, T, H, W, C)
566
+
567
+ return x, output_size
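The rolled key/value expansion near the top of this hunk (the four torch.roll calls) is easier to see in isolation. Below is a minimal, self-contained PyTorch sketch of the same idea, with made-up sizes and variable names that are not part of the repository code:

import torch

# toy feature map: batch=1, time=1, height=8, width=8 (channels folded away)
x = torch.arange(64, dtype=torch.float32).view(1, 1, 8, 8)

expand = (2, 2)  # how far keys/values look beyond the window border
# the four diagonal shifts used to build the "rolled" keys/values
x_tl = torch.roll(x, shifts=(-expand[0], -expand[1]), dims=(2, 3))  # top-left
x_tr = torch.roll(x, shifts=(-expand[0],  expand[1]), dims=(2, 3))  # top-right
x_bl = torch.roll(x, shifts=( expand[0], -expand[1]), dims=(2, 3))  # bottom-left
x_br = torch.roll(x, shifts=( expand[0],  expand[1]), dims=(2, 3))  # bottom-right

# concatenating the shifted copies along the token axis (after window
# partitioning in the real code) lets every window attend to tokens that
# originally belonged to its diagonal neighbours
expanded = torch.cat([t.flatten(2) for t in (x_tl, x_tr, x_bl, x_br)], dim=2)
print(expanded.shape)  # torch.Size([1, 1, 256])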
inpainter/util/__init__.py ADDED
File without changes
inpainter/util/tensor_util.py ADDED
@@ -0,0 +1,24 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+ # resize frames
5
+ def resize_frames(frames, size=None):
6
+ """
7
+ size: (w, h)
8
+ """
9
+ if size is not None:
10
+ frames = [cv2.resize(f, size) for f in frames]
11
+ frames = np.stack(frames, 0)
12
+
13
+ return frames
14
+
15
+ # resize masks
16
+ def resize_masks(masks, size=None):
17
+ """
18
+ size: (w, h)
19
+ """
20
+ if size is not None:
21
+ masks = [np.expand_dims(cv2.resize(m, size), 2) for m in masks]
22
+ masks = np.stack(masks, 0)
23
+
24
+ return masks
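A small usage sketch for the two helpers above; the frame count and target sizes are arbitrary and chosen only for illustration (it assumes the snippet runs inside this module, where cv2 and the helpers are already defined):

import numpy as np

# pretend we decoded four 854x480 RGB frames plus their single-channel masks
frame_list = [np.zeros((480, 854, 3), dtype=np.uint8) for _ in range(4)]
mask_list = [np.zeros((480, 854), dtype=np.uint8) for _ in range(4)]

frames = resize_frames(frame_list, size=(432, 240))  # -> array of shape (4, 240, 432, 3)
masks = resize_masks(mask_list, size=(108, 60))      # -> array of shape (4, 60, 108, 1)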
overleaf/.DS_Store ADDED
Binary file (6.15 kB).
 
overleaf/Track Anything.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d271378ac9538e322b362b43a41e2c22a21cffac6f539a0c3e5b140c3b24b47e
3
+ size 5370701
overleaf/Track Anything/figs/avengers_1.pdf ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a519eb00a2d315ecdc36b5a53e174e9b3361a9526c7fcd8a96bfefde2eeb940f
3
+ size 2570569
overleaf/Track Anything/figs/davisresults.pdf ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fefd74df3daafd48ffb72a725c43354712a244db70e6c5d7ae8773203e0be492
3
+ size 1349133
overleaf/Track Anything/figs/failedcases.pdf ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccb662ff62914d05fe8dc99640b9f89b32847675dd2069900a27771569378aa4
3
+ size 1200242
overleaf/Track Anything/figs/overview_4.pdf ADDED
Binary file (424 kB).
 
overleaf/Track Anything/neurips_2022.bbl ADDED
@@ -0,0 +1,105 @@
1
+ \begin{thebibliography}{10}
2
+
3
+ \bibitem{xmem}
4
+ Ho~Kei Cheng and Alexander~G. Schwing.
5
+ \newblock Xmem: Long-term video object segmentation with an atkinson-shiffrin
6
+ memory model.
7
+ \newblock In {\em {ECCV} {(28)}}, volume 13688 of {\em Lecture Notes in
8
+ Computer Science}, pages 640--658. Springer, 2022.
9
+
10
+ \bibitem{mivos}
11
+ Ho~Kei Cheng, Yu{-}Wing Tai, and Chi{-}Keung Tang.
12
+ \newblock Modular interactive video object segmentation: Interaction-to-mask,
13
+ propagation and difference-aware fusion.
14
+ \newblock In {\em {CVPR}}, pages 5559--5568. Computer Vision Foundation /
15
+ {IEEE}, 2021.
16
+
17
+ \bibitem{vit}
18
+ Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn,
19
+ Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg
20
+ Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby.
21
+ \newblock An image is worth 16x16 words: Transformers for image recognition at
22
+ scale.
23
+ \newblock In {\em {ICLR}}. OpenReview.net, 2021.
24
+
25
+ \bibitem{vos}
26
+ Mingqi Gao, Feng Zheng, James J.~Q. Yu, Caifeng Shan, Guiguang Ding, and
27
+ Jungong Han.
28
+ \newblock Deep learning for video object segmentation: a review.
29
+ \newblock {\em Artif. Intell. Rev.}, 56(1):457--531, 2023.
30
+
31
+ \bibitem{sam}
32
+ Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura
33
+ Gustafson, Tete Xiao, Spencer Whitehead, Alexander~C Berg, Wan-Yen Lo, et~al.
34
+ \newblock Segment anything.
35
+ \newblock {\em arXiv preprint arXiv:2304.02643}, 2023.
36
+
37
+ \bibitem{vot10}
38
+ Matej Kristan, Ale{\v{s}} Leonardis, Ji{\v{r}}{\'\i} Matas, Michael Felsberg,
39
+ Roman Pflugfelder, Joni-Kristian K{\"a}m{\"a}r{\"a}inen, Hyung~Jin Chang,
40
+ Martin Danelljan, Luka~{\v{C}}ehovin Zajc, Alan Luke{\v{z}}i{\v{c}}, et~al.
41
+ \newblock The tenth visual object tracking vot2022 challenge results.
42
+ \newblock In {\em Computer Vision--ECCV 2022 Workshops: Tel Aviv, Israel,
43
+ October 23--27, 2022, Proceedings, Part VIII}, pages 431--460. Springer,
44
+ 2023.
45
+
46
+ \bibitem{vot8}
47
+ Matej Kristan, Ale{\v{s}} Leonardis, Ji{\v{r}}{\'\i} Matas, Michael Felsberg,
48
+ Roman Pflugfelder, Joni-Kristian K{\"a}m{\"a}r{\"a}inen, Martin Danelljan,
49
+ Luka~{\v{C}}ehovin Zajc, Alan Luke{\v{z}}i{\v{c}}, Ondrej Drbohlav, et~al.
50
+ \newblock The eighth visual object tracking vot2020 challenge results.
51
+ \newblock In {\em European Conference on Computer Vision}, pages 547--601.
52
+ Springer, 2020.
53
+
54
+ \bibitem{vot6}
55
+ Matej Kristan, Ales Leonardis, Jiri Matas, Michael Felsberg, Roman Pflugfelder,
56
+ Luka ˇCehovin~Zajc, Tomas Vojir, Goutam Bhat, Alan Lukezic, Abdelrahman
57
+ Eldesokey, et~al.
58
+ \newblock The sixth visual object tracking vot2018 challenge results.
59
+ \newblock In {\em Proceedings of the European Conference on Computer Vision
60
+ (ECCV) Workshops}, pages 0--0, 2018.
61
+
62
+ \bibitem{vot9}
63
+ Matej Kristan, Ji{\v{r}}{\'\i} Matas, Ale{\v{s}} Leonardis, Michael Felsberg,
64
+ Roman Pflugfelder, Joni-Kristian K{\"a}m{\"a}r{\"a}inen, Hyung~Jin Chang,
65
+ Martin Danelljan, Luka Cehovin, Alan Luke{\v{z}}i{\v{c}}, et~al.
66
+ \newblock The ninth visual object tracking vot2021 challenge results.
67
+ \newblock In {\em Proceedings of the IEEE/CVF International Conference on
68
+ Computer Vision}, pages 2711--2738, 2021.
69
+
70
+ \bibitem{vot7}
71
+ Matej Kristan, Jiri Matas, Ales Leonardis, Michael Felsberg, Roman Pflugfelder,
72
+ Joni-Kristian Kamarainen, Luka ˇCehovin~Zajc, Ondrej Drbohlav, Alan Lukezic,
73
+ Amanda Berg, et~al.
74
+ \newblock The seventh visual object tracking vot2019 challenge results.
75
+ \newblock In {\em Proceedings of the IEEE/CVF International Conference on
76
+ Computer Vision Workshops}, pages 0--0, 2019.
77
+
78
+ \bibitem{e2fgvi}
79
+ Zhen Li, Chengze Lu, Jianhua Qin, Chun{-}Le Guo, and Ming{-}Ming Cheng.
80
+ \newblock Towards an end-to-end framework for flow-guided video inpainting.
81
+ \newblock In {\em {CVPR}}, pages 17541--17550. {IEEE}, 2022.
82
+
83
+ \bibitem{stm}
84
+ Seoung~Wug Oh, Joon{-}Young Lee, Ning Xu, and Seon~Joo Kim.
85
+ \newblock Video object segmentation using space-time memory networks.
86
+ \newblock In {\em {ICCV}}, pages 9225--9234. {IEEE}, 2019.
87
+
88
+ \bibitem{davis}
89
+ Jordi Pont{-}Tuset, Federico Perazzi, Sergi Caelles, Pablo Arbelaez, Alexander
90
+ Sorkine{-}Hornung, and Luc~Van Gool.
91
+ \newblock The 2017 {DAVIS} challenge on video object segmentation.
92
+ \newblock {\em CoRR}, abs/1704.00675, 2017.
93
+
94
+ \bibitem{siammask}
95
+ Qiang Wang, Li~Zhang, Luca Bertinetto, Weiming Hu, and Philip H.~S. Torr.
96
+ \newblock Fast online object tracking and segmentation: {A} unifying approach.
97
+ \newblock In {\em {CVPR}}, pages 1328--1338. Computer Vision Foundation /
98
+ {IEEE}, 2019.
99
+
100
+ \bibitem{aot}
101
+ Zongxin Yang, Yunchao Wei, and Yi~Yang.
102
+ \newblock Associating objects with transformers for video object segmentation.
103
+ \newblock In {\em NeurIPS}, pages 2491--2502, 2021.
104
+
105
+ \end{thebibliography}
overleaf/Track Anything/neurips_2022.bib ADDED
@@ -0,0 +1,187 @@
1
+ @article{sam,
2
+ title={Segment anything},
3
+ author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C and Lo, Wan-Yen and others},
4
+ journal={arXiv preprint arXiv:2304.02643},
5
+ year={2023}
6
+ }
7
+
8
+ @inproceedings{xmem,
9
+ author = {Ho Kei Cheng and
10
+ Alexander G. Schwing},
11
+ title = {XMem: Long-Term Video Object Segmentation with an Atkinson-Shiffrin
12
+ Memory Model},
13
+ booktitle = {{ECCV} {(28)}},
14
+ series = {Lecture Notes in Computer Science},
15
+ volume = {13688},
16
+ pages = {640--658},
17
+ publisher = {Springer},
18
+ year = {2022}
19
+ }
20
+
21
+
22
+ %related
23
+
24
+ @article{vos,
25
+ author = {Mingqi Gao and
26
+ Feng Zheng and
27
+ James J. Q. Yu and
28
+ Caifeng Shan and
29
+ Guiguang Ding and
30
+ Jungong Han},
31
+ title = {Deep learning for video object segmentation: a review},
32
+ journal = {Artif. Intell. Rev.},
33
+ volume = {56},
34
+ number = {1},
35
+ pages = {457--531},
36
+ year = {2023}
37
+ }
38
+
39
+ @inproceedings{vot9,
40
+ title={The ninth visual object tracking vot2021 challenge results},
41
+ author={Kristan, Matej and Matas, Ji{\v{r}}{\'\i} and Leonardis, Ale{\v{s}} and Felsberg, Michael and Pflugfelder, Roman and K{\"a}m{\"a}r{\"a}inen, Joni-Kristian and Chang, Hyung Jin and Danelljan, Martin and Cehovin, Luka and Luke{\v{z}}i{\v{c}}, Alan and others},
42
+ booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
43
+ pages={2711--2738},
44
+ year={2021}
45
+ }
46
+
47
+ @inproceedings{vot10,
48
+ title={The Tenth Visual Object Tracking VOT2022 Challenge Results},
49
+ author={Kristan, Matej and Leonardis, Ale{\v{s}} and Matas, Ji{\v{r}}{\'\i} and Felsberg, Michael and Pflugfelder, Roman and K{\"a}m{\"a}r{\"a}inen, Joni-Kristian and Chang, Hyung Jin and Danelljan, Martin and Zajc, Luka {\v{C}}ehovin and Luke{\v{z}}i{\v{c}}, Alan and others},
50
+ booktitle={Computer Vision--ECCV 2022 Workshops: Tel Aviv, Israel, October 23--27, 2022, Proceedings, Part VIII},
51
+ pages={431--460},
52
+ year={2023},
53
+ organization={Springer}
54
+ }
55
+
56
+ @inproceedings{vot8,
57
+ title={The eighth visual object tracking VOT2020 challenge results},
58
+ author={Kristan, Matej and Leonardis, Ale{\v{s}} and Matas, Ji{\v{r}}{\'\i} and Felsberg, Michael and Pflugfelder, Roman and K{\"a}m{\"a}r{\"a}inen, Joni-Kristian and Danelljan, Martin and Zajc, Luka {\v{C}}ehovin and Luke{\v{z}}i{\v{c}}, Alan and Drbohlav, Ondrej and others},
59
+ booktitle={European Conference on Computer Vision},
60
+ pages={547--601},
61
+ year={2020},
62
+ organization={Springer}
63
+ }
64
+ @inproceedings{vot7,
65
+ title={The seventh visual object tracking vot2019 challenge results},
66
+ author={Kristan, Matej and Matas, Jiri and Leonardis, Ales and Felsberg, Michael and Pflugfelder, Roman and Kamarainen, Joni-Kristian and ˇCehovin Zajc, Luka and Drbohlav, Ondrej and Lukezic, Alan and Berg, Amanda and others},
67
+ booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision Workshops},
68
+ pages={0--0},
69
+ year={2019}
70
+ }
71
+ @inproceedings{vot6,
72
+ title={The sixth visual object tracking vot2018 challenge results},
73
+ author={Kristan, Matej and Leonardis, Ales and Matas, Jiri and Felsberg, Michael and Pflugfelder, Roman and ˇCehovin Zajc, Luka and Vojir, Tomas and Bhat, Goutam and Lukezic, Alan and Eldesokey, Abdelrahman and others},
74
+ booktitle={Proceedings of the European Conference on Computer Vision (ECCV) Workshops},
75
+ pages={0--0},
76
+ year={2018}
77
+ }
78
+
79
+ @inproceedings{vit,
80
+ author = {Alexey Dosovitskiy and
81
+ Lucas Beyer and
82
+ Alexander Kolesnikov and
83
+ Dirk Weissenborn and
84
+ Xiaohua Zhai and
85
+ Thomas Unterthiner and
86
+ Mostafa Dehghani and
87
+ Matthias Minderer and
88
+ Georg Heigold and
89
+ Sylvain Gelly and
90
+ Jakob Uszkoreit and
91
+ Neil Houlsby},
92
+ title = {An Image is Worth 16x16 Words: Transformers for Image Recognition
93
+ at Scale},
94
+ booktitle = {{ICLR}},
95
+ publisher = {OpenReview.net},
96
+ year = {2021}
97
+ }
98
+
99
+ @inproceedings{stm,
100
+ author = {Seoung Wug Oh and
101
+ Joon{-}Young Lee and
102
+ Ning Xu and
103
+ Seon Joo Kim},
104
+ title = {Video Object Segmentation Using Space-Time Memory Networks},
105
+ booktitle = {{ICCV}},
106
+ pages = {9225--9234},
107
+ publisher = {{IEEE}},
108
+ year = {2019}
109
+ }
110
+
111
+ @inproceedings{siammask,
112
+ author = {Qiang Wang and
113
+ Li Zhang and
114
+ Luca Bertinetto and
115
+ Weiming Hu and
116
+ Philip H. S. Torr},
117
+ title = {Fast Online Object Tracking and Segmentation: {A} Unifying Approach},
118
+ booktitle = {{CVPR}},
119
+ pages = {1328--1338},
120
+ publisher = {Computer Vision Foundation / {IEEE}},
121
+ year = {2019}
122
+ }
123
+
124
+ @inproceedings{mivos,
125
+ author = {Ho Kei Cheng and
126
+ Yu{-}Wing Tai and
127
+ Chi{-}Keung Tang},
128
+ title = {Modular Interactive Video Object Segmentation: Interaction-to-Mask,
129
+ Propagation and Difference-Aware Fusion},
130
+ booktitle = {{CVPR}},
131
+ pages = {5559--5568},
132
+ publisher = {Computer Vision Foundation / {IEEE}},
133
+ year = {2021}
134
+ }
135
+
136
+ @article{davis,
137
+ author = {Jordi Pont{-}Tuset and
138
+ Federico Perazzi and
139
+ Sergi Caelles and
140
+ Pablo Arbelaez and
141
+ Alexander Sorkine{-}Hornung and
142
+ Luc Van Gool},
143
+ title = {The 2017 {DAVIS} Challenge on Video Object Segmentation},
144
+ journal = {CoRR},
145
+ volume = {abs/1704.00675},
146
+ year = {2017}
147
+ }
148
+
149
+ @inproceedings{aot,
150
+ author = {Zongxin Yang and
151
+ Yunchao Wei and
152
+ Yi Yang},
153
+ title = {Associating Objects with Transformers for Video Object Segmentation},
154
+ booktitle = {NeurIPS},
155
+ pages = {2491--2502},
156
+ year = {2021}
157
+ }
158
+
159
+ @inproceedings{icip,
160
+ author = {St{\'{e}}phane Vujasinovic and
161
+ Sebastian Bullinger and
162
+ Stefan Becker and
163
+ Norbert Scherer{-}Negenborn and
164
+ Michael Arens and
165
+ Rainer Stiefelhagen},
166
+ title = {Revisiting Click-Based Interactive Video Object Segmentation},
167
+ booktitle = {{ICIP}},
168
+ pages = {2756--2760},
169
+ publisher = {{IEEE}},
170
+ year = {2022}
171
+ }
172
+
173
+
174
+
175
+
176
+ @inproceedings{e2fgvi,
177
+ author = {Zhen Li and
178
+ Chengze Lu and
179
+ Jianhua Qin and
180
+ Chun{-}Le Guo and
181
+ Ming{-}Ming Cheng},
182
+ title = {Towards An End-to-End Framework for Flow-Guided Video Inpainting},
183
+ booktitle = {{CVPR}},
184
+ pages = {17541--17550},
185
+ publisher = {{IEEE}},
186
+ year = {2022}
187
+ }
overleaf/Track Anything/neurips_2022.sty ADDED
@@ -0,0 +1,381 @@
1
+ % partial rewrite of the LaTeX2e package for submissions to the
2
+ % Conference on Neural Information Processing Systems (NeurIPS):
3
+ %
4
+ % - uses more LaTeX conventions
5
+ % - line numbers at submission time replaced with aligned numbers from
6
+ % lineno package
7
+ % - \nipsfinalcopy replaced with [final] package option
8
+ % - automatically loads times package for authors
9
+ % - loads natbib automatically; this can be suppressed with the
10
+ % [nonatbib] package option
11
+ % - adds foot line to first page identifying the conference
12
+ % - adds preprint option for submission to e.g. arXiv
13
+ % - conference acronym modified
14
+ %
15
+ % Roman Garnett (garnett@wustl.edu) and the many authors of
16
+ % nips15submit_e.sty, including MK and drstrip@sandia
17
+ %
18
+ % last revision: March 2022
19
+
20
+ \NeedsTeXFormat{LaTeX2e}
21
+ \ProvidesPackage{neurips_2022}[2022/03/31 NeurIPS 2022 submission/camera-ready style file]
22
+
23
+ % declare final option, which creates camera-ready copy
24
+ \newif\if@neuripsfinal\@neuripsfinalfalse
25
+ \DeclareOption{final}{
26
+ \@neuripsfinaltrue
27
+ }
28
+
29
+ % declare nonatbib option, which does not load natbib in case of
30
+ % package clash (users can pass options to natbib via
31
+ % \PassOptionsToPackage)
32
+ \newif\if@natbib\@natbibtrue
33
+ \DeclareOption{nonatbib}{
34
+ \@natbibfalse
35
+ }
36
+
37
+ % declare preprint option, which creates a preprint version ready for
38
+ % upload to, e.g., arXiv
39
+ \newif\if@preprint\@preprintfalse
40
+ \DeclareOption{preprint}{
41
+ \@preprinttrue
42
+ }
43
+
44
+ \ProcessOptions\relax
45
+
46
+ % determine whether this is an anonymized submission
47
+ \newif\if@submission\@submissiontrue
48
+ \if@neuripsfinal\@submissionfalse\fi
49
+ \if@preprint\@submissionfalse\fi
50
+
51
+ % fonts
52
+ \renewcommand{\rmdefault}{ptm}
53
+ \renewcommand{\sfdefault}{phv}
54
+
55
+ % change this every year for notice string at bottom
56
+ \newcommand{\@neuripsordinal}{36th}
57
+ \newcommand{\@neuripsyear}{2022}
58
+ \newcommand{\@neuripslocation}{New Orleans}
59
+
60
+ % acknowledgments
61
+ \usepackage{environ}
62
+ \newcommand{\acksection}{\section*{Acknowledgments and Disclosure of Funding}}
63
+ \NewEnviron{ack}{%
64
+ \acksection
65
+ \BODY
66
+ }
67
+
68
+
69
+ % load natbib unless told otherwise
70
+ %\if@natbib
71
+ % \RequirePackage{natbib}
72
+ %\fi
73
+
74
+ % set page geometry
75
+ \usepackage[verbose=true,letterpaper]{geometry}
76
+ \AtBeginDocument{
77
+ \newgeometry{
78
+ textheight=9in,
79
+ textwidth=5.5in,
80
+ top=1in,
81
+ headheight=12pt,
82
+ headsep=25pt,
83
+ footskip=30pt
84
+ }
85
+ \@ifpackageloaded{fullpage}
86
+ {\PackageWarning{neurips_2022}{fullpage package not allowed! Overwriting formatting.}}
87
+ {}
88
+ }
89
+
90
+ \widowpenalty=10000
91
+ \clubpenalty=10000
92
+ \flushbottom
93
+ \sloppy
94
+
95
+
96
+ % font sizes with reduced leading
97
+ \renewcommand{\normalsize}{%
98
+ \@setfontsize\normalsize\@xpt\@xipt
99
+ \abovedisplayskip 7\p@ \@plus 2\p@ \@minus 5\p@
100
+ \abovedisplayshortskip \z@ \@plus 3\p@
101
+ \belowdisplayskip \abovedisplayskip
102
+ \belowdisplayshortskip 4\p@ \@plus 3\p@ \@minus 3\p@
103
+ }
104
+ \normalsize
105
+ \renewcommand{\small}{%
106
+ \@setfontsize\small\@ixpt\@xpt
107
+ \abovedisplayskip 6\p@ \@plus 1.5\p@ \@minus 4\p@
108
+ \abovedisplayshortskip \z@ \@plus 2\p@
109
+ \belowdisplayskip \abovedisplayskip
110
+ \belowdisplayshortskip 3\p@ \@plus 2\p@ \@minus 2\p@
111
+ }
112
+ \renewcommand{\footnotesize}{\@setfontsize\footnotesize\@ixpt\@xpt}
113
+ \renewcommand{\scriptsize}{\@setfontsize\scriptsize\@viipt\@viiipt}
114
+ \renewcommand{\tiny}{\@setfontsize\tiny\@vipt\@viipt}
115
+ \renewcommand{\large}{\@setfontsize\large\@xiipt{14}}
116
+ \renewcommand{\Large}{\@setfontsize\Large\@xivpt{16}}
117
+ \renewcommand{\LARGE}{\@setfontsize\LARGE\@xviipt{20}}
118
+ \renewcommand{\huge}{\@setfontsize\huge\@xxpt{23}}
119
+ \renewcommand{\Huge}{\@setfontsize\Huge\@xxvpt{28}}
120
+
121
+ % sections with less space
122
+ \providecommand{\section}{}
123
+ \renewcommand{\section}{%
124
+ \@startsection{section}{1}{\z@}%
125
+ {-2.0ex \@plus -0.5ex \@minus -0.2ex}%
126
+ { 1.5ex \@plus 0.3ex \@minus 0.2ex}%
127
+ {\large\bf\raggedright}%
128
+ }
129
+ \providecommand{\subsection}{}
130
+ \renewcommand{\subsection}{%
131
+ \@startsection{subsection}{2}{\z@}%
132
+ {-1.8ex \@plus -0.5ex \@minus -0.2ex}%
133
+ { 0.8ex \@plus 0.2ex}%
134
+ {\normalsize\bf\raggedright}%
135
+ }
136
+ \providecommand{\subsubsection}{}
137
+ \renewcommand{\subsubsection}{%
138
+ \@startsection{subsubsection}{3}{\z@}%
139
+ {-1.5ex \@plus -0.5ex \@minus -0.2ex}%
140
+ { 0.5ex \@plus 0.2ex}%
141
+ {\normalsize\bf\raggedright}%
142
+ }
143
+ \providecommand{\paragraph}{}
144
+ \renewcommand{\paragraph}{%
145
+ \@startsection{paragraph}{4}{\z@}%
146
+ {1.5ex \@plus 0.5ex \@minus 0.2ex}%
147
+ {-1em}%
148
+ {\normalsize\bf}%
149
+ }
150
+ \providecommand{\subparagraph}{}
151
+ \renewcommand{\subparagraph}{%
152
+ \@startsection{subparagraph}{5}{\z@}%
153
+ {1.5ex \@plus 0.5ex \@minus 0.2ex}%
154
+ {-1em}%
155
+ {\normalsize\bf}%
156
+ }
157
+ \providecommand{\subsubsubsection}{}
158
+ \renewcommand{\subsubsubsection}{%
159
+ \vskip5pt{\noindent\normalsize\rm\raggedright}%
160
+ }
161
+
162
+ % float placement
163
+ \renewcommand{\topfraction }{0.85}
164
+ \renewcommand{\bottomfraction }{0.4}
165
+ \renewcommand{\textfraction }{0.1}
166
+ \renewcommand{\floatpagefraction}{0.7}
167
+
168
+ \newlength{\@neuripsabovecaptionskip}\setlength{\@neuripsabovecaptionskip}{7\p@}
169
+ \newlength{\@neuripsbelowcaptionskip}\setlength{\@neuripsbelowcaptionskip}{\z@}
170
+
171
+ \setlength{\abovecaptionskip}{\@neuripsabovecaptionskip}
172
+ \setlength{\belowcaptionskip}{\@neuripsbelowcaptionskip}
173
+
174
+ % swap above/belowcaptionskip lengths for tables
175
+ \renewenvironment{table}
176
+ {\setlength{\abovecaptionskip}{\@neuripsbelowcaptionskip}%
177
+ \setlength{\belowcaptionskip}{\@neuripsabovecaptionskip}%
178
+ \@float{table}}
179
+ {\end@float}
180
+
181
+ % footnote formatting
182
+ \setlength{\footnotesep }{6.65\p@}
183
+ \setlength{\skip\footins}{9\p@ \@plus 4\p@ \@minus 2\p@}
184
+ \renewcommand{\footnoterule}{\kern-3\p@ \hrule width 12pc \kern 2.6\p@}
185
+ \setcounter{footnote}{0}
186
+
187
+ % paragraph formatting
188
+ \setlength{\parindent}{\z@}
189
+ \setlength{\parskip }{5.5\p@}
190
+
191
+ % list formatting
192
+ \setlength{\topsep }{4\p@ \@plus 1\p@ \@minus 2\p@}
193
+ \setlength{\partopsep }{1\p@ \@plus 0.5\p@ \@minus 0.5\p@}
194
+ \setlength{\itemsep }{2\p@ \@plus 1\p@ \@minus 0.5\p@}
195
+ \setlength{\parsep }{2\p@ \@plus 1\p@ \@minus 0.5\p@}
196
+ \setlength{\leftmargin }{3pc}
197
+ \setlength{\leftmargini }{\leftmargin}
198
+ \setlength{\leftmarginii }{2em}
199
+ \setlength{\leftmarginiii}{1.5em}
200
+ \setlength{\leftmarginiv }{1.0em}
201
+ \setlength{\leftmarginv }{0.5em}
202
+ \def\@listi {\leftmargin\leftmargini}
203
+ \def\@listii {\leftmargin\leftmarginii
204
+ \labelwidth\leftmarginii
205
+ \advance\labelwidth-\labelsep
206
+ \topsep 2\p@ \@plus 1\p@ \@minus 0.5\p@
207
+ \parsep 1\p@ \@plus 0.5\p@ \@minus 0.5\p@
208
+ \itemsep \parsep}
209
+ \def\@listiii{\leftmargin\leftmarginiii
210
+ \labelwidth\leftmarginiii
211
+ \advance\labelwidth-\labelsep
212
+ \topsep 1\p@ \@plus 0.5\p@ \@minus 0.5\p@
213
+ \parsep \z@
214
+ \partopsep 0.5\p@ \@plus 0\p@ \@minus 0.5\p@
215
+ \itemsep \topsep}
216
+ \def\@listiv {\leftmargin\leftmarginiv
217
+ \labelwidth\leftmarginiv
218
+ \advance\labelwidth-\labelsep}
219
+ \def\@listv {\leftmargin\leftmarginv
220
+ \labelwidth\leftmarginv
221
+ \advance\labelwidth-\labelsep}
222
+ \def\@listvi {\leftmargin\leftmarginvi
223
+ \labelwidth\leftmarginvi
224
+ \advance\labelwidth-\labelsep}
225
+
226
+ % create title
227
+ \providecommand{\maketitle}{}
228
+ \renewcommand{\maketitle}{%
229
+ \par
230
+ \begingroup
231
+ \renewcommand{\thefootnote}{\fnsymbol{footnote}}
232
+ % for perfect author name centering
233
+ \renewcommand{\@makefnmark}{\hbox to \z@{$^{\@thefnmark}$\hss}}
234
+ % The footnote-mark was overlapping the footnote-text,
235
+ % added the following to fix this problem (MK)
236
+ \long\def\@makefntext##1{%
237
+ \parindent 1em\noindent
238
+ \hbox to 1.8em{\hss $\m@th ^{\@thefnmark}$}##1
239
+ }
240
+ \thispagestyle{empty}
241
+ \@maketitle
242
+ \@thanks
243
+ \@notice
244
+ \endgroup
245
+ \let\maketitle\relax
246
+ \let\thanks\relax
247
+ }
248
+
249
+ % rules for title box at top of first page
250
+ \newcommand{\@toptitlebar}{
251
+ \hrule height 4\p@
252
+ \vskip 0.25in
253
+ \vskip -\parskip%
254
+ }
255
+ \newcommand{\@bottomtitlebar}{
256
+ \vskip 0.29in
257
+ \vskip -\parskip
258
+ \hrule height 1\p@
259
+ \vskip 0.09in%
260
+ }
261
+
262
+ % create title (includes both anonymized and non-anonymized versions)
263
+ \providecommand{\@maketitle}{}
264
+ \renewcommand{\@maketitle}{%
265
+ \vbox{%
266
+ \hsize\textwidth
267
+ \linewidth\hsize
268
+ \vskip 0.1in
269
+ \@toptitlebar
270
+ \centering
271
+ {\LARGE\bf \@title\par}
272
+ \@bottomtitlebar
273
+ \if@submission
274
+ \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}
275
+ Anonymous Author(s) \\
276
+ Affiliation \\
277
+ Address \\
278
+ \texttt{email} \\
279
+ \end{tabular}%
280
+ \else
281
+ \def\And{%
282
+ \end{tabular}\hfil\linebreak[0]\hfil%
283
+ \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}\ignorespaces%
284
+ }
285
+ \def\AND{%
286
+ \end{tabular}\hfil\linebreak[4]\hfil%
287
+ \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}\ignorespaces%
288
+ }
289
+ \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}\@author\end{tabular}%
290
+ \fi
291
+ \vskip 0.3in \@minus 0.1in
292
+ }
293
+ }
294
+
295
+ % add conference notice to bottom of first page
296
+ \newcommand{\ftype@noticebox}{8}
297
+ \newcommand{\@notice}{%
298
+ % give a bit of extra room back to authors on first page
299
+ \enlargethispage{2\baselineskip}%
300
+ \@float{noticebox}[b]%
301
+ \footnotesize\@noticestring%
302
+ \end@float%
303
+ }
304
+
305
+ % abstract styling
306
+ \renewenvironment{abstract}%
307
+ {%
308
+ \vskip 0.075in%
309
+ \centerline%
310
+ {\large\bf Abstract}%
311
+ \vspace{0.5ex}%
312
+ \begin{quote}%
313
+ }
314
+ {
315
+ \par%
316
+ \end{quote}%
317
+ \vskip 1ex%
318
+ }
319
+
320
+ % For the paper checklist
321
+ \newcommand{\answerYes}[1][]{\textcolor{blue}{[Yes] #1}}
322
+ \newcommand{\answerNo}[1][]{\textcolor{orange}{[No] #1}}
323
+ \newcommand{\answerNA}[1][]{\textcolor{gray}{[N/A] #1}}
324
+ \newcommand{\answerTODO}[1][]{\textcolor{red}{\bf [TODO]}}
325
+
326
+ % handle tweaks for camera-ready copy vs. submission copy
327
+ \if@preprint
328
+ \newcommand{\@noticestring}{%
329
+ Preprint. Under review.%
330
+ }
331
+ \else
332
+ \if@neuripsfinal
333
+ \newcommand{\@noticestring}{%
334
+ \@neuripsordinal\/ Conference on Neural Information Processing Systems
335
+ (NeurIPS \@neuripsyear).%, \@neuripslocation.%
336
+ }
337
+ \else
338
+ \newcommand{\@noticestring}{%
339
+ Submitted to \@neuripsordinal\/ Conference on Neural Information
340
+ Processing Systems (NeurIPS \@neuripsyear). Do not distribute.%
341
+ }
342
+
343
+ % hide the acknowledgements
344
+ \NewEnviron{hide}{}
345
+ \let\ack\hide
346
+ \let\endack\endhide
347
+
348
+ % line numbers for submission
349
+ \RequirePackage{lineno}
350
+ \linenumbers
351
+
352
+ % fix incompatibilities between lineno and amsmath, if required, by
353
+ % transparently wrapping linenomath environments around amsmath
354
+ % environments
355
+ \AtBeginDocument{%
356
+ \@ifpackageloaded{amsmath}{%
357
+ \newcommand*\patchAmsMathEnvironmentForLineno[1]{%
358
+ \expandafter\let\csname old#1\expandafter\endcsname\csname #1\endcsname
359
+ \expandafter\let\csname oldend#1\expandafter\endcsname\csname end#1\endcsname
360
+ \renewenvironment{#1}%
361
+ {\linenomath\csname old#1\endcsname}%
362
+ {\csname oldend#1\endcsname\endlinenomath}%
363
+ }%
364
+ \newcommand*\patchBothAmsMathEnvironmentsForLineno[1]{%
365
+ \patchAmsMathEnvironmentForLineno{#1}%
366
+ \patchAmsMathEnvironmentForLineno{#1*}%
367
+ }%
368
+ \patchBothAmsMathEnvironmentsForLineno{equation}%
369
+ \patchBothAmsMathEnvironmentsForLineno{align}%
370
+ \patchBothAmsMathEnvironmentsForLineno{flalign}%
371
+ \patchBothAmsMathEnvironmentsForLineno{alignat}%
372
+ \patchBothAmsMathEnvironmentsForLineno{gather}%
373
+ \patchBothAmsMathEnvironmentsForLineno{multline}%
374
+ }
375
+ {}
376
+ }
377
+ \fi
378
+ \fi
379
+
380
+
381
+ \endinput
overleaf/Track Anything/neurips_2022.tex ADDED
@@ -0,0 +1,378 @@
1
+ \documentclass{article}
2
+
3
+
4
+ % if you need to pass options to natbib, use, e.g.:
5
+ % \PassOptionsToPackage{numbers, compress}{natbib}
6
+ % before loading neurips_2022
7
+
8
+
9
+ % ready for submission
10
+ % \usepackage{neurips_2022}
11
+
12
+
13
+ % to compile a preprint version, e.g., for submission to arXiv, add the
14
+ % [preprint] option:
15
+ \usepackage[preprint]{neurips_2022}
16
+
17
+ % to compile a camera-ready version, add the [final] option, e.g.:
18
+ % \usepackage[final]{neurips_2022}
19
+
20
+
21
+ % to avoid loading the natbib package, add option nonatbib:
22
+ % \usepackage[nonatbib]{neurips_2022}
23
+ \usepackage{graphicx}
24
+ \usepackage[utf8]{inputenc} % allow utf-8 input
25
+ \usepackage[T1]{fontenc} % use 8-bit T1 fonts
26
+ \usepackage{hyperref} % hyperlinks
27
+ \usepackage{url} % simple URL typesetting
28
+ \usepackage{booktabs} % professional-quality tables
29
+ \usepackage{amsfonts} % blackboard math symbols
30
+ \usepackage{nicefrac} % compact symbols for 1/2, etc.
31
+ \usepackage{microtype} % microtypography
32
+ \usepackage{xcolor} % colors
33
+ % \usepackage{acmart}
34
+
35
+ \title{Track Anything: High-performance Interactive Tracking and Segmentation}
36
+ \title{Track Anything: High-performance Object Tracking in Videos by Interactive Masks}
37
+ % \title{Track Anything: Interaction to Mask in Videos}
38
+ \title{Track Anything: Segment Anything Meets Videos}
39
+
40
+ % \author{%
41
+ % David S.~Hippocampus\thanks{Use footnote for providing further information
42
+ % about author (webpage, alternative address)---\emph{not} for acknowledging
43
+ % funding agencies.} \\
44
+ % SUSTech VIPG\\
45
+
46
+ % \author{Jinyu Yang}
47
+ % \authornote{equal}
48
+
49
+ % \author{Mingqi Gao}
50
+ % \authornotemark[1]
51
+
52
+ \author{%
53
+ Jinyu Yang\thanks{Equal contribution. Alphabetical order.},\enskip Mingqi Gao\footnotemark[1],\enskip Zhe Li\footnotemark[1],\enskip Shang Gao, Fangjing Wang, Feng Zheng \\
54
+ SUSTech VIP Lab\\
55
+ % Cranberry-Lemon University\\
56
+ % Pittsburgh, PA 15213 \\
57
+ % \texttt{hippo@cs.cranberry-lemon.edu} \\
58
+ % \url{https://github.com/gaomingqi/Track-Anything}\\
59
+ % examples of more authors
60
+ % \And
61
+ % Coauthor \\
62
+ % Affiliation \\
63
+ % Address \\
64
+ % \texttt{email} \\
65
+ % \AND
66
+ % Coauthor \\
67
+ % Affiliation \\
68
+ % Address \\
69
+ % \texttt{email} \\
70
+ % \And
71
+ % Coauthor \\
72
+ % Affiliation \\
73
+ % Address \\
74
+ % \texttt{email} \\
75
+ % \And
76
+ % Coauthor \\
77
+ % Affiliation \\
78
+ % Address \\
79
+ % \texttt{email} \\
80
+ % \thanks{these authors contributed equally}
81
+ }
82
+ % \affiliation{\institution{SUSTech VIP Lab}}
83
+ % \footnote{Equal contribution. Alphabetical order.}
84
+
85
+ \begin{document}
86
+
87
+
88
+ \maketitle
89
+
90
+
91
+ \begin{abstract}
92
+
93
+ Recently, the Segment Anything Model (SAM) has rapidly gained a great deal of attention due to its impressive segmentation performance on images.
94
+ Despite its strong image segmentation ability and high interactivity with different prompts, we found that it performs poorly on consistent segmentation in videos.
95
+ Therefore, in this report, we propose the Track Anything Model (TAM), which achieves high-performance interactive tracking and segmentation in videos.
96
+ In detail, given a video sequence, with only very little human participation, \textit{i.e.}, several clicks, users can track anything they are interested in and obtain satisfactory results in one-pass inference.
97
+ Without additional training, such an interactive design performs impressively on video object tracking and segmentation.
98
+ % superior to prior works on video object tracking and segmentation.
99
+ All resources are available at \url{https://github.com/gaomingqi/Track-Anything}.
100
+ We hope this work can facilitate related research.
101
+
102
+ \end{abstract}
103
+
104
+ \section{Introduction}
105
+
106
+ Tracking an arbitrary object in generic scenes is important, and Video Object Tracking (VOT) is a fundamental task in computer vision.
107
+ Similar to VOT, Video Object Segmentation (VOS) aims to separate the target (region of interest) from the background in a video sequence, which can be seen as a kind of more fine-grained object tracking.
108
+ We notice that current state-of-the-art video trackers/segmenters are trained on large-scale manually-annotated datasets and initialized by a bounding box or a segmentation mask.
109
+ On the one hand, a massive amount of human labor is hidden behind these huge amounts of labeled data.
110
+ % Recently, interactive algorithms help to liberate users from labor-expensive initialization and annotation.
111
+ Moreover, current initialization settings, especially in semi-supervised VOS, need specific object mask ground truth for model initialization.
112
+ How to liberate researchers from labor-expensive annotation and initialization is therefore of great importance.
113
+
114
+
115
+ Recently, the Segment Anything Model (SAM)~\cite{sam}, a large foundation model for image segmentation, has been proposed.
116
+ It supports flexible prompts and computes masks in real-time, thus allowing interactive use.
117
+ We conclude that SAM has the following advantages that can assist interactive tracking:
118
+ \textbf{1) Strong image segmentation ability.}
119
+ Trained on 11 million images and 1.1 billion masks, SAM can produce high-quality masks and do zero-shot segmentation in generic scenarios.
120
+ \textbf{2) High interactivity with different kinds of prompts. }
121
+ With user-friendly prompts such as points, boxes, or language, SAM can produce satisfactory segmentation masks for specific image areas.
122
+ However, directly applying SAM to videos does not give impressive performance due to its deficiency in temporal correspondence.
123
+
124
+ On the other hand, tracking or segmenting in videos faces challenges from scale variation, target deformation, motion blur, camera motion, similar objects, and so on~\cite{vos,vot6,vot7,vot8,vot9,vot10}.
125
+ Even state-of-the-art models suffer from complex scenarios in public datasets~\cite{xmem}, not to mention real-world applications.
126
+ Therefore, we consider the following question:
127
+ \textit{can we achieve high-performance tracking/segmentation in videos through the way of interaction?}
128
+
129
+ In this technical report, we introduce our Track-Anything project, which develops an efficient toolkit for high-performance object tracking and segmentation in videos.
130
+ With a user-friendly interface, the Track Anything Model (TAM) can track and segment any objects in a given video with only one-pass inference.
131
+ Figure~\ref{fig:overview} shows the one-pass interactive process in the proposed TAM.
132
+ In detail, TAM combines SAM~\cite{sam}, a large segmentation model, and XMem~\cite{xmem}, an advanced VOS model.
133
+ As shown, we integrate them in an interactive way.
134
+ Firstly, users can interactively initialize the SAM, \textit{i.e.}, clicking on the object, to define a target object;
135
+ then, XMem is used to give a mask prediction of the object in the next frame according to both temporal and spatial correspondence;
136
+ next, SAM is utilized to give a more precise mask description;
137
+ during the tracking process, users can pause and correct as soon as they notice tracking failures.
138
+
139
+ Our contributions can be summarized as follows:
140
+
141
+ 1) We extend SAM applications to the video level to achieve interactive video object tracking and segmentation.
142
+ % We combine the SAM with VOS models to achieve interactive video object tracking and segmentation.
143
+ Rather than separately using SAM per frame, we integrate SAM into the process of temporal correspondence construction.
144
+
145
+ 2) We propose one-pass interactive tracking and segmentation for efficient annotation and a user-friendly tracking interface, which uses minimal human participation to solve extreme difficulties in video object perception.
146
+
147
+ 3) Our proposed method shows superior performance and high usability in complex scenes and has many potential applications.
148
+
149
+ % \section{Related Works}
150
+
151
+ % \textbf{Video Object Tracking.}
152
+
153
+
154
+
155
+ % \textbf{Video Object Segmentation.}
156
+ \section{Track Anything Task}
157
+
158
+ Inspired by the Segment Anything task~\cite{sam}, we propose the Track Anything task, which aims at flexible object tracking in arbitrary videos.
159
+ Here we define that the target objects can be flexibly selected, added, or removed in any way according to the users' interests.
160
+ Also, the video length and types can be arbitrary rather than limited to trimmed or natural videos.
161
+ With such settings, diverse downstream tasks can be achieved, including single/multiple object tracking, short-/long-term object tracking, unsupervised VOS, semi-supervised VOS, referring VOS, interactive VOS, long-term VOS, and so on.
162
+
163
+ \section{Methodology}
164
+
165
+ \subsection{Preliminaries}
166
+
167
+ \textbf{Segment Anything Model~\cite{sam}.}
168
+ Very recently, the Segment Anything Model (SAM) has been proposed by Meta AI Research and has attracted enormous attention.
169
+ As a foundation model for image segmentation, SAM is based on ViT~\cite{vit} and trained on the large-scale dataset SA-1B~\cite{sam}.
170
+ Obviously, SAM shows promising segmentation ability on images, especially on zero-shot segmentation tasks.
171
+ Unfortunately, SAM only shows superior performance on image segmentation, while it cannot deal with complex video segmentation.
172
+
173
+
174
+ \textbf{XMem~\cite{xmem}.}
175
+ Given the mask description of the target object at the first frame, XMem can track the object and generate corresponding masks in the subsequent frames.
176
+ Inspired by the Atkinson-Shiffrin memory model, it aims to solve the difficulties in long-term videos with unified feature memory stores.
177
+ The drawbacks of XMem are also obvious: 1) as a semi-supervised VOS model, it requires a precise mask to initialize; 2) for long videos, it is difficult for XMem to recover from tracking or segmentation failure.
178
+ In this paper, we solve both difficulties by importing interactive tracking with SAM.
179
+
180
+
181
+ \textbf{Interactive Video Object Segmentation.}
182
+ Interactive VOS~\cite{mivos} takes user interactions as inputs, \textit{e.g.}, scribbles.
183
+ Then, users can iteratively refine the segmentation results until they are satisfied with them.
184
+ Interactive VOS gains lots of attention as it is much easier to provide scribbles than to specify every pixel for an object mask.
185
+ However, we found that current interactive VOS methods require multiple rounds to refine the results, which impedes their efficiency in real-world applications.
186
+
187
+ \begin{figure}[t]
188
+ \centering
189
+ \includegraphics[width=\linewidth]{figs/overview_4.pdf}
190
+ \caption{Pipeline of our proposed Track Anything Model (TAM). Only within one round of inference can the TAM obtain impressive tracking and segmentation performance on the human-selected target.}
191
+ \label{fig:overview}
192
+ \end{figure}
193
+
194
+ \begin{table}
195
+ \caption{Results on DAVIS-2016-val and DAVIS-2017-test-dev datasets~\cite{davis}.}
196
+ \label{davis1617}
197
+ \centering
198
+ \small
199
+ \setlength\tabcolsep{4pt}
200
+ \begin{tabular}{l|c|c|c|ccc|ccc}
201
+ \toprule
202
+ & & & &\multicolumn{3}{c|}{DAVIS-2016-val} &\multicolumn{3}{c}{DAVIS-2017-test-dev} \\
203
+ Method & Venue & Initialization & Evaluation& $J\&F$ & $J$ &$F$ &$J\&F$ & $J$ &$F$\\
204
+ \midrule
205
+ STM~\cite{stm} & ICCV2019 &Mask & One Pass &89.3 &88.7 &89.9 & 72.2 & 69.3 & 75.2 \\
206
+ AOT~\cite{aot} &NeurIPS2021 &Mask & One Pass & 91.1 & 90.1 & 92.1 & 79.6 & 75.9 & 83.3 \\
207
+ XMem~\cite{xmem} & NeurIPS2022 &Mask & One Pass & 92.0 &90.7 &93.2 & 81.2 & 77.6 & 84.7\\
208
+ \midrule
209
+ % SiamMask~\cite{siammask}& CVPR2019 &Box & One Pass & 69.8 &71.7 &67.8 &56.4 &54.3 &58.5 \\
210
+ SiamMask~\cite{siammask}& CVPR2019 &Box & One Pass & 69.8 &71.7 &67.8 &- &- &- \\
211
+ \midrule
212
+ % MiVOS~\cite{mivos} & CVPR2021 &Scribble &8 Rounds &91.0 &89.6 &92.4 & 84.5 &81.7 &87.4\\
213
+ MiVOS~\cite{mivos} & CVPR2021 &Scribble &8 Rounds &91.0 &89.6 &92.4 &78.6 &74.9 &82.2\\
214
+ % \midrule
215
+ % & ICIP2022 &Click & \\
216
+ \midrule
217
+ TAM (Proposed) &- & Click & One Pass & 88.4 & 87.5 &89.4 & 73.1 & 69.8 & 76.4\\
218
+ % Ours & & 5 Clicks & \\
219
+ \bottomrule
220
+ \end{tabular}
221
+ \end{table}
222
+
223
+
224
+
225
+ \subsection{Implementation}\label{implementation}
226
+
227
+ Inspired by SAM, we consider tracking anything in videos.
228
+ We aim to define this task with high interactivity and ease of use.
229
+ This design is easy to use and obtains high performance with very little human interaction effort.
230
+ Figure~\ref{fig:overview} shows the pipeline of our Track Anything Model (TAM).
231
+ As shown, we divide our Track-Anything process into the following four steps:
232
+
233
+ \textbf{Step 1: Initialization with SAM~\cite{sam}.}
234
+ As SAM provides us an opportunity to segment a region of interest with weak prompts, \textit{e.g.}, points, and bounding boxes, we use it to give an initial mask of the target object.
235
+ Following SAM, users can get a mask description of the object of interest with a single click, or refine the object mask with several clicks to obtain a satisfactory initialization.
236
+
237
+ \textbf{Step 2: Tracking with XMem~\cite{xmem}.}
238
+ Given the initialized mask, XMem performs semi-supervised VOS on the following frames.
239
+ Since XMem is an advanced VOS method that can output satisfactory results on simple scenarios, we output the predicted masks of XMem on most occasions.
240
+ When the mask quality is not good enough, we save the XMem predictions and the corresponding intermediate parameters, \textit{i.e.}, probes and affinities, and skip to Step 3.
241
+ % Given the initialized mask and the whole sequence, XMem performs semi-supervised VOS, which aims to solve the performance decay in long-term prediction with memory potentiation.
242
+
243
+
244
+ \textbf{Step 3: Refinement with SAM~\cite{sam}.}
245
+ We notice that, during the inference of VOS models, it is challenging to keep predicting consistent and precise masks.
246
+ In fact, most state-of-the-art VOS models tend to segment more and more coarsely over time during inference.
247
+ Therefore, we utilize SAM to refine the masks predicted by XMem when its quality assessment is not satisfactory.
248
+ Specifically, we project the probes and affinities into point prompts for SAM, and the predicted mask from Step 2 is used as a mask prompt for SAM.
249
+ Then, with these prompts, SAM is able to produce a refined segmentation mask.
250
+ Such refined masks will also be added to the temporal correspondence of XMem to refine all subsequent object discrimination.
251
+
252
+ \textbf{Step 4: Correction with human participation.}
253
+ % Long video annotation.
254
+ After the above three steps, the TAM can now successfully solve some common challenges and predict segmentation masks.
255
+ However, we notice that it is still difficult to accurately distinguish the objects in some extremely challenging scenarios, especially when processing long videos.
256
+ Therefore, we propose to add human correction during inference, which can bring a qualitative leap in performance with only very small human efforts.
257
+ In detail, users can forcibly pause the TAM process and correct the mask of the current frame with positive and negative clicks.
258
+
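Putting the four steps together, the control flow can be summarized in the following Python-style sketch. It is only an illustration of the pipeline described above: the callables sam_segment, xmem_propagate, sam_refine, quality_ok, and get_correction are hypothetical stand-ins, not the repository's actual interfaces.

def track_anything(frames, init_clicks, sam_segment, xmem_propagate, sam_refine,
                   quality_ok, get_correction=None):
    """One-pass TAM-style loop (illustrative sketch only, not the repository API)."""
    mask = sam_segment(frames[0], init_clicks)                  # Step 1: SAM initialization
    masks = [mask]
    for frame in frames[1:]:
        mask, probes, affinities = xmem_propagate(frame, mask)  # Step 2: XMem propagation
        if not quality_ok(mask):                                # Step 3: SAM refinement
            mask = sam_refine(frame, probes, affinities, mask)  #   probes/affinities -> prompts
        if get_correction is not None:                          # Step 4: human correction
            clicks = get_correction(frame, mask)                #   e.g. positive/negative clicks
            if clicks:
                mask = sam_segment(frame, clicks)
        masks.append(mask)
    return masks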
259
+ \section{Experiments}
260
+
261
+ \subsection{Quantitative Results}
262
+
263
+
264
+ To evaluate TAM, we utilize the validation set of DAVIS-2016 and test-development set of DAVIS-2017~\cite{davis}.
265
+ % The evaluation process follows the one we proposed in Section~\ref{implementation}.
266
+ Then, we execute the proposed TAM as demonstrated in Section~\ref{implementation}.
267
+ The results are given in Table~\ref{davis1617}.
268
+ As shown, our TAM obtains $J\&F$ scores of 88.4 and 73.1 on DAVIS-2016-val and DAVIS-2017-test-dev datasets, respectively.
269
+ Note that TAM is initialized by clicks and evaluated in one pass.
270
+ Notably, we found that TAM performs well in difficult and complex scenarios.
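For reference, the scores above follow the standard DAVIS protocol: region similarity $J$ is the intersection-over-union between the predicted mask $M$ and the ground-truth mask $G$, $F$ is a contour-based F-measure, and $J\&F$ is their mean:

\[
  J = \frac{|M \cap G|}{|M \cup G|}, \qquad J\&F = \frac{J + F}{2}.
\]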
271
+ % During the evaluation,
272
+
273
+ % click-based interactive video object segmentation
274
+
275
+ % CLICK-BASED INTERACTIVE VIDEO OBJECT
276
+ % SEGMENTATION
277
+
278
+
279
+ \begin{figure}[t]
280
+ \centering
281
+ \includegraphics[width=\linewidth]{figs/davisresults.pdf}
282
+ \caption{Qualitative results on video sequences from DAVIS-16 and DAVIS-17 datasets~\cite{davis}.}
283
+ \label{fig:davisresult}
284
+ \end{figure}
285
+
286
+
287
+ \begin{figure}[t]
288
+ \centering
289
+ \includegraphics[width=\linewidth]{figs/failedcases.pdf}
290
+ \caption{Failed cases.}
291
+ \label{fig:failedcases}
292
+ \end{figure}
293
+
294
+ \subsection{Qualitative Results}
295
+
296
+ % As we use a new one-pass interactive method to evaluation our TAM, here we only present some qualitative results.
297
+ We also give some qualitative results in Figure~\ref{fig:davisresult}.
298
+ As shown, TAM can handle multi-object separation, target deformation, scale change, and camera motion well, which demonstrates its superior tracking and segmentation abilities with only click initialization and one-round inference.
299
+
300
+ \subsection{Failed Cases}
301
+ We here also analyze the failed cases, as shown in Figure~\ref{fig:failedcases}.
302
+ Overall, we notice that the failed cases typically appear on the following two occasions.
303
+ 1)
304
+ % Separated masks of one object in a long video.
305
+ Current VOS models are mostly designed for short videos, which focus more on maintaining short-term memory rather than long-term memory.
306
+ This leads to mask shrinkage or a lack of refinement in long videos, as shown in seq (a).
307
+ Essentially, we aim to solve these issues in Step 3 via the refinement ability of SAM, but its effectiveness is lower than expected in realistic applications.
308
+ This indicates that SAM's refinement ability based on multiple prompts can be further improved in the future.
309
+ On the other hand, human participation/interaction in TAM can be an approach to solving such difficulties, while too much interaction will also result in low efficiency.
310
+ Thus, a mechanism for preserving long-term memory and updating transient memory is still important.
311
+ % Limited refinement by SAM. Although SAM supports to refine previous predictions, via point and mask prompts, . How to .
312
+ 2) When the object structure is complex, \textit{e.g.}, the bicycle wheels in seq (b), whose ground-truth masks contain many cavities, we found it very difficult to obtain a fine-grained initial mask by propagating the clicks.
313
+ Thus, the coarse initialized masks may have side effects on the subsequent frames and lead to poor predictions.
314
+ This also suggests that SAM still struggles with complex and fine-grained structures.
315
+
316
+
317
+ \begin{figure}[t]
318
+ \centering
319
+ \includegraphics[width=\linewidth]{figs/avengers_1.pdf}
320
+ \caption{Raw frames, object masks, and inpainted results from the movie \textit{Captain America: Civil War (2016)}.}
321
+ \label{fig:captain}
322
+ \end{figure}
323
+
324
+
325
+
326
+ \section{Applications}
327
+ The proposed Track Anything Model (TAM) provides many possibilities for flexible tracking and segmentation in videos.
328
+ Here, we demonstrate several applications enabled by our proposed method.
329
+ % Our method may be able to a variety of applications.
330
+ In such an interactive way, diverse downstream tasks can be easily achieved.
331
+ % \textbf{Demo.}
332
+ % It is able to solve diverse downstream tasks in such a interactive way.
333
+
334
+ \textbf{Efficient video annotation.}
335
+ TAM has the ability to segment the regions of interest in videos and flexibly choose the objects users want to track. Thus, it can be used for video annotation for tasks like video object tracking and video object segmentation.
336
+ On the other hand, click-based interaction makes it easy to use, and the annotation process is highly efficient.
337
+
338
+
339
+ \textbf{Long-term object tracking.}
340
+ The study of long-term tracking is gaining more and more attention because it is much closer to practical applications.
341
+ The current long-term object tracking task requires the tracker to handle target disappearance and reappearance, yet it is still limited to the scope of trimmed videos.
342
+ Our TAM is better suited to real-world applications since it can handle shot changes in long videos.
343
+
344
+
345
+ \textbf{User-friendly video editing.}
346
+ The Track Anything Model provides us the opportunity to segment objects in videos flexibly.
347
+ With the object segmentation masks provided by TAM, we are then able to remove or alter any of the existing objects in a given video.
348
+ Here we combine TAM with E$^2$FGVI~\cite{e2fgvi} to evaluate its application value.
349
+
350
+ \textbf{Visualized development toolkit for video tasks.}
351
+ For ease of use, we also provide visualized interfaces for multiple video tasks, \textit{e.g.}, VOS, VOT, video inpainting, and so on.
352
+ With the provided toolkit, users can apply their models on real-world videos and visualize the results instantaneously.
353
+ Corresponding demos are available on Hugging Face\footnote{\url{https://huggingface.co/spaces/watchtowerss/Track-Anything}}.
354
+
355
+
356
+ To show its effectiveness, we conduct a comprehensive test by applying TAM to the movie \textit{Captain America: Civil War (2016)}.
357
+ Some representative results are given in Figure \ref{fig:captain}.
358
+ As shown, TAM can precisely track multiple objects in videos with many shot changes and can further be helpful for video inpainting.
359
+
360
+ % \section{Further work}
361
+
362
+
363
+ % \section*{Acknowledgements}
364
+
365
+ % \appendix
366
+
367
+ % \section{Appendix}
368
+
369
+
370
+ % Optionally include extra information (complete proofs, additional experiments and plots) in the appendix.
371
+ % This section will often be part of the supplemental material.
372
+
373
+
374
+
375
+ \bibliographystyle{plain}
376
+ \bibliography{neurips_2022}
377
+
378
+ \end{document}