mt-cly
commited on
Commit
·
909940e
1
Parent(s):
a6d2ec4
init
Browse files- .gitattributes +2 -0
- .gitignore +31 -0
- README.md +6 -3
- app.py +359 -0
- assets/0846x4.png +3 -0
- assets/0873.png +3 -0
- assets/0873x4.png +3 -0
- assets/0873x4_cropped_120x120.png +3 -0
- assets/0892x4.png +3 -0
- assets/Screenshot_cropped_180x100.png +3 -0
- dist/gscuda-0.0.0-cp310-cp310-linux_x86_64.whl +3 -0
- requirements.txt +12 -0
- setup.py +26 -0
- utils/edsrbaseline.py +113 -0
- utils/fea2gsropeamp.py +749 -0
- utils/gaussian_splatting.py +265 -0
- utils/gs_cuda/check.py +115 -0
- utils/gs_cuda/gs.cu +199 -0
- utils/gs_cuda/gs.h +24 -0
- utils/gs_cuda/gswrapper.cpp +80 -0
- utils/gs_cuda/gswrapper.py +49 -0
- utils/gs_cuda/mylineprofiler.py +264 -0
- utils/gs_cuda/profile.log +69 -0
- utils/gs_cuda/profile.py +137 -0
- utils/gs_cuda_dmax/__init__.py +0 -0
- utils/gs_cuda_dmax/check.py +122 -0
- utils/gs_cuda_dmax/gs copy.cu +212 -0
- utils/gs_cuda_dmax/gs.backup.cu +188 -0
- utils/gs_cuda_dmax/gs.cu +187 -0
- utils/gs_cuda_dmax/gs.h +26 -0
- utils/gs_cuda_dmax/gswrapper.cpp +82 -0
- utils/gs_cuda_dmax/gswrapper.py +63 -0
- utils/gs_cuda_dmax/mylineprofiler.py +264 -0
- utils/gs_cuda_dmax/profile.py +142 -0
- utils/hatropeamp.py +1156 -0
- utils/rdn.py +120 -0
- utils/split_and_joint_image.py +232 -0
- utils/swinir.py +1243 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
dist/gscuda-0.0.0-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Python build and package directories
|
2 |
+
build/
|
3 |
+
gscuda.egg-info/
|
4 |
+
|
5 |
+
# Additional common Python ignore patterns
|
6 |
+
__pycache__/
|
7 |
+
*.py[cod]
|
8 |
+
*$py.class
|
9 |
+
*.so
|
10 |
+
*.egg
|
11 |
+
*.egg-info/
|
12 |
+
|
13 |
+
# IDE and editor files
|
14 |
+
.vscode/
|
15 |
+
.idea/
|
16 |
+
*.swp
|
17 |
+
*.swo
|
18 |
+
*~
|
19 |
+
|
20 |
+
# OS generated files
|
21 |
+
.DS_Store
|
22 |
+
.DS_Store?
|
23 |
+
._*
|
24 |
+
.Spotlight-V100
|
25 |
+
.Trashes
|
26 |
+
ehthumbs.db
|
27 |
+
Thumbs.db
|
28 |
+
|
29 |
+
# Gradio cache
|
30 |
+
.gradio/
|
31 |
+
.setup_complete
|
README.md
CHANGED
@@ -1,13 +1,16 @@
|
|
1 |
---
|
2 |
title: GSASR
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
|
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
10 |
license: mit
|
|
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
title: GSASR
|
3 |
+
emoji: 🌖
|
4 |
+
colorFrom: pink
|
5 |
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.44.1
|
8 |
+
python_version: 3.10
|
9 |
app_file: app.py
|
10 |
pinned: false
|
11 |
+
# suggested_hardware: zero-a10g
|
12 |
license: mit
|
13 |
+
short_description: GSASR(2d gaussian for arbitrary-scale super-resolution)
|
14 |
---
|
15 |
|
16 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import gradio as gr
|
6 |
+
from PIL import Image
|
7 |
+
import math
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import os
|
10 |
+
import tempfile
|
11 |
+
import time
|
12 |
+
import threading
|
13 |
+
|
14 |
+
from utils.hatropeamp import HATNOUP_ROPE_AMP
|
15 |
+
from utils.fea2gsropeamp import Fea2GS_ROPE_AMP
|
16 |
+
from utils.edsrbaseline import EDSRNOUP
|
17 |
+
from utils.hatropeamp import HATNOUP_ROPE_AMP
|
18 |
+
from utils.rdn import RDNNOUP
|
19 |
+
from utils.swinir import SwinIRNOUP
|
20 |
+
from utils.fea2gsropeamp import Fea2GS_ROPE_AMP
|
21 |
+
from utils.gaussian_splatting import generate_2D_gaussian_splatting_step
|
22 |
+
from utils.split_and_joint_image import split_and_joint_image
|
23 |
+
from huggingface_hub import hf_hub_download
|
24 |
+
import subprocess
|
25 |
+
import sys
|
26 |
+
import spaces
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
# Device setup
|
31 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
32 |
+
|
33 |
+
# Global stop flag for interrupting inference
|
34 |
+
stop_inference = False
|
35 |
+
inference_lock = threading.Lock()
|
36 |
+
|
37 |
+
def load_model(
|
38 |
+
pretrained_model_name_or_path: str = "mutou0308/GSASR",
|
39 |
+
model_name: str = "HATL_SA1B",
|
40 |
+
device: str | torch.device = "cuda"
|
41 |
+
):
|
42 |
+
enc_path = hf_hub_download(
|
43 |
+
repo_id=pretrained_model_name_or_path, filename=os.path.join(model_name, 'encoder.pth')
|
44 |
+
)
|
45 |
+
dec_path = hf_hub_download(
|
46 |
+
repo_id=pretrained_model_name_or_path, filename=os.path.join(model_name, 'decoder.pth')
|
47 |
+
)
|
48 |
+
|
49 |
+
enc_weight = torch.load(enc_path, weights_only=True)['params_ema']
|
50 |
+
dec_weight = torch.load(dec_path, weights_only=True)['params_ema']
|
51 |
+
|
52 |
+
if model_name in ['EDSR_DIV2K', 'EDSR_DF2K']:
|
53 |
+
encoder = EDSRNOUP()
|
54 |
+
decoder = Fea2GS_ROPE_AMP()
|
55 |
+
elif model_name in ['RDN_DIV2K', 'RDN_DF2K']:
|
56 |
+
encoder = RDNNOUP()
|
57 |
+
decoder = Fea2GS_ROPE_AMP(num_crossattn_blocks = 2)
|
58 |
+
elif model_name in ['SwinIR_DIV2K', 'SwinIR_DF2K']:
|
59 |
+
encoder = SwinIRNOUP()
|
60 |
+
decoder = Fea2GS_ROPE_AMP(num_crossattn_blocks=2, num_crossattn_layers=4, num_gs_seed=256, window_size=16)
|
61 |
+
elif model_name in ['HATL_SA1B']:
|
62 |
+
encoder = HATNOUP_ROPE_AMP()
|
63 |
+
decoder = Fea2GS_ROPE_AMP(channel=192, num_crossattn_blocks=4, num_crossattn_layers=4, num_selfattn_blocks=8, num_selfattn_layers=6,
|
64 |
+
num_gs_seed=256, window_size=16)
|
65 |
+
else:
|
66 |
+
raise ValueError(f"args.model-{model_name} must be in ['EDSR_DIV2K', 'EDSR_DF2K', 'RDN_DIV2K', 'RDN_DF2K', 'SwinIR_DIV2K', 'SwinIR_DF2K', 'HATL_SA1B']")
|
67 |
+
|
68 |
+
encoder.load_state_dict(enc_weight, strict=True)
|
69 |
+
decoder.load_state_dict(dec_weight, strict=True)
|
70 |
+
encoder.eval()
|
71 |
+
decoder.eval()
|
72 |
+
encoder = encoder.to(device)
|
73 |
+
decoder = decoder.to(device)
|
74 |
+
return encoder, decoder
|
75 |
+
|
76 |
+
|
77 |
+
def preprocess(x, denominator=16):
|
78 |
+
"""Preprocess image to ensure dimensions are multiples of denominator"""
|
79 |
+
_, c, h, w = x.shape
|
80 |
+
if h % denominator > 0:
|
81 |
+
pad_h = denominator - h % denominator
|
82 |
+
else:
|
83 |
+
pad_h = 0
|
84 |
+
if w % denominator > 0:
|
85 |
+
pad_w = denominator - w % denominator
|
86 |
+
else:
|
87 |
+
pad_w = 0
|
88 |
+
x_new = F.pad(x, (0, pad_w, 0, pad_h), 'reflect')
|
89 |
+
return x_new
|
90 |
+
|
91 |
+
def postprocess(x, gt_size_h, gt_size_w):
|
92 |
+
"""Post-process by cropping to target size"""
|
93 |
+
x_new = x[:, :, :gt_size_h, :gt_size_w]
|
94 |
+
return x_new
|
95 |
+
|
96 |
+
def should_use_tile(image_height, image_width, threshold=1024):
|
97 |
+
"""Determine if tile processing should be used based on image resolution"""
|
98 |
+
return max(image_height, image_width) > threshold
|
99 |
+
|
100 |
+
def set_stop_flag():
|
101 |
+
"""Set the global stop flag to interrupt inference"""
|
102 |
+
global stop_inference
|
103 |
+
with inference_lock:
|
104 |
+
stop_inference = True
|
105 |
+
return "🛑 Stopping inference...", gr.update(interactive=False)
|
106 |
+
|
107 |
+
def reset_stop_flag():
|
108 |
+
"""Reset the global stop flag"""
|
109 |
+
global stop_inference
|
110 |
+
with inference_lock:
|
111 |
+
stop_inference = False
|
112 |
+
|
113 |
+
def check_stop_flag():
|
114 |
+
"""Check if inference should be stopped"""
|
115 |
+
global stop_inference
|
116 |
+
with inference_lock:
|
117 |
+
return stop_inference
|
118 |
+
|
119 |
+
@spaces.GPU
|
120 |
+
def super_resolution_inference(image, scale=4.0):
|
121 |
+
"""Super-resolution inference function with automatic tile processing"""
|
122 |
+
|
123 |
+
# Check if gscuda setup has been run
|
124 |
+
setup_marker = ".setup_complete"
|
125 |
+
if not os.path.exists(setup_marker):
|
126 |
+
print("First run detected, installing dependencies...")
|
127 |
+
try:
|
128 |
+
# subprocess.check_call(["pip", "install", "-e", "."])
|
129 |
+
subprocess.check_call(["pip", "install", "dist/gscuda-0.0.0-cp310-cp310-linux_x86_64.whl"])
|
130 |
+
# Create marker file to indicate setup is complete
|
131 |
+
with open(setup_marker, "w") as f:
|
132 |
+
f.write("Setup completed")
|
133 |
+
print("Setup completed successfully!")
|
134 |
+
except subprocess.CalledProcessError as e:
|
135 |
+
return None, f"❌ Setup failed with error: {e}", None
|
136 |
+
|
137 |
+
|
138 |
+
|
139 |
+
if image is None:
|
140 |
+
return None, "Please upload an image", None
|
141 |
+
|
142 |
+
# Load model
|
143 |
+
encoder, decoder = load_model(model_name="HATL_SA1B")
|
144 |
+
|
145 |
+
# Reset stop flag at the beginning
|
146 |
+
reset_stop_flag()
|
147 |
+
|
148 |
+
# Fixed parameters
|
149 |
+
tile_overlap = 16 # Fixed overlap size
|
150 |
+
crop_size = 8 # Fixed crop size
|
151 |
+
tile_size = 1024 # Fixed tile size for large images
|
152 |
+
|
153 |
+
try:
|
154 |
+
# Check for interruption
|
155 |
+
if check_stop_flag():
|
156 |
+
return None, "❌ Inference interrupted", None
|
157 |
+
|
158 |
+
# Convert PIL image to numpy array
|
159 |
+
img_np = np.array(image)
|
160 |
+
if len(img_np.shape) == 3:
|
161 |
+
img_np = img_np[:, :, [2, 1, 0]] # RGB to BGR
|
162 |
+
|
163 |
+
# Convert to tensor
|
164 |
+
img = torch.from_numpy(np.transpose(img_np.astype(np.float32) / 255., (2, 0, 1))).float()
|
165 |
+
img = img.unsqueeze(0).to(device)
|
166 |
+
|
167 |
+
# Check for interruption
|
168 |
+
if check_stop_flag():
|
169 |
+
return None, "❌ Inference interrupted", None
|
170 |
+
|
171 |
+
# Calculate target size
|
172 |
+
gt_size = [math.floor(scale * img.shape[2]), math.floor(scale * img.shape[3])]
|
173 |
+
|
174 |
+
# Determine if tile processing should be used
|
175 |
+
use_tile = should_use_tile(img.shape[2], img.shape[3])
|
176 |
+
|
177 |
+
# Force AMP mixed precision
|
178 |
+
with torch.inference_mode():
|
179 |
+
with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
|
180 |
+
# Check for interruption before main processing
|
181 |
+
if check_stop_flag():
|
182 |
+
return None, "❌ Inference interrupted", None
|
183 |
+
|
184 |
+
if use_tile:
|
185 |
+
# Use tile processing
|
186 |
+
assert tile_size % 16 == 0, f"tile_size-{tile_size} must be divisible by 16"
|
187 |
+
assert 2 * tile_overlap < tile_size, f"2 * tile_overlap must be less than tile_size"
|
188 |
+
assert 2 * crop_size <= tile_overlap, f"2 * crop_size must be less than or equal to tile_overlap"
|
189 |
+
|
190 |
+
with torch.no_grad():
|
191 |
+
output = split_and_joint_image(
|
192 |
+
lq=img,
|
193 |
+
scale_factor=scale,
|
194 |
+
split_size=tile_size,
|
195 |
+
overlap_size=tile_overlap,
|
196 |
+
model_g=encoder,
|
197 |
+
model_fea2gs=decoder,
|
198 |
+
crop_size=crop_size,
|
199 |
+
scale_modify=torch.tensor([scale, scale]),
|
200 |
+
default_step_size=1.2,
|
201 |
+
cuda_rendering=True,
|
202 |
+
mode='scale_modify',
|
203 |
+
if_dmax=True,
|
204 |
+
dmax_mode='fix',
|
205 |
+
dmax=0.1
|
206 |
+
)
|
207 |
+
else:
|
208 |
+
# Direct processing without tiles
|
209 |
+
lq_pad = preprocess(img, 16) # denominator=16 for HATL
|
210 |
+
gt_size_pad = torch.tensor([math.floor(scale * lq_pad.shape[2]),
|
211 |
+
math.floor(scale * lq_pad.shape[3])])
|
212 |
+
gt_size_pad = gt_size_pad.unsqueeze(0)
|
213 |
+
|
214 |
+
with torch.no_grad():
|
215 |
+
# Check for interruption before encoder
|
216 |
+
if check_stop_flag():
|
217 |
+
return None, "❌ Inference interrupted", None
|
218 |
+
|
219 |
+
# Encoder output
|
220 |
+
encoder_output = encoder(lq_pad) # b,c,h,w
|
221 |
+
|
222 |
+
# Check for interruption before decoder
|
223 |
+
if check_stop_flag():
|
224 |
+
return None, "❌ Inference interrupted", None
|
225 |
+
|
226 |
+
scale_vector = torch.tensor(scale, dtype=torch.float32).unsqueeze(0).to(device)
|
227 |
+
|
228 |
+
# Decoder output
|
229 |
+
batch_gs_parameters = decoder(encoder_output, scale_vector)
|
230 |
+
gs_parameters = batch_gs_parameters[0, :]
|
231 |
+
|
232 |
+
# Check for interruption before gaussian rendering
|
233 |
+
if check_stop_flag():
|
234 |
+
return None, "❌ Inference interrupted", None
|
235 |
+
|
236 |
+
# Gaussian rendering
|
237 |
+
b_output = generate_2D_gaussian_splatting_step(
|
238 |
+
gs_parameters=gs_parameters,
|
239 |
+
sr_size=gt_size_pad[0],
|
240 |
+
scale=scale,
|
241 |
+
sample_coords=None,
|
242 |
+
scale_modify=torch.tensor([scale, scale]),
|
243 |
+
default_step_size=1.2,
|
244 |
+
cuda_rendering=True,
|
245 |
+
mode='scale_modify',
|
246 |
+
if_dmax=True,
|
247 |
+
dmax_mode='fix',
|
248 |
+
dmax=0.1
|
249 |
+
)
|
250 |
+
output = b_output.unsqueeze(0)
|
251 |
+
|
252 |
+
# Check for interruption before post-processing
|
253 |
+
if check_stop_flag():
|
254 |
+
return None, "❌ Inference interrupted", None
|
255 |
+
|
256 |
+
# Post-processing
|
257 |
+
output = postprocess(output, gt_size[0], gt_size[1])
|
258 |
+
|
259 |
+
# Convert back to PIL image format
|
260 |
+
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
261 |
+
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # BGR to RGB
|
262 |
+
output = (output * 255.0).round().astype(np.uint8)
|
263 |
+
|
264 |
+
# Convert to PIL image
|
265 |
+
output_pil = Image.fromarray(output)
|
266 |
+
|
267 |
+
# Generate result information
|
268 |
+
original_size = f"{img.shape[3]}x{img.shape[2]}"
|
269 |
+
output_size = f"{output.shape[1]}x{output.shape[0]}"
|
270 |
+
tile_info = f"Tile processing enabled (size: {tile_size})" if use_tile else "Direct processing (no tiles)"
|
271 |
+
result_info = f"✅ Processing completed successfully!\nOriginal size: {original_size}\nSuper-resolution size: {output_size}\nScale factor: {scale:.2f}x\nProcessing mode: {tile_info}\nAMP acceleration: Force enabled\nOverlap size: {tile_overlap}\nCrop size: {crop_size}"
|
272 |
+
|
273 |
+
return output_pil, result_info, output_pil
|
274 |
+
|
275 |
+
except Exception as e:
|
276 |
+
if check_stop_flag():
|
277 |
+
return None, "❌ Inference interrupted", None
|
278 |
+
return None, f"❌ Error during processing: {str(e)}", None
|
279 |
+
|
280 |
+
def predict(image, scale):
|
281 |
+
"""Gradio prediction function"""
|
282 |
+
output_image, info, download_image = super_resolution_inference(image, scale)
|
283 |
+
|
284 |
+
# If processing successful, save image for download
|
285 |
+
if output_image is not None:
|
286 |
+
# Create temporary filename
|
287 |
+
timestamp = int(time.time())
|
288 |
+
temp_filename = f"GSASR_SR_result_{scale}x_{timestamp}.png"
|
289 |
+
temp_path = os.path.join(tempfile.gettempdir(), temp_filename)
|
290 |
+
|
291 |
+
# Save image
|
292 |
+
output_image.save(temp_path, "PNG")
|
293 |
+
|
294 |
+
return output_image, temp_path, "✅ Ready", gr.update(interactive=True)
|
295 |
+
else:
|
296 |
+
return output_image, None, info if info else "❌ Processing failed", gr.update(interactive=True)
|
297 |
+
|
298 |
+
# Create Gradio interface
|
299 |
+
with gr.Blocks(title="🚀 GSASR (2D Gaussian Splatting Super-Resolution)") as demo:
|
300 |
+
gr.Markdown("# **🚀 GSASR (Generalized and efficient 2d gaussian splatting for arbitrary-scale super-resolution)**")
|
301 |
+
gr.Markdown("Official demo for GSASR. Please refer to our [paper](https://arxiv.org/pdf/2501.06838), [project page](https://mt-cly.github.io/GSASR.github.io/), and [github](https://github.com/ChrisDud0257/GSASR) for more details.")
|
302 |
+
|
303 |
+
with gr.Row():
|
304 |
+
with gr.Column():
|
305 |
+
input_image = gr.Image(type="pil", label="Input Image")
|
306 |
+
|
307 |
+
# Scale parameters
|
308 |
+
with gr.Group():
|
309 |
+
gr.Markdown("### SR Scale")
|
310 |
+
scale_slider = gr.Slider(minimum=1.0, maximum=30.0, value=4.0, step=0.1, label="SR Scale")
|
311 |
+
|
312 |
+
# Control buttons
|
313 |
+
with gr.Row():
|
314 |
+
submit_btn = gr.Button("🚀 Start Super-Resolution", variant="primary")
|
315 |
+
stop_btn = gr.Button("🛑 Stop Inference", variant="stop")
|
316 |
+
|
317 |
+
with gr.Column():
|
318 |
+
output_image = gr.Image(type="pil", label="Super-Resolution Result")
|
319 |
+
|
320 |
+
# Status display
|
321 |
+
status_text = gr.Textbox(label="Status", value="✅ Ready", interactive=False)
|
322 |
+
|
323 |
+
# Download component
|
324 |
+
with gr.Group():
|
325 |
+
gr.Markdown("### 📥 Download Super-Resolution Result")
|
326 |
+
download_btn = gr.File(visible=True)
|
327 |
+
|
328 |
+
# Event handlers
|
329 |
+
submit_event = submit_btn.click(
|
330 |
+
fn=predict,
|
331 |
+
inputs=[input_image, scale_slider],
|
332 |
+
outputs=[output_image, download_btn, status_text, stop_btn]
|
333 |
+
)
|
334 |
+
|
335 |
+
stop_btn.click(
|
336 |
+
fn=set_stop_flag,
|
337 |
+
inputs=[],
|
338 |
+
outputs=[status_text, stop_btn],
|
339 |
+
cancels=[submit_event]
|
340 |
+
)
|
341 |
+
|
342 |
+
# Example images
|
343 |
+
gr.Markdown("### 📚 Example Images")
|
344 |
+
gr.Markdown("Try these examples with different scales:")
|
345 |
+
|
346 |
+
gr.Examples(
|
347 |
+
examples=[
|
348 |
+
["assets/0846x4.png", 1.5],
|
349 |
+
["assets/0892x4.png", 2.8],
|
350 |
+
["assets/0873x4_cropped_120x120.png", 30.0]
|
351 |
+
],
|
352 |
+
inputs=[input_image, scale_slider],
|
353 |
+
examples_per_page=3,
|
354 |
+
cache_examples=False,
|
355 |
+
label="Examples"
|
356 |
+
)
|
357 |
+
|
358 |
+
if __name__ == "__main__":
|
359 |
+
demo.launch(share=True, server_name="0.0.0.0")
|
assets/0846x4.png
ADDED
![]() |
Git LFS Details
|
assets/0873.png
ADDED
![]() |
Git LFS Details
|
assets/0873x4.png
ADDED
![]() |
Git LFS Details
|
assets/0873x4_cropped_120x120.png
ADDED
![]() |
Git LFS Details
|
assets/0892x4.png
ADDED
![]() |
Git LFS Details
|
assets/Screenshot_cropped_180x100.png
ADDED
![]() |
Git LFS Details
|
dist/gscuda-0.0.0-cp310-cp310-linux_x86_64.whl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:832b5f0cd6cd078e39a8bf68c481488cf606ec9633591d4d981794338a3f2b29
|
3 |
+
size 90122
|
requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.5.1 --index-url https://download.pytorch.org/whl/cu124
|
2 |
+
torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu124
|
3 |
+
torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
|
4 |
+
# gradio==5.32.0
|
5 |
+
gradio==5.23.0
|
6 |
+
huggingface-hub==0.32.3
|
7 |
+
pillow==11.2.1
|
8 |
+
numpy==1.23.0
|
9 |
+
einops==0.8.1
|
10 |
+
opencv-python==4.11.0.86
|
11 |
+
pydantic==2.10.6
|
12 |
+
# dist/gscuda-0.0.0-cp310-cp310-linux_x86_64.whl
|
setup.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup
|
2 |
+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
|
6 |
+
print("Building gscuda")
|
7 |
+
# 假设源文件在 gs_cuda 目录下
|
8 |
+
file_path = "utils/gs_cuda_dmax"
|
9 |
+
|
10 |
+
setup(
|
11 |
+
name="gscuda", # 模块名
|
12 |
+
ext_modules=[
|
13 |
+
CUDAExtension(
|
14 |
+
name="gscuda", # 可以直接作为模块导入
|
15 |
+
sources=[
|
16 |
+
os.path.join(file_path, "gswrapper.cpp"),
|
17 |
+
os.path.join(file_path, "gs.cu")
|
18 |
+
],
|
19 |
+
# 设置运行时库路径(可选)
|
20 |
+
library_dirs=[os.path.join(os.path.dirname(torch.__file__), 'lib')],
|
21 |
+
)
|
22 |
+
],
|
23 |
+
cmdclass={
|
24 |
+
"build_ext": BuildExtension
|
25 |
+
},
|
26 |
+
)
|
utils/edsrbaseline.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections.abc
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torchvision
|
5 |
+
import warnings
|
6 |
+
from itertools import repeat
|
7 |
+
from torch import nn as nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
from torch.nn import init as init
|
10 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
11 |
+
|
12 |
+
@torch.no_grad()
|
13 |
+
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
|
14 |
+
"""Initialize network weights.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
module_list (list[nn.Module] | nn.Module): Modules to be initialized.
|
18 |
+
scale (float): Scale initialized weights, especially for residual
|
19 |
+
blocks. Default: 1.
|
20 |
+
bias_fill (float): The value to fill bias. Default: 0
|
21 |
+
kwargs (dict): Other arguments for initialization function.
|
22 |
+
"""
|
23 |
+
if not isinstance(module_list, list):
|
24 |
+
module_list = [module_list]
|
25 |
+
for module in module_list:
|
26 |
+
for m in module.modules():
|
27 |
+
if isinstance(m, nn.Conv2d):
|
28 |
+
init.kaiming_normal_(m.weight, **kwargs)
|
29 |
+
m.weight.data *= scale
|
30 |
+
if m.bias is not None:
|
31 |
+
m.bias.data.fill_(bias_fill)
|
32 |
+
elif isinstance(m, nn.Linear):
|
33 |
+
init.kaiming_normal_(m.weight, **kwargs)
|
34 |
+
m.weight.data *= scale
|
35 |
+
if m.bias is not None:
|
36 |
+
m.bias.data.fill_(bias_fill)
|
37 |
+
elif isinstance(m, _BatchNorm):
|
38 |
+
init.constant_(m.weight, 1)
|
39 |
+
if m.bias is not None:
|
40 |
+
m.bias.data.fill_(bias_fill)
|
41 |
+
|
42 |
+
def make_layer(basic_block, num_basic_block, **kwarg):
|
43 |
+
"""Make layers by stacking the same blocks.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
basic_block (nn.module): nn.module class for basic block.
|
47 |
+
num_basic_block (int): number of blocks.
|
48 |
+
|
49 |
+
Returns:
|
50 |
+
nn.Sequential: Stacked blocks in nn.Sequential.
|
51 |
+
"""
|
52 |
+
layers = []
|
53 |
+
for _ in range(num_basic_block):
|
54 |
+
layers.append(basic_block(**kwarg))
|
55 |
+
return nn.Sequential(*layers)
|
56 |
+
|
57 |
+
class ResidualBlockNoBN(nn.Module):
|
58 |
+
"""Residual block without BN.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
num_feat (int): Channel number of intermediate features.
|
62 |
+
Default: 64.
|
63 |
+
res_scale (float): Residual scale. Default: 1.
|
64 |
+
pytorch_init (bool): If set to True, use pytorch default init,
|
65 |
+
otherwise, use default_init_weights. Default: False.
|
66 |
+
"""
|
67 |
+
|
68 |
+
def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
|
69 |
+
super(ResidualBlockNoBN, self).__init__()
|
70 |
+
self.res_scale = res_scale
|
71 |
+
self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
72 |
+
self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
73 |
+
self.relu = nn.ReLU(inplace=True)
|
74 |
+
|
75 |
+
if not pytorch_init:
|
76 |
+
default_init_weights([self.conv1, self.conv2], 0.1)
|
77 |
+
|
78 |
+
def forward(self, x):
|
79 |
+
identity = x
|
80 |
+
out = self.conv2(self.relu(self.conv1(x)))
|
81 |
+
return identity + out * self.res_scale
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
class EDSRNOUP(nn.Module):
|
86 |
+
def __init__(self,
|
87 |
+
num_in_ch=3,
|
88 |
+
num_out_ch=3,
|
89 |
+
num_feat=64,
|
90 |
+
num_block=16,
|
91 |
+
upscale=4,
|
92 |
+
res_scale=1):
|
93 |
+
super(EDSRNOUP, self).__init__()
|
94 |
+
|
95 |
+
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
96 |
+
self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True)
|
97 |
+
self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
98 |
+
|
99 |
+
|
100 |
+
def forward(self, x):
|
101 |
+
|
102 |
+
x = self.conv_first(x)
|
103 |
+
res = self.conv_after_body(self.body(x))
|
104 |
+
x = res + x
|
105 |
+
|
106 |
+
return res
|
107 |
+
|
108 |
+
|
109 |
+
if __name__ == '__main__':
|
110 |
+
x = torch.randn(8,3,48,48)
|
111 |
+
model = EDSRNOUP(num_in_ch=3, num_out_ch=3)
|
112 |
+
y = model(x)
|
113 |
+
print(y.shape)
|
utils/fea2gsropeamp.py
ADDED
@@ -0,0 +1,749 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import print_function
|
3 |
+
from __future__ import division
|
4 |
+
|
5 |
+
import warnings
|
6 |
+
import math
|
7 |
+
import copy
|
8 |
+
from einops import rearrange
|
9 |
+
import torch
|
10 |
+
from torch import nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_, kaiming_normal_
|
13 |
+
from einops import rearrange
|
14 |
+
from torch.utils.checkpoint import checkpoint
|
15 |
+
from functools import partial
|
16 |
+
from typing import Any, Optional, Tuple
|
17 |
+
import numpy as np
|
18 |
+
|
19 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
20 |
+
# From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
|
21 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
22 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
23 |
+
def norm_cdf(x):
|
24 |
+
# Computes standard normal cumulative distribution function
|
25 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
26 |
+
|
27 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
28 |
+
warnings.warn(
|
29 |
+
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
|
30 |
+
'The distribution of values may be incorrect.',
|
31 |
+
stacklevel=2)
|
32 |
+
|
33 |
+
with torch.no_grad():
|
34 |
+
# Values are generated by using a truncated uniform distribution and
|
35 |
+
# then using the inverse CDF for the normal distribution.
|
36 |
+
# Get upper and lower cdf values
|
37 |
+
low = norm_cdf((a - mean) / std)
|
38 |
+
up = norm_cdf((b - mean) / std)
|
39 |
+
|
40 |
+
# Uniformly fill tensor with values from [low, up], then translate to
|
41 |
+
# [2l-1, 2u-1].
|
42 |
+
tensor.uniform_(2 * low - 1, 2 * up - 1)
|
43 |
+
|
44 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
45 |
+
# standard normal
|
46 |
+
tensor.erfinv_()
|
47 |
+
|
48 |
+
# Transform to proper mean, std
|
49 |
+
tensor.mul_(std * math.sqrt(2.))
|
50 |
+
tensor.add_(mean)
|
51 |
+
|
52 |
+
# Clamp to ensure it's in the proper range
|
53 |
+
tensor.clamp_(min=a, max=b)
|
54 |
+
return tensor
|
55 |
+
|
56 |
+
|
57 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
58 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
59 |
+
normal distribution.
|
60 |
+
|
61 |
+
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
|
62 |
+
|
63 |
+
The values are effectively drawn from the
|
64 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
65 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
66 |
+
the bounds. The method used for generating the random values works
|
67 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
tensor: an n-dimensional `torch.Tensor`
|
71 |
+
mean: the mean of the normal distribution
|
72 |
+
std: the standard deviation of the normal distribution
|
73 |
+
a: the minimum cutoff value
|
74 |
+
b: the maximum cutoff value
|
75 |
+
|
76 |
+
Examples:
|
77 |
+
>>> w = torch.empty(3, 5)
|
78 |
+
>>> nn.init.trunc_normal_(w)
|
79 |
+
"""
|
80 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
81 |
+
|
82 |
+
def init_t_xy(end_x: int, end_y: int, zero_center=False):
|
83 |
+
t = torch.arange(end_x * end_y, dtype=torch.float32)
|
84 |
+
t_x = (t % end_x).float()
|
85 |
+
t_y = torch.div(t, end_x, rounding_mode='floor').float()
|
86 |
+
|
87 |
+
return t_x, t_y
|
88 |
+
|
89 |
+
def init_random_2d_freqs(head_dim: int, num_heads: int, theta: float = 10.0, rotate: bool = True):
|
90 |
+
freqs_x = []
|
91 |
+
freqs_y = []
|
92 |
+
theta = theta
|
93 |
+
mag = 1 / (theta ** (torch.arange(0, head_dim, 4)[: (head_dim // 4)].float() / head_dim))
|
94 |
+
for i in range(num_heads):
|
95 |
+
angles = torch.rand(1) * 2 * torch.pi if rotate else torch.zeros(1)
|
96 |
+
fx = torch.cat([mag * torch.cos(angles), mag * torch.cos(torch.pi/2 + angles)], dim=-1)
|
97 |
+
fy = torch.cat([mag * torch.sin(angles), mag * torch.sin(torch.pi/2 + angles)], dim=-1)
|
98 |
+
freqs_x.append(fx)
|
99 |
+
freqs_y.append(fy)
|
100 |
+
freqs_x = torch.stack(freqs_x, dim=0)
|
101 |
+
freqs_y = torch.stack(freqs_y, dim=0)
|
102 |
+
freqs = torch.stack([freqs_x, freqs_y], dim=0)
|
103 |
+
return freqs
|
104 |
+
|
105 |
+
def compute_cis(freqs, t_x, t_y):
|
106 |
+
N = t_x.shape[0]
|
107 |
+
# No float 16 for this range
|
108 |
+
with torch.cuda.amp.autocast(enabled=False):
|
109 |
+
freqs_x = (t_x.unsqueeze(-1) @ freqs[0].unsqueeze(-2))
|
110 |
+
freqs_y = (t_y.unsqueeze(-1) @ freqs[1].unsqueeze(-2))
|
111 |
+
freqs_cis = torch.polar(torch.ones_like(freqs_x), freqs_x + freqs_y)
|
112 |
+
|
113 |
+
return freqs_cis
|
114 |
+
|
115 |
+
|
116 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
117 |
+
ndim = x.ndim
|
118 |
+
assert 0 <= 1 < ndim
|
119 |
+
# assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
|
120 |
+
# print(f"freqs_cis shape is {freqs_cis.shape}, x shape is {x.shape}")
|
121 |
+
if freqs_cis.shape == (x.shape[-2], x.shape[-1]):
|
122 |
+
shape = [d if i >= ndim-2 else 1 for i, d in enumerate(x.shape)]
|
123 |
+
elif freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1]):
|
124 |
+
shape = [d if i >= ndim-3 else 1 for i, d in enumerate(x.shape)]
|
125 |
+
|
126 |
+
return freqs_cis.view(*shape)
|
127 |
+
|
128 |
+
def apply_rotary_emb(
|
129 |
+
xq: torch.Tensor,
|
130 |
+
xk: torch.Tensor,
|
131 |
+
freqs_cis: torch.Tensor,
|
132 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
133 |
+
# print(f"xq shape is {xq.shape}, xq.shape[:-1] is {xq.shape[:-1]}")
|
134 |
+
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
135 |
+
# print(f"xq_ shape is {xq_.shape}")
|
136 |
+
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
137 |
+
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
138 |
+
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
139 |
+
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
140 |
+
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
|
141 |
+
|
142 |
+
def apply_rotary_emb_single(x, freqs_cis):
|
143 |
+
x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
144 |
+
seq_len = x_.shape[2]
|
145 |
+
freqs_cis = freqs_cis[:, :seq_len, :]
|
146 |
+
freqs_cis = freqs_cis.unsqueeze(0).expand_as(x_)
|
147 |
+
x_out = torch.view_as_real(x_ * freqs_cis).flatten(3)
|
148 |
+
return x_out.type_as(x).to(x.device)
|
149 |
+
|
150 |
+
def window_partition(x, window_size):
|
151 |
+
# x is the feature from net_g
|
152 |
+
b, c, h, w = x.shape
|
153 |
+
windows = rearrange(x, 'b c (h_count dh) (w_count dw) -> (b h_count w_count) (dh dw) c', dh=window_size,
|
154 |
+
dw=window_size)
|
155 |
+
# h_count = h // window_size
|
156 |
+
# w_count = w // window_size
|
157 |
+
# windows = x.reshape(b,c,h_count, window_size, w_count, window_size)
|
158 |
+
# windows = windows.permute(0,1,2,4,3,5) #b,c,h_count,w_count,window_size,window_size
|
159 |
+
# windows = windows.reshape(b,c,h_count*w_count, window_size * window_size)
|
160 |
+
# windows = windows.permute(0,2,3,1) #b,h_count*w_count, window_size*window_size,c
|
161 |
+
# windows = windows.reshape(-1, window_size*window_size, c)
|
162 |
+
|
163 |
+
return windows
|
164 |
+
|
165 |
+
|
166 |
+
def with_pos_embed(tensor, pos):
|
167 |
+
return tensor if pos is None else tensor + pos
|
168 |
+
|
169 |
+
|
170 |
+
class MLP(nn.Module):
|
171 |
+
def __init__(self, in_features, hidden_features, out_features, act_layer=nn.ReLU):
|
172 |
+
super(MLP, self).__init__()
|
173 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
174 |
+
self.act = act_layer()
|
175 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
176 |
+
|
177 |
+
def forward(self, x):
|
178 |
+
x = self.fc1(x)
|
179 |
+
x = self.act(x)
|
180 |
+
x = self.fc2(x)
|
181 |
+
return x
|
182 |
+
|
183 |
+
class WindowCrossAttn(nn.Module):
|
184 |
+
def __init__(self, dim=180, num_heads=6, window_size=12, num_gs_seed=2304, rope_mixed = True, rope_theta = 10.0):
|
185 |
+
super(WindowCrossAttn, self).__init__()
|
186 |
+
self.dim = dim
|
187 |
+
self.num_heads = num_heads
|
188 |
+
self.window_size = window_size
|
189 |
+
self.num_gs_seed = num_gs_seed
|
190 |
+
self.num_gs_seed_sqrt = int(math.sqrt(num_gs_seed))
|
191 |
+
|
192 |
+
|
193 |
+
self.rope_mixed = rope_mixed
|
194 |
+
|
195 |
+
t_x, t_y = init_t_xy(end_x=max(self.num_gs_seed_sqrt, self.window_size), end_y=max(self.num_gs_seed_sqrt, self.window_size))
|
196 |
+
self.register_buffer('rope_t_x', t_x)
|
197 |
+
self.register_buffer('rope_t_y', t_y)
|
198 |
+
|
199 |
+
freqs = init_random_2d_freqs(
|
200 |
+
head_dim=self.dim // self.num_heads, num_heads=self.num_heads, theta=rope_theta,
|
201 |
+
rotate=self.rope_mixed
|
202 |
+
)
|
203 |
+
if self.rope_mixed:
|
204 |
+
self.rope_freqs = nn.Parameter(freqs, requires_grad=True)
|
205 |
+
else:
|
206 |
+
self.register_buffer('rope_freqs', freqs)
|
207 |
+
freqs_cis = compute_cis(self.rope_freqs, self.rope_t_x, self.rope_t_y)
|
208 |
+
self.rope_freqs_cis = freqs_cis
|
209 |
+
|
210 |
+
self.qhead = nn.Linear(dim, dim, bias=True)
|
211 |
+
self.khead = nn.Linear(dim, dim, bias=True)
|
212 |
+
self.vhead = nn.Linear(dim, dim, bias=True)
|
213 |
+
|
214 |
+
self.proj = nn.Linear(dim, dim)
|
215 |
+
|
216 |
+
|
217 |
+
def forward(self, gs, feat):
|
218 |
+
# gs shape: b*h_count*w_count, num_gs, c the input gs here should already include pos embedding and scale embedding
|
219 |
+
# feat shape: b*h_count*w_count, dh*dw, c dh=dw=window_size
|
220 |
+
b_, num_gs, c = gs.shape
|
221 |
+
b_, n, c = feat.shape
|
222 |
+
|
223 |
+
q = self.qhead(gs) # b_, num_gs_, c
|
224 |
+
q = q.reshape(b_, num_gs, self.num_heads, c // self.num_heads)
|
225 |
+
q = q.permute(0, 2, 1, 3) # b_, num_heads, n, c // num_heads
|
226 |
+
|
227 |
+
k = self.khead(feat) # b_, n_, c
|
228 |
+
k = k.reshape(b_, n, self.num_heads, c // self.num_heads)
|
229 |
+
k = k.permute(0, 2, 1, 3) # b_, num_heads, n, c // num_heads
|
230 |
+
|
231 |
+
v = self.vhead(feat) # b_, n_, c
|
232 |
+
v = v.reshape(b_, n, self.num_heads, c // self.num_heads)
|
233 |
+
v = v.permute(0, 2, 1, 3) # b_, num_heads, n, c // num_heads
|
234 |
+
|
235 |
+
###### Apply rotary position embedding
|
236 |
+
if self.rope_mixed:
|
237 |
+
freqs_cis = compute_cis(self.rope_freqs, self.rope_t_x, self.rope_t_y)
|
238 |
+
else:
|
239 |
+
freqs_cis = self.rope_freqs_cis.to(gs.device)
|
240 |
+
q = apply_rotary_emb_single(q, freqs_cis)
|
241 |
+
k = apply_rotary_emb_single(k, freqs_cis)
|
242 |
+
#########
|
243 |
+
|
244 |
+
attn = F.scaled_dot_product_attention(q, k, v)
|
245 |
+
|
246 |
+
x = attn.transpose(1, 2).reshape(b_, num_gs, c)
|
247 |
+
|
248 |
+
x = self.proj(x)
|
249 |
+
|
250 |
+
return x
|
251 |
+
|
252 |
+
|
253 |
+
class WindowCrossAttnLayer(nn.Module):
|
254 |
+
def __init__(self, dim=180, num_heads=6, window_size=12, shift_size=0, num_gs_seed=2308, rope_mixed = True, rope_theta = 10.0):
|
255 |
+
super(WindowCrossAttnLayer, self).__init__()
|
256 |
+
|
257 |
+
self.gs_cross_attn_scale = nn.MultiheadAttention(dim, num_heads, batch_first=True)
|
258 |
+
|
259 |
+
self.norm1 = nn.LayerNorm(dim)
|
260 |
+
self.norm2 = nn.LayerNorm(dim)
|
261 |
+
self.norm3 = nn.LayerNorm(dim)
|
262 |
+
self.norm4 = nn.LayerNorm(dim)
|
263 |
+
self.shift_size = shift_size
|
264 |
+
self.window_size = window_size
|
265 |
+
|
266 |
+
self.window_cross_attn = WindowCrossAttn(dim=dim, num_heads=num_heads, window_size=window_size,
|
267 |
+
num_gs_seed=num_gs_seed, rope_mixed = rope_mixed, rope_theta = rope_theta)
|
268 |
+
self.mlp_crossattn_scale = MLP(in_features=dim, hidden_features=dim, out_features=dim)
|
269 |
+
self.mlp_crossattn_feature = MLP(in_features=dim, hidden_features=dim, out_features=dim)
|
270 |
+
|
271 |
+
def forward(self, x, query_pos, feat, scale_embedding):
|
272 |
+
# gs shape: b*h_count*w_count, num_gs, c
|
273 |
+
# query_pos shape: b*h_count*w_count, num_gs, c
|
274 |
+
# feat shape: b,c,h,w
|
275 |
+
# scale_embedding shape: b*h_count*w_count, 1, c
|
276 |
+
|
277 |
+
###GS cross attn with scale embedding
|
278 |
+
resi = x
|
279 |
+
x = self.norm1(x)
|
280 |
+
# print(f"x: {x.shape} {x.device}, query_pos: {query_pos.shape}, {query_pos.device}, scale_embedding: {scale_embedding.shape}, {scale_embedding.device}")
|
281 |
+
x, _ = self.gs_cross_attn_scale(with_pos_embed(x, query_pos), scale_embedding, scale_embedding)
|
282 |
+
x = resi + x
|
283 |
+
|
284 |
+
###FFN
|
285 |
+
resi = x
|
286 |
+
x = self.norm2(x)
|
287 |
+
x = self.mlp_crossattn_scale(x)
|
288 |
+
x = resi + x
|
289 |
+
|
290 |
+
###cross attention for Q,K,V
|
291 |
+
resi = x
|
292 |
+
x = self.norm3(x)
|
293 |
+
if self.shift_size > 0:
|
294 |
+
shift_feat = torch.roll(feat, shifts=(-self.shift_size, -self.shift_size), dims=(2, 3))
|
295 |
+
else:
|
296 |
+
shift_feat = feat
|
297 |
+
shift_feat = window_partition(shift_feat, self.window_size) # b*h_count*w_count, dh*dw, c dh=dw=window_size
|
298 |
+
x = self.window_cross_attn(with_pos_embed(x, query_pos),
|
299 |
+
shift_feat) # b*h_count*w_count, num_gs, c dh=dw=window_size
|
300 |
+
x = resi + x
|
301 |
+
|
302 |
+
###FFN
|
303 |
+
resi = x
|
304 |
+
x = self.norm4(x)
|
305 |
+
x = self.mlp_crossattn_feature(x)
|
306 |
+
x = resi + x
|
307 |
+
|
308 |
+
return x
|
309 |
+
|
310 |
+
|
311 |
+
class WindowCrossAttnBlock(nn.Module):
|
312 |
+
def __init__(self, dim=180, window_size=12, num_heads=6, num_layers=4, num_gs_seed=230, rope_mixed = True, rope_theta = 10.0):
|
313 |
+
super(WindowCrossAttnBlock, self).__init__()
|
314 |
+
|
315 |
+
self.num_gs_seed_sqrt = int(math.sqrt(num_gs_seed))
|
316 |
+
|
317 |
+
self.mlp = nn.Sequential(
|
318 |
+
nn.Linear(dim, dim),
|
319 |
+
nn.ReLU(),
|
320 |
+
nn.Linear(dim, dim)
|
321 |
+
)
|
322 |
+
self.norm = nn.LayerNorm(dim)
|
323 |
+
self.blocks = nn.ModuleList([
|
324 |
+
WindowCrossAttnLayer(
|
325 |
+
dim=dim,
|
326 |
+
num_heads=num_heads,
|
327 |
+
window_size=window_size,
|
328 |
+
shift_size=0 if i % 2 == 0 else window_size // 2,
|
329 |
+
num_gs_seed=num_gs_seed,
|
330 |
+
rope_mixed = rope_mixed, rope_theta = rope_theta) for i in range(num_layers)
|
331 |
+
])
|
332 |
+
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
|
333 |
+
|
334 |
+
def forward(self, x, query_pos, feat, scale_embedding, h_count, w_count):
|
335 |
+
resi = x
|
336 |
+
x = self.norm(x)
|
337 |
+
for block in self.blocks:
|
338 |
+
x = block(x, query_pos, feat, scale_embedding)
|
339 |
+
x = self.mlp(x)
|
340 |
+
|
341 |
+
x = rearrange(x, '(b m n) (h w) c -> b c (m h) (n w)', m=h_count, n=w_count, h=self.num_gs_seed_sqrt)
|
342 |
+
x = self.conv(x)
|
343 |
+
x = rearrange(x, 'b c (m h) (n w) -> (b m n) (h w) c', m=h_count, n=w_count, h=self.num_gs_seed_sqrt)
|
344 |
+
|
345 |
+
x = resi + x
|
346 |
+
return x
|
347 |
+
|
348 |
+
|
349 |
+
class GSSelfAttn(nn.Module):
|
350 |
+
def __init__(self, dim=180, num_heads=6, num_gs_seed_sqrt = 12, rope_mixed = True, rope_theta=10.0):
|
351 |
+
super(GSSelfAttn, self).__init__()
|
352 |
+
self.dim = dim
|
353 |
+
self.num_heads = num_heads
|
354 |
+
self.num_gs_seed_sqrt = num_gs_seed_sqrt
|
355 |
+
|
356 |
+
self.proj = nn.Linear(dim, dim)
|
357 |
+
self.rope_mixed = rope_mixed
|
358 |
+
|
359 |
+
t_x, t_y = init_t_xy(end_x=self.num_gs_seed_sqrt, end_y=self.num_gs_seed_sqrt)
|
360 |
+
self.register_buffer('rope_t_x', t_x)
|
361 |
+
self.register_buffer('rope_t_y', t_y)
|
362 |
+
|
363 |
+
freqs = init_random_2d_freqs(
|
364 |
+
head_dim=self.dim // self.num_heads, num_heads=self.num_heads, theta=rope_theta,
|
365 |
+
rotate=self.rope_mixed
|
366 |
+
)
|
367 |
+
if self.rope_mixed:
|
368 |
+
self.rope_freqs = nn.Parameter(freqs, requires_grad=True)
|
369 |
+
else:
|
370 |
+
self.register_buffer('rope_freqs', freqs)
|
371 |
+
freqs_cis = compute_cis(self.rope_freqs, self.rope_t_x, self.rope_t_y)
|
372 |
+
self.rope_freqs_cis = freqs_cis
|
373 |
+
|
374 |
+
self.qhead = nn.Linear(dim, dim, bias=True)
|
375 |
+
self.khead = nn.Linear(dim, dim, bias=True)
|
376 |
+
self.vhead = nn.Linear(dim, dim, bias=True)
|
377 |
+
|
378 |
+
def forward(self, gs):
|
379 |
+
# gs shape: b*h_count*w_count, num_gs, c
|
380 |
+
# pos shape: b*h_count*w_count, num_gs, c
|
381 |
+
b_, num_gs, c = gs.shape
|
382 |
+
|
383 |
+
q = self.qhead(gs)
|
384 |
+
q = q.reshape(b_, num_gs, self.num_heads, c // self.num_heads)
|
385 |
+
q = q.permute(0, 2, 1, 3) # b_, num_heads, n, c // num_heads
|
386 |
+
|
387 |
+
k = self.khead(gs)
|
388 |
+
k = k.reshape(b_, num_gs, self.num_heads, c // self.num_heads)
|
389 |
+
k = k.permute(0, 2, 1, 3) # b_, num_heads, n, c // num_heads
|
390 |
+
|
391 |
+
v = self.vhead(gs)
|
392 |
+
v = v.reshape(b_, num_gs, self.num_heads, c // self.num_heads)
|
393 |
+
v = v.permute(0, 2, 1, 3) # b_, num_heads, n, c // num_heads
|
394 |
+
|
395 |
+
###### Apply rotary position embedding
|
396 |
+
if self.rope_mixed:
|
397 |
+
freqs_cis = compute_cis(self.rope_freqs, self.rope_t_x, self.rope_t_y)
|
398 |
+
else:
|
399 |
+
freqs_cis = self.rope_freqs_cis.to(gs.device)
|
400 |
+
q, k = apply_rotary_emb(q, k, freqs_cis)
|
401 |
+
#########
|
402 |
+
|
403 |
+
attn = F.scaled_dot_product_attention(q, k, v)
|
404 |
+
|
405 |
+
attn = attn.transpose(1, 2).reshape(b_, num_gs, c)
|
406 |
+
|
407 |
+
|
408 |
+
attn = self.proj(attn)
|
409 |
+
|
410 |
+
return attn
|
411 |
+
|
412 |
+
|
413 |
+
class GSSelfAttnLayer(nn.Module):
|
414 |
+
def __init__(self, dim=180, num_heads=6, num_gs_seed_sqrt = 12, shift_size = 0, rope_mixed = True, rope_theta=10.0):
|
415 |
+
super(GSSelfAttnLayer, self).__init__()
|
416 |
+
|
417 |
+
self.norm1 = nn.LayerNorm(dim)
|
418 |
+
self.norm2 = nn.LayerNorm(dim)
|
419 |
+
self.norm3 = nn.LayerNorm(dim)
|
420 |
+
self.norm4 = nn.LayerNorm(dim)
|
421 |
+
|
422 |
+
self.gs_self_attn = GSSelfAttn(dim = dim, num_heads = num_heads, num_gs_seed_sqrt = num_gs_seed_sqrt, rope_mixed = rope_mixed, rope_theta=rope_theta)
|
423 |
+
|
424 |
+
self.mlp_selfattn = MLP(in_features=dim, hidden_features=dim, out_features=dim)
|
425 |
+
|
426 |
+
self.num_gs_seed_sqrt = num_gs_seed_sqrt
|
427 |
+
self.shift_size = shift_size
|
428 |
+
|
429 |
+
self.gs_cross_attn_scale = nn.MultiheadAttention(dim, num_heads, batch_first=True)
|
430 |
+
|
431 |
+
self.mlp_crossattn = MLP(in_features=dim, hidden_features=dim, out_features=dim)
|
432 |
+
|
433 |
+
def forward(self, gs, pos, h_count, w_count, scale_embedding):
|
434 |
+
# gs shape:b*h_count*w_count, num_gs_seed, channel
|
435 |
+
# pos shape: b*h_count*w_count, num_gs_seed, channel
|
436 |
+
# scale_embedding shape: b*h_count*w_count, 1, channel
|
437 |
+
|
438 |
+
# gs cross attn with scale_embedding
|
439 |
+
resi = gs
|
440 |
+
gs = self.norm3(gs)
|
441 |
+
gs, _ = self.gs_cross_attn_scale(with_pos_embed(gs, pos), scale_embedding, scale_embedding)
|
442 |
+
gs = gs + resi
|
443 |
+
|
444 |
+
# FFN
|
445 |
+
resi = gs
|
446 |
+
gs = self.norm4(gs)
|
447 |
+
gs = self.mlp_crossattn(gs)
|
448 |
+
gs = gs + resi
|
449 |
+
|
450 |
+
resi = gs
|
451 |
+
gs = self.norm1(gs)
|
452 |
+
|
453 |
+
#### shift gs
|
454 |
+
if self.shift_size > 0:
|
455 |
+
shift_gs = rearrange(gs, '(b m n) (h w) c -> b (m h) (n w) c', m=h_count, n=w_count, h=self.num_gs_seed_sqrt, w = self.num_gs_seed_sqrt)
|
456 |
+
shift_gs = torch.roll(shift_gs, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
457 |
+
shift_gs = rearrange(shift_gs, 'b (m h) (n w) c -> (b m n) (h w) c', m=h_count, n=w_count, h=self.num_gs_seed_sqrt, w = self.num_gs_seed_sqrt)
|
458 |
+
else:
|
459 |
+
shift_gs = gs
|
460 |
+
|
461 |
+
#### gs self attention
|
462 |
+
gs = self.gs_self_attn(shift_gs)
|
463 |
+
|
464 |
+
#### shift gs back
|
465 |
+
if self.shift_size > 0:
|
466 |
+
shift_gs = rearrange(gs, '(b m n) (h w) c -> b (m h) (n w) c', m=h_count, n=w_count, h=self.num_gs_seed_sqrt, w = self.num_gs_seed_sqrt)
|
467 |
+
shift_gs = torch.roll(shift_gs, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
468 |
+
shift_gs = rearrange(shift_gs, 'b (m h) (n w) c -> (b m n) (h w) c', m=h_count, n=w_count, h=self.num_gs_seed_sqrt, w = self.num_gs_seed_sqrt)
|
469 |
+
else:
|
470 |
+
shift_gs = gs
|
471 |
+
|
472 |
+
gs = shift_gs + resi
|
473 |
+
|
474 |
+
#FFN
|
475 |
+
resi = gs
|
476 |
+
gs = self.norm2(gs)
|
477 |
+
gs = self.mlp_selfattn(gs)
|
478 |
+
gs = gs + resi
|
479 |
+
return gs
|
480 |
+
|
481 |
+
|
482 |
+
class GSSelfAttnBlock(nn.Module):
|
483 |
+
def __init__(self, dim=180, num_heads=6, num_selfattn_layers=4, num_gs_seed_sqrt = 12, rope_mixed = True, rope_theta=10.0):
|
484 |
+
super(GSSelfAttnBlock, self).__init__()
|
485 |
+
self.num_gs_seed_sqrt = num_gs_seed_sqrt
|
486 |
+
|
487 |
+
self.mlp = nn.Sequential(
|
488 |
+
nn.Linear(dim, dim),
|
489 |
+
nn.ReLU(),
|
490 |
+
nn.Linear(dim, dim)
|
491 |
+
)
|
492 |
+
self.norm = nn.LayerNorm(dim)
|
493 |
+
self.blocks = nn.ModuleList([
|
494 |
+
GSSelfAttnLayer(
|
495 |
+
dim = dim,
|
496 |
+
num_heads = num_heads,
|
497 |
+
num_gs_seed_sqrt=num_gs_seed_sqrt,
|
498 |
+
shift_size=0 if i % 2 == 0 else num_gs_seed_sqrt // 2,
|
499 |
+
rope_mixed = rope_mixed, rope_theta=rope_theta
|
500 |
+
) for i in range(num_selfattn_layers)
|
501 |
+
])
|
502 |
+
|
503 |
+
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
|
504 |
+
|
505 |
+
def forward(self, gs, pos, h_count, w_count, scale_embedding):
|
506 |
+
resi = gs
|
507 |
+
gs = self.norm(gs)
|
508 |
+
for block in self.blocks:
|
509 |
+
gs = block(gs, pos, h_count, w_count, scale_embedding)
|
510 |
+
|
511 |
+
gs = self.mlp(gs)
|
512 |
+
gs = rearrange(gs, '(b m n) (h w) c -> b c (m h) (n w)', m=h_count, n=w_count, h=self.num_gs_seed_sqrt)
|
513 |
+
gs = self.conv(gs)
|
514 |
+
gs = rearrange(gs, 'b c (m h) (n w) -> (b m n) (h w) c', m=h_count, n=w_count, h=self.num_gs_seed_sqrt)
|
515 |
+
gs = gs + resi
|
516 |
+
return gs
|
517 |
+
|
518 |
+
class Fea2GS_ROPE_AMP(nn.Module):
|
519 |
+
def __init__(self, inchannel=64, channel=192, num_heads=6, num_crossattn_blocks=1, num_crossattn_layers=2, num_selfattn_blocks = 6, num_selfattn_layers = 6,
|
520 |
+
num_gs_seed=144, gs_up_factor=1.0, window_size=12, img_range=1.0, shuffle_scale1 = 2, shuffle_scale2 = 2, use_checkpoint = False,
|
521 |
+
rope_mixed = True, rope_theta = 10.0):
|
522 |
+
"""
|
523 |
+
Args:
|
524 |
+
gs_repeat_factor: the ratio of gs embedding number and pixel number along width&height, will generate
|
525 |
+
(h * gs_repeat_factor) * (w * gs_repeat_factor) gs embedding, higher values means repeat more gs embedding.
|
526 |
+
gs_up_factor: how many 2d gaussian are generated by one gasussian embedding.
|
527 |
+
"""
|
528 |
+
super(Fea2GS_ROPE_AMP, self).__init__()
|
529 |
+
self.channel = channel
|
530 |
+
self.nhead = num_heads
|
531 |
+
self.gs_up_factor = gs_up_factor
|
532 |
+
self.num_gs_seed = num_gs_seed
|
533 |
+
self.window_size = window_size
|
534 |
+
self.img_range = img_range
|
535 |
+
self.use_checkpoint = use_checkpoint
|
536 |
+
|
537 |
+
self.num_gs_seed_sqrt = int(math.sqrt(num_gs_seed))
|
538 |
+
self.gs_up_factor_sqrt = int(math.sqrt(gs_up_factor))
|
539 |
+
|
540 |
+
self.shuffle_scale1 = shuffle_scale1
|
541 |
+
self.shuffle_scale2 = shuffle_scale2
|
542 |
+
|
543 |
+
# shared gaussian embedding and its pos embedding
|
544 |
+
self.gs_embedding = nn.Parameter(torch.randn(self.num_gs_seed, channel), requires_grad=True)
|
545 |
+
self.pos_embedding = nn.Parameter(torch.randn(self.num_gs_seed, channel), requires_grad=True)
|
546 |
+
|
547 |
+
self.img_feat_proj = nn.Sequential(
|
548 |
+
nn.Conv2d(inchannel, channel, 3, 1, 1),
|
549 |
+
nn.ReLU(),
|
550 |
+
nn.Conv2d(channel, channel, 3, 1, 1)
|
551 |
+
)
|
552 |
+
|
553 |
+
self.window_crossattn_blocks = nn.ModuleList([
|
554 |
+
WindowCrossAttnBlock(dim=channel,
|
555 |
+
window_size=window_size,
|
556 |
+
num_heads=num_heads,
|
557 |
+
num_layers=num_crossattn_layers,
|
558 |
+
num_gs_seed=num_gs_seed, rope_mixed = rope_mixed, rope_theta = rope_theta) for i in range(num_crossattn_blocks)
|
559 |
+
])
|
560 |
+
|
561 |
+
self.gs_selfattn_blocks = nn.ModuleList([
|
562 |
+
GSSelfAttnBlock(dim=channel,
|
563 |
+
num_heads=num_heads,
|
564 |
+
num_selfattn_layers=num_selfattn_layers,
|
565 |
+
num_gs_seed_sqrt=self.num_gs_seed_sqrt,
|
566 |
+
rope_mixed = rope_mixed, rope_theta=rope_theta
|
567 |
+
) for i in range(num_selfattn_blocks)
|
568 |
+
])
|
569 |
+
|
570 |
+
# GS sigma_x, sigma_y
|
571 |
+
self.mlp_block_sigma = nn.Sequential(
|
572 |
+
nn.Linear(channel, channel),
|
573 |
+
nn.ReLU(),
|
574 |
+
nn.Linear(channel, channel * 4),
|
575 |
+
nn.ReLU(),
|
576 |
+
nn.Linear(channel * 4, int(2 * gs_up_factor))
|
577 |
+
)
|
578 |
+
|
579 |
+
# GS rho
|
580 |
+
self.mlp_block_rho = nn.Sequential(
|
581 |
+
nn.Linear(channel, channel),
|
582 |
+
nn.ReLU(),
|
583 |
+
nn.Linear(channel, channel * 4),
|
584 |
+
nn.ReLU(),
|
585 |
+
nn.Linear(channel * 4, int(1 * gs_up_factor))
|
586 |
+
)
|
587 |
+
|
588 |
+
# GS alpha
|
589 |
+
self.mlp_block_alpha = nn.Sequential(
|
590 |
+
nn.Linear(channel, channel),
|
591 |
+
nn.ReLU(),
|
592 |
+
nn.Linear(channel, channel * 4),
|
593 |
+
nn.ReLU(),
|
594 |
+
nn.Linear(channel * 4, int(1 * gs_up_factor))
|
595 |
+
)
|
596 |
+
|
597 |
+
# GS RGB values
|
598 |
+
self.mlp_block_rgb = nn.Sequential(
|
599 |
+
nn.Linear(channel, channel),
|
600 |
+
nn.ReLU(),
|
601 |
+
nn.Linear(channel, channel * 4),
|
602 |
+
nn.ReLU(),
|
603 |
+
nn.Linear(channel * 4, int(3 * gs_up_factor))
|
604 |
+
)
|
605 |
+
|
606 |
+
# GS mean_x, mean_y
|
607 |
+
self.mlp_block_mean = nn.Sequential(
|
608 |
+
nn.Linear(channel, channel),
|
609 |
+
nn.ReLU(),
|
610 |
+
nn.Linear(channel, channel * 4),
|
611 |
+
nn.ReLU(),
|
612 |
+
nn.Linear(channel * 4, int(2 * gs_up_factor))
|
613 |
+
)
|
614 |
+
|
615 |
+
self.scale_mlp = nn.Sequential(
|
616 |
+
nn.Linear(1, channel * 4),
|
617 |
+
nn.ReLU(),
|
618 |
+
nn.Linear(channel * 4, channel)
|
619 |
+
)
|
620 |
+
|
621 |
+
self.UPNet = nn.Sequential(
|
622 |
+
nn.Conv2d(channel, channel * self.shuffle_scale1 * self.shuffle_scale1, 3, 1, 1),
|
623 |
+
nn.PixelShuffle(self.shuffle_scale1),
|
624 |
+
nn.Conv2d(channel, channel * self.shuffle_scale2 * self.shuffle_scale2, 3, 1, 1),
|
625 |
+
nn.PixelShuffle(self.shuffle_scale2)
|
626 |
+
)
|
627 |
+
|
628 |
+
self.conv_final = nn.Conv2d(channel, channel, 3, 1, 1)
|
629 |
+
|
630 |
+
@staticmethod
|
631 |
+
def get_N_reference_points(h, w, device='cuda'):
|
632 |
+
# step_y = 1/(h+1)
|
633 |
+
# step_x = 1/(w+1)
|
634 |
+
step_y = 1 / h
|
635 |
+
step_x = 1 / w
|
636 |
+
ref_y, ref_x = torch.meshgrid(torch.linspace(step_y / 2, 1 - step_y / 2, h, dtype=torch.float32, device=device),
|
637 |
+
torch.linspace(step_x / 2, 1 - step_x / 2, w, dtype=torch.float32, device=device))
|
638 |
+
reference_points = torch.stack((ref_x.reshape(-1), ref_y.reshape(-1)), -1)
|
639 |
+
reference_points = reference_points[None, :, None]
|
640 |
+
return reference_points
|
641 |
+
|
642 |
+
def forward(self, srcs, scale):
|
643 |
+
'''
|
644 |
+
using deformable detr decoder for cross attention
|
645 |
+
Args:
|
646 |
+
query: (batch_size, num_query, dim)
|
647 |
+
query_pos: (batch_size, num_query, dim)
|
648 |
+
srcs: (batch_size, dim, h1, w1)
|
649 |
+
'''
|
650 |
+
b, c, h, w = srcs.shape ###srcs is pad to the size that could be divided by window_size
|
651 |
+
query = self.gs_embedding.unsqueeze(0).unsqueeze(1).repeat(b, (h // self.window_size) * (w // self.window_size),
|
652 |
+
1, 1) # b, h_count*w_count, num_gs_seed, channel
|
653 |
+
query = query.reshape(b * (h // self.window_size) * (w // self.window_size), -1,
|
654 |
+
self.channel) # b*h_count*w_count, num_gs_seed, channel
|
655 |
+
|
656 |
+
scale = 1 / scale
|
657 |
+
scale = scale.unsqueeze(1) # b*1
|
658 |
+
scale_embedding = self.scale_mlp(scale) # b*channel
|
659 |
+
scale_embedding = scale_embedding.unsqueeze(1).unsqueeze(2).repeat(1, (h // self.window_size) * (
|
660 |
+
w // self.window_size), self.num_gs_seed, 1) # b, h_count*w_count, num_gs_seed, channel
|
661 |
+
scale_embedding = scale_embedding.reshape(b * (h // self.window_size) * (w // self.window_size), -1,
|
662 |
+
self.channel) # b*h_count*w_count, num_gs_seed, channel
|
663 |
+
|
664 |
+
query_pos = self.pos_embedding.unsqueeze(0).unsqueeze(1).repeat(b, (h // self.window_size) * (
|
665 |
+
w // self.window_size), 1, 1) # b, h_count*w_count, num_gs_seed, channel
|
666 |
+
|
667 |
+
feat = self.img_feat_proj(srcs) # b*channel*h*w
|
668 |
+
|
669 |
+
query_pos = query_pos.reshape(b * (h // self.window_size) * (w // self.window_size), -1,
|
670 |
+
self.channel) # b*h_count*w_count, num_gs_seed, channel
|
671 |
+
|
672 |
+
for block in self.window_crossattn_blocks:
|
673 |
+
if self.use_checkpoint:
|
674 |
+
query = checkpoint(block, query, query_pos, feat, scale_embedding, h // self.window_size, w // self.window_size)
|
675 |
+
else:
|
676 |
+
query = block(query, query_pos, feat, scale_embedding, h // self.window_size, w // self.window_size) # b*h_count*w_count, num_gs_seed, channel
|
677 |
+
|
678 |
+
resi = query
|
679 |
+
for block in self.gs_selfattn_blocks:
|
680 |
+
if self.use_checkpoint:
|
681 |
+
query = checkpoint(block, query, query_pos, h // self.window_size, w // self.window_size, scale_embedding)
|
682 |
+
else:
|
683 |
+
query = block(query, query_pos, h // self.window_size, w // self.window_size, scale_embedding)
|
684 |
+
|
685 |
+
|
686 |
+
query = rearrange(query, '(b m n) (h w) c -> b c (m h) (n w)', m=h // self.window_size, n=w // self.window_size,
|
687 |
+
h=self.num_gs_seed_sqrt)
|
688 |
+
query = self.conv_final(query)
|
689 |
+
|
690 |
+
|
691 |
+
resi = rearrange(resi, '(b m n) (h w) c -> b c (m h) (n w)', m=h // self.window_size, n=w // self.window_size,
|
692 |
+
h=self.num_gs_seed_sqrt)
|
693 |
+
|
694 |
+
query = query + resi
|
695 |
+
query = self.UPNet(query)
|
696 |
+
query = query.permute(0,2,3,1)
|
697 |
+
|
698 |
+
# query = rearrange(query, '(b m n) (h w) c -> b m h n w c', m=h // self.window_size, n=w // self.window_size,
|
699 |
+
# h=self.num_gs_seed_sqrt)
|
700 |
+
|
701 |
+
query_sigma = self.mlp_block_sigma(query).reshape(b, -1, 2)
|
702 |
+
query_rho = self.mlp_block_rho(query).reshape(b, -1, 1)
|
703 |
+
query_alpha = self.mlp_block_alpha(query).reshape(b, -1, 1)
|
704 |
+
query_rgb = self.mlp_block_rgb(query).reshape(b, -1, 3)
|
705 |
+
query_mean = self.mlp_block_mean(query).reshape(b, -1, 2)
|
706 |
+
|
707 |
+
query_mean = query_mean / torch.tensor(
|
708 |
+
[self.num_gs_seed_sqrt * (w // self.window_size) * self.shuffle_scale1 * self.shuffle_scale2,
|
709 |
+
self.num_gs_seed_sqrt * (h // self.window_size) * self.shuffle_scale1 * self.shuffle_scale2])[
|
710 |
+
None, None].to(query_mean.device) # b, h_count*w_count*num_gs_seed, 2
|
711 |
+
|
712 |
+
reference_offset = self.get_N_reference_points(self.num_gs_seed_sqrt * (h // self.window_size) * self.shuffle_scale1 * self.shuffle_scale2,
|
713 |
+
self.num_gs_seed_sqrt * (w // self.window_size) * self.shuffle_scale1 * self.shuffle_scale2, srcs.device)
|
714 |
+
query_mean = query_mean + reference_offset.reshape(1, -1, 2)
|
715 |
+
|
716 |
+
query = torch.cat([query_sigma, query_rho, query_alpha, query_rgb, query_mean],
|
717 |
+
dim=-1) # b, h_count*w_count*num_gs_seed, 9
|
718 |
+
|
719 |
+
return query
|
720 |
+
|
721 |
+
|
722 |
+
if __name__ == '__main__':
|
723 |
+
srcs = torch.randn(6, 64, 64, 64, requires_grad = True).cuda()
|
724 |
+
scale = torch.randn(6).cuda()
|
725 |
+
decoder = Fea2GS_ROPE_AMP(inchannel=64, channel=192, num_heads=6,
|
726 |
+
num_crossattn_blocks=1, num_crossattn_layers=2,
|
727 |
+
num_selfattn_blocks = 6, num_selfattn_layers = 6,
|
728 |
+
num_gs_seed=256, gs_up_factor=1.0, window_size=16,
|
729 |
+
img_range=1.0, shuffle_scale1 = 2, shuffle_scale2 = 2).cuda()
|
730 |
+
import time
|
731 |
+
|
732 |
+
for i in range(10):
|
733 |
+
torch.cuda.synchronize()
|
734 |
+
time1 = time.time()
|
735 |
+
# with torch.autocast(device_type = 'cuda'):
|
736 |
+
y = decoder(srcs, scale)
|
737 |
+
torch.cuda.synchronize()
|
738 |
+
time2 = time.time()
|
739 |
+
print(f"decoder time is {time2 - time1}")
|
740 |
+
print(y.shape)
|
741 |
+
|
742 |
+
torch.cuda.synchronize()
|
743 |
+
time3 = time.time()
|
744 |
+
y.sum().backward()
|
745 |
+
torch.cuda.synchronize()
|
746 |
+
time4 = time.time()
|
747 |
+
print(f"backward time is {time4 - time3}")
|
748 |
+
|
749 |
+
|
utils/gaussian_splatting.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import math
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
import torchvision.utils
|
8 |
+
from torchvision.utils import save_image
|
9 |
+
|
10 |
+
|
11 |
+
def rendering_python(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device):
|
12 |
+
sr_h, sr_w = sr_size[0], sr_size[1]
|
13 |
+
num_gs = sigma_x.shape[0]
|
14 |
+
|
15 |
+
sigma_x = sigma_x[...,None]
|
16 |
+
sigma_y = sigma_y[...,None]
|
17 |
+
rho = rho[...,None]
|
18 |
+
covariance = torch.stack(
|
19 |
+
[torch.stack([sigma_x**2, rho*sigma_x*sigma_y], dim=-1),
|
20 |
+
torch.stack([rho*sigma_x*sigma_y, sigma_y**2], dim=-1)],
|
21 |
+
dim=-2
|
22 |
+
)
|
23 |
+
|
24 |
+
# Check for positive semi-definiteness
|
25 |
+
determinant = (sigma_x**2) * (sigma_y**2) - (rho * sigma_x * sigma_y)**2
|
26 |
+
if (determinant < 0).any():
|
27 |
+
raise ValueError("Covariance matrix must be positive semi-definite")
|
28 |
+
|
29 |
+
inv_covariance = torch.inverse(covariance)
|
30 |
+
|
31 |
+
# Sampling progress
|
32 |
+
num_step = int(10 * 2 / step_size)
|
33 |
+
ax_h_batch = torch.tensor([i * step_size for i in range(num_step)]).to(device)[None]
|
34 |
+
ax_h_batch -= ax_h_batch.mean()
|
35 |
+
ax_w_batch = torch.tensor([i * step_size for i in range(num_step)]).to(device)[None]
|
36 |
+
ax_w_batch -= ax_w_batch.mean()
|
37 |
+
|
38 |
+
# Expanding dims for broadcasting
|
39 |
+
ax_batch_expanded_x = ax_h_batch.unsqueeze(-1).expand(-1, -1, num_step)
|
40 |
+
ax_batch_expanded_y = ax_w_batch.unsqueeze(1).expand(-1, num_step, -1)
|
41 |
+
|
42 |
+
# Creating a batch-wise meshgrid using broadcasting
|
43 |
+
xx, yy = ax_batch_expanded_x, ax_batch_expanded_y
|
44 |
+
|
45 |
+
xy = torch.stack([xx, yy], dim=-1)
|
46 |
+
|
47 |
+
max_buffer = 2000
|
48 |
+
final_image = torch.zeros((3, sr_h, sr_w), device=device)
|
49 |
+
for i in range(num_gs // max_buffer + 1):
|
50 |
+
# print('processing gs buffer id:', i, num_gs // max_buffer )
|
51 |
+
s_idx, e_idx = i * max_buffer, min((i + 1) * max_buffer, num_gs)
|
52 |
+
buffer_size = e_idx - s_idx
|
53 |
+
if buffer_size == 0:
|
54 |
+
break
|
55 |
+
# print(f"buffer_size is {buffer_size}")
|
56 |
+
buff_inv_covariance = inv_covariance[s_idx:e_idx]
|
57 |
+
buff_covariance = covariance[s_idx:e_idx]
|
58 |
+
buffer_pixel_coords = coords[s_idx:e_idx]
|
59 |
+
buffer_alpha = colours_with_alpha[s_idx:e_idx].unsqueeze(-1).unsqueeze(-1)
|
60 |
+
|
61 |
+
z = torch.einsum('b...i,b...ij,b...j->b...', xy, -0.5 * buff_inv_covariance, xy)
|
62 |
+
kernel = torch.exp(z) / (2 * torch.tensor(np.pi, device=device) * torch.sqrt(torch.det(buff_covariance)).view(buffer_size, 1, 1))
|
63 |
+
|
64 |
+
kernel_max = kernel.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0]
|
65 |
+
kernel_normalized = kernel / (kernel_max + 1e-4)
|
66 |
+
kernel_reshaped = kernel_normalized.repeat(1, 3, 1).view(buffer_size * 3, num_step, num_step)
|
67 |
+
kernel_reshaped = kernel_reshaped.unsqueeze(0).reshape(buffer_size, 3, num_step, num_step)
|
68 |
+
|
69 |
+
b, c, h, w = kernel_reshaped.shape
|
70 |
+
|
71 |
+
# Create a batch of 2D affine matrices
|
72 |
+
theta = torch.zeros(b, 2, 3, dtype=torch.float32, device=device)
|
73 |
+
theta[:, 0, 0] = 1 * sr_w / num_step
|
74 |
+
theta[:, 1, 1] = 1 * sr_h / num_step
|
75 |
+
theta[:, 0, 2] = -buffer_pixel_coords[:, 0] * sr_w / num_step # !!!!!!!! note -1
|
76 |
+
theta[:, 1, 2] = -buffer_pixel_coords[:, 1] * sr_h / num_step # !!!!!!!! note -1
|
77 |
+
|
78 |
+
grid = F.affine_grid(theta, size=(b, c, sr_h, sr_w), align_corners=False) # !!!!! align_corners=False
|
79 |
+
kernel_reshaped_translated = F.grid_sample(kernel_reshaped, grid,
|
80 |
+
align_corners=False) # !!!! align_corners=False
|
81 |
+
buffer_final_image = buffer_alpha * kernel_reshaped_translated
|
82 |
+
final_image += buffer_final_image.sum(0)
|
83 |
+
|
84 |
+
return final_image
|
85 |
+
|
86 |
+
def rendering_cuda(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device):
|
87 |
+
from utils.gs_cuda.gswrapper import GSCUDA
|
88 |
+
sigmas = torch.cat([sigma_y/step_size*2/(sr_size[1] - 1), sigma_x/step_size*2/(sr_size[0] - 1), rho], dim=-1).contiguous() # (gs num, 3)
|
89 |
+
coords[:, 0] = (coords[:, 0] + 1 - 1/sr_size[1]) * sr_size[1] / (sr_size[1] - 1) - 1.0
|
90 |
+
coords[:, 1] = (coords[:, 1] + 1 - 1/sr_size[0]) * sr_size[0] / (sr_size[0] - 1) - 1.0
|
91 |
+
colours_with_alpha = colours_with_alpha.contiguous() # (gs num, 3)
|
92 |
+
rendered_img = torch.zeros(sr_size[0], sr_size[1], 3).to(device).type(torch.float32).contiguous()
|
93 |
+
# with torch.no_grad():
|
94 |
+
# final_image = GSCUDA.apply(sigmas, coords, colours_with_alpha, rendered_img)
|
95 |
+
# final_image = (torch.sum(sigmas)+torch.sum(coords)+torch.sum(colours_with_alpha))*final_image
|
96 |
+
final_image = GSCUDA.apply(sigmas, coords, colours_with_alpha, rendered_img)
|
97 |
+
final_image = final_image.permute(2, 0, 1).contiguous()
|
98 |
+
return final_image
|
99 |
+
|
100 |
+
def rendering_cuda_buffer(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device, buffer_size = 1000000):
|
101 |
+
from utils.gs_cuda.gswrapper import GSCUDA
|
102 |
+
sigmas = torch.cat([sigma_y/step_size*2/(sr_size[1] - 1), sigma_x/step_size*2/(sr_size[0] - 1), rho], dim=-1).contiguous() # (gs num, 3)
|
103 |
+
coords[:, 0] = (coords[:, 0] + 1 - 1/sr_size[1]) * sr_size[1] / (sr_size[1] - 1) - 1.0
|
104 |
+
coords[:, 1] = (coords[:, 1] + 1 - 1/sr_size[0]) * sr_size[0] / (sr_size[0] - 1) - 1.0
|
105 |
+
colours_with_alpha = colours_with_alpha.contiguous() # (gs num, 3)
|
106 |
+
final_image = torch.zeros(sr_size[0], sr_size[1], 3).to(device).type(torch.float32).contiguous()
|
107 |
+
|
108 |
+
# buffer
|
109 |
+
buffer_num = len(sigma_x)// buffer_size+1
|
110 |
+
for buffer_id in range(buffer_num):
|
111 |
+
# print(f'processing{buffer_id+1}/{buffer_num}')
|
112 |
+
idx_start, idx_end = buffer_id * buffer_size, (buffer_id+1) * buffer_size
|
113 |
+
final_image = GSCUDA.apply(sigmas[idx_start:idx_end], coords[idx_start:idx_end],
|
114 |
+
colours_with_alpha[idx_start:idx_end], final_image)
|
115 |
+
# final_image += buffer_image
|
116 |
+
final_image = final_image.permute(2, 0, 1).contiguous()
|
117 |
+
return final_image
|
118 |
+
|
119 |
+
def rendering_cuda_dmax(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device, dmax=1):
|
120 |
+
from utils.gs_cuda_dmax.gswrapper import GSCUDA
|
121 |
+
sigmas = torch.cat([sigma_y/step_size*2/(sr_size[1] - 1), sigma_x/step_size*2/(sr_size[0] - 1), rho], dim=-1).contiguous() # (gs num, 3)
|
122 |
+
coords[:, 0] = (coords[:, 0] + 1 - 1/sr_size[1]) * sr_size[1] / (sr_size[1] - 1) - 1.0
|
123 |
+
coords[:, 1] = (coords[:, 1] + 1 - 1/sr_size[0]) * sr_size[0] / (sr_size[0] - 1) - 1.0
|
124 |
+
colours_with_alpha = colours_with_alpha.contiguous() # (gs num, 3)
|
125 |
+
rendered_img = torch.zeros(sr_size[0], sr_size[1], 3).to(device).type(torch.float32).contiguous()
|
126 |
+
# with torch.no_grad():
|
127 |
+
# final_image = GSCUDA.apply(sigmas, coords, colours_with_alpha, rendered_img, dmax)
|
128 |
+
# final_image = (torch.sum(sigmas)+torch.sum(coords)+torch.sum(colours_with_alpha))*final_image
|
129 |
+
final_image = GSCUDA.apply(sigmas, coords, colours_with_alpha, rendered_img, dmax)
|
130 |
+
final_image = final_image.permute(2, 0, 1).contiguous()
|
131 |
+
return final_image
|
132 |
+
|
133 |
+
def rendering_cuda_dmax_buffer(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device, dmax=1, buffer_size = 1000000):
|
134 |
+
from utils.gs_cuda_dmax.gswrapper import GSCUDA
|
135 |
+
sigmas = torch.cat([sigma_y/step_size*2/(sr_size[1] - 1), sigma_x/step_size*2/(sr_size[0] - 1), rho], dim=-1).contiguous() # (gs num, 3)
|
136 |
+
coords[:, 0] = (coords[:, 0] + 1 - 1/sr_size[1]) * sr_size[1] / (sr_size[1] - 1) - 1.0
|
137 |
+
coords[:, 1] = (coords[:, 1] + 1 - 1/sr_size[0]) * sr_size[0] / (sr_size[0] - 1) - 1.0
|
138 |
+
colours_with_alpha = colours_with_alpha.contiguous() # (gs num, 3)
|
139 |
+
|
140 |
+
final_image = torch.zeros(sr_size[0], sr_size[1], 3).to(device).type(torch.float32).contiguous()
|
141 |
+
# with torch.no_grad():
|
142 |
+
# final_image = GSCUDA.apply(sigmas, coords, colours_with_alpha, rendered_img, dmax)
|
143 |
+
# final_image = (torch.sum(sigmas)+torch.sum(coords)+torch.sum(colours_with_alpha))*final_image
|
144 |
+
|
145 |
+
# buffer
|
146 |
+
buffer_num = len(sigma_x)// buffer_size+1
|
147 |
+
for buffer_id in range(buffer_num):
|
148 |
+
# print(f'processing{buffer_id+1}/{buffer_num}')
|
149 |
+
idx_start, idx_end = buffer_id * buffer_size, (buffer_id+1) * buffer_size
|
150 |
+
final_image = GSCUDA.apply(sigmas[idx_start:idx_end], coords[idx_start:idx_end],
|
151 |
+
colours_with_alpha[idx_start:idx_end], final_image, dmax)
|
152 |
+
# final_image += buffer_image
|
153 |
+
|
154 |
+
final_image = final_image.permute(2, 0, 1).contiguous()
|
155 |
+
return final_image
|
156 |
+
|
157 |
+
|
158 |
+
def generate_2D_gaussian_splatting_step(sr_size, gs_parameters, scale, scale_modify,
|
159 |
+
sample_coords = None, default_step_size = 1.2,
|
160 |
+
cuda_rendering=True, mode = 'scale_modify',
|
161 |
+
if_dmax = True,
|
162 |
+
dmax_mode = 'fix',
|
163 |
+
dmax = 25):
|
164 |
+
|
165 |
+
# set step_size according to scale factor
|
166 |
+
if mode == 'scale':
|
167 |
+
final_scale = scale
|
168 |
+
elif mode == 'scale_modify':
|
169 |
+
assert scale_modify[0] == scale_modify[1], f"scale_modify is not the same-{scale_modify}"
|
170 |
+
final_scale = scale_modify[0]
|
171 |
+
step_size = default_step_size/ final_scale
|
172 |
+
|
173 |
+
# prepare gaussian properties
|
174 |
+
sigma_x = 0.99999 * torch.sigmoid(gs_parameters[:, 0:1]) + 1e-6
|
175 |
+
sigma_y = 0.99999 * torch.sigmoid(gs_parameters[:, 1:2]) + 1e-6
|
176 |
+
rho = 0.999999 * torch.tanh(gs_parameters[:, 2:3])
|
177 |
+
alpha = torch.sigmoid(gs_parameters[:, 3:4])
|
178 |
+
colours = torch.sigmoid(gs_parameters[:, 4:7])
|
179 |
+
coords = (gs_parameters[:, 7:9] * 2 - 1)
|
180 |
+
colours_with_alpha = colours * alpha
|
181 |
+
|
182 |
+
|
183 |
+
## todo for save GS parameters
|
184 |
+
# GS_parameters = torch.cat([sigma_x, sigma_y, rho, alpha, colours, coords], dim = 1)
|
185 |
+
# torch.save(GS_parameters.cpu(), "/home/notebook/code/personal/S9053766/chendu/myprojects/GSSR_20240606/results/0804_48*48.pt")
|
186 |
+
# print(f"GS_parameter shape is {GS_parameters.shape}")
|
187 |
+
# print(f"-------")
|
188 |
+
|
189 |
+
# todo for visualization the position of Gaussian
|
190 |
+
# select = (torch.randn_like(alpha[..., 0])>2.5)
|
191 |
+
# colours_with_alpha[select, 0] = 1
|
192 |
+
# colours_with_alpha[select, 1] = 0
|
193 |
+
# colours_with_alpha[select, 2] = 0
|
194 |
+
# todo for visualization the shape of Gaussian
|
195 |
+
# sigma_x = torch.ones_like(sigma_x)*0.05
|
196 |
+
# sigma_y = torch.ones_like(sigma_y)*0.05
|
197 |
+
# rho = torch.ones_like(rho) * 0
|
198 |
+
# colours_with_alpha = torch.ones_like(colours_with_alpha)*0.5
|
199 |
+
|
200 |
+
# rendering
|
201 |
+
if cuda_rendering:
|
202 |
+
if if_dmax:
|
203 |
+
if dmax_mode == 'dynamic':
|
204 |
+
dmax = (dmax + 2) / min(sr_size[0], sr_size[1])
|
205 |
+
elif dmax_mode == 'fix':
|
206 |
+
pass
|
207 |
+
else:
|
208 |
+
raise ValueError(f"dmax_mode-{dmax_mode} must be fix or dynamic")
|
209 |
+
final_image = rendering_cuda_dmax(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, dmax=dmax, device=sigma_x.device)
|
210 |
+
else:
|
211 |
+
final_image = rendering_cuda(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device=sigma_x.device)
|
212 |
+
else:
|
213 |
+
final_image = rendering_python(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device=sigma_x.device)
|
214 |
+
if sample_coords is not None:
|
215 |
+
sample_RGB_values = [final_image[:, coord[0], coord[1]] for coord in sample_coords]
|
216 |
+
final_image = torch.stack(sample_RGB_values, dim = 1)
|
217 |
+
return final_image
|
218 |
+
|
219 |
+
def generate_2D_gaussian_splatting_step_buffer(sr_size, gs_parameters, scale, scale_modify,
|
220 |
+
sample_coords = None, default_step_size = 1.2,
|
221 |
+
cuda_rendering=True, mode = 'scale_modify',
|
222 |
+
if_dmax = True,
|
223 |
+
dmax_mode = 'fix',
|
224 |
+
dmax = 25,
|
225 |
+
buffer_size = 4000000):
|
226 |
+
|
227 |
+
# set step_size according to scale factor
|
228 |
+
if mode == 'scale':
|
229 |
+
final_scale = scale
|
230 |
+
elif mode == 'scale_modify':
|
231 |
+
assert scale_modify[0] == scale_modify[1], f"scale_modify is not the same-{scale_modify}"
|
232 |
+
final_scale = scale_modify[0]
|
233 |
+
step_size = default_step_size/ final_scale
|
234 |
+
|
235 |
+
# prepare gaussian properties
|
236 |
+
sigma_x = 0.99999 * torch.sigmoid(gs_parameters[:, 0:1]) + 1e-6
|
237 |
+
sigma_y = 0.99999 * torch.sigmoid(gs_parameters[:, 1:2]) + 1e-6
|
238 |
+
rho = 0.999999 * torch.tanh(gs_parameters[:, 2:3])
|
239 |
+
alpha = torch.sigmoid(gs_parameters[:, 3:4])
|
240 |
+
colours = torch.sigmoid(gs_parameters[:, 4:7])
|
241 |
+
coords = (gs_parameters[:, 7:9] * 2 - 1)
|
242 |
+
colours_with_alpha = colours * alpha
|
243 |
+
|
244 |
+
# rendering
|
245 |
+
if cuda_rendering:
|
246 |
+
if if_dmax:
|
247 |
+
if dmax_mode == 'dynamic':
|
248 |
+
dmax = (dmax + 2) / min(sr_size[0], sr_size[1])
|
249 |
+
elif dmax_mode == 'fix':
|
250 |
+
pass
|
251 |
+
else:
|
252 |
+
raise ValueError(f"dmax_mode-{dmax_mode} must be fix or dynamic")
|
253 |
+
final_image = rendering_cuda_dmax_buffer(sigma_x, sigma_y, rho, coords, colours_with_alpha,
|
254 |
+
sr_size, step_size, dmax=dmax, device=sigma_x.device,
|
255 |
+
buffer_size = buffer_size)
|
256 |
+
else:
|
257 |
+
final_image = rendering_cuda_buffer(sigma_x, sigma_y, rho, coords, colours_with_alpha,
|
258 |
+
sr_size, step_size, device=sigma_x.device,
|
259 |
+
buffer_size = buffer_size)
|
260 |
+
else:
|
261 |
+
final_image = rendering_python(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device=sigma_x.device)
|
262 |
+
if sample_coords is not None:
|
263 |
+
sample_RGB_values = [final_image[:, coord[0], coord[1]] for coord in sample_coords]
|
264 |
+
final_image = torch.stack(sample_RGB_values, dim = 1)
|
265 |
+
return final_image
|
utils/gs_cuda/check.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from gswrapper import gaussiansplatting_render
|
3 |
+
|
4 |
+
def torch_version(sigmas, coords, colors, image_size):
|
5 |
+
h, w = image_size
|
6 |
+
c = colors.shape[-1]
|
7 |
+
|
8 |
+
if h >= 50 or w >= 50:
|
9 |
+
logger.warning(f'too large values for h({h}), w({w}), torch version would be slow')
|
10 |
+
|
11 |
+
rendered_img = torch.zeros(h, w, c).to(colors.device).to(torch.float32)
|
12 |
+
|
13 |
+
for hi in range(h):
|
14 |
+
for wi in range(w):
|
15 |
+
curh = 2*hi/(h-1)-1.0
|
16 |
+
curw = 2*wi/(w-1)-1.0
|
17 |
+
|
18 |
+
v = (curw-coords[:,0])**2/sigmas[:,0]**2
|
19 |
+
v -= (2*sigmas[:,2])*(curw-coords[:,0])*(curh-coords[:,1])/sigmas[:,0]/sigmas[:,1]
|
20 |
+
v += (curh-coords[:,1])**2/sigmas[:,1]**2
|
21 |
+
v *= -1.0/(2.0*(1-sigmas[:,2]**2))
|
22 |
+
v = torch.exp(v)
|
23 |
+
|
24 |
+
for ci in range(c):
|
25 |
+
rendered_img[hi, wi, ci] = torch.sum(v*colors[:, ci])
|
26 |
+
|
27 |
+
return rendered_img
|
28 |
+
|
29 |
+
|
30 |
+
if __name__ == "__main__":
|
31 |
+
s = 40 # the number of gs
|
32 |
+
image_size = (49, 49)
|
33 |
+
|
34 |
+
for _ in range(1):
|
35 |
+
print(f"--------------------------- begins --------------------------------")
|
36 |
+
|
37 |
+
sigmas = 0.999*torch.rand(s, 3).to(torch.float32).to("cuda")
|
38 |
+
# sigmas[:,:2] = 5*sigmas[:, :2]
|
39 |
+
coords = 2*torch.rand(s, 2).to(torch.float32).to("cuda")-1.0
|
40 |
+
colors = torch.rand(s, 3).to(torch.float32).to("cuda")
|
41 |
+
|
42 |
+
# sigmas = torch.Tensor([[0.9196, 0.3979, 0.7784]]).to(torch.float32).to("cuda")
|
43 |
+
# coords = torch.Tensor([[-0.0469, -0.1726]]).to(torch.float32).to("cuda")
|
44 |
+
# colors = torch.Tensor([[0.3775, 0.2346, 0.1513]]).to(torch.float32).to("cuda")
|
45 |
+
# colors = torch.ones_like(coords[:,0:1])
|
46 |
+
|
47 |
+
print(f"sigmas: {sigmas}, \ncoords:{coords}, \ncolors:{colors}")
|
48 |
+
|
49 |
+
# --- check forward ---
|
50 |
+
with torch.no_grad():
|
51 |
+
rendered_img_th = torch_version(sigmas,coords,colors,image_size)
|
52 |
+
rendered_img_cuda = gaussiansplatting_render(sigmas,coords,colors,image_size)
|
53 |
+
|
54 |
+
#
|
55 |
+
distance = (rendered_img_th-rendered_img_cuda)**2
|
56 |
+
print(f"check forward - torch: {rendered_img_th[:2,:2,0]}")
|
57 |
+
print(f"check forward - cuda: {rendered_img_cuda[:2,:2,0]}")
|
58 |
+
print(f"check forward - distance: {distance[:2, :2, 0]}")
|
59 |
+
print(f"check forward - sum: {torch.sum(distance)}\n")
|
60 |
+
# --- ends ---
|
61 |
+
|
62 |
+
# --- check backward ---
|
63 |
+
sigmas.requires_grad_(True)
|
64 |
+
coords.requires_grad_(True)
|
65 |
+
colors.requires_grad_(True)
|
66 |
+
# sigmas.retain_grad()
|
67 |
+
# coords.retain_grad()
|
68 |
+
# colors.retain_grad()
|
69 |
+
weight = torch.rand_like(rendered_img_th) # make each pixel has different grads
|
70 |
+
|
71 |
+
sigmas.grad = None
|
72 |
+
coords.grad = None
|
73 |
+
colors.grad = None
|
74 |
+
rendered_img_th = torch_version(sigmas,coords,colors,image_size)
|
75 |
+
loss_th = torch.sum(weight*rendered_img_th)
|
76 |
+
loss_th.backward()
|
77 |
+
|
78 |
+
sigmas_grad_th = sigmas.grad
|
79 |
+
coords_grad_th = coords.grad
|
80 |
+
colors_grad_th = colors.grad
|
81 |
+
|
82 |
+
sigmas.grad = None
|
83 |
+
coords.grad = None
|
84 |
+
colors.grad = None
|
85 |
+
rendered_img_cuda = gaussiansplatting_render(sigmas,coords,colors,image_size)
|
86 |
+
loss_cuda = torch.sum(weight*rendered_img_cuda)
|
87 |
+
# loss_cuda = torch.sum(rendered_img_cuda)
|
88 |
+
loss_cuda.backward()
|
89 |
+
|
90 |
+
sigmas_grad_cuda = sigmas.grad
|
91 |
+
coords_grad_cuda = coords.grad
|
92 |
+
colors_grad_cuda = colors.grad
|
93 |
+
|
94 |
+
distance_sigmas_grad = (sigmas_grad_th-sigmas_grad_cuda)**2
|
95 |
+
distance_coords_grad = (coords_grad_th-coords_grad_cuda)**2
|
96 |
+
distance_colors_grad = (colors_grad_th-colors_grad_cuda)**2
|
97 |
+
|
98 |
+
print(f"check backward - sigmas - torch: {sigmas_grad_th[:2]}")
|
99 |
+
print(f"check backward - sigmas - cuda: {sigmas_grad_cuda[:2]}")
|
100 |
+
print(f"check backward - sigmas - distance: {distance_sigmas_grad[:2]}")
|
101 |
+
print(f"check backward - sigmas - sum: {torch.sum(distance_sigmas_grad)}\n")
|
102 |
+
|
103 |
+
print(f"check backward - coords - torch: {coords_grad_th[:2]}")
|
104 |
+
print(f"check backward - coords - cuda: {coords_grad_cuda[:2]}")
|
105 |
+
print(f"check backward - coords - distance: {distance_coords_grad[:2]}")
|
106 |
+
print(f"check backward - coords - sum: {torch.sum(distance_coords_grad)}\n")
|
107 |
+
|
108 |
+
print(f"check backward - colors - torch: {colors_grad_th[:2]}")
|
109 |
+
print(f"check backward - colors - cuda: {colors_grad_cuda[:2]}")
|
110 |
+
print(f"check backward - colors - distance: {distance_colors_grad[:2]}")
|
111 |
+
print(f"check backward - colors - sum: {torch.sum(distance_colors_grad)}\n")
|
112 |
+
|
113 |
+
print(f"--------------------------- ends --------------------------------\n\n")
|
114 |
+
|
115 |
+
|
utils/gs_cuda/gs.cu
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <stdio.h>
|
2 |
+
#include <cmath>
|
3 |
+
#include <curand_kernel.h>
|
4 |
+
|
5 |
+
#define PI 3.1415926536
|
6 |
+
#define PI2 6.283153072
|
7 |
+
|
8 |
+
extern "C"
|
9 |
+
__global__ void _gs_render_cuda(
|
10 |
+
const float *sigmas,
|
11 |
+
const float *coords,
|
12 |
+
const float *colors,
|
13 |
+
float *rendered_img,
|
14 |
+
const int s, // gs num
|
15 |
+
const int h,
|
16 |
+
const int w,
|
17 |
+
const int c
|
18 |
+
){
|
19 |
+
|
20 |
+
int index = blockIdx.x*blockDim.x + threadIdx.x;
|
21 |
+
int curw = index % w;
|
22 |
+
int curh = int((index-curw)/w);
|
23 |
+
if(curw >= w || curh >=h){
|
24 |
+
return;
|
25 |
+
}
|
26 |
+
|
27 |
+
float curw_f = 2.0*curw/(w-1) - 1.0;
|
28 |
+
float curh_f = 2.0*curh/(h-1) - 1.0;
|
29 |
+
|
30 |
+
// printf("index:%d, curw:%d, curh:%d, curw_f:%f, curh_f:%f\n",index,curw,curh,curw_f,curh_f);
|
31 |
+
|
32 |
+
for(int si=0; si<s; si++){
|
33 |
+
|
34 |
+
// compute the 2d gs value
|
35 |
+
float sigma_x = sigmas[si*3+0];
|
36 |
+
float sigma_y = sigmas[si*3+1];
|
37 |
+
float rho = sigmas[si*3+2];
|
38 |
+
float x = coords[si*2+0];
|
39 |
+
float y = coords[si*2+1];
|
40 |
+
|
41 |
+
//
|
42 |
+
float one_div_one_minus_rho2 = 1.0 / (1-rho*rho) ;
|
43 |
+
float one_div_sigma_x = 1.0 / sigma_x;
|
44 |
+
float one_div_sigma_y = 1.0 / sigma_y;
|
45 |
+
float d_x = curw_f - x;
|
46 |
+
float d_y = curh_f - y;
|
47 |
+
|
48 |
+
float v = one_div_sigma_x*one_div_sigma_x*d_x*d_x;
|
49 |
+
v -= 2*rho*d_x*d_y*one_div_sigma_x*one_div_sigma_y;
|
50 |
+
v += d_y*d_y*one_div_sigma_y*one_div_sigma_y;
|
51 |
+
v *= -one_div_one_minus_rho2 / 2.0;
|
52 |
+
v = exp(v);
|
53 |
+
// since we normlize the v with the max, we remove this step to obtain equal result
|
54 |
+
// v *= one_div_sigma_x * one_div_sigma_y * pow(one_div_one_minus_rho2, 0.5) / PI2 ;
|
55 |
+
// printf("si:%d, sigma_x: %f, sigma_y:%f, rho:%f, x:%f, y:%f, v:%f\n", si, sigma_x, sigma_y, rho, x,y,v);
|
56 |
+
|
57 |
+
for(int ci=0; ci<c; ci++){
|
58 |
+
rendered_img[(curh*w+curw)*c+ci] += v*colors[si*3+ci];
|
59 |
+
}
|
60 |
+
}
|
61 |
+
}
|
62 |
+
|
63 |
+
|
64 |
+
void _gs_render(
|
65 |
+
const float *sigmas,
|
66 |
+
const float *coords,
|
67 |
+
const float *colors,
|
68 |
+
float *rendered_img,
|
69 |
+
const int s,
|
70 |
+
const int h,
|
71 |
+
const int w,
|
72 |
+
const int c
|
73 |
+
) {
|
74 |
+
|
75 |
+
int threads=64;
|
76 |
+
dim3 grid( h*w, 1);
|
77 |
+
dim3 block( threads, 1);
|
78 |
+
_gs_render_cuda<<<grid, block>>>(sigmas, coords, colors, rendered_img, s, h, w, c);
|
79 |
+
}
|
80 |
+
|
81 |
+
extern "C"
|
82 |
+
__global__ void _gs_render_backward_cuda(
|
83 |
+
const float *sigmas,
|
84 |
+
const float *coords,
|
85 |
+
const float *colors,
|
86 |
+
const float *grads,
|
87 |
+
float *grads_sigmas,
|
88 |
+
float *grads_coords,
|
89 |
+
float *grads_colors,
|
90 |
+
const int s, // gs num
|
91 |
+
const int h,
|
92 |
+
const int w,
|
93 |
+
const int c
|
94 |
+
){
|
95 |
+
|
96 |
+
int curs = blockIdx.x*blockDim.x + threadIdx.x;
|
97 |
+
if(curs >= s){
|
98 |
+
return ;
|
99 |
+
}
|
100 |
+
|
101 |
+
// obtain parameters of gs
|
102 |
+
float sigma_x = sigmas[curs*3+0];
|
103 |
+
float sigma_y = sigmas[curs*3+1];
|
104 |
+
float rho = sigmas[curs*3+2];
|
105 |
+
float x = coords[curs*2+0];
|
106 |
+
float y = coords[curs*2+1];
|
107 |
+
float cr = colors[curs*3+0];
|
108 |
+
float cg = colors[curs*3+1];
|
109 |
+
float cb = colors[curs*3+2];
|
110 |
+
|
111 |
+
//
|
112 |
+
float w1 = -0.5 / (1-rho*rho) ;
|
113 |
+
float w2 = 1.0 / (sigma_x*sigma_x);
|
114 |
+
float w3 = 1.0 / (sigma_x*sigma_y);
|
115 |
+
float w4 = 1.0 / (sigma_y*sigma_y);
|
116 |
+
float od_sx = 1.0 / sigma_x;
|
117 |
+
float od_sy = 1.0 / sigma_y;
|
118 |
+
|
119 |
+
// init
|
120 |
+
float _gr=0.0, _gg=0.0, _gb=0.0;
|
121 |
+
float _gx=0.0, _gy=0.0;
|
122 |
+
float _gsx=0.0, _gsy=0.0, _gsr=0.0;
|
123 |
+
|
124 |
+
for(int hi = 0; hi < h; hi++){
|
125 |
+
for( int wi=0; wi < w; wi++){
|
126 |
+
|
127 |
+
float curw_f = 2.0*wi/(w-1) - 1.0;
|
128 |
+
float curh_f = 2.0*hi/(h-1) - 1.0;
|
129 |
+
|
130 |
+
// obtain grad to p^t_r, p^t_g, p^t_b
|
131 |
+
float gptr = grads[(hi*w+wi)*c+0]; // grad of loss to P^t_r
|
132 |
+
float gptg = grads[(hi*w+wi)*c+1];
|
133 |
+
float gptb = grads[(hi*w+wi)*c+2];
|
134 |
+
|
135 |
+
// compute the 2d gs value
|
136 |
+
|
137 |
+
float d_x = curw_f - x; // distance along x axis
|
138 |
+
float d_y = curh_f - y;
|
139 |
+
float d = w2*d_x*d_x - 2*rho*w3*d_x*d_y + w4*d_y*d_y;
|
140 |
+
float v = w1*d;
|
141 |
+
v = exp(v);
|
142 |
+
// printf("si:%d, sigma_x: %f, sigma_y:%f, rho:%f, x:%f, y:%f, v:%f\n", si, sigma_x, sigma_y, rho, x,y,v);
|
143 |
+
|
144 |
+
// compute grad of colors
|
145 |
+
_gr += v*gptr;
|
146 |
+
_gg += v*gptg;
|
147 |
+
_gb += v*gptb;
|
148 |
+
|
149 |
+
// compute grad of coords
|
150 |
+
float gpt = gptr*cr+gptg*cg+gptb*cb;
|
151 |
+
float v_2_w1 = v*2*w1;
|
152 |
+
|
153 |
+
float g_vst_to_gsx = v_2_w1*(-w2*d_x+rho*w3*d_y); // grad of v^{st} to G^s_x
|
154 |
+
_gx += gpt*g_vst_to_gsx;
|
155 |
+
float g_vst_to_gsy = v_2_w1*(-w4*d_y+rho*w3*d_x); // grad of v^{st} to G^s_y
|
156 |
+
_gy += gpt*g_vst_to_gsy;
|
157 |
+
|
158 |
+
// compute grad of sigmas
|
159 |
+
float g_vst_to_gsigx = v_2_w1*od_sx* (w3*rho*d_x*d_y - w2*d_x*d_x);
|
160 |
+
_gsx += gpt*g_vst_to_gsigx;
|
161 |
+
float g_vst_to_gsigy = v_2_w1*od_sy* (w3*rho*d_x*d_y - w4*d_y*d_y);
|
162 |
+
_gsy += gpt*g_vst_to_gsigy;
|
163 |
+
float g_vst_to_rho = -v_2_w1*(2*w1*rho*d+w3*d_x*d_y);
|
164 |
+
_gsr += gpt*g_vst_to_rho;
|
165 |
+
}
|
166 |
+
}
|
167 |
+
|
168 |
+
// write the values
|
169 |
+
grads_sigmas[curs*3+0] = _gsx;
|
170 |
+
grads_sigmas[curs*3+1] = _gsy;
|
171 |
+
grads_sigmas[curs*3+2] = _gsr;
|
172 |
+
grads_coords[curs*2+0] = _gx;
|
173 |
+
grads_coords[curs*2+1] = _gy;
|
174 |
+
grads_colors[curs*3+0] = _gr;
|
175 |
+
grads_colors[curs*3+1] = _gg;
|
176 |
+
grads_colors[curs*3+2] = _gb;
|
177 |
+
|
178 |
+
}
|
179 |
+
|
180 |
+
void _gs_render_backward(
|
181 |
+
const float *sigmas,
|
182 |
+
const float *coords,
|
183 |
+
const float *colors,
|
184 |
+
const float *grads, // (h, w, c)
|
185 |
+
float *grads_sigmas,
|
186 |
+
float *grads_coords,
|
187 |
+
float *grads_colors,
|
188 |
+
const int s,
|
189 |
+
const int h,
|
190 |
+
const int w,
|
191 |
+
const int c
|
192 |
+
) {
|
193 |
+
|
194 |
+
int threads=64;
|
195 |
+
dim3 grid(s, 1);
|
196 |
+
dim3 block( threads, 1);
|
197 |
+
_gs_render_backward_cuda<<<grid, block>>>(sigmas, coords, colors, grads, grads_sigmas, grads_coords, grads_colors, s, h, w, c);
|
198 |
+
}
|
199 |
+
|
utils/gs_cuda/gs.h
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
void _gs_render(
|
2 |
+
const float *sigmas,
|
3 |
+
const float *coords,
|
4 |
+
const float *colors,
|
5 |
+
float *rendered_img,
|
6 |
+
const int s,
|
7 |
+
const int h,
|
8 |
+
const int w,
|
9 |
+
const int c
|
10 |
+
);
|
11 |
+
|
12 |
+
void _gs_render_backward(
|
13 |
+
const float *sigmas,
|
14 |
+
const float *coords,
|
15 |
+
const float *colors,
|
16 |
+
const float *grads,
|
17 |
+
float *grads_sigmas,
|
18 |
+
float *grads_coords,
|
19 |
+
float *grads_colors,
|
20 |
+
const int s,
|
21 |
+
const int h,
|
22 |
+
const int w,
|
23 |
+
const int c
|
24 |
+
);
|
utils/gs_cuda/gswrapper.cpp
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "gs.h"
|
2 |
+
#include <torch/extension.h>
|
3 |
+
#include <c10/cuda/CUDAGuard.h>
|
4 |
+
|
5 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
6 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
7 |
+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
8 |
+
|
9 |
+
void gs_render(
|
10 |
+
torch::Tensor &sigmas,
|
11 |
+
torch::Tensor &coords,
|
12 |
+
torch::Tensor &colors,
|
13 |
+
torch::Tensor &rendered_img,
|
14 |
+
const int s,
|
15 |
+
const int h,
|
16 |
+
const int w,
|
17 |
+
const int c
|
18 |
+
){
|
19 |
+
|
20 |
+
CHECK_INPUT(sigmas);
|
21 |
+
CHECK_INPUT(coords);
|
22 |
+
CHECK_INPUT(colors);
|
23 |
+
CHECK_INPUT(rendered_img);
|
24 |
+
|
25 |
+
// run the code at the cuda device same with the input
|
26 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(sigmas));
|
27 |
+
|
28 |
+
_gs_render(
|
29 |
+
(const float *) sigmas.data_ptr(),
|
30 |
+
(const float *) coords.data_ptr(),
|
31 |
+
(const float *) colors.data_ptr(),
|
32 |
+
(float *) rendered_img.data_ptr(),
|
33 |
+
s, h, w, c);
|
34 |
+
}
|
35 |
+
|
36 |
+
void gs_render_backward(
|
37 |
+
torch::Tensor &sigmas,
|
38 |
+
torch::Tensor &coords,
|
39 |
+
torch::Tensor &colors,
|
40 |
+
torch::Tensor &grads,
|
41 |
+
torch::Tensor &grads_sigmas,
|
42 |
+
torch::Tensor &grads_coords,
|
43 |
+
torch::Tensor &grads_colors,
|
44 |
+
const int s,
|
45 |
+
const int h,
|
46 |
+
const int w,
|
47 |
+
const int c
|
48 |
+
){
|
49 |
+
|
50 |
+
CHECK_INPUT(sigmas);
|
51 |
+
CHECK_INPUT(coords);
|
52 |
+
CHECK_INPUT(colors);
|
53 |
+
CHECK_INPUT(grads);
|
54 |
+
CHECK_INPUT(grads_sigmas);
|
55 |
+
CHECK_INPUT(grads_coords);
|
56 |
+
CHECK_INPUT(grads_colors);
|
57 |
+
|
58 |
+
|
59 |
+
// run the code at the cuda device same with the input
|
60 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(sigmas));
|
61 |
+
|
62 |
+
_gs_render_backward(
|
63 |
+
(const float *) sigmas.data_ptr(),
|
64 |
+
(const float *) coords.data_ptr(),
|
65 |
+
(const float *) colors.data_ptr(),
|
66 |
+
(const float *) grads.data_ptr(),
|
67 |
+
(float *) grads_sigmas.data_ptr(),
|
68 |
+
(float *) grads_coords.data_ptr(),
|
69 |
+
(float *) grads_colors.data_ptr(),
|
70 |
+
s, h, w, c);
|
71 |
+
}
|
72 |
+
|
73 |
+
PYBIND11_MODULE( TORCH_EXTENSION_NAME, m) {
|
74 |
+
m.def( "gs_render",
|
75 |
+
&gs_render,
|
76 |
+
"cuda forward wrapper");
|
77 |
+
m.def( "gs_render_backward",
|
78 |
+
&gs_render_backward,
|
79 |
+
"cuda backward wrapper");
|
80 |
+
}
|
utils/gs_cuda/gswrapper.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from torch.utils.cpp_extension import load
|
4 |
+
from torch.autograd import Function
|
5 |
+
from torch.autograd.function import once_differentiable
|
6 |
+
|
7 |
+
build_path = os.path.join(os.path.split(os.path.abspath(__file__))[0], 'build')
|
8 |
+
os.makedirs(build_path, exist_ok=True)
|
9 |
+
|
10 |
+
file_path = os.path.split(os.path.abspath(__file__))[0]
|
11 |
+
GSWrapper = load(
|
12 |
+
name="gscuda",
|
13 |
+
# sources=["gs_cuda/gswrapper.cpp", "gs_cuda/gs.cu"],
|
14 |
+
sources=[os.path.join(file_path, "gswrapper.cpp"),
|
15 |
+
os.path.join(file_path, "gs.cu")],
|
16 |
+
build_directory=build_path,
|
17 |
+
verbose=True)
|
18 |
+
|
19 |
+
class GSCUDA(Function):
|
20 |
+
|
21 |
+
@staticmethod
|
22 |
+
def forward(ctx, sigmas, coords, colors, rendered_img):
|
23 |
+
ctx.save_for_backward(sigmas, coords, colors)
|
24 |
+
h, w, c = rendered_img.shape
|
25 |
+
s = sigmas.shape[0]
|
26 |
+
GSWrapper.gs_render(sigmas, coords, colors, rendered_img, s, h, w, c)
|
27 |
+
return rendered_img
|
28 |
+
|
29 |
+
@staticmethod
|
30 |
+
@once_differentiable
|
31 |
+
def backward(ctx, grad_output):
|
32 |
+
sigmas, coords, colors = ctx.saved_tensors
|
33 |
+
h, w, c = grad_output.shape
|
34 |
+
s = sigmas.shape[0]
|
35 |
+
grads_sigmas = torch.zeros_like(sigmas)
|
36 |
+
grads_coords = torch.zeros_like(coords)
|
37 |
+
grads_colors = torch.zeros_like(colors)
|
38 |
+
GSWrapper.gs_render_backward(sigmas, coords, colors, grad_output.contiguous(), grads_sigmas, grads_coords, grads_colors, s, h, w, c)
|
39 |
+
return (grads_sigmas, grads_coords, grads_colors, None)
|
40 |
+
|
41 |
+
def gaussiansplatting_render(sigmas, coords, colors, image_size):
|
42 |
+
sigmas = sigmas.contiguous() # (gs num, 3)
|
43 |
+
coords = coords.contiguous() # (gs num, 2)
|
44 |
+
colors = colors.contiguous() # (gs num, c)
|
45 |
+
h, w = image_size[:2]
|
46 |
+
c = colors.shape[-1]
|
47 |
+
rendered_img = torch.zeros(h, w, c).to(colors.device).to(torch.float32)
|
48 |
+
return GSCUDA.apply(sigmas, coords, colors, rendered_img)
|
49 |
+
|
utils/gs_cuda/mylineprofiler.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import io
|
3 |
+
import sys
|
4 |
+
import timeit
|
5 |
+
import tokenize
|
6 |
+
import torch
|
7 |
+
import psutil
|
8 |
+
import inspect
|
9 |
+
from loguru import logger
|
10 |
+
from prettytable import PrettyTable
|
11 |
+
|
12 |
+
# implement by xtudbxk
|
13 |
+
# github: https://github.com/xtudbxk/lineprofiler
|
14 |
+
class MyLineProfiler():
|
15 |
+
def __init__(self, base='ms', cuda_sync=True, gpuids=(0,), warmup=0, warmup_lineno=-1):
|
16 |
+
|
17 |
+
if base == 'ms':
|
18 |
+
self.base_n = 1000
|
19 |
+
elif base == 's':
|
20 |
+
self.base_n = 1
|
21 |
+
else:
|
22 |
+
logguru.warning(f'Unsupported base - {base}, using "s" instead')
|
23 |
+
|
24 |
+
self.base = base
|
25 |
+
self.cuda_sync = cuda_sync
|
26 |
+
self.gpuids = gpuids
|
27 |
+
self.warmup = warmup
|
28 |
+
self.warmup_counter = warmup
|
29 |
+
# we should wait this line execute warup_counter times
|
30 |
+
# before recording the stats
|
31 |
+
self.warmup_lineno = warmup_lineno
|
32 |
+
|
33 |
+
# for time profiling
|
34 |
+
self._times = {}
|
35 |
+
self._func_name = None
|
36 |
+
self._func_filename = None
|
37 |
+
self._last_time = -1
|
38 |
+
self._last_lineno = -1
|
39 |
+
self._func_hit_count = 0
|
40 |
+
self._func_firstlineno = 0
|
41 |
+
|
42 |
+
# for memory profiling
|
43 |
+
self._process = psutil.Process(os.getpid())
|
44 |
+
self._memory = {}
|
45 |
+
self._last_memory = 0
|
46 |
+
|
47 |
+
# for cuda memory profiling
|
48 |
+
self._gpu_memory = {}
|
49 |
+
self._gpu_last_memory = 0
|
50 |
+
|
51 |
+
def __trace_func__(self, frame, event, arg):
|
52 |
+
# print(f'in {frame.f_code.co_filename} func {frame.f_code.co_name} line {frame.f_lineno}, event - {event}')
|
53 |
+
|
54 |
+
# check if run into the decorated func
|
55 |
+
if self._func_firstlineno == frame.f_code.co_firstlineno and frame.f_code.co_name == self._func_name and frame.f_code.co_filename == self._func_filename:
|
56 |
+
|
57 |
+
# --- obtain info for current hit ---
|
58 |
+
# cuda related
|
59 |
+
if self.cuda_sync is True:
|
60 |
+
torch.cuda.synchronize()
|
61 |
+
|
62 |
+
current_time = timeit.default_timer()
|
63 |
+
memory = self._process.memory_info().rss
|
64 |
+
gpu_memory = torch.cuda.memory_allocated()
|
65 |
+
# --- ends ---
|
66 |
+
|
67 |
+
# --- initilize the info when first hit ---
|
68 |
+
if frame.f_lineno not in self._times: # first hit time for this line
|
69 |
+
self._times[frame.f_lineno] = {'hit':0, 'time': 0}
|
70 |
+
self._memory[frame.f_lineno] = 0
|
71 |
+
self._gpu_memory[frame.f_lineno] = 0
|
72 |
+
# --- ends ---
|
73 |
+
|
74 |
+
# --- record info before call the decorated func ---
|
75 |
+
# 'call' - before call the func
|
76 |
+
if event == 'call':
|
77 |
+
self._last_time = current_time
|
78 |
+
self._last_lineno = frame.f_lineno
|
79 |
+
self._last_memory = memory
|
80 |
+
self._last_gpu_memory = gpu_memory
|
81 |
+
|
82 |
+
if self.warmup_lineno < 0:
|
83 |
+
self.warmup_counter -= 1
|
84 |
+
if self.warmup_counter < 0:
|
85 |
+
self._func_hit_count += 1
|
86 |
+
# --- ends ---
|
87 |
+
|
88 |
+
# 'line' - after excuting the line
|
89 |
+
# 'return' - return from the function
|
90 |
+
if event == 'line' or event == 'return':
|
91 |
+
|
92 |
+
if event == 'line' and self.warmup_counter < 0:
|
93 |
+
self._times[frame.f_lineno]['hit'] += 1
|
94 |
+
|
95 |
+
|
96 |
+
# --- obtain the memory and time consumed by this line ---
|
97 |
+
if self.warmup_counter < 0:
|
98 |
+
self._times[self._last_lineno]['time'] += current_time - self._last_time
|
99 |
+
self._memory[self._last_lineno] += memory - self._last_memory
|
100 |
+
self._gpu_memory[self._last_lineno] += gpu_memory - self._gpu_last_memory
|
101 |
+
# --- ends ---
|
102 |
+
|
103 |
+
if self.cuda_sync is True:
|
104 |
+
torch.cuda.synchronize()
|
105 |
+
|
106 |
+
self._last_time = timeit.default_timer()
|
107 |
+
self._last_memory = memory
|
108 |
+
self._gpu_last_memory = gpu_memory
|
109 |
+
self._last_lineno = frame.f_lineno
|
110 |
+
|
111 |
+
return self.__trace_func__
|
112 |
+
|
113 |
+
def decorate(self, func):
|
114 |
+
if self._func_name is not None:
|
115 |
+
logger.warning(f'Only support decorate only one func. Aready decorated "{self._func_name}"')
|
116 |
+
self._func_name = func.__name__
|
117 |
+
self._func_filename = func.__code__.co_filename
|
118 |
+
self._func_firstlineno = func.__code__.co_firstlineno
|
119 |
+
|
120 |
+
def _f(*args, **kwargs):
|
121 |
+
origin_trace_func = sys.gettrace()
|
122 |
+
sys.settrace(self.__trace_func__)
|
123 |
+
ret = func(*args, **kwargs)
|
124 |
+
sys.settrace(origin_trace_func)
|
125 |
+
return ret
|
126 |
+
return _f
|
127 |
+
|
128 |
+
def _get_table(self):
|
129 |
+
|
130 |
+
if len(self._times) <= 0:
|
131 |
+
logger.warning(f"un recorded datas, please ensure the function is executed")
|
132 |
+
return None
|
133 |
+
|
134 |
+
# --- load the source code ---
|
135 |
+
with open(self._func_filename, 'r') as f:
|
136 |
+
source_lines = [line.strip('\n') for line in f.readlines()]
|
137 |
+
code_str = "\n".join(source_lines)
|
138 |
+
|
139 |
+
def_lineno = min(self._times.keys())
|
140 |
+
final_lineno = max(self._times.keys())
|
141 |
+
|
142 |
+
# remove the additional blank content
|
143 |
+
pre_blank_count = len(source_lines[def_lineno-1]) - len(source_lines[def_lineno-1].lstrip(' ').lstrip('\t'))
|
144 |
+
# --- ends ---
|
145 |
+
|
146 |
+
# --- analysize the source code and collect infos for multi-line code ---
|
147 |
+
new_logic_linenos = [token.start[0] for token in tokenize.generate_tokens(
|
148 |
+
io.StringIO(code_str).readline) if token.type == 4]
|
149 |
+
# --- ends ---
|
150 |
+
|
151 |
+
# --- merge the stats multi-line code ---
|
152 |
+
sorted_linenos = [lineno for lineno in self._times.keys()]
|
153 |
+
sorted_linenos.sort(key=int)
|
154 |
+
|
155 |
+
lineno_cache = []
|
156 |
+
for lineno in sorted_linenos:
|
157 |
+
if lineno not in new_logic_linenos:
|
158 |
+
lineno_cache.append(lineno)
|
159 |
+
else:
|
160 |
+
# we should merge its info to the prev_lineno
|
161 |
+
if len(lineno_cache) <= 0:
|
162 |
+
continue
|
163 |
+
else:
|
164 |
+
lineno_cache.append(lineno)
|
165 |
+
first_lineno = lineno_cache[0]
|
166 |
+
for prev_lineno in lineno_cache[1:]:
|
167 |
+
self._times[first_lineno]["hit"] = min(self._times[first_lineno]["hit"], self._times[prev_lineno]["hit"])
|
168 |
+
self._times[first_lineno]["time"] += self._times[prev_lineno]["time"]
|
169 |
+
del self._times[prev_lineno]
|
170 |
+
|
171 |
+
self._memory[first_lineno] += self._memory[prev_lineno]
|
172 |
+
del self._memory[prev_lineno]
|
173 |
+
|
174 |
+
self._gpu_memory[first_lineno] += self._gpu_memory[prev_lineno]
|
175 |
+
del self._gpu_memory[prev_lineno]
|
176 |
+
lineno_cache = []
|
177 |
+
# --- ends ---
|
178 |
+
|
179 |
+
# --- initialize the pretty table for output ---
|
180 |
+
table = PrettyTable(['lineno', 'hits', 'time', 'time per hit', 'hit perc', 'time perc', 'mem inc', 'mem peak', 'gpu mem inc', 'gpu mem peak'])
|
181 |
+
# --- ends ---
|
182 |
+
|
183 |
+
# --- compute some statisticals ---
|
184 |
+
total_hit = 0 # for compute the hit percentage
|
185 |
+
total_time = 0
|
186 |
+
for lineno, stats in self._times.items():
|
187 |
+
if lineno == def_lineno: continue
|
188 |
+
total_hit += stats['hit']
|
189 |
+
total_time += stats['time']
|
190 |
+
|
191 |
+
total_memory = sum([m for l,m in self._memory.items()]) / 1024 / 1024
|
192 |
+
total_gpu_memory = sum([m for l,m in self._gpu_memory.items()]) / 1024 / 1024
|
193 |
+
# --- ends ---
|
194 |
+
|
195 |
+
peak_cpu_memory = 0
|
196 |
+
peak_gpu_memory = 0
|
197 |
+
for lineno in range(def_lineno, final_lineno+1):
|
198 |
+
if lineno not in self._times:
|
199 |
+
# the comment line, empty line or merged line from multi-lines code
|
200 |
+
table.add_row([lineno, '-', '-', '-', '-', '-', '-',f'{peak_cpu_memory:5.3f} MB', '-', f'{peak_gpu_memory:5.3f} MB'])
|
201 |
+
else:
|
202 |
+
stats = self._times[lineno]
|
203 |
+
if lineno == def_lineno:
|
204 |
+
table.add_row([lineno, self._func_hit_count, f'{total_time*self.base_n:.4f} {self.base}', f'{total_time/self._func_hit_count*self.base_n:.4f} {self.base}', '-', '-', f'{total_memory:5.3f} MB', 'baseline', f'{total_gpu_memory:5.3f} MB', 'baseline'])
|
205 |
+
else:
|
206 |
+
|
207 |
+
line_result = [lineno, stats['hit'],
|
208 |
+
f'{stats["time"]*self.base_n:.4f} {self.base}',
|
209 |
+
f'{stats["time"]/stats["hit"]*self.base_n:.4f} {self.base}' if stats['hit'] > 0 else 'nan',
|
210 |
+
f'{stats["hit"]/total_hit*100:.3f}%' if total_hit > 0 else 'nan',
|
211 |
+
f'{stats["time"]/total_time*100:.3f}%'] if total_time > 0 else 'nan'
|
212 |
+
|
213 |
+
line_result += [f'{self._memory[lineno]/1024/1024:5.3f} MB' if stats['hit'] > 0 else '0 MB']
|
214 |
+
peak_cpu_memory = peak_cpu_memory + self._memory[lineno]/1024/1024
|
215 |
+
line_result += [f'{peak_cpu_memory:5.3f} MB']
|
216 |
+
|
217 |
+
line_result += [f'{self._gpu_memory[lineno]/1024/1024:5.3f} MB' if stats['hit'] > 0 else '0 MB']
|
218 |
+
peak_gpu_memory = peak_gpu_memory + self._gpu_memory[lineno]/1024/1024
|
219 |
+
line_result += [f'{peak_gpu_memory:5.3f} MB']
|
220 |
+
|
221 |
+
table.add_row(line_result)
|
222 |
+
|
223 |
+
table.add_column('sources', [source_lines[i-1][pre_blank_count:] if len(source_lines[i-1])>pre_blank_count else '' for i in range(def_lineno, final_lineno+1)], 'l')
|
224 |
+
return table
|
225 |
+
|
226 |
+
def print(self, filename=None, mode="w"):
|
227 |
+
introducation = '''
|
228 |
+
1. The first line of table reports the overall results of the whole function and the following lines reports the statistics of each line in the function.
|
229 |
+
2. The `hit perc` and `time perc` represent `hit percentage` and `time percentage`.
|
230 |
+
3. For memory, there exists four categories `mem inc`, `mem peak`, `gpu mem inc` and `gpu mem peak`. They denotes `cpu memory increasement`, `cpu memory peak`, `gpu memory increasement` and `gpu memory peak`. All the results are collected in the last run. The number in the increasement field denots the increasement of corresponding memory of each line (the first line is related to the whole function). Sometimes, the number of each line is far less of the number of the first line, which is valid since python may auto release the unused memory after the function execution. The number of each line in the peak filed is a simple sum of the numbers of above lines in the increasement field, which is used to demonstrate the possible maxinum memory usage in the function.
|
231 |
+
4. For any issue, please concact us via https://github.com/xtudbxk/lineprofiler or zhengqiang.zhang@hotmail.com
|
232 |
+
'''
|
233 |
+
print(introducation)
|
234 |
+
|
235 |
+
table = PrettyTable(['lineno', 'hits', 'time', 'time per hit', 'hit perc', 'time perc', 'mem inc', 'mem peak', 'gpu mem inc', 'gpu mem peak'])
|
236 |
+
table = self._get_table()
|
237 |
+
print(table)
|
238 |
+
if filename is not None:
|
239 |
+
with open(filename, mode) as f:
|
240 |
+
f.write(introducation)
|
241 |
+
f.write(f"args - base={self.base}, cuda_sync={self.cuda_sync}, gpuids={self.gpuids}, warmup={self.warmup}\n")
|
242 |
+
f.write(str(table))
|
243 |
+
|
244 |
+
if __name__ == '__main__':
|
245 |
+
import numpy as np
|
246 |
+
def mytest(h='hello',
|
247 |
+
xx="xx"):
|
248 |
+
|
249 |
+
h = h + 'world'
|
250 |
+
a = []
|
251 |
+
for _ in range(200):
|
252 |
+
# a = np.zeros((1000, 1000), dtype=np.float32)
|
253 |
+
a.append(np.zeros((1000, 1000), dtype=np.float32))
|
254 |
+
a.append(
|
255 |
+
np.zeros((1000, 1000),
|
256 |
+
dtype=np.float32))
|
257 |
+
# print(a[0,0])
|
258 |
+
print(h)
|
259 |
+
|
260 |
+
profiler = MyLineProfiler(cuda_sync=False, warmup=2)
|
261 |
+
mytest = profiler.decorate(mytest)
|
262 |
+
for _ in range(5):
|
263 |
+
mytest()
|
264 |
+
profiler.print()
|
utils/gs_cuda/profile.log
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
1. The first line of table reports the overall results of the whole function and the following lines reports the statistics of each line in the function.
|
3 |
+
2. The `hit perc` and `time perc` represent `hit percentage` and `time percentage`.
|
4 |
+
3. For memory, there exists four categories `mem inc`, `mem peak`, `gpu mem inc` and `gpu mem peak`. They denotes `cpu memory increasement`, `cpu memory peak`, `gpu memory increasement` and `gpu memory peak`. All the results are collected in the last run. The number in the increasement field denots the increasement of corresponding memory of each line (the first line is related to the whole function). Sometimes, the number of each line is far less of the number of the first line, which is valid since python may auto release the unused memory after the function execution. The number of each line in the peak filed is a simple sum of the numbers of above lines in the increasement field, which is used to demonstrate the possible maxinum memory usage in the function.
|
5 |
+
4. For any issue, please concact us via https://github.com/xtudbxk/lineprofiler or zhengqiang.zhang@hotmail.com
|
6 |
+
args - base=ms, cuda_sync=True, gpuids=(0,), warmup=0
|
7 |
+
+--------+------+------------+--------------+----------+-----------+----------+----------+-------------+--------------+-----------------------------------------------------------------------------+
|
8 |
+
| lineno | hits | time | time per hit | hit perc | time perc | mem inc | mem peak | gpu mem inc | gpu mem peak | sources |
|
9 |
+
+--------+------+------------+--------------+----------+-----------+----------+----------+-------------+--------------+-----------------------------------------------------------------------------+
|
10 |
+
| 41 | 1 | 76.8299 ms | 76.8299 ms | - | - | 0.902 MB | baseline | 3.500 MB | baseline | def gaussiansplatting_render(sigmas, coords, colors, image_size): |
|
11 |
+
| 42 | 1 | 0.0353 ms | 0.0353 ms | 14.286% | 0.046% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | sigmas = sigmas.contiguous() # (gs num, 3) |
|
12 |
+
| 43 | 1 | 0.0078 ms | 0.0078 ms | 14.286% | 0.010% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | coords = coords.contiguous() # (gs num, 2) |
|
13 |
+
| 44 | 1 | 0.0063 ms | 0.0063 ms | 14.286% | 0.008% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | colors = colors.contiguous() # (gs num, c) |
|
14 |
+
| 45 | 1 | 0.0063 ms | 0.0063 ms | 14.286% | 0.008% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | h, w = image_size[:2] |
|
15 |
+
| 46 | 1 | 0.0093 ms | 0.0093 ms | 14.286% | 0.012% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | c = colors.shape[-1] |
|
16 |
+
| 47 | 1 | 1.8306 ms | 1.8306 ms | 14.286% | 2.383% | 0.438 MB | 0.438 MB | 3.000 MB | 3.000 MB | rendered_img = torch.zeros(h, w, c).to(colors.device).to(torch.float32) |
|
17 |
+
| 48 | 1 | 74.9344 ms | 74.9344 ms | 14.286% | 97.533% | 0.465 MB | 0.902 MB | 0.000 MB | 3.000 MB | return GSCUDA.apply(sigmas, coords, colors, rendered_img) |
|
18 |
+
+--------+------+------------+--------------+----------+-----------+----------+----------+-------------+--------------+-----------------------------------------------------------------------------+
|
19 |
+
1. The first line of table reports the overall results of the whole function and the following lines reports the statistics of each line in the function.
|
20 |
+
2. The `hit perc` and `time perc` represent `hit percentage` and `time percentage`.
|
21 |
+
3. For memory, there exists four categories `mem inc`, `mem peak`, `gpu mem inc` and `gpu mem peak`. They denotes `cpu memory increasement`, `cpu memory peak`, `gpu memory increasement` and `gpu memory peak`. All the results are collected in the last run. The number in the increasement field denots the increasement of corresponding memory of each line (the first line is related to the whole function). Sometimes, the number of each line is far less of the number of the first line, which is valid since python may auto release the unused memory after the function execution. The number of each line in the peak filed is a simple sum of the numbers of above lines in the increasement field, which is used to demonstrate the possible maxinum memory usage in the function.
|
22 |
+
4. For any issue, please concact us via https://github.com/xtudbxk/lineprofiler or zhengqiang.zhang@hotmail.com
|
23 |
+
args - base=ms, cuda_sync=True, gpuids=(0,), warmup=0
|
24 |
+
+--------+------+--------------+--------------+----------+-----------+----------+----------+-------------+--------------+-----------------------------------------------------------------------------+
|
25 |
+
| lineno | hits | time | time per hit | hit perc | time perc | mem inc | mem peak | gpu mem inc | gpu mem peak | sources |
|
26 |
+
+--------+------+--------------+--------------+----------+-----------+----------+----------+-------------+--------------+-----------------------------------------------------------------------------+
|
27 |
+
| 41 | 1 | 1175.7406 ms | 1175.7406 ms | - | - | 0.777 MB | baseline | 12.000 MB | baseline | def gaussiansplatting_render(sigmas, coords, colors, image_size): |
|
28 |
+
| 42 | 1 | 0.0304 ms | 0.0304 ms | 14.286% | 0.003% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | sigmas = sigmas.contiguous() # (gs num, 3) |
|
29 |
+
| 43 | 1 | 0.0069 ms | 0.0069 ms | 14.286% | 0.001% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | coords = coords.contiguous() # (gs num, 2) |
|
30 |
+
| 44 | 1 | 0.0064 ms | 0.0064 ms | 14.286% | 0.001% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | colors = colors.contiguous() # (gs num, c) |
|
31 |
+
| 45 | 1 | 0.0065 ms | 0.0065 ms | 14.286% | 0.001% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | h, w = image_size[:2] |
|
32 |
+
| 46 | 1 | 0.0099 ms | 0.0099 ms | 14.286% | 0.001% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | c = colors.shape[-1] |
|
33 |
+
| 47 | 1 | 1.2594 ms | 1.2594 ms | 14.286% | 0.107% | 0.133 MB | 0.133 MB | 3.000 MB | 3.000 MB | rendered_img = torch.zeros(h, w, c).to(colors.device).to(torch.float32) |
|
34 |
+
| 48 | 1 | 1174.4211 ms | 1174.4211 ms | 14.286% | 99.888% | 0.645 MB | 0.777 MB | 0.000 MB | 3.000 MB | return GSCUDA.apply(sigmas, coords, colors, rendered_img) |
|
35 |
+
+--------+------+--------------+--------------+----------+-----------+----------+----------+-------------+--------------+-----------------------------------------------------------------------------+
|
36 |
+
1. The first line of table reports the overall results of the whole function and the following lines reports the statistics of each line in the function.
|
37 |
+
2. The `hit perc` and `time perc` represent `hit percentage` and `time percentage`.
|
38 |
+
3. For memory, there exists four categories `mem inc`, `mem peak`, `gpu mem inc` and `gpu mem peak`. They denotes `cpu memory increasement`, `cpu memory peak`, `gpu memory increasement` and `gpu memory peak`. All the results are collected in the last run. The number in the increasement field denots the increasement of corresponding memory of each line (the first line is related to the whole function). Sometimes, the number of each line is far less of the number of the first line, which is valid since python may auto release the unused memory after the function execution. The number of each line in the peak filed is a simple sum of the numbers of above lines in the increasement field, which is used to demonstrate the possible maxinum memory usage in the function.
|
39 |
+
4. For any issue, please concact us via https://github.com/xtudbxk/lineprofiler or zhengqiang.zhang@hotmail.com
|
40 |
+
args - base=ms, cuda_sync=True, gpuids=(0,), warmup=0
|
41 |
+
+--------+------+---------------+--------------+----------+-----------+-----------+-----------+-------------+--------------+-----------------------------------------------------------------------------+
|
42 |
+
| lineno | hits | time | time per hit | hit perc | time perc | mem inc | mem peak | gpu mem inc | gpu mem peak | sources |
|
43 |
+
+--------+------+---------------+--------------+----------+-----------+-----------+-----------+-------------+--------------+-----------------------------------------------------------------------------+
|
44 |
+
| 41 | 10 | 11844.9229 ms | 1184.4923 ms | - | - | 20.227 MB | baseline | 15.000 MB | baseline | def gaussiansplatting_render(sigmas, coords, colors, image_size): |
|
45 |
+
| 42 | 10 | 0.1342 ms | 0.0134 ms | 14.286% | 0.001% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | sigmas = sigmas.contiguous() # (gs num, 3) |
|
46 |
+
| 43 | 10 | 0.0654 ms | 0.0065 ms | 14.286% | 0.001% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | coords = coords.contiguous() # (gs num, 2) |
|
47 |
+
| 44 | 10 | 0.0618 ms | 0.0062 ms | 14.286% | 0.001% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | colors = colors.contiguous() # (gs num, c) |
|
48 |
+
| 45 | 10 | 0.0710 ms | 0.0071 ms | 14.286% | 0.001% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | h, w = image_size[:2] |
|
49 |
+
| 46 | 10 | 0.0803 ms | 0.0080 ms | 14.286% | 0.001% | 0.062 MB | 0.062 MB | 0.000 MB | 0.000 MB | c = colors.shape[-1] |
|
50 |
+
| 47 | 10 | 7.2555 ms | 0.7256 ms | 14.286% | 0.061% | 19.105 MB | 19.168 MB | 30.000 MB | 30.000 MB | rendered_img = torch.zeros(h, w, c).to(colors.device).to(torch.float32) |
|
51 |
+
| 48 | 10 | 11837.2547 ms | 1183.7255 ms | 14.286% | 99.935% | 1.059 MB | 20.227 MB | 0.000 MB | 30.000 MB | return GSCUDA.apply(sigmas, coords, colors, rendered_img) |
|
52 |
+
+--------+------+---------------+--------------+----------+-----------+-----------+-----------+-------------+--------------+-----------------------------------------------------------------------------+
|
53 |
+
1. The first line of table reports the overall results of the whole function and the following lines reports the statistics of each line in the function.
|
54 |
+
2. The `hit perc` and `time perc` represent `hit percentage` and `time percentage`.
|
55 |
+
3. For memory, there exists four categories `mem inc`, `mem peak`, `gpu mem inc` and `gpu mem peak`. They denotes `cpu memory increasement`, `cpu memory peak`, `gpu memory increasement` and `gpu memory peak`. All the results are collected in the last run. The number in the increasement field denots the increasement of corresponding memory of each line (the first line is related to the whole function). Sometimes, the number of each line is far less of the number of the first line, which is valid since python may auto release the unused memory after the function execution. The number of each line in the peak filed is a simple sum of the numbers of above lines in the increasement field, which is used to demonstrate the possible maxinum memory usage in the function.
|
56 |
+
4. For any issue, please concact us via https://github.com/xtudbxk/lineprofiler or zhengqiang.zhang@hotmail.com
|
57 |
+
args - base=ms, cuda_sync=True, gpuids=(0,), warmup=0
|
58 |
+
+--------+------+---------------+--------------+----------+-----------+-----------+-----------+-------------+--------------+-----------------------------------------------------------------------------+
|
59 |
+
| lineno | hits | time | time per hit | hit perc | time perc | mem inc | mem peak | gpu mem inc | gpu mem peak | sources |
|
60 |
+
+--------+------+---------------+--------------+----------+-----------+-----------+-----------+-------------+--------------+-----------------------------------------------------------------------------+
|
61 |
+
| 41 | 10 | 11855.0900 ms | 1185.5090 ms | - | - | 20.242 MB | baseline | 15.000 MB | baseline | def gaussiansplatting_render(sigmas, coords, colors, image_size): |
|
62 |
+
| 42 | 10 | 0.1263 ms | 0.0126 ms | 14.286% | 0.001% | 0.078 MB | 0.078 MB | 0.000 MB | 0.000 MB | sigmas = sigmas.contiguous() # (gs num, 3) |
|
63 |
+
| 43 | 10 | 0.0632 ms | 0.0063 ms | 14.286% | 0.001% | 0.000 MB | 0.078 MB | 0.000 MB | 0.000 MB | coords = coords.contiguous() # (gs num, 2) |
|
64 |
+
| 44 | 10 | 0.0588 ms | 0.0059 ms | 14.286% | 0.000% | 0.000 MB | 0.078 MB | 0.000 MB | 0.000 MB | colors = colors.contiguous() # (gs num, c) |
|
65 |
+
| 45 | 10 | 0.0626 ms | 0.0063 ms | 14.286% | 0.001% | 0.000 MB | 0.078 MB | 0.000 MB | 0.000 MB | h, w = image_size[:2] |
|
66 |
+
| 46 | 10 | 0.0747 ms | 0.0075 ms | 14.286% | 0.001% | 0.000 MB | 0.078 MB | 0.000 MB | 0.000 MB | c = colors.shape[-1] |
|
67 |
+
| 47 | 10 | 7.0497 ms | 0.7050 ms | 14.286% | 0.059% | 19.078 MB | 19.156 MB | 30.000 MB | 30.000 MB | rendered_img = torch.zeros(h, w, c).to(colors.device).to(torch.float32) |
|
68 |
+
| 48 | 10 | 11847.6547 ms | 1184.7655 ms | 14.286% | 99.937% | 0.820 MB | 19.977 MB | 0.000 MB | 30.000 MB | return GSCUDA.apply(sigmas, coords, colors, rendered_img) |
|
69 |
+
+--------+------+---------------+--------------+----------+-----------+-----------+-----------+-------------+--------------+-----------------------------------------------------------------------------+
|
utils/gs_cuda/profile.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torchvision.utils import save_image
|
6 |
+
from gswrapper import gaussiansplatting_render
|
7 |
+
|
8 |
+
def generate_2D_gaussian_splatting(kernel_size, sigma_x, sigma_y, rho, coords,
|
9 |
+
colours, image_size=(256, 256, 3), device="cuda"):
|
10 |
+
|
11 |
+
batch_size = colours.shape[0]
|
12 |
+
|
13 |
+
sigma_x = sigma_x.view(batch_size, 1, 1)
|
14 |
+
sigma_y = sigma_y.view(batch_size, 1, 1)
|
15 |
+
rho = rho.view(batch_size, 1, 1)
|
16 |
+
|
17 |
+
covariance = torch.stack(
|
18 |
+
[torch.stack([sigma_x**2, rho*sigma_x*sigma_y], dim=-1),
|
19 |
+
torch.stack([rho*sigma_x*sigma_y, sigma_y**2], dim=-1)],
|
20 |
+
dim=-2
|
21 |
+
)
|
22 |
+
|
23 |
+
# Check for positive semi-definiteness
|
24 |
+
# determinant = (sigma_x**2) * (sigma_y**2) - (rho * sigma_x * sigma_y)**2
|
25 |
+
# if (determinant <= 0).any():
|
26 |
+
# raise ValueError("Covariance matrix must be positive semi-definite")
|
27 |
+
|
28 |
+
inv_covariance = torch.inverse(covariance)
|
29 |
+
|
30 |
+
# Choosing quite a broad range for the distribution [-5,5] to avoid any clipping
|
31 |
+
start = torch.tensor([-5.0], device=device).view(-1, 1)
|
32 |
+
end = torch.tensor([5.0], device=device).view(-1, 1)
|
33 |
+
base_linspace = torch.linspace(0, 1, steps=kernel_size, device=device)
|
34 |
+
ax_batch = start + (end - start) * base_linspace
|
35 |
+
|
36 |
+
# Expanding dims for broadcasting
|
37 |
+
ax_batch_expanded_x = ax_batch.unsqueeze(-1).expand(-1, -1, kernel_size)
|
38 |
+
ax_batch_expanded_y = ax_batch.unsqueeze(1).expand(-1, kernel_size, -1)
|
39 |
+
|
40 |
+
# Creating a batch-wise meshgrid using broadcasting
|
41 |
+
xx, yy = ax_batch_expanded_x, ax_batch_expanded_y # (batchsize, kernelsize, kernelsize)
|
42 |
+
|
43 |
+
xy = torch.stack([xx, yy], dim=-1) # (batchsize, kernelsize, kernelsize, 2)
|
44 |
+
z = torch.einsum('b...i,b...ij,b...j->b...', xy, -0.5 * inv_covariance, xy) # (batchsize, kernelsize, kernelsize, 2)
|
45 |
+
kernel = torch.exp(z) / (2 * torch.tensor(np.pi, device=device) * torch.sqrt(torch.det(covariance)).view(batch_size, 1, 1)) # (batchsize, kernelsize, kernelsize)
|
46 |
+
|
47 |
+
|
48 |
+
kernel_max_1, _ = kernel.max(dim=-1, keepdim=True) # Find max along the last dimension
|
49 |
+
kernel_max_2, _ = kernel_max_1.max(dim=-2, keepdim=True) # Find max along the second-to-last dimension
|
50 |
+
kernel_normalized = kernel / kernel_max_2 # (batchsize, kernelsize, kernelsize)
|
51 |
+
|
52 |
+
|
53 |
+
kernel_reshaped = kernel_normalized.repeat(1, 3, 1).view(batch_size * 3, kernel_size, kernel_size)
|
54 |
+
kernel_rgb = kernel_reshaped.unsqueeze(0).reshape(batch_size, 3, kernel_size, kernel_size) # (batchsize, 3, kernelsize, kernelsize)
|
55 |
+
|
56 |
+
# Calculating the padding needed to match the image size
|
57 |
+
pad_h = image_size[0] - kernel_size
|
58 |
+
pad_w = image_size[1] - kernel_size
|
59 |
+
|
60 |
+
if pad_h < 0 or pad_w < 0:
|
61 |
+
raise ValueError("Kernel size should be smaller or equal to the image size.")
|
62 |
+
|
63 |
+
# Adding padding to make kernel size equal to the image size
|
64 |
+
padding = (pad_w // 2, pad_w // 2 + pad_w % 2, # padding left and right
|
65 |
+
pad_h // 2, pad_h // 2 + pad_h % 2) # padding top and bottom
|
66 |
+
|
67 |
+
kernel_rgb_padded = torch.nn.functional.pad(kernel_rgb, padding, "constant", 0) # (batchsize, 3, h, w)
|
68 |
+
|
69 |
+
# Extracting shape information
|
70 |
+
b, c, h, w = kernel_rgb_padded.shape
|
71 |
+
|
72 |
+
# Create a batch of 2D affine matrices
|
73 |
+
theta = torch.zeros(b, 2, 3, dtype=torch.float32, device=device)
|
74 |
+
theta[:, 0, 0] = 1.0
|
75 |
+
theta[:, 1, 1] = 1.0
|
76 |
+
theta[:, :, 2] = -coords # (b, 2) - the offset of gaussian splating
|
77 |
+
|
78 |
+
# Creating grid and performing grid sampling
|
79 |
+
grid = F.affine_grid(theta, size=(b, c, h, w), align_corners=True) # (b, 3, h, w)
|
80 |
+
# grid_y = torch.linspace(-1, 1, steps=h, device=device).reshape(1, h, 1, 1).repeat(1, 1, w, 1)
|
81 |
+
# grid_x = torch.linspace(-1, 1, steps=w, device=device).reshape(1, 1, w, 1).repeat(1, h, 1, 1)
|
82 |
+
# grid = torch.cat([grid_x, grid_y], dim=-1)
|
83 |
+
# grid = grid - coords.reshape(-1, 1, 1, 2)
|
84 |
+
|
85 |
+
kernel_rgb_padded_translated = F.grid_sample(kernel_rgb_padded, grid, align_corners=True) # (b, 3, h, w)
|
86 |
+
|
87 |
+
rgb_values_reshaped = colours.unsqueeze(-1).unsqueeze(-1)
|
88 |
+
|
89 |
+
final_image_layers = rgb_values_reshaped * kernel_rgb_padded_translated
|
90 |
+
final_image = final_image_layers.sum(dim=0)
|
91 |
+
# final_image = torch.clamp(final_image, 0, 1)
|
92 |
+
final_image = final_image.permute(1,2,0)
|
93 |
+
|
94 |
+
return final_image
|
95 |
+
|
96 |
+
|
97 |
+
if __name__ == "__main__":
|
98 |
+
from mylineprofiler import MyLineProfiler
|
99 |
+
profiler_th = MyLineProfiler(cuda_sync=True)
|
100 |
+
generate_2D_gaussian_splatting = profiler_th.decorate(generate_2D_gaussian_splatting)
|
101 |
+
profiler_cuda = MyLineProfiler(cuda_sync=True)
|
102 |
+
gaussiansplatting_render = profiler_cuda.decorate(gaussiansplatting_render)
|
103 |
+
|
104 |
+
|
105 |
+
# --- test ---
|
106 |
+
s = int(512 * 512)
|
107 |
+
# s = 5
|
108 |
+
image_size = (512, 512, 3)
|
109 |
+
|
110 |
+
sigmas = 0.2*torch.rand(s, 3).to(torch.float32).to("cuda")
|
111 |
+
sigmas[:,:2] = 5*sigmas[:, :2]
|
112 |
+
coords = 2*torch.rand(s, 2).to(torch.float32).to("cuda")-1.0
|
113 |
+
colors = torch.rand(s, 3).to(torch.float32).to("cuda")
|
114 |
+
|
115 |
+
# --- torch version ---
|
116 |
+
import gc
|
117 |
+
# gc.collect()
|
118 |
+
# torch.cuda.empty_cache()
|
119 |
+
# for _ in range(1):
|
120 |
+
# img_python = generate_2D_gaussian_splatting(128, sigmas[:,1], sigmas[:,0], sigmas[:,2], coords, colors, image_size)
|
121 |
+
# profiler_th.print("profile.log", "w")
|
122 |
+
# cv2.imwrite("th.png", 255.0*img_python.detach().clamp(0,1).cpu().numpy())
|
123 |
+
# --- ends ---
|
124 |
+
|
125 |
+
# --- cuda version ---
|
126 |
+
sigmas[:, 0] = sigmas[:, 0]
|
127 |
+
sigmas[:, 1] = sigmas[:, 1]
|
128 |
+
gc.collect()
|
129 |
+
torch.cuda.empty_cache()
|
130 |
+
for _ in range(10):
|
131 |
+
with torch.no_grad():
|
132 |
+
img_cuda = gaussiansplatting_render(sigmas, coords, colors, image_size)
|
133 |
+
|
134 |
+
profiler_cuda.print("profile.log", "a")
|
135 |
+
cv2.imwrite("cuda.png", 255.0*img_cuda.detach().clamp(0,1).cpu().numpy())
|
136 |
+
# --- ends ---
|
137 |
+
pass
|
utils/gs_cuda_dmax/__init__.py
ADDED
File without changes
|
utils/gs_cuda_dmax/check.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from gswrapper import gaussiansplatting_render
|
3 |
+
|
4 |
+
def torch_version(sigmas, coords, colors, image_size, dmax=100):
|
5 |
+
h, w = image_size
|
6 |
+
c = colors.shape[-1]
|
7 |
+
|
8 |
+
if h >= 50 or w >= 50:
|
9 |
+
logger.warning(f'too large values for h({h}), w({w}), torch version would be slow')
|
10 |
+
|
11 |
+
rendered_img = torch.zeros(h, w, c).to(colors.device).to(torch.float32)
|
12 |
+
|
13 |
+
for hi in range(h):
|
14 |
+
for wi in range(w):
|
15 |
+
curh = 2*hi/(h-1)-1.0
|
16 |
+
curw = 2*wi/(w-1)-1.0
|
17 |
+
|
18 |
+
v = (curw-coords[:,0])**2/sigmas[:,0]**2
|
19 |
+
v -= (2*sigmas[:,2])*(curw-coords[:,0])*(curh-coords[:,1])/sigmas[:,0]/sigmas[:,1]
|
20 |
+
v += (curh-coords[:,1])**2/sigmas[:,1]**2
|
21 |
+
v *= -1.0/(2.0*(1-sigmas[:,2]**2))
|
22 |
+
v = torch.exp(v)
|
23 |
+
|
24 |
+
mask_w = abs(curw-coords[:,0]) <= dmax
|
25 |
+
mask_h = abs(curh-coords[:,1]) <= dmax
|
26 |
+
mask = torch.logical_and(mask_w, mask_h)
|
27 |
+
|
28 |
+
for ci in range(c):
|
29 |
+
rendered_img[hi, wi, ci] = torch.sum((v*colors[:, ci])[mask])
|
30 |
+
|
31 |
+
return rendered_img
|
32 |
+
|
33 |
+
|
34 |
+
if __name__ == "__main__":
|
35 |
+
s = 4 # the number of gs
|
36 |
+
image_size = (10, 10)
|
37 |
+
|
38 |
+
for _ in range(1):
|
39 |
+
print(f"--------------------------- begins --------------------------------")
|
40 |
+
|
41 |
+
sigmas = 0.999*torch.rand(s, 3).to(torch.float32).to("cuda")
|
42 |
+
sigmas[:,:2] = 5*sigmas[:, :2]
|
43 |
+
coords = 2*torch.rand(s, 2).to(torch.float32).to("cuda")-1.0
|
44 |
+
colors = torch.rand(s, 3).to(torch.float32).to("cuda")
|
45 |
+
# colors = torch.rand(s, 5).to(torch.float32).to("cuda")
|
46 |
+
dmax = 0.5
|
47 |
+
|
48 |
+
# sigmas = torch.Tensor([[0.9196, 0.3979, 0.7784]]).to(torch.float32).to("cuda")
|
49 |
+
# coords = torch.Tensor([[-0.0469, -0.1726]]).to(torch.float32).to("cuda")
|
50 |
+
# colors = torch.Tensor([[0.3775, 0.2346, 0.1513]]).to(torch.float32).to("cuda")
|
51 |
+
# colors = torch.ones_like(coords[:,0:1])
|
52 |
+
|
53 |
+
print(f"sigmas: {sigmas}, \ncoords:{coords}, \ncolors:{colors}\ndmax:{dmax}")
|
54 |
+
|
55 |
+
# --- check forward ---
|
56 |
+
with torch.no_grad():
|
57 |
+
rendered_img_th = torch_version(sigmas,coords,colors,image_size,dmax)
|
58 |
+
rendered_img_cuda = gaussiansplatting_render(sigmas,coords,colors,image_size,dmax)
|
59 |
+
|
60 |
+
#
|
61 |
+
distance = (rendered_img_th-rendered_img_cuda)**2
|
62 |
+
print(f"check forward - torch: {rendered_img_th[:2,:2,0]}")
|
63 |
+
print(f"check forward - cuda: {rendered_img_cuda[:2,:2,0]}")
|
64 |
+
print(f"check forward - distance: {distance[:2, :2, 0]}")
|
65 |
+
print(f"check forward - sum: {torch.sum(distance)}\n")
|
66 |
+
# --- ends ---
|
67 |
+
|
68 |
+
# --- check backward ---
|
69 |
+
sigmas.requires_grad_(True)
|
70 |
+
coords.requires_grad_(True)
|
71 |
+
colors.requires_grad_(True)
|
72 |
+
# sigmas.retain_grad()
|
73 |
+
# coords.retain_grad()
|
74 |
+
# colors.retain_grad()
|
75 |
+
weight = torch.rand_like(rendered_img_th) # make each pixel has different grads
|
76 |
+
|
77 |
+
sigmas.grad = None
|
78 |
+
coords.grad = None
|
79 |
+
colors.grad = None
|
80 |
+
rendered_img_th = torch_version(sigmas,coords,colors,image_size,dmax)
|
81 |
+
loss_th = torch.sum(weight*rendered_img_th)
|
82 |
+
# loss_th = torch.sum(rendered_img_th)
|
83 |
+
loss_th.backward()
|
84 |
+
|
85 |
+
sigmas_grad_th = sigmas.grad
|
86 |
+
coords_grad_th = coords.grad
|
87 |
+
colors_grad_th = colors.grad
|
88 |
+
|
89 |
+
sigmas.grad = None
|
90 |
+
coords.grad = None
|
91 |
+
colors.grad = None
|
92 |
+
rendered_img_cuda = gaussiansplatting_render(sigmas,coords,colors,image_size,dmax)
|
93 |
+
loss_cuda = torch.sum(weight*rendered_img_cuda)
|
94 |
+
# loss_cuda = torch.sum(rendered_img_cuda)
|
95 |
+
loss_cuda.backward()
|
96 |
+
|
97 |
+
sigmas_grad_cuda = sigmas.grad
|
98 |
+
coords_grad_cuda = coords.grad
|
99 |
+
colors_grad_cuda = colors.grad
|
100 |
+
|
101 |
+
distance_sigmas_grad = (sigmas_grad_th-sigmas_grad_cuda)**2
|
102 |
+
distance_coords_grad = (coords_grad_th-coords_grad_cuda)**2
|
103 |
+
distance_colors_grad = (colors_grad_th-colors_grad_cuda)**2
|
104 |
+
|
105 |
+
print(f"check backward - sigmas - torch: {sigmas_grad_th[:2]}")
|
106 |
+
print(f"check backward - sigmas - cuda: {sigmas_grad_cuda[:2]}")
|
107 |
+
print(f"check backward - sigmas - distance: {distance_sigmas_grad[:2]}")
|
108 |
+
print(f"check backward - sigmas - sum: {torch.sum(distance_sigmas_grad)}\n")
|
109 |
+
|
110 |
+
print(f"check backward - coords - torch: {coords_grad_th[:2]}")
|
111 |
+
print(f"check backward - coords - cuda: {coords_grad_cuda[:2]}")
|
112 |
+
print(f"check backward - coords - distance: {distance_coords_grad[:2]}")
|
113 |
+
print(f"check backward - coords - sum: {torch.sum(distance_coords_grad)}\n")
|
114 |
+
|
115 |
+
print(f"check backward - colors - torch: {colors_grad_th[:2]}")
|
116 |
+
print(f"check backward - colors - cuda: {colors_grad_cuda[:2]}")
|
117 |
+
print(f"check backward - colors - distance: {distance_colors_grad[:2]}")
|
118 |
+
print(f"check backward - colors - sum: {torch.sum(distance_colors_grad)}\n")
|
119 |
+
|
120 |
+
print(f"--------------------------- ends --------------------------------\n\n")
|
121 |
+
|
122 |
+
|
utils/gs_cuda_dmax/gs copy.cu
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <stdio.h>
|
2 |
+
#include <cmath>
|
3 |
+
|
4 |
+
#define PI 3.1415926536
|
5 |
+
#define PI2 6.283153072
|
6 |
+
|
7 |
+
__global__ void _gs_render_cuda(
|
8 |
+
const float *sigmas,
|
9 |
+
const float *coords,
|
10 |
+
const float *colors,
|
11 |
+
float *rendered_img,
|
12 |
+
const int s, // gs num
|
13 |
+
const int h,
|
14 |
+
const int w,
|
15 |
+
const int c,
|
16 |
+
const float dmax
|
17 |
+
){
|
18 |
+
|
19 |
+
int index = blockIdx.x*blockDim.x + threadIdx.x;
|
20 |
+
int curw = index % w;
|
21 |
+
int curh = int((index-curw)/w);
|
22 |
+
if(curw >= w || curh >=h){
|
23 |
+
return;
|
24 |
+
}
|
25 |
+
|
26 |
+
float curw_f = 2.0*curw/(w-1) - 1.0;
|
27 |
+
float curh_f = 2.0*curh/(h-1) - 1.0;
|
28 |
+
|
29 |
+
// printf("index:%d, curw:%d, curh:%d, curw_f:%f, curh_f:%f\n",index,curw,curh,curw_f,curh_f);
|
30 |
+
|
31 |
+
for(int si=0; si<s; si++){
|
32 |
+
|
33 |
+
// compute the 2d gs value
|
34 |
+
float sigma_x = sigmas[si*3+0];
|
35 |
+
float sigma_y = sigmas[si*3+1];
|
36 |
+
float rho = sigmas[si*3+2];
|
37 |
+
float x = coords[si*2+0];
|
38 |
+
float y = coords[si*2+1];
|
39 |
+
|
40 |
+
//
|
41 |
+
float one_div_one_minus_rho2 = 1.0 / (1-rho*rho) ;
|
42 |
+
float one_div_sigma_x = 1.0 / sigma_x;
|
43 |
+
float one_div_sigma_y = 1.0 / sigma_y;
|
44 |
+
float d_x = curw_f - x;
|
45 |
+
float d_y = curh_f - y;
|
46 |
+
|
47 |
+
if(d_x > dmax || d_x < -dmax || d_y > dmax || d_y < -dmax){
|
48 |
+
continue;
|
49 |
+
}
|
50 |
+
|
51 |
+
float v = one_div_sigma_x*one_div_sigma_x*d_x*d_x;
|
52 |
+
v -= 2*rho*d_x*d_y*one_div_sigma_x*one_div_sigma_y;
|
53 |
+
v += d_y*d_y*one_div_sigma_y*one_div_sigma_y;
|
54 |
+
v *= -one_div_one_minus_rho2 / 2.0;
|
55 |
+
v = exp(v);
|
56 |
+
// since we normlize the v with the max, we remove this step to obtain equal result
|
57 |
+
// v *= one_div_sigma_x * one_div_sigma_y * pow(one_div_one_minus_rho2, 0.5) / PI2 ;
|
58 |
+
// printf("si:%d, sigma_x: %f, sigma_y:%f, rho:%f, x:%f, y:%f, v:%f\n", si, sigma_x, sigma_y, rho, x,y,v);
|
59 |
+
|
60 |
+
for(int ci=0; ci<c; ci++){
|
61 |
+
rendered_img[(curh*w+curw)*c+ci] += v*colors[si*c+ci];
|
62 |
+
}
|
63 |
+
}
|
64 |
+
|
65 |
+
}
|
66 |
+
|
67 |
+
|
68 |
+
void _gs_render(
|
69 |
+
const float *sigmas,
|
70 |
+
const float *coords,
|
71 |
+
const float *colors,
|
72 |
+
float *rendered_img,
|
73 |
+
const int s,
|
74 |
+
const int h,
|
75 |
+
const int w,
|
76 |
+
const int c,
|
77 |
+
const float dmax
|
78 |
+
) {
|
79 |
+
|
80 |
+
int threads=16;
|
81 |
+
dim3 grid( h*w, 1);
|
82 |
+
dim3 block( threads, 1);
|
83 |
+
_gs_render_cuda<<<grid, block>>>(sigmas, coords, colors, rendered_img, s, h, w, c, dmax);
|
84 |
+
}
|
85 |
+
|
86 |
+
|
87 |
+
__global__ void _gs_render_backward_cuda(
|
88 |
+
const float *sigmas,
|
89 |
+
const float *coords,
|
90 |
+
const float *colors,
|
91 |
+
const float *grads,
|
92 |
+
float *grads_sigmas,
|
93 |
+
float *grads_coords,
|
94 |
+
float *grads_colors,
|
95 |
+
const int s, // gs num
|
96 |
+
const int h,
|
97 |
+
const int w,
|
98 |
+
const int c,
|
99 |
+
const float dmax
|
100 |
+
|
101 |
+
){
|
102 |
+
|
103 |
+
int curs = blockIdx.x*blockDim.x + threadIdx.x;
|
104 |
+
if(curs >= s){
|
105 |
+
return ;
|
106 |
+
}
|
107 |
+
|
108 |
+
// obtain parameters of gs
|
109 |
+
float sigma_x = sigmas[curs*3+0];
|
110 |
+
float sigma_y = sigmas[curs*3+1];
|
111 |
+
float rho = sigmas[curs*3+2];
|
112 |
+
float x = coords[curs*2+0];
|
113 |
+
float y = coords[curs*2+1];
|
114 |
+
float cr = colors[curs*3+0];
|
115 |
+
float cg = colors[curs*3+1];
|
116 |
+
float cb = colors[curs*3+2];
|
117 |
+
|
118 |
+
//
|
119 |
+
float w1 = -0.5 / (1-rho*rho) ;
|
120 |
+
float w2 = 1.0 / (sigma_x*sigma_x);
|
121 |
+
float w3 = 1.0 / (sigma_x*sigma_y);
|
122 |
+
float w4 = 1.0 / (sigma_y*sigma_y);
|
123 |
+
float od_sx = 1.0 / sigma_x;
|
124 |
+
float od_sy = 1.0 / sigma_y;
|
125 |
+
|
126 |
+
// init
|
127 |
+
float _gr=0.0, _gg=0.0, _gb=0.0;
|
128 |
+
float _gx=0.0, _gy=0.0;
|
129 |
+
float _gsx=0.0, _gsy=0.0, _gsr=0.0;
|
130 |
+
|
131 |
+
for(int hi = 0; hi < h; hi++){
|
132 |
+
for( int wi=0; wi < w; wi++){
|
133 |
+
|
134 |
+
float curw_f = 2.0*wi/(w-1) - 1.0;
|
135 |
+
float curh_f = 2.0*hi/(h-1) - 1.0;
|
136 |
+
|
137 |
+
// obtain grad to p^t_r, p^t_g, p^t_b
|
138 |
+
float gptr = grads[(hi*w+wi)*c+0]; // grad of loss to P^t_r
|
139 |
+
float gptg = grads[(hi*w+wi)*c+1];
|
140 |
+
float gptb = grads[(hi*w+wi)*c+2];
|
141 |
+
|
142 |
+
// compute the 2d gs value
|
143 |
+
|
144 |
+
float d_x = curw_f - x; // distance along x axis
|
145 |
+
float d_y = curh_f - y;
|
146 |
+
// if(d_x > dmax || d_x < -dmax || d_y > dmax || d_y < -dmax){
|
147 |
+
// continue;
|
148 |
+
// }
|
149 |
+
// printf("here");
|
150 |
+
|
151 |
+
float d = w2*d_x*d_x - 2*rho*w3*d_x*d_y + w4*d_y*d_y;
|
152 |
+
float v = w1*d;
|
153 |
+
v = exp(v);
|
154 |
+
// printf("si:%d, sigma_x: %f, sigma_y:%f, rho:%f, x:%f, y:%f, v:%f\n", si, sigma_x, sigma_y, rho, x,y,v);
|
155 |
+
|
156 |
+
// compute grad of colors
|
157 |
+
_gr += v*gptr;
|
158 |
+
_gg += v*gptg;
|
159 |
+
_gb += v*gptb;
|
160 |
+
|
161 |
+
// compute grad of coords
|
162 |
+
float gpt = gptr*cr+gptg*cg+gptb*cb;
|
163 |
+
float v_2_w1 = v*2*w1;
|
164 |
+
|
165 |
+
float g_vst_to_gsx = v_2_w1*(-w2*d_x+rho*w3*d_y); // grad of v^{st} to G^s_x
|
166 |
+
_gx += gpt*g_vst_to_gsx;
|
167 |
+
float g_vst_to_gsy = v_2_w1*(-w4*d_y+rho*w3*d_x); // grad of v^{st} to G^s_y
|
168 |
+
_gy += gpt*g_vst_to_gsy;
|
169 |
+
|
170 |
+
// compute grad of sigmas
|
171 |
+
float g_vst_to_gsigx = v_2_w1*od_sx* (w3*rho*d_x*d_y - w2*d_x*d_x);
|
172 |
+
_gsx += gpt*g_vst_to_gsigx;
|
173 |
+
float g_vst_to_gsigy = v_2_w1*od_sy* (w3*rho*d_x*d_y - w4*d_y*d_y);
|
174 |
+
_gsy += gpt*g_vst_to_gsigy;
|
175 |
+
float g_vst_to_rho = -v_2_w1*(2*w1*rho*d+w3*d_x*d_y);
|
176 |
+
_gsr += gpt*g_vst_to_rho;
|
177 |
+
}
|
178 |
+
}
|
179 |
+
|
180 |
+
// write the values
|
181 |
+
grads_sigmas[curs*3+0] = _gsx;
|
182 |
+
grads_sigmas[curs*3+1] = _gsy;
|
183 |
+
grads_sigmas[curs*3+2] = _gsr;
|
184 |
+
grads_coords[curs*2+0] = _gx;
|
185 |
+
grads_coords[curs*2+1] = _gy;
|
186 |
+
grads_colors[curs*3+0] = _gr;
|
187 |
+
grads_colors[curs*3+1] = _gg;
|
188 |
+
grads_colors[curs*3+2] = _gb;
|
189 |
+
|
190 |
+
}
|
191 |
+
|
192 |
+
void _gs_render_backward(
|
193 |
+
const float *sigmas,
|
194 |
+
const float *coords,
|
195 |
+
const float *colors,
|
196 |
+
const float *grads, // (h, w, c)
|
197 |
+
float *grads_sigmas,
|
198 |
+
float *grads_coords,
|
199 |
+
float *grads_colors,
|
200 |
+
const int s,
|
201 |
+
const int h,
|
202 |
+
const int w,
|
203 |
+
const int c,
|
204 |
+
const float dmax
|
205 |
+
) {
|
206 |
+
|
207 |
+
int threads=16;
|
208 |
+
dim3 grid(s, 1);
|
209 |
+
dim3 block( threads, 1);
|
210 |
+
_gs_render_backward_cuda<<<grid, block>>>(sigmas, coords, colors, grads, grads_sigmas, grads_coords, grads_colors, s, h, w, c, dmax);
|
211 |
+
}
|
212 |
+
|
utils/gs_cuda_dmax/gs.backup.cu
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <stdio.h>
|
2 |
+
#include <cmath>
|
3 |
+
|
4 |
+
#define PI 3.1415926536
|
5 |
+
#define PI2 6.283153072
|
6 |
+
|
7 |
+
__global__ void _gs_render_cuda(
|
8 |
+
const float *sigmas,
|
9 |
+
const float *coords,
|
10 |
+
const float *colors,
|
11 |
+
float *rendered_img,
|
12 |
+
const int s, // gs num
|
13 |
+
const int h,
|
14 |
+
const int w,
|
15 |
+
const int c,
|
16 |
+
const float dmax
|
17 |
+
){
|
18 |
+
|
19 |
+
int index = blockIdx.x*blockDim.x + threadIdx.x;
|
20 |
+
int curw = index % w;
|
21 |
+
int curh = int((index-curw)/w);
|
22 |
+
if(curw >= w || curh >=h){
|
23 |
+
return;
|
24 |
+
}
|
25 |
+
|
26 |
+
float curw_f = 2.0*curw/(w-1) - 1.0;
|
27 |
+
float curh_f = 2.0*curh/(h-1) - 1.0;
|
28 |
+
|
29 |
+
// printf("index:%d, curw:%d, curh:%d, curw_f:%f, curh_f:%f\n",index,curw,curh,curw_f,curh_f);
|
30 |
+
|
31 |
+
for(int si=0; si<s; si++){
|
32 |
+
|
33 |
+
// compute the 2d gs value
|
34 |
+
float sigma_x = sigmas[si*3+0];
|
35 |
+
float sigma_y = sigmas[si*3+1];
|
36 |
+
float rho = sigmas[si*3+2];
|
37 |
+
float x = coords[si*2+0];
|
38 |
+
float y = coords[si*2+1];
|
39 |
+
|
40 |
+
//
|
41 |
+
float one_div_one_minus_rho2 = 1.0 / (1-rho*rho) ;
|
42 |
+
float one_div_sigma_x = 1.0 / sigma_x;
|
43 |
+
float one_div_sigma_y = 1.0 / sigma_y;
|
44 |
+
float d_x = curw_f - x;
|
45 |
+
float d_y = curh_f - y;
|
46 |
+
|
47 |
+
if(d_x > dmax || d_x < -dmax || d_y > dmax || d_y < -dmax){
|
48 |
+
continue;
|
49 |
+
}
|
50 |
+
|
51 |
+
float v = one_div_sigma_x*one_div_sigma_x*d_x*d_x;
|
52 |
+
v -= 2*rho*d_x*d_y*one_div_sigma_x*one_div_sigma_y;
|
53 |
+
v += d_y*d_y*one_div_sigma_y*one_div_sigma_y;
|
54 |
+
v *= -one_div_one_minus_rho2 / 2.0;
|
55 |
+
v = exp(v);
|
56 |
+
// since we normlize the v with the max, we remove this step to obtain equal result
|
57 |
+
// v *= one_div_sigma_x * one_div_sigma_y * pow(one_div_one_minus_rho2, 0.5) / PI2 ;
|
58 |
+
// printf("si:%d, sigma_x: %f, sigma_y:%f, rho:%f, x:%f, y:%f, v:%f\n", si, sigma_x, sigma_y, rho, x,y,v);
|
59 |
+
|
60 |
+
for(int ci=0; ci<c; ci++){
|
61 |
+
rendered_img[(curh*w+curw)*c+ci] += v*colors[si*c+ci];
|
62 |
+
}
|
63 |
+
}
|
64 |
+
|
65 |
+
}
|
66 |
+
|
67 |
+
|
68 |
+
void _gs_render(
|
69 |
+
const float *sigmas,
|
70 |
+
const float *coords,
|
71 |
+
const float *colors,
|
72 |
+
float *rendered_img,
|
73 |
+
const int s,
|
74 |
+
const int h,
|
75 |
+
const int w,
|
76 |
+
const int c,
|
77 |
+
const float dmax
|
78 |
+
) {
|
79 |
+
|
80 |
+
int threads=16;
|
81 |
+
dim3 grid( h*w, 1);
|
82 |
+
dim3 block( threads, 1);
|
83 |
+
_gs_render_cuda<<<grid, block>>>(sigmas, coords, colors, rendered_img, s, h, w, c, dmax);
|
84 |
+
}
|
85 |
+
|
86 |
+
__global__ void _gs_render_backward_cuda(
|
87 |
+
const float *sigmas,
|
88 |
+
const float *coords,
|
89 |
+
const float *colors,
|
90 |
+
const float *grads,
|
91 |
+
float *grads_sigmas,
|
92 |
+
float *grads_coords,
|
93 |
+
float *grads_colors,
|
94 |
+
const int s, // gs num
|
95 |
+
const int h,
|
96 |
+
const int w,
|
97 |
+
const int c,
|
98 |
+
const float dmax
|
99 |
+
){
|
100 |
+
|
101 |
+
int curs = blockIdx.x*blockDim.x + threadIdx.x;
|
102 |
+
if(curs >= s){
|
103 |
+
return ;
|
104 |
+
}
|
105 |
+
|
106 |
+
// obtain parameters of gs
|
107 |
+
float sigma_x = sigmas[curs*3+0];
|
108 |
+
float sigma_y = sigmas[curs*3+1];
|
109 |
+
float rho = sigmas[curs*3+2];
|
110 |
+
float x = coords[curs*2+0];
|
111 |
+
float y = coords[curs*2+1];
|
112 |
+
|
113 |
+
//
|
114 |
+
float w1 = -0.5 / (1-rho*rho) ;
|
115 |
+
float w2 = 1.0 / (sigma_x*sigma_x);
|
116 |
+
float w3 = 1.0 / (sigma_x*sigma_y);
|
117 |
+
float w4 = 1.0 / (sigma_y*sigma_y);
|
118 |
+
float od_sx = 1.0 / sigma_x;
|
119 |
+
float od_sy = 1.0 / sigma_y;
|
120 |
+
|
121 |
+
// init
|
122 |
+
for(int hi = 0; hi < h; hi++){
|
123 |
+
for( int wi=0; wi < w; wi++){
|
124 |
+
|
125 |
+
float curw_f = 2.0*wi/(w-1) - 1.0;
|
126 |
+
float curh_f = 2.0*hi/(h-1) - 1.0;
|
127 |
+
|
128 |
+
// compute the 2d gs value
|
129 |
+
float d_x = curw_f - x; // distance along x axis
|
130 |
+
float d_y = curh_f - y;
|
131 |
+
if(d_x > dmax || d_x < -dmax || d_y > dmax || d_y < -dmax){
|
132 |
+
continue;
|
133 |
+
}
|
134 |
+
float d = w2*d_x*d_x - 2*rho*w3*d_x*d_y + w4*d_y*d_y;
|
135 |
+
float v = w1*d;
|
136 |
+
v = exp(v);
|
137 |
+
// printf("si:%d, sigma_x: %f, sigma_y:%f, rho:%f, x:%f, y:%f, v:%f\n", si, sigma_x, sigma_y, rho, x,y,v);
|
138 |
+
|
139 |
+
// compute grad of coords
|
140 |
+
float v_2_w1 = v*2*w1;
|
141 |
+
float g_vst_to_gsx = v_2_w1*(-w2*d_x+rho*w3*d_y); // grad of v^{st} to G^s_x
|
142 |
+
float g_vst_to_gsy = v_2_w1*(-w4*d_y+rho*w3*d_x); // grad of v^{st} to G^s_y
|
143 |
+
|
144 |
+
// compute grad of sigmas
|
145 |
+
float g_vst_to_gsigx = v_2_w1*od_sx* (w3*rho*d_x*d_y - w2*d_x*d_x);
|
146 |
+
float g_vst_to_gsigy = v_2_w1*od_sy* (w3*rho*d_x*d_y - w4*d_y*d_y);
|
147 |
+
float g_vst_to_rho = -v_2_w1*(2*w1*rho*d+w3*d_x*d_y);
|
148 |
+
|
149 |
+
for(int ci=0; ci<c; ci++){
|
150 |
+
float _gptc = grads[(hi*w+wi)*c+ci];
|
151 |
+
float _gpt = _gptc*colors[curs*c+ci];
|
152 |
+
|
153 |
+
grads_colors[curs*c+ci] += v*_gptc;
|
154 |
+
|
155 |
+
grads_coords[curs*2+0] += _gpt*g_vst_to_gsx;
|
156 |
+
grads_coords[curs*2+1] += _gpt*g_vst_to_gsy;
|
157 |
+
|
158 |
+
grads_sigmas[curs*3+0] += _gpt*g_vst_to_gsigx;
|
159 |
+
grads_sigmas[curs*3+1] += _gpt*g_vst_to_gsigy;
|
160 |
+
grads_sigmas[curs*3+2] += _gpt*g_vst_to_rho;
|
161 |
+
}
|
162 |
+
|
163 |
+
}
|
164 |
+
}
|
165 |
+
|
166 |
+
}
|
167 |
+
|
168 |
+
void _gs_render_backward(
|
169 |
+
const float *sigmas,
|
170 |
+
const float *coords,
|
171 |
+
const float *colors,
|
172 |
+
const float *grads, // (h, w, c)
|
173 |
+
float *grads_sigmas,
|
174 |
+
float *grads_coords,
|
175 |
+
float *grads_colors,
|
176 |
+
const int s,
|
177 |
+
const int h,
|
178 |
+
const int w,
|
179 |
+
const int c,
|
180 |
+
const float dmax
|
181 |
+
) {
|
182 |
+
|
183 |
+
int threads=16;
|
184 |
+
dim3 grid(s, 1);
|
185 |
+
dim3 block( threads, 1);
|
186 |
+
_gs_render_backward_cuda<<<grid, block>>>(sigmas, coords, colors, grads, grads_sigmas, grads_coords, grads_colors, s, h, w, c, dmax);
|
187 |
+
}
|
188 |
+
|
utils/gs_cuda_dmax/gs.cu
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <stdio.h>
|
2 |
+
#include <cmath>
|
3 |
+
|
4 |
+
#define PI 3.1415926536
|
5 |
+
#define PI2 6.283153072
|
6 |
+
|
7 |
+
__global__ void _gs_render_cuda(
|
8 |
+
const float *sigmas,
|
9 |
+
const float *coords,
|
10 |
+
const float *colors,
|
11 |
+
float *rendered_img,
|
12 |
+
const int s, // gs num
|
13 |
+
const int h,
|
14 |
+
const int w,
|
15 |
+
const int c,
|
16 |
+
const float dmax
|
17 |
+
){
|
18 |
+
|
19 |
+
int curs = blockIdx.x*blockDim.x + threadIdx.x;
|
20 |
+
if(curs >= s){
|
21 |
+
return;
|
22 |
+
}
|
23 |
+
|
24 |
+
float sigma_x = sigmas[curs*3+0];
|
25 |
+
float sigma_y = sigmas[curs*3+1];
|
26 |
+
float rho = sigmas[curs*3+2];
|
27 |
+
float x = coords[curs*2+0];
|
28 |
+
float y = coords[curs*2+1];
|
29 |
+
float r = colors[curs*3];
|
30 |
+
float g = colors[curs*3+1];
|
31 |
+
float b = colors[curs*3+2];
|
32 |
+
|
33 |
+
float negative_half_one_div_one_minus_rho2 = -0.5 / (1-rho*rho);
|
34 |
+
float one_div_sigma_x_2 = 1.0 / sigma_x / sigma_x;
|
35 |
+
float one_div_sigma_y_2 = 1.0 / sigma_y / sigma_y;
|
36 |
+
float two_rho_div_sigma_x_one_div_sigma_y = 2*rho / sigma_x / sigma_y;
|
37 |
+
|
38 |
+
for(int hi=0; hi<h; hi++){
|
39 |
+
float curh_f = 2.0*hi/(h-1) - 1.0;
|
40 |
+
float d_y = curh_f - y;
|
41 |
+
if(d_y > dmax || d_y < -dmax){
|
42 |
+
continue;
|
43 |
+
}
|
44 |
+
|
45 |
+
for(int wi=0; wi<w; wi++){
|
46 |
+
float curw_f = 2.0*wi/(w-1) - 1.0;
|
47 |
+
float d_x = curw_f - x;
|
48 |
+
if(d_x > dmax || d_x < -dmax){
|
49 |
+
continue;
|
50 |
+
}
|
51 |
+
|
52 |
+
float v = one_div_sigma_x_2*d_x*d_x;
|
53 |
+
v -= two_rho_div_sigma_x_one_div_sigma_y*d_x*d_y;
|
54 |
+
v += one_div_sigma_y_2*d_y*d_y;
|
55 |
+
v *= negative_half_one_div_one_minus_rho2;
|
56 |
+
v = exp(v);
|
57 |
+
|
58 |
+
atomicAdd(&rendered_img[(hi*w+wi)*c+0], v*r);
|
59 |
+
atomicAdd(&rendered_img[(hi*w+wi)*c+1], v*g);
|
60 |
+
atomicAdd(&rendered_img[(hi*w+wi)*c+2], v*b);
|
61 |
+
}
|
62 |
+
}
|
63 |
+
|
64 |
+
}
|
65 |
+
|
66 |
+
|
67 |
+
void _gs_render(
|
68 |
+
const float *sigmas,
|
69 |
+
const float *coords,
|
70 |
+
const float *colors,
|
71 |
+
float *rendered_img,
|
72 |
+
const int s,
|
73 |
+
const int h,
|
74 |
+
const int w,
|
75 |
+
const int c,
|
76 |
+
const float dmax
|
77 |
+
) {
|
78 |
+
|
79 |
+
int threads=64;
|
80 |
+
dim3 grid(int(s/threads)+1);
|
81 |
+
dim3 block(threads);
|
82 |
+
_gs_render_cuda<<<grid, block>>>(sigmas, coords, colors, rendered_img, s, h, w, c, dmax);
|
83 |
+
}
|
84 |
+
|
85 |
+
__global__ void _gs_render_backward_cuda(
|
86 |
+
const float *sigmas,
|
87 |
+
const float *coords,
|
88 |
+
const float *colors,
|
89 |
+
const float *grads,
|
90 |
+
float *grads_sigmas,
|
91 |
+
float *grads_coords,
|
92 |
+
float *grads_colors,
|
93 |
+
const int s, // gs num
|
94 |
+
const int h,
|
95 |
+
const int w,
|
96 |
+
const int c,
|
97 |
+
const float dmax
|
98 |
+
){
|
99 |
+
|
100 |
+
int curs = blockIdx.x*blockDim.x + threadIdx.x;
|
101 |
+
if(curs >= s){
|
102 |
+
return ;
|
103 |
+
}
|
104 |
+
|
105 |
+
// obtain parameters of gs
|
106 |
+
float sigma_x = sigmas[curs*3+0];
|
107 |
+
float sigma_y = sigmas[curs*3+1];
|
108 |
+
float rho = sigmas[curs*3+2];
|
109 |
+
float x = coords[curs*2+0];
|
110 |
+
float y = coords[curs*2+1];
|
111 |
+
|
112 |
+
//
|
113 |
+
float w1 = -0.5 / (1-rho*rho) ;
|
114 |
+
float w2 = 1.0 / (sigma_x*sigma_x);
|
115 |
+
float w3 = 1.0 / (sigma_x*sigma_y);
|
116 |
+
float w4 = 1.0 / (sigma_y*sigma_y);
|
117 |
+
float od_sx = 1.0 / sigma_x;
|
118 |
+
float od_sy = 1.0 / sigma_y;
|
119 |
+
|
120 |
+
// init
|
121 |
+
for(int hi = 0; hi < h; hi++){
|
122 |
+
for( int wi=0; wi < w; wi++){
|
123 |
+
|
124 |
+
float curw_f = 2.0*wi/(w-1) - 1.0;
|
125 |
+
float curh_f = 2.0*hi/(h-1) - 1.0;
|
126 |
+
|
127 |
+
// compute the 2d gs value
|
128 |
+
float d_x = curw_f - x; // distance along x axis
|
129 |
+
float d_y = curh_f - y;
|
130 |
+
if(d_x > dmax || d_x < -dmax || d_y > dmax || d_y < -dmax){
|
131 |
+
continue;
|
132 |
+
}
|
133 |
+
float d = w2*d_x*d_x - 2*rho*w3*d_x*d_y + w4*d_y*d_y;
|
134 |
+
float v = w1*d;
|
135 |
+
v = exp(v);
|
136 |
+
// printf("si:%d, sigma_x: %f, sigma_y:%f, rho:%f, x:%f, y:%f, v:%f\n", si, sigma_x, sigma_y, rho, x,y,v);
|
137 |
+
|
138 |
+
// compute grad of coords
|
139 |
+
float v_2_w1 = v*2*w1;
|
140 |
+
float g_vst_to_gsx = v_2_w1*(-w2*d_x+rho*w3*d_y); // grad of v^{st} to G^s_x
|
141 |
+
float g_vst_to_gsy = v_2_w1*(-w4*d_y+rho*w3*d_x); // grad of v^{st} to G^s_y
|
142 |
+
|
143 |
+
// compute grad of sigmas
|
144 |
+
float g_vst_to_gsigx = v_2_w1*od_sx* (w3*rho*d_x*d_y - w2*d_x*d_x);
|
145 |
+
float g_vst_to_gsigy = v_2_w1*od_sy* (w3*rho*d_x*d_y - w4*d_y*d_y);
|
146 |
+
float g_vst_to_rho = -v_2_w1*(2*w1*rho*d+w3*d_x*d_y);
|
147 |
+
|
148 |
+
for(int ci=0; ci<c; ci++){
|
149 |
+
float _gptc = grads[(hi*w+wi)*c+ci];
|
150 |
+
float _gpt = _gptc*colors[curs*c+ci];
|
151 |
+
|
152 |
+
grads_colors[curs*c+ci] += v*_gptc;
|
153 |
+
|
154 |
+
grads_coords[curs*2+0] += _gpt*g_vst_to_gsx;
|
155 |
+
grads_coords[curs*2+1] += _gpt*g_vst_to_gsy;
|
156 |
+
|
157 |
+
grads_sigmas[curs*3+0] += _gpt*g_vst_to_gsigx;
|
158 |
+
grads_sigmas[curs*3+1] += _gpt*g_vst_to_gsigy;
|
159 |
+
grads_sigmas[curs*3+2] += _gpt*g_vst_to_rho;
|
160 |
+
}
|
161 |
+
|
162 |
+
}
|
163 |
+
}
|
164 |
+
|
165 |
+
}
|
166 |
+
|
167 |
+
void _gs_render_backward(
|
168 |
+
const float *sigmas,
|
169 |
+
const float *coords,
|
170 |
+
const float *colors,
|
171 |
+
const float *grads, // (h, w, c)
|
172 |
+
float *grads_sigmas,
|
173 |
+
float *grads_coords,
|
174 |
+
float *grads_colors,
|
175 |
+
const int s,
|
176 |
+
const int h,
|
177 |
+
const int w,
|
178 |
+
const int c,
|
179 |
+
const float dmax
|
180 |
+
) {
|
181 |
+
|
182 |
+
int threads=64;
|
183 |
+
dim3 grid(s, 1);
|
184 |
+
dim3 block( threads, 1);
|
185 |
+
_gs_render_backward_cuda<<<grid, block>>>(sigmas, coords, colors, grads, grads_sigmas, grads_coords, grads_colors, s, h, w, c, dmax);
|
186 |
+
}
|
187 |
+
|
utils/gs_cuda_dmax/gs.h
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
void _gs_render(
|
2 |
+
const float *sigmas,
|
3 |
+
const float *coords,
|
4 |
+
const float *colors,
|
5 |
+
float *rendered_img,
|
6 |
+
const int s,
|
7 |
+
const int h,
|
8 |
+
const int w,
|
9 |
+
const int c,
|
10 |
+
const float dmax
|
11 |
+
);
|
12 |
+
|
13 |
+
void _gs_render_backward(
|
14 |
+
const float *sigmas,
|
15 |
+
const float *coords,
|
16 |
+
const float *colors,
|
17 |
+
const float *grads,
|
18 |
+
float *grads_sigmas,
|
19 |
+
float *grads_coords,
|
20 |
+
float *grads_colors,
|
21 |
+
const int s,
|
22 |
+
const int h,
|
23 |
+
const int w,
|
24 |
+
const int c,
|
25 |
+
const float dmax
|
26 |
+
);
|
utils/gs_cuda_dmax/gswrapper.cpp
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "gs.h"
|
2 |
+
#include <torch/extension.h>
|
3 |
+
#include <c10/cuda/CUDAGuard.h>
|
4 |
+
|
5 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
6 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
7 |
+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
8 |
+
|
9 |
+
void gs_render(
|
10 |
+
torch::Tensor &sigmas,
|
11 |
+
torch::Tensor &coords,
|
12 |
+
torch::Tensor &colors,
|
13 |
+
torch::Tensor &rendered_img,
|
14 |
+
const int s,
|
15 |
+
const int h,
|
16 |
+
const int w,
|
17 |
+
const int c,
|
18 |
+
const float dmax
|
19 |
+
){
|
20 |
+
|
21 |
+
CHECK_INPUT(sigmas);
|
22 |
+
CHECK_INPUT(coords);
|
23 |
+
CHECK_INPUT(colors);
|
24 |
+
CHECK_INPUT(rendered_img);
|
25 |
+
|
26 |
+
// run the code at the cuda device same with the input
|
27 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(sigmas));
|
28 |
+
|
29 |
+
_gs_render(
|
30 |
+
(const float *) sigmas.data_ptr(),
|
31 |
+
(const float *) coords.data_ptr(),
|
32 |
+
(const float *) colors.data_ptr(),
|
33 |
+
(float *) rendered_img.data_ptr(),
|
34 |
+
s, h, w, c, dmax);
|
35 |
+
}
|
36 |
+
|
37 |
+
void gs_render_backward(
|
38 |
+
torch::Tensor &sigmas,
|
39 |
+
torch::Tensor &coords,
|
40 |
+
torch::Tensor &colors,
|
41 |
+
torch::Tensor &grads,
|
42 |
+
torch::Tensor &grads_sigmas,
|
43 |
+
torch::Tensor &grads_coords,
|
44 |
+
torch::Tensor &grads_colors,
|
45 |
+
const int s,
|
46 |
+
const int h,
|
47 |
+
const int w,
|
48 |
+
const int c,
|
49 |
+
const float dmax
|
50 |
+
){
|
51 |
+
|
52 |
+
CHECK_INPUT(sigmas);
|
53 |
+
CHECK_INPUT(coords);
|
54 |
+
CHECK_INPUT(colors);
|
55 |
+
CHECK_INPUT(grads);
|
56 |
+
CHECK_INPUT(grads_sigmas);
|
57 |
+
CHECK_INPUT(grads_coords);
|
58 |
+
CHECK_INPUT(grads_colors);
|
59 |
+
|
60 |
+
|
61 |
+
// run the code at the cuda device same with the input
|
62 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(sigmas));
|
63 |
+
|
64 |
+
_gs_render_backward(
|
65 |
+
(const float *) sigmas.data_ptr(),
|
66 |
+
(const float *) coords.data_ptr(),
|
67 |
+
(const float *) colors.data_ptr(),
|
68 |
+
(const float *) grads.data_ptr(),
|
69 |
+
(float *) grads_sigmas.data_ptr(),
|
70 |
+
(float *) grads_coords.data_ptr(),
|
71 |
+
(float *) grads_colors.data_ptr(),
|
72 |
+
s, h, w, c, dmax);
|
73 |
+
}
|
74 |
+
|
75 |
+
PYBIND11_MODULE( TORCH_EXTENSION_NAME, m) {
|
76 |
+
m.def( "gs_render",
|
77 |
+
&gs_render,
|
78 |
+
"cuda forward wrapper");
|
79 |
+
m.def( "gs_render_backward",
|
80 |
+
&gs_render_backward,
|
81 |
+
"cuda backward wrapper");
|
82 |
+
}
|
utils/gs_cuda_dmax/gswrapper.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from torch.utils.cpp_extension import load
|
4 |
+
from torch.autograd import Function
|
5 |
+
from torch.autograd.function import once_differentiable
|
6 |
+
|
7 |
+
#
|
8 |
+
build_path = os.path.join(os.path.split(os.path.abspath(__file__))[0], 'build')
|
9 |
+
os.makedirs(build_path, exist_ok=True)
|
10 |
+
|
11 |
+
file_path = os.path.split(os.path.abspath(__file__))[0]
|
12 |
+
# GSWrapper = load(
|
13 |
+
# name="gscuda",
|
14 |
+
# # sources=["gs_cuda/gswrapper.cpp", "gs_cuda/gs.cu"],
|
15 |
+
# sources=[os.path.join(file_path, "gswrapper.cpp"),
|
16 |
+
# os.path.join(file_path, "gs.cu")],
|
17 |
+
# build_directory=build_path,
|
18 |
+
# verbose=True)
|
19 |
+
|
20 |
+
import gscuda
|
21 |
+
GSWrapper = gscuda
|
22 |
+
|
23 |
+
class GSCUDA(Function):
|
24 |
+
|
25 |
+
@staticmethod
|
26 |
+
def forward(ctx, sigmas, coords, colors, rendered_img, dmax):
|
27 |
+
ctx.save_for_backward(sigmas, coords, colors)
|
28 |
+
ctx.dmax = dmax
|
29 |
+
h, w, c = rendered_img.shape
|
30 |
+
s = sigmas.shape[0]
|
31 |
+
GSWrapper.gs_render(sigmas, coords, colors, rendered_img, s, h, w, c, dmax)
|
32 |
+
return rendered_img
|
33 |
+
|
34 |
+
@staticmethod
|
35 |
+
@once_differentiable
|
36 |
+
def backward(ctx, grad_output):
|
37 |
+
sigmas, coords, colors = ctx.saved_tensors
|
38 |
+
dmax = ctx.dmax
|
39 |
+
h, w, c = grad_output.shape
|
40 |
+
s = sigmas.shape[0]
|
41 |
+
grads_sigmas = torch.zeros_like(sigmas)
|
42 |
+
grads_coords = torch.zeros_like(coords)
|
43 |
+
grads_colors = torch.zeros_like(colors)
|
44 |
+
GSWrapper.gs_render_backward(sigmas, coords, colors, grad_output.contiguous(), grads_sigmas, grads_coords, grads_colors, s, h, w, c, dmax)
|
45 |
+
return (grads_sigmas, grads_coords, grads_colors, None, None)
|
46 |
+
|
47 |
+
def gaussiansplatting_render(sigmas, coords, colors, image_size,dmax=100):
|
48 |
+
sigmas = sigmas.contiguous() # (gs num, 3)
|
49 |
+
coords = coords.contiguous() # (gs num, 2)
|
50 |
+
colors = colors.contiguous() # (gs num, c)
|
51 |
+
h, w = image_size[:2]
|
52 |
+
c = colors.shape[-1]
|
53 |
+
rendered_img = torch.zeros(h, w, c).to(colors.device).to(torch.float32)
|
54 |
+
return GSCUDA.apply(sigmas, coords, colors, rendered_img, dmax)
|
55 |
+
|
56 |
+
if __name__ == "__main__":
|
57 |
+
sigmas = torch.randn(10, 3).cuda()
|
58 |
+
coords = torch.randn(10, 2).cuda()
|
59 |
+
colors = torch.randn(10, 3).cuda()
|
60 |
+
image_size = (100, 100)
|
61 |
+
dmax = 0.1
|
62 |
+
rendered_img = gaussiansplatting_render(sigmas, coords, colors, image_size, dmax)
|
63 |
+
print(rendered_img.shape)
|
utils/gs_cuda_dmax/mylineprofiler.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import io
|
3 |
+
import sys
|
4 |
+
import timeit
|
5 |
+
import tokenize
|
6 |
+
import torch
|
7 |
+
import psutil
|
8 |
+
import inspect
|
9 |
+
from loguru import logger
|
10 |
+
from prettytable import PrettyTable
|
11 |
+
|
12 |
+
# implement by xtudbxk
|
13 |
+
# github: https://github.com/xtudbxk/lineprofiler
|
14 |
+
class MyLineProfiler():
|
15 |
+
def __init__(self, base='ms', cuda_sync=True, gpuids=(0,), warmup=0, warmup_lineno=-1):
|
16 |
+
|
17 |
+
if base == 'ms':
|
18 |
+
self.base_n = 1000
|
19 |
+
elif base == 's':
|
20 |
+
self.base_n = 1
|
21 |
+
else:
|
22 |
+
logguru.warning(f'Unsupported base - {base}, using "s" instead')
|
23 |
+
|
24 |
+
self.base = base
|
25 |
+
self.cuda_sync = cuda_sync
|
26 |
+
self.gpuids = gpuids
|
27 |
+
self.warmup = warmup
|
28 |
+
self.warmup_counter = warmup
|
29 |
+
# we should wait this line execute warup_counter times
|
30 |
+
# before recording the stats
|
31 |
+
self.warmup_lineno = warmup_lineno
|
32 |
+
|
33 |
+
# for time profiling
|
34 |
+
self._times = {}
|
35 |
+
self._func_name = None
|
36 |
+
self._func_filename = None
|
37 |
+
self._last_time = -1
|
38 |
+
self._last_lineno = -1
|
39 |
+
self._func_hit_count = 0
|
40 |
+
self._func_firstlineno = 0
|
41 |
+
|
42 |
+
# for memory profiling
|
43 |
+
self._process = psutil.Process(os.getpid())
|
44 |
+
self._memory = {}
|
45 |
+
self._last_memory = 0
|
46 |
+
|
47 |
+
# for cuda memory profiling
|
48 |
+
self._gpu_memory = {}
|
49 |
+
self._gpu_last_memory = 0
|
50 |
+
|
51 |
+
def __trace_func__(self, frame, event, arg):
|
52 |
+
# print(f'in {frame.f_code.co_filename} func {frame.f_code.co_name} line {frame.f_lineno}, event - {event}')
|
53 |
+
|
54 |
+
# check if run into the decorated func
|
55 |
+
if self._func_firstlineno == frame.f_code.co_firstlineno and frame.f_code.co_name == self._func_name and frame.f_code.co_filename == self._func_filename:
|
56 |
+
|
57 |
+
# --- obtain info for current hit ---
|
58 |
+
# cuda related
|
59 |
+
if self.cuda_sync is True:
|
60 |
+
torch.cuda.synchronize()
|
61 |
+
|
62 |
+
current_time = timeit.default_timer()
|
63 |
+
memory = self._process.memory_info().rss
|
64 |
+
gpu_memory = torch.cuda.memory_allocated()
|
65 |
+
# --- ends ---
|
66 |
+
|
67 |
+
# --- initilize the info when first hit ---
|
68 |
+
if frame.f_lineno not in self._times: # first hit time for this line
|
69 |
+
self._times[frame.f_lineno] = {'hit':0, 'time': 0}
|
70 |
+
self._memory[frame.f_lineno] = 0
|
71 |
+
self._gpu_memory[frame.f_lineno] = 0
|
72 |
+
# --- ends ---
|
73 |
+
|
74 |
+
# --- record info before call the decorated func ---
|
75 |
+
# 'call' - before call the func
|
76 |
+
if event == 'call':
|
77 |
+
self._last_time = current_time
|
78 |
+
self._last_lineno = frame.f_lineno
|
79 |
+
self._last_memory = memory
|
80 |
+
self._last_gpu_memory = gpu_memory
|
81 |
+
|
82 |
+
if self.warmup_lineno < 0:
|
83 |
+
self.warmup_counter -= 1
|
84 |
+
if self.warmup_counter < 0:
|
85 |
+
self._func_hit_count += 1
|
86 |
+
# --- ends ---
|
87 |
+
|
88 |
+
# 'line' - after excuting the line
|
89 |
+
# 'return' - return from the function
|
90 |
+
if event == 'line' or event == 'return':
|
91 |
+
|
92 |
+
if event == 'line' and self.warmup_counter < 0:
|
93 |
+
self._times[frame.f_lineno]['hit'] += 1
|
94 |
+
|
95 |
+
|
96 |
+
# --- obtain the memory and time consumed by this line ---
|
97 |
+
if self.warmup_counter < 0:
|
98 |
+
self._times[self._last_lineno]['time'] += current_time - self._last_time
|
99 |
+
self._memory[self._last_lineno] += memory - self._last_memory
|
100 |
+
self._gpu_memory[self._last_lineno] += gpu_memory - self._gpu_last_memory
|
101 |
+
# --- ends ---
|
102 |
+
|
103 |
+
if self.cuda_sync is True:
|
104 |
+
torch.cuda.synchronize()
|
105 |
+
|
106 |
+
self._last_time = timeit.default_timer()
|
107 |
+
self._last_memory = memory
|
108 |
+
self._gpu_last_memory = gpu_memory
|
109 |
+
self._last_lineno = frame.f_lineno
|
110 |
+
|
111 |
+
return self.__trace_func__
|
112 |
+
|
113 |
+
def decorate(self, func):
|
114 |
+
if self._func_name is not None:
|
115 |
+
logger.warning(f'Only support decorate only one func. Aready decorated "{self._func_name}"')
|
116 |
+
self._func_name = func.__name__
|
117 |
+
self._func_filename = func.__code__.co_filename
|
118 |
+
self._func_firstlineno = func.__code__.co_firstlineno
|
119 |
+
|
120 |
+
def _f(*args, **kwargs):
|
121 |
+
origin_trace_func = sys.gettrace()
|
122 |
+
sys.settrace(self.__trace_func__)
|
123 |
+
ret = func(*args, **kwargs)
|
124 |
+
sys.settrace(origin_trace_func)
|
125 |
+
return ret
|
126 |
+
return _f
|
127 |
+
|
128 |
+
def _get_table(self):
|
129 |
+
|
130 |
+
if len(self._times) <= 0:
|
131 |
+
logger.warning(f"un recorded datas, please ensure the function is executed")
|
132 |
+
return None
|
133 |
+
|
134 |
+
# --- load the source code ---
|
135 |
+
with open(self._func_filename, 'r') as f:
|
136 |
+
source_lines = [line.strip('\n') for line in f.readlines()]
|
137 |
+
code_str = "\n".join(source_lines)
|
138 |
+
|
139 |
+
def_lineno = min(self._times.keys())
|
140 |
+
final_lineno = max(self._times.keys())
|
141 |
+
|
142 |
+
# remove the additional blank content
|
143 |
+
pre_blank_count = len(source_lines[def_lineno-1]) - len(source_lines[def_lineno-1].lstrip(' ').lstrip('\t'))
|
144 |
+
# --- ends ---
|
145 |
+
|
146 |
+
# --- analysize the source code and collect infos for multi-line code ---
|
147 |
+
new_logic_linenos = [token.start[0] for token in tokenize.generate_tokens(
|
148 |
+
io.StringIO(code_str).readline) if token.type == 4]
|
149 |
+
# --- ends ---
|
150 |
+
|
151 |
+
# --- merge the stats multi-line code ---
|
152 |
+
sorted_linenos = [lineno for lineno in self._times.keys()]
|
153 |
+
sorted_linenos.sort(key=int)
|
154 |
+
|
155 |
+
lineno_cache = []
|
156 |
+
for lineno in sorted_linenos:
|
157 |
+
if lineno not in new_logic_linenos:
|
158 |
+
lineno_cache.append(lineno)
|
159 |
+
else:
|
160 |
+
# we should merge its info to the prev_lineno
|
161 |
+
if len(lineno_cache) <= 0:
|
162 |
+
continue
|
163 |
+
else:
|
164 |
+
lineno_cache.append(lineno)
|
165 |
+
first_lineno = lineno_cache[0]
|
166 |
+
for prev_lineno in lineno_cache[1:]:
|
167 |
+
self._times[first_lineno]["hit"] = min(self._times[first_lineno]["hit"], self._times[prev_lineno]["hit"])
|
168 |
+
self._times[first_lineno]["time"] += self._times[prev_lineno]["time"]
|
169 |
+
del self._times[prev_lineno]
|
170 |
+
|
171 |
+
self._memory[first_lineno] += self._memory[prev_lineno]
|
172 |
+
del self._memory[prev_lineno]
|
173 |
+
|
174 |
+
self._gpu_memory[first_lineno] += self._gpu_memory[prev_lineno]
|
175 |
+
del self._gpu_memory[prev_lineno]
|
176 |
+
lineno_cache = []
|
177 |
+
# --- ends ---
|
178 |
+
|
179 |
+
# --- initialize the pretty table for output ---
|
180 |
+
table = PrettyTable(['lineno', 'hits', 'time', 'time per hit', 'hit perc', 'time perc', 'mem inc', 'mem peak', 'gpu mem inc', 'gpu mem peak'])
|
181 |
+
# --- ends ---
|
182 |
+
|
183 |
+
# --- compute some statisticals ---
|
184 |
+
total_hit = 0 # for compute the hit percentage
|
185 |
+
total_time = 0
|
186 |
+
for lineno, stats in self._times.items():
|
187 |
+
if lineno == def_lineno: continue
|
188 |
+
total_hit += stats['hit']
|
189 |
+
total_time += stats['time']
|
190 |
+
|
191 |
+
total_memory = sum([m for l,m in self._memory.items()]) / 1024 / 1024
|
192 |
+
total_gpu_memory = sum([m for l,m in self._gpu_memory.items()]) / 1024 / 1024
|
193 |
+
# --- ends ---
|
194 |
+
|
195 |
+
peak_cpu_memory = 0
|
196 |
+
peak_gpu_memory = 0
|
197 |
+
for lineno in range(def_lineno, final_lineno+1):
|
198 |
+
if lineno not in self._times:
|
199 |
+
# the comment line, empty line or merged line from multi-lines code
|
200 |
+
table.add_row([lineno, '-', '-', '-', '-', '-', '-',f'{peak_cpu_memory:5.3f} MB', '-', f'{peak_gpu_memory:5.3f} MB'])
|
201 |
+
else:
|
202 |
+
stats = self._times[lineno]
|
203 |
+
if lineno == def_lineno:
|
204 |
+
table.add_row([lineno, self._func_hit_count, f'{total_time*self.base_n:.4f} {self.base}', f'{total_time/self._func_hit_count*self.base_n:.4f} {self.base}', '-', '-', f'{total_memory:5.3f} MB', 'baseline', f'{total_gpu_memory:5.3f} MB', 'baseline'])
|
205 |
+
else:
|
206 |
+
|
207 |
+
line_result = [lineno, stats['hit'],
|
208 |
+
f'{stats["time"]*self.base_n:.4f} {self.base}',
|
209 |
+
f'{stats["time"]/stats["hit"]*self.base_n:.4f} {self.base}' if stats['hit'] > 0 else 'nan',
|
210 |
+
f'{stats["hit"]/total_hit*100:.3f}%' if total_hit > 0 else 'nan',
|
211 |
+
f'{stats["time"]/total_time*100:.3f}%'] if total_time > 0 else 'nan'
|
212 |
+
|
213 |
+
line_result += [f'{self._memory[lineno]/1024/1024:5.3f} MB' if stats['hit'] > 0 else '0 MB']
|
214 |
+
peak_cpu_memory = peak_cpu_memory + self._memory[lineno]/1024/1024
|
215 |
+
line_result += [f'{peak_cpu_memory:5.3f} MB']
|
216 |
+
|
217 |
+
line_result += [f'{self._gpu_memory[lineno]/1024/1024:5.3f} MB' if stats['hit'] > 0 else '0 MB']
|
218 |
+
peak_gpu_memory = peak_gpu_memory + self._gpu_memory[lineno]/1024/1024
|
219 |
+
line_result += [f'{peak_gpu_memory:5.3f} MB']
|
220 |
+
|
221 |
+
table.add_row(line_result)
|
222 |
+
|
223 |
+
table.add_column('sources', [source_lines[i-1][pre_blank_count:] if len(source_lines[i-1])>pre_blank_count else '' for i in range(def_lineno, final_lineno+1)], 'l')
|
224 |
+
return table
|
225 |
+
|
226 |
+
def print(self, filename=None, mode="w"):
|
227 |
+
introducation = '''
|
228 |
+
1. The first line of table reports the overall results of the whole function and the following lines reports the statistics of each line in the function.
|
229 |
+
2. The `hit perc` and `time perc` represent `hit percentage` and `time percentage`.
|
230 |
+
3. For memory, there exists four categories `mem inc`, `mem peak`, `gpu mem inc` and `gpu mem peak`. They denotes `cpu memory increasement`, `cpu memory peak`, `gpu memory increasement` and `gpu memory peak`. All the results are collected in the last run. The number in the increasement field denots the increasement of corresponding memory of each line (the first line is related to the whole function). Sometimes, the number of each line is far less of the number of the first line, which is valid since python may auto release the unused memory after the function execution. The number of each line in the peak filed is a simple sum of the numbers of above lines in the increasement field, which is used to demonstrate the possible maxinum memory usage in the function.
|
231 |
+
4. For any issue, please concact us via https://github.com/xtudbxk/lineprofiler or zhengqiang.zhang@hotmail.com
|
232 |
+
'''
|
233 |
+
print(introducation)
|
234 |
+
|
235 |
+
table = PrettyTable(['lineno', 'hits', 'time', 'time per hit', 'hit perc', 'time perc', 'mem inc', 'mem peak', 'gpu mem inc', 'gpu mem peak'])
|
236 |
+
table = self._get_table()
|
237 |
+
print(table)
|
238 |
+
if filename is not None:
|
239 |
+
with open(filename, mode) as f:
|
240 |
+
f.write(introducation)
|
241 |
+
f.write(f"args - base={self.base}, cuda_sync={self.cuda_sync}, gpuids={self.gpuids}, warmup={self.warmup}\n")
|
242 |
+
f.write(str(table))
|
243 |
+
|
244 |
+
if __name__ == '__main__':
|
245 |
+
import numpy as np
|
246 |
+
def mytest(h='hello',
|
247 |
+
xx="xx"):
|
248 |
+
|
249 |
+
h = h + 'world'
|
250 |
+
a = []
|
251 |
+
for _ in range(200):
|
252 |
+
# a = np.zeros((1000, 1000), dtype=np.float32)
|
253 |
+
a.append(np.zeros((1000, 1000), dtype=np.float32))
|
254 |
+
a.append(
|
255 |
+
np.zeros((1000, 1000),
|
256 |
+
dtype=np.float32))
|
257 |
+
# print(a[0,0])
|
258 |
+
print(h)
|
259 |
+
|
260 |
+
profiler = MyLineProfiler(cuda_sync=False, warmup=2)
|
261 |
+
mytest = profiler.decorate(mytest)
|
262 |
+
for _ in range(5):
|
263 |
+
mytest()
|
264 |
+
profiler.print()
|
utils/gs_cuda_dmax/profile.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from gswrapper import gaussiansplatting_render
|
7 |
+
|
8 |
+
def generate_2D_gaussian_splatting(kernel_size, sigma_x, sigma_y, rho, coords,
|
9 |
+
colours, image_size=(256, 256, 3), device="cuda"):
|
10 |
+
|
11 |
+
batch_size = colours.shape[0]
|
12 |
+
|
13 |
+
sigma_x = sigma_x.view(batch_size, 1, 1)
|
14 |
+
sigma_y = sigma_y.view(batch_size, 1, 1)
|
15 |
+
rho = rho.view(batch_size, 1, 1)
|
16 |
+
|
17 |
+
covariance = torch.stack(
|
18 |
+
[torch.stack([sigma_x**2, rho*sigma_x*sigma_y], dim=-1),
|
19 |
+
torch.stack([rho*sigma_x*sigma_y, sigma_y**2], dim=-1)],
|
20 |
+
dim=-2
|
21 |
+
)
|
22 |
+
|
23 |
+
# Check for positive semi-definiteness
|
24 |
+
# determinant = (sigma_x**2) * (sigma_y**2) - (rho * sigma_x * sigma_y)**2
|
25 |
+
# if (determinant <= 0).any():
|
26 |
+
# raise ValueError("Covariance matrix must be positive semi-definite")
|
27 |
+
|
28 |
+
inv_covariance = torch.inverse(covariance)
|
29 |
+
|
30 |
+
# Choosing quite a broad range for the distribution [-5,5] to avoid any clipping
|
31 |
+
start = torch.tensor([-5.0], device=device).view(-1, 1)
|
32 |
+
end = torch.tensor([5.0], device=device).view(-1, 1)
|
33 |
+
base_linspace = torch.linspace(0, 1, steps=kernel_size, device=device)
|
34 |
+
ax_batch = start + (end - start) * base_linspace
|
35 |
+
|
36 |
+
# Expanding dims for broadcasting
|
37 |
+
ax_batch_expanded_x = ax_batch.unsqueeze(-1).expand(-1, -1, kernel_size)
|
38 |
+
ax_batch_expanded_y = ax_batch.unsqueeze(1).expand(-1, kernel_size, -1)
|
39 |
+
|
40 |
+
# Creating a batch-wise meshgrid using broadcasting
|
41 |
+
xx, yy = ax_batch_expanded_x, ax_batch_expanded_y # (batchsize, kernelsize, kernelsize)
|
42 |
+
|
43 |
+
xy = torch.stack([xx, yy], dim=-1) # (batchsize, kernelsize, kernelsize, 2)
|
44 |
+
z = torch.einsum('b...i,b...ij,b...j->b...', xy, -0.5 * inv_covariance, xy) # (batchsize, kernelsize, kernelsize, 2)
|
45 |
+
kernel = torch.exp(z) / (2 * torch.tensor(np.pi, device=device) * torch.sqrt(torch.det(covariance)).view(batch_size, 1, 1)) # (batchsize, kernelsize, kernelsize)
|
46 |
+
|
47 |
+
|
48 |
+
kernel_max_1, _ = kernel.max(dim=-1, keepdim=True) # Find max along the last dimension
|
49 |
+
kernel_max_2, _ = kernel_max_1.max(dim=-2, keepdim=True) # Find max along the second-to-last dimension
|
50 |
+
kernel_normalized = kernel / kernel_max_2 # (batchsize, kernelsize, kernelsize)
|
51 |
+
|
52 |
+
|
53 |
+
kernel_reshaped = kernel_normalized.repeat(1, 3, 1).view(batch_size * 3, kernel_size, kernel_size)
|
54 |
+
kernel_rgb = kernel_reshaped.unsqueeze(0).reshape(batch_size, 3, kernel_size, kernel_size) # (batchsize, 3, kernelsize, kernelsize)
|
55 |
+
|
56 |
+
# Calculating the padding needed to match the image size
|
57 |
+
pad_h = image_size[0] - kernel_size
|
58 |
+
pad_w = image_size[1] - kernel_size
|
59 |
+
|
60 |
+
if pad_h < 0 or pad_w < 0:
|
61 |
+
raise ValueError("Kernel size should be smaller or equal to the image size.")
|
62 |
+
|
63 |
+
# Adding padding to make kernel size equal to the image size
|
64 |
+
padding = (pad_w // 2, pad_w // 2 + pad_w % 2, # padding left and right
|
65 |
+
pad_h // 2, pad_h // 2 + pad_h % 2) # padding top and bottom
|
66 |
+
|
67 |
+
kernel_rgb_padded = torch.nn.functional.pad(kernel_rgb, padding, "constant", 0) # (batchsize, 3, h, w)
|
68 |
+
|
69 |
+
# Extracting shape information
|
70 |
+
b, c, h, w = kernel_rgb_padded.shape
|
71 |
+
|
72 |
+
# Create a batch of 2D affine matrices
|
73 |
+
theta = torch.zeros(b, 2, 3, dtype=torch.float32, device=device)
|
74 |
+
theta[:, 0, 0] = 1.0
|
75 |
+
theta[:, 1, 1] = 1.0
|
76 |
+
theta[:, :, 2] = -coords # (b, 2) - the offset of gaussian splating
|
77 |
+
|
78 |
+
# Creating grid and performing grid sampling
|
79 |
+
grid = F.affine_grid(theta, size=(b, c, h, w), align_corners=True) # (b, 3, h, w)
|
80 |
+
# grid_y = torch.linspace(-1, 1, steps=h, device=device).reshape(1, h, 1, 1).repeat(1, 1, w, 1)
|
81 |
+
# grid_x = torch.linspace(-1, 1, steps=w, device=device).reshape(1, 1, w, 1).repeat(1, h, 1, 1)
|
82 |
+
# grid = torch.cat([grid_x, grid_y], dim=-1)
|
83 |
+
# grid = grid - coords.reshape(-1, 1, 1, 2)
|
84 |
+
|
85 |
+
kernel_rgb_padded_translated = F.grid_sample(kernel_rgb_padded, grid, align_corners=True) # (b, 3, h, w)
|
86 |
+
|
87 |
+
rgb_values_reshaped = colours.unsqueeze(-1).unsqueeze(-1)
|
88 |
+
|
89 |
+
final_image_layers = rgb_values_reshaped * kernel_rgb_padded_translated
|
90 |
+
final_image = final_image_layers.sum(dim=0)
|
91 |
+
# final_image = torch.clamp(final_image, 0, 1)
|
92 |
+
final_image = final_image.permute(1,2,0)
|
93 |
+
|
94 |
+
return final_image
|
95 |
+
|
96 |
+
|
97 |
+
if __name__ == "__main__":
|
98 |
+
from mylineprofiler import MyLineProfiler
|
99 |
+
profiler_th = MyLineProfiler(cuda_sync=True)
|
100 |
+
generate_2D_gaussian_splatting = profiler_th.decorate(generate_2D_gaussian_splatting)
|
101 |
+
profiler_cuda = MyLineProfiler(cuda_sync=True)
|
102 |
+
gaussiansplatting_render = profiler_cuda.decorate(gaussiansplatting_render)
|
103 |
+
|
104 |
+
|
105 |
+
# --- test ---
|
106 |
+
# s = 1000
|
107 |
+
s = 5
|
108 |
+
# image_size = (512, 512, 3)
|
109 |
+
image_size = (511, 511, 3)
|
110 |
+
# image_size = (256, 512, 3)
|
111 |
+
# image_size = (256, 256, 3)
|
112 |
+
|
113 |
+
sigmas = 0.999*torch.rand(s, 3).to(torch.float32).to("cuda")
|
114 |
+
sigmas[:,:2] = 5*sigmas[:, :2]
|
115 |
+
coords = 2*torch.rand(s, 2).to(torch.float32).to("cuda")-1.0
|
116 |
+
colors = torch.rand(s, 3).to(torch.float32).to("cuda")
|
117 |
+
|
118 |
+
# --- torch version ---
|
119 |
+
import gc
|
120 |
+
gc.collect()
|
121 |
+
torch.cuda.empty_cache()
|
122 |
+
for _ in range(20):
|
123 |
+
img = generate_2D_gaussian_splatting(101, sigmas[:,1], sigmas[:,0], sigmas[:,2], coords, colors, image_size)
|
124 |
+
profiler_th.print("profile.log", "w")
|
125 |
+
cv2.imwrite("th.png", 255.0 * img.detach().clamp(0, 1).cpu().numpy())
|
126 |
+
# --- ends ---
|
127 |
+
|
128 |
+
# --- cuda version ---
|
129 |
+
_stepsize_of_gs_th = 10 / (101-1)
|
130 |
+
_stepsize_of_gs_cuda_w = 2 / (image_size[1]-1)
|
131 |
+
_stepsize_of_gs_cuda_h = 2 / (image_size[0]-1)
|
132 |
+
sigmas[:, 0] = sigmas[:, 0] * _stepsize_of_gs_cuda_w / _stepsize_of_gs_th
|
133 |
+
sigmas[:, 1] = sigmas[:, 1] * _stepsize_of_gs_cuda_h / _stepsize_of_gs_th
|
134 |
+
dmax = 101/2*_stepsize_of_gs_cuda_w
|
135 |
+
gc.collect()
|
136 |
+
torch.cuda.empty_cache()
|
137 |
+
for _ in range(20):
|
138 |
+
img = gaussiansplatting_render(sigmas, coords, colors, image_size, dmax)
|
139 |
+
|
140 |
+
profiler_cuda.print("profile.log", "a")
|
141 |
+
cv2.imwrite("cuda.png", 255.0 * img.detach().clamp(0, 1).cpu().numpy())
|
142 |
+
# --- ends ---
|
utils/hatropeamp.py
ADDED
@@ -0,0 +1,1156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.utils.checkpoint import checkpoint
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import collections.abc
|
7 |
+
from itertools import repeat
|
8 |
+
|
9 |
+
from functools import partial
|
10 |
+
from typing import Any, Optional, Tuple
|
11 |
+
|
12 |
+
from einops import rearrange
|
13 |
+
|
14 |
+
# From PyTorch
|
15 |
+
def _ntuple(n):
|
16 |
+
|
17 |
+
def parse(x):
|
18 |
+
if isinstance(x, collections.abc.Iterable):
|
19 |
+
return x
|
20 |
+
return tuple(repeat(x, n))
|
21 |
+
|
22 |
+
return parse
|
23 |
+
|
24 |
+
|
25 |
+
to_1tuple = _ntuple(1)
|
26 |
+
to_2tuple = _ntuple(2)
|
27 |
+
to_3tuple = _ntuple(3)
|
28 |
+
to_4tuple = _ntuple(4)
|
29 |
+
to_ntuple = _ntuple
|
30 |
+
|
31 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
32 |
+
# From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
|
33 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
34 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
35 |
+
def norm_cdf(x):
|
36 |
+
# Computes standard normal cumulative distribution function
|
37 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
38 |
+
|
39 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
40 |
+
warnings.warn(
|
41 |
+
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
|
42 |
+
'The distribution of values may be incorrect.',
|
43 |
+
stacklevel=2)
|
44 |
+
|
45 |
+
with torch.no_grad():
|
46 |
+
# Values are generated by using a truncated uniform distribution and
|
47 |
+
# then using the inverse CDF for the normal distribution.
|
48 |
+
# Get upper and lower cdf values
|
49 |
+
low = norm_cdf((a - mean) / std)
|
50 |
+
up = norm_cdf((b - mean) / std)
|
51 |
+
|
52 |
+
# Uniformly fill tensor with values from [low, up], then translate to
|
53 |
+
# [2l-1, 2u-1].
|
54 |
+
tensor.uniform_(2 * low - 1, 2 * up - 1)
|
55 |
+
|
56 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
57 |
+
# standard normal
|
58 |
+
tensor.erfinv_()
|
59 |
+
|
60 |
+
# Transform to proper mean, std
|
61 |
+
tensor.mul_(std * math.sqrt(2.))
|
62 |
+
tensor.add_(mean)
|
63 |
+
|
64 |
+
# Clamp to ensure it's in the proper range
|
65 |
+
tensor.clamp_(min=a, max=b)
|
66 |
+
return tensor
|
67 |
+
|
68 |
+
|
69 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
70 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
71 |
+
normal distribution.
|
72 |
+
|
73 |
+
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
|
74 |
+
|
75 |
+
The values are effectively drawn from the
|
76 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
77 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
78 |
+
the bounds. The method used for generating the random values works
|
79 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
tensor: an n-dimensional `torch.Tensor`
|
83 |
+
mean: the mean of the normal distribution
|
84 |
+
std: the standard deviation of the normal distribution
|
85 |
+
a: the minimum cutoff value
|
86 |
+
b: the maximum cutoff value
|
87 |
+
|
88 |
+
Examples:
|
89 |
+
>>> w = torch.empty(3, 5)
|
90 |
+
>>> nn.init.trunc_normal_(w)
|
91 |
+
"""
|
92 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
93 |
+
|
94 |
+
def init_t_xy(end_x: int, end_y: int, zero_center=False):
|
95 |
+
t = torch.arange(end_x * end_y, dtype=torch.float32)
|
96 |
+
t_x = (t % end_x).float()
|
97 |
+
t_y = torch.div(t, end_x, rounding_mode='floor').float()
|
98 |
+
|
99 |
+
return t_x, t_y
|
100 |
+
|
101 |
+
def init_random_2d_freqs(head_dim: int, num_heads: int, theta: float = 10.0, rotate: bool = True):
|
102 |
+
freqs_x = []
|
103 |
+
freqs_y = []
|
104 |
+
theta = theta
|
105 |
+
mag = 1 / (theta ** (torch.arange(0, head_dim, 4)[: (head_dim // 4)].float() / head_dim))
|
106 |
+
for i in range(num_heads):
|
107 |
+
angles = torch.rand(1) * 2 * torch.pi if rotate else torch.zeros(1)
|
108 |
+
fx = torch.cat([mag * torch.cos(angles), mag * torch.cos(torch.pi/2 + angles)], dim=-1)
|
109 |
+
fy = torch.cat([mag * torch.sin(angles), mag * torch.sin(torch.pi/2 + angles)], dim=-1)
|
110 |
+
freqs_x.append(fx)
|
111 |
+
freqs_y.append(fy)
|
112 |
+
freqs_x = torch.stack(freqs_x, dim=0)
|
113 |
+
freqs_y = torch.stack(freqs_y, dim=0)
|
114 |
+
freqs = torch.stack([freqs_x, freqs_y], dim=0)
|
115 |
+
return freqs
|
116 |
+
|
117 |
+
def compute_cis(freqs, t_x, t_y):
|
118 |
+
N = t_x.shape[0]
|
119 |
+
# No float 16 for this range
|
120 |
+
with torch.cuda.amp.autocast(enabled=False):
|
121 |
+
freqs_x = (t_x.unsqueeze(-1) @ freqs[0].unsqueeze(-2))
|
122 |
+
freqs_y = (t_y.unsqueeze(-1) @ freqs[1].unsqueeze(-2))
|
123 |
+
freqs_cis = torch.polar(torch.ones_like(freqs_x), freqs_x + freqs_y)
|
124 |
+
|
125 |
+
return freqs_cis
|
126 |
+
|
127 |
+
|
128 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
129 |
+
ndim = x.ndim
|
130 |
+
assert 0 <= 1 < ndim
|
131 |
+
# assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
|
132 |
+
# print(f"freqs_cis shape is {freqs_cis.shape}, x shape is {x.shape}")
|
133 |
+
if freqs_cis.shape == (x.shape[-2], x.shape[-1]):
|
134 |
+
shape = [d if i >= ndim-2 else 1 for i, d in enumerate(x.shape)]
|
135 |
+
elif freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1]):
|
136 |
+
shape = [d if i >= ndim-3 else 1 for i, d in enumerate(x.shape)]
|
137 |
+
|
138 |
+
return freqs_cis.view(*shape)
|
139 |
+
|
140 |
+
def apply_rotary_emb(
|
141 |
+
xq: torch.Tensor,
|
142 |
+
xk: torch.Tensor,
|
143 |
+
freqs_cis: torch.Tensor,
|
144 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
145 |
+
# print(f"xq shape is {xq.shape}, xq.shape[:-1] is {xq.shape[:-1]}")
|
146 |
+
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
147 |
+
# print(f"xq_ shape is {xq_.shape}")
|
148 |
+
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
149 |
+
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
150 |
+
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
151 |
+
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
152 |
+
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
|
153 |
+
|
154 |
+
def apply_rotary_emb_single(x, freqs_cis):
|
155 |
+
x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
156 |
+
seq_len = x_.shape[2]
|
157 |
+
freqs_cis = freqs_cis[:, :seq_len, :]
|
158 |
+
freqs_cis = freqs_cis.unsqueeze(0).expand_as(x_)
|
159 |
+
x_out = torch.view_as_real(x_ * freqs_cis).flatten(3)
|
160 |
+
return x_out.type_as(x).to(x.device)
|
161 |
+
|
162 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
163 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
164 |
+
|
165 |
+
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
|
166 |
+
"""
|
167 |
+
if drop_prob == 0. or not training:
|
168 |
+
return x
|
169 |
+
keep_prob = 1 - drop_prob
|
170 |
+
shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
171 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
172 |
+
random_tensor.floor_() # binarize
|
173 |
+
output = x.div(keep_prob) * random_tensor
|
174 |
+
return output
|
175 |
+
|
176 |
+
|
177 |
+
class DropPath(nn.Module):
|
178 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
179 |
+
|
180 |
+
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
|
181 |
+
"""
|
182 |
+
|
183 |
+
def __init__(self, drop_prob=None):
|
184 |
+
super(DropPath, self).__init__()
|
185 |
+
self.drop_prob = drop_prob
|
186 |
+
|
187 |
+
def forward(self, x):
|
188 |
+
return drop_path(x, self.drop_prob, self.training)
|
189 |
+
|
190 |
+
|
191 |
+
class ChannelAttention(nn.Module):
|
192 |
+
"""Channel attention used in RCAN.
|
193 |
+
Args:
|
194 |
+
num_feat (int): Channel number of intermediate features.
|
195 |
+
squeeze_factor (int): Channel squeeze factor. Default: 16.
|
196 |
+
"""
|
197 |
+
|
198 |
+
def __init__(self, num_feat, squeeze_factor=16):
|
199 |
+
super(ChannelAttention, self).__init__()
|
200 |
+
self.attention = nn.Sequential(
|
201 |
+
nn.AdaptiveAvgPool2d(1),
|
202 |
+
nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
|
203 |
+
nn.ReLU(inplace=True),
|
204 |
+
nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
|
205 |
+
nn.Sigmoid())
|
206 |
+
|
207 |
+
def forward(self, x):
|
208 |
+
y = self.attention(x)
|
209 |
+
return x * y
|
210 |
+
|
211 |
+
|
212 |
+
class CAB(nn.Module):
|
213 |
+
|
214 |
+
def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30):
|
215 |
+
super(CAB, self).__init__()
|
216 |
+
|
217 |
+
self.cab = nn.Sequential(
|
218 |
+
nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
|
219 |
+
nn.GELU(),
|
220 |
+
nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
|
221 |
+
ChannelAttention(num_feat, squeeze_factor)
|
222 |
+
)
|
223 |
+
|
224 |
+
def forward(self, x):
|
225 |
+
return self.cab(x)
|
226 |
+
|
227 |
+
|
228 |
+
class Mlp(nn.Module):
|
229 |
+
|
230 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
231 |
+
super().__init__()
|
232 |
+
out_features = out_features or in_features
|
233 |
+
hidden_features = hidden_features or in_features
|
234 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
235 |
+
self.act = act_layer()
|
236 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
237 |
+
self.drop = nn.Dropout(drop)
|
238 |
+
|
239 |
+
def forward(self, x):
|
240 |
+
x = self.fc1(x)
|
241 |
+
x = self.act(x)
|
242 |
+
x = self.drop(x)
|
243 |
+
x = self.fc2(x)
|
244 |
+
x = self.drop(x)
|
245 |
+
return x
|
246 |
+
|
247 |
+
|
248 |
+
def window_partition(x, window_size):
|
249 |
+
"""
|
250 |
+
Args:
|
251 |
+
x: (b, h, w, c)
|
252 |
+
window_size (int): window size
|
253 |
+
|
254 |
+
Returns:
|
255 |
+
windows: (num_windows*b, window_size, window_size, c)
|
256 |
+
"""
|
257 |
+
b, h, w, c = x.shape
|
258 |
+
x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
|
259 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
|
260 |
+
return windows
|
261 |
+
|
262 |
+
|
263 |
+
def window_reverse(windows, window_size, h, w):
|
264 |
+
"""
|
265 |
+
Args:
|
266 |
+
windows: (num_windows*b, window_size, window_size, c)
|
267 |
+
window_size (int): Window size
|
268 |
+
h (int): Height of image
|
269 |
+
w (int): Width of image
|
270 |
+
|
271 |
+
Returns:
|
272 |
+
x: (b, h, w, c)
|
273 |
+
"""
|
274 |
+
b = int(windows.shape[0] / (h * w / window_size / window_size))
|
275 |
+
x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
|
276 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
|
277 |
+
return x
|
278 |
+
|
279 |
+
|
280 |
+
class WindowAttention(nn.Module):
|
281 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
282 |
+
It supports both of shifted and non-shifted window.
|
283 |
+
|
284 |
+
Args:
|
285 |
+
dim (int): Number of input channels.
|
286 |
+
window_size (tuple[int]): The height and width of the window.
|
287 |
+
num_heads (int): Number of attention heads.
|
288 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
289 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
290 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
291 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
292 |
+
"""
|
293 |
+
|
294 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., rope_mixed = True, rope_theta=10.0):
|
295 |
+
|
296 |
+
super().__init__()
|
297 |
+
self.dim = dim
|
298 |
+
self.window_size = window_size # Wh, Ww
|
299 |
+
self.num_heads = num_heads
|
300 |
+
head_dim = dim // num_heads
|
301 |
+
|
302 |
+
self.rope_mixed = rope_mixed
|
303 |
+
t_x, t_y = init_t_xy(end_x=self.window_size[1], end_y=self.window_size[0])
|
304 |
+
self.register_buffer('rope_t_x', t_x)
|
305 |
+
self.register_buffer('rope_t_y', t_y)
|
306 |
+
|
307 |
+
freqs = init_random_2d_freqs(
|
308 |
+
head_dim=self.dim // self.num_heads, num_heads=self.num_heads, theta=rope_theta,
|
309 |
+
rotate=self.rope_mixed
|
310 |
+
)
|
311 |
+
if self.rope_mixed:
|
312 |
+
self.rope_freqs = nn.Parameter(freqs, requires_grad=True)
|
313 |
+
else:
|
314 |
+
self.register_buffer('rope_freqs', freqs)
|
315 |
+
freqs_cis = compute_cis(self.rope_freqs, self.rope_t_x, self.rope_t_y)
|
316 |
+
self.rope_freqs_cis = freqs_cis
|
317 |
+
|
318 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
319 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
320 |
+
self.proj = nn.Linear(dim, dim)
|
321 |
+
|
322 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
323 |
+
|
324 |
+
|
325 |
+
def forward(self, x, rpi, mask=None):
|
326 |
+
"""
|
327 |
+
Args:
|
328 |
+
x: input features with shape of (num_windows*b, n, c)
|
329 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
330 |
+
"""
|
331 |
+
b_, n, c = x.shape
|
332 |
+
qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
|
333 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
334 |
+
|
335 |
+
###### Apply rotary position embedding
|
336 |
+
if self.rope_mixed:
|
337 |
+
freqs_cis = compute_cis(self.rope_freqs, self.rope_t_x, self.rope_t_y)
|
338 |
+
else:
|
339 |
+
freqs_cis = self.rope_freqs_cis.to(x.device)
|
340 |
+
q, k = apply_rotary_emb(q, k, freqs_cis)
|
341 |
+
#########
|
342 |
+
|
343 |
+
attn = F.scaled_dot_product_attention(q, k, v)
|
344 |
+
|
345 |
+
attn = attn.transpose(1, 2).reshape(b_, n, c)
|
346 |
+
|
347 |
+
x = self.proj(attn)
|
348 |
+
x = self.proj_drop(x)
|
349 |
+
return x
|
350 |
+
|
351 |
+
|
352 |
+
class HAB(nn.Module):
|
353 |
+
r""" Hybrid Attention Block.
|
354 |
+
|
355 |
+
Args:
|
356 |
+
dim (int): Number of input channels.
|
357 |
+
input_resolution (tuple[int]): Input resolution.
|
358 |
+
num_heads (int): Number of attention heads.
|
359 |
+
window_size (int): Window size.
|
360 |
+
shift_size (int): Shift size for SW-MSA.
|
361 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
362 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
363 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
364 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
365 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
366 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
367 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
368 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
369 |
+
"""
|
370 |
+
|
371 |
+
def __init__(self,
|
372 |
+
dim,
|
373 |
+
input_resolution,
|
374 |
+
num_heads,
|
375 |
+
window_size=7,
|
376 |
+
shift_size=0,
|
377 |
+
compress_ratio=3,
|
378 |
+
squeeze_factor=30,
|
379 |
+
conv_scale=0.01,
|
380 |
+
mlp_ratio=4.,
|
381 |
+
qkv_bias=True,
|
382 |
+
qk_scale=None,
|
383 |
+
drop=0.,
|
384 |
+
attn_drop=0.,
|
385 |
+
drop_path=0.,
|
386 |
+
act_layer=nn.GELU,
|
387 |
+
norm_layer=nn.LayerNorm,
|
388 |
+
rope_mixed = True, rope_theta=10.0):
|
389 |
+
super().__init__()
|
390 |
+
self.dim = dim
|
391 |
+
self.input_resolution = input_resolution
|
392 |
+
self.num_heads = num_heads
|
393 |
+
self.window_size = window_size
|
394 |
+
self.shift_size = shift_size
|
395 |
+
self.mlp_ratio = mlp_ratio
|
396 |
+
if min(self.input_resolution) <= self.window_size:
|
397 |
+
# if window size is larger than input resolution, we don't partition windows
|
398 |
+
self.shift_size = 0
|
399 |
+
self.window_size = min(self.input_resolution)
|
400 |
+
assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'
|
401 |
+
|
402 |
+
self.norm1 = norm_layer(dim)
|
403 |
+
self.attn = WindowAttention(
|
404 |
+
dim,
|
405 |
+
window_size=to_2tuple(self.window_size),
|
406 |
+
num_heads=num_heads,
|
407 |
+
qkv_bias=qkv_bias,
|
408 |
+
qk_scale=qk_scale,
|
409 |
+
attn_drop=attn_drop,
|
410 |
+
proj_drop=drop,
|
411 |
+
rope_mixed = rope_mixed, rope_theta=rope_theta)
|
412 |
+
|
413 |
+
self.conv_scale = conv_scale
|
414 |
+
self.conv_block = CAB(num_feat=dim, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor)
|
415 |
+
|
416 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
417 |
+
self.norm2 = norm_layer(dim)
|
418 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
419 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
420 |
+
|
421 |
+
def forward(self, x, x_size, rpi_sa, attn_mask):
|
422 |
+
h, w = x_size
|
423 |
+
b, _, c = x.shape
|
424 |
+
# assert seq_len == h * w, "input feature has wrong size"
|
425 |
+
|
426 |
+
shortcut = x
|
427 |
+
x = self.norm1(x)
|
428 |
+
x = x.view(b, h, w, c)
|
429 |
+
|
430 |
+
# Conv_X
|
431 |
+
conv_x = self.conv_block(x.permute(0, 3, 1, 2).contiguous())
|
432 |
+
conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(b, h * w, c)
|
433 |
+
|
434 |
+
# cyclic shift
|
435 |
+
if self.shift_size > 0:
|
436 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
437 |
+
attn_mask = attn_mask
|
438 |
+
else:
|
439 |
+
shifted_x = x
|
440 |
+
attn_mask = None
|
441 |
+
|
442 |
+
# partition windows
|
443 |
+
x_windows = window_partition(shifted_x, self.window_size) # nw*b, window_size, window_size, c
|
444 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c
|
445 |
+
|
446 |
+
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
|
447 |
+
attn_windows = self.attn(x_windows, rpi=rpi_sa, mask=attn_mask)
|
448 |
+
|
449 |
+
# merge windows
|
450 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
|
451 |
+
shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c
|
452 |
+
|
453 |
+
# reverse cyclic shift
|
454 |
+
if self.shift_size > 0:
|
455 |
+
attn_x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
456 |
+
else:
|
457 |
+
attn_x = shifted_x
|
458 |
+
attn_x = attn_x.view(b, h * w, c)
|
459 |
+
|
460 |
+
# FFN
|
461 |
+
x = shortcut + self.drop_path(attn_x) + conv_x * self.conv_scale
|
462 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
463 |
+
|
464 |
+
return x
|
465 |
+
|
466 |
+
|
467 |
+
class PatchMerging(nn.Module):
|
468 |
+
r""" Patch Merging Layer.
|
469 |
+
|
470 |
+
Args:
|
471 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
472 |
+
dim (int): Number of input channels.
|
473 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
474 |
+
"""
|
475 |
+
|
476 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
477 |
+
super().__init__()
|
478 |
+
self.input_resolution = input_resolution
|
479 |
+
self.dim = dim
|
480 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
481 |
+
self.norm = norm_layer(4 * dim)
|
482 |
+
|
483 |
+
def forward(self, x):
|
484 |
+
"""
|
485 |
+
x: b, h*w, c
|
486 |
+
"""
|
487 |
+
h, w = self.input_resolution
|
488 |
+
b, seq_len, c = x.shape
|
489 |
+
assert seq_len == h * w, 'input feature has wrong size'
|
490 |
+
assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) are not even.'
|
491 |
+
|
492 |
+
x = x.view(b, h, w, c)
|
493 |
+
|
494 |
+
x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c
|
495 |
+
x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c
|
496 |
+
x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c
|
497 |
+
x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c
|
498 |
+
x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c
|
499 |
+
x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c
|
500 |
+
|
501 |
+
x = self.norm(x)
|
502 |
+
x = self.reduction(x)
|
503 |
+
|
504 |
+
return x
|
505 |
+
|
506 |
+
|
507 |
+
class OCAB(nn.Module):
|
508 |
+
# overlapping cross-attention block
|
509 |
+
|
510 |
+
def __init__(self, dim,
|
511 |
+
input_resolution,
|
512 |
+
window_size,
|
513 |
+
overlap_ratio,
|
514 |
+
num_heads,
|
515 |
+
qkv_bias=True,
|
516 |
+
qk_scale=None,
|
517 |
+
mlp_ratio=2,
|
518 |
+
norm_layer=nn.LayerNorm,
|
519 |
+
rope_mixed = True, rope_theta = 10.0
|
520 |
+
):
|
521 |
+
|
522 |
+
super().__init__()
|
523 |
+
self.dim = dim
|
524 |
+
self.input_resolution = input_resolution
|
525 |
+
self.window_size = window_size
|
526 |
+
self.num_heads = num_heads
|
527 |
+
head_dim = dim // num_heads
|
528 |
+
self.rope_mixed = rope_mixed
|
529 |
+
|
530 |
+
self.overlap_win_size = int(window_size * overlap_ratio) + window_size
|
531 |
+
|
532 |
+
self.norm1 = norm_layer(dim)
|
533 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
534 |
+
self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size, padding=(self.overlap_win_size-window_size)//2)
|
535 |
+
|
536 |
+
t_x, t_y = init_t_xy(end_x=max(self.window_size, self.overlap_win_size), end_y=max(self.window_size, self.overlap_win_size))
|
537 |
+
self.register_buffer('rope_t_x', t_x)
|
538 |
+
self.register_buffer('rope_t_y', t_y)
|
539 |
+
|
540 |
+
freqs = init_random_2d_freqs(
|
541 |
+
head_dim=self.dim // self.num_heads, num_heads=self.num_heads, theta=rope_theta,
|
542 |
+
rotate=self.rope_mixed
|
543 |
+
)
|
544 |
+
if self.rope_mixed:
|
545 |
+
self.rope_freqs = nn.Parameter(freqs, requires_grad=True)
|
546 |
+
else:
|
547 |
+
self.register_buffer('rope_freqs', freqs)
|
548 |
+
freqs_cis = compute_cis(self.rope_freqs, self.rope_t_x, self.rope_t_y)
|
549 |
+
self.rope_freqs_cis = freqs_cis
|
550 |
+
|
551 |
+
|
552 |
+
self.proj = nn.Linear(dim,dim)
|
553 |
+
|
554 |
+
self.norm2 = norm_layer(dim)
|
555 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
556 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=nn.GELU)
|
557 |
+
|
558 |
+
def forward(self, x, x_size, rpi):
|
559 |
+
h, w = x_size
|
560 |
+
b, _, c = x.shape
|
561 |
+
|
562 |
+
shortcut = x
|
563 |
+
x = self.norm1(x)
|
564 |
+
x = x.view(b, h, w, c)
|
565 |
+
|
566 |
+
qkv = self.qkv(x).reshape(b, h, w, 3, c).permute(3, 0, 4, 1, 2).contiguous() # 3, b, c, h, w
|
567 |
+
q = qkv[0].permute(0, 2, 3, 1).contiguous() # b, h, w, c
|
568 |
+
kv = torch.cat((qkv[1], qkv[2]), dim=1) # b, 2*c, h, w
|
569 |
+
|
570 |
+
# partition windows
|
571 |
+
q_windows = window_partition(q, self.window_size) # nw*b, window_size, window_size, c
|
572 |
+
q_windows = q_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c
|
573 |
+
|
574 |
+
kv_windows = self.unfold(kv) # b, c*w*w, nw
|
575 |
+
kv_windows = rearrange(kv_windows, 'b (nc ch owh oww) nw -> nc (b nw) (owh oww) ch', nc=2, ch=c, owh=self.overlap_win_size, oww=self.overlap_win_size).contiguous() # 2, nw*b, ow*ow, c
|
576 |
+
k_windows, v_windows = kv_windows[0], kv_windows[1] # nw*b, ow*ow, c
|
577 |
+
|
578 |
+
b_, nq, _ = q_windows.shape
|
579 |
+
_, n, _ = k_windows.shape
|
580 |
+
# print(f"nq is {nq}, n is {n}")
|
581 |
+
d = self.dim // self.num_heads
|
582 |
+
q = q_windows.reshape(b_, nq, self.num_heads, d).permute(0, 2, 1, 3).contiguous() # nw*b, nH, nq, d
|
583 |
+
k = k_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3).contiguous() # nw*b, nH, n, d
|
584 |
+
v = v_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3).contiguous() # nw*b, nH, n, d
|
585 |
+
|
586 |
+
###### Apply rotary position embedding
|
587 |
+
if self.rope_mixed:
|
588 |
+
freqs_cis = compute_cis(self.rope_freqs, self.rope_t_x, self.rope_t_y)
|
589 |
+
else:
|
590 |
+
freqs_cis = self.rope_freqs_cis.to(x.device)
|
591 |
+
q = apply_rotary_emb_single(q, freqs_cis)
|
592 |
+
k = apply_rotary_emb_single(k, freqs_cis)
|
593 |
+
#########
|
594 |
+
|
595 |
+
attn = F.scaled_dot_product_attention(q, k, v)
|
596 |
+
attn_windows = attn.transpose(1, 2).reshape(b_, nq, self.dim)
|
597 |
+
|
598 |
+
# merge windows
|
599 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.dim)
|
600 |
+
x = window_reverse(attn_windows, self.window_size, h, w) # b h w c
|
601 |
+
x = x.view(b, h * w, self.dim)
|
602 |
+
|
603 |
+
x = self.proj(x) + shortcut
|
604 |
+
|
605 |
+
x = x + self.mlp(self.norm2(x))
|
606 |
+
return x
|
607 |
+
|
608 |
+
|
609 |
+
class AttenBlocks(nn.Module):
|
610 |
+
""" A series of attention blocks for one RHAG.
|
611 |
+
|
612 |
+
Args:
|
613 |
+
dim (int): Number of input channels.
|
614 |
+
input_resolution (tuple[int]): Input resolution.
|
615 |
+
depth (int): Number of blocks.
|
616 |
+
num_heads (int): Number of attention heads.
|
617 |
+
window_size (int): Local window size.
|
618 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
619 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
620 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
621 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
622 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
623 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
624 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
625 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
626 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
627 |
+
"""
|
628 |
+
|
629 |
+
def __init__(self,
|
630 |
+
dim,
|
631 |
+
input_resolution,
|
632 |
+
depth,
|
633 |
+
num_heads,
|
634 |
+
window_size,
|
635 |
+
compress_ratio,
|
636 |
+
squeeze_factor,
|
637 |
+
conv_scale,
|
638 |
+
overlap_ratio,
|
639 |
+
mlp_ratio=4.,
|
640 |
+
qkv_bias=True,
|
641 |
+
qk_scale=None,
|
642 |
+
drop=0.,
|
643 |
+
attn_drop=0.,
|
644 |
+
drop_path=0.,
|
645 |
+
norm_layer=nn.LayerNorm,
|
646 |
+
downsample=None,
|
647 |
+
use_checkpoint=False,
|
648 |
+
rope_mixed = True, rope_theta=10.0):
|
649 |
+
|
650 |
+
super().__init__()
|
651 |
+
self.dim = dim
|
652 |
+
self.input_resolution = input_resolution
|
653 |
+
self.depth = depth
|
654 |
+
self.use_checkpoint = use_checkpoint
|
655 |
+
|
656 |
+
# build blocks
|
657 |
+
self.blocks = nn.ModuleList([
|
658 |
+
HAB(
|
659 |
+
dim=dim,
|
660 |
+
input_resolution=input_resolution,
|
661 |
+
num_heads=num_heads,
|
662 |
+
window_size=window_size,
|
663 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
664 |
+
compress_ratio=compress_ratio,
|
665 |
+
squeeze_factor=squeeze_factor,
|
666 |
+
conv_scale=conv_scale,
|
667 |
+
mlp_ratio=mlp_ratio,
|
668 |
+
qkv_bias=qkv_bias,
|
669 |
+
qk_scale=qk_scale,
|
670 |
+
drop=drop,
|
671 |
+
attn_drop=attn_drop,
|
672 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
673 |
+
norm_layer=norm_layer,
|
674 |
+
rope_mixed = rope_mixed, rope_theta=rope_theta) for i in range(depth)
|
675 |
+
])
|
676 |
+
|
677 |
+
# OCAB
|
678 |
+
self.overlap_attn = OCAB(
|
679 |
+
dim=dim,
|
680 |
+
input_resolution=input_resolution,
|
681 |
+
window_size=window_size,
|
682 |
+
overlap_ratio=overlap_ratio,
|
683 |
+
num_heads=num_heads,
|
684 |
+
qkv_bias=qkv_bias,
|
685 |
+
qk_scale=qk_scale,
|
686 |
+
mlp_ratio=mlp_ratio,
|
687 |
+
norm_layer=norm_layer,
|
688 |
+
rope_mixed = rope_mixed, rope_theta = rope_theta)
|
689 |
+
|
690 |
+
|
691 |
+
# patch merging layer
|
692 |
+
if downsample is not None:
|
693 |
+
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
694 |
+
else:
|
695 |
+
self.downsample = None
|
696 |
+
|
697 |
+
def forward(self, x, x_size, params):
|
698 |
+
for blk in self.blocks:
|
699 |
+
x = blk(x, x_size, params['rpi_sa'], params['attn_mask'])
|
700 |
+
|
701 |
+
|
702 |
+
x = self.overlap_attn(x, x_size, params['rpi_oca'])
|
703 |
+
|
704 |
+
|
705 |
+
if self.downsample is not None:
|
706 |
+
x = self.downsample(x)
|
707 |
+
return x
|
708 |
+
|
709 |
+
|
710 |
+
class RHAG(nn.Module):
|
711 |
+
"""Residual Hybrid Attention Group (RHAG).
|
712 |
+
|
713 |
+
Args:
|
714 |
+
dim (int): Number of input channels.
|
715 |
+
input_resolution (tuple[int]): Input resolution.
|
716 |
+
depth (int): Number of blocks.
|
717 |
+
num_heads (int): Number of attention heads.
|
718 |
+
window_size (int): Local window size.
|
719 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
720 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
721 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
722 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
723 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
724 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
725 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
726 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
727 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
728 |
+
img_size: Input image size.
|
729 |
+
patch_size: Patch size.
|
730 |
+
resi_connection: The convolutional block before residual connection.
|
731 |
+
"""
|
732 |
+
|
733 |
+
def __init__(self,
|
734 |
+
dim,
|
735 |
+
input_resolution,
|
736 |
+
depth,
|
737 |
+
num_heads,
|
738 |
+
window_size,
|
739 |
+
compress_ratio,
|
740 |
+
squeeze_factor,
|
741 |
+
conv_scale,
|
742 |
+
overlap_ratio,
|
743 |
+
mlp_ratio=4.,
|
744 |
+
qkv_bias=True,
|
745 |
+
qk_scale=None,
|
746 |
+
drop=0.,
|
747 |
+
attn_drop=0.,
|
748 |
+
drop_path=0.,
|
749 |
+
norm_layer=nn.LayerNorm,
|
750 |
+
downsample=None,
|
751 |
+
use_checkpoint=False,
|
752 |
+
img_size=224,
|
753 |
+
patch_size=4,
|
754 |
+
resi_connection='1conv',
|
755 |
+
rope_mixed = True, rope_theta=10.0):
|
756 |
+
super(RHAG, self).__init__()
|
757 |
+
|
758 |
+
self.dim = dim
|
759 |
+
self.input_resolution = input_resolution
|
760 |
+
|
761 |
+
self.residual_group = AttenBlocks(
|
762 |
+
dim=dim,
|
763 |
+
input_resolution=input_resolution,
|
764 |
+
depth=depth,
|
765 |
+
num_heads=num_heads,
|
766 |
+
window_size=window_size,
|
767 |
+
compress_ratio=compress_ratio,
|
768 |
+
squeeze_factor=squeeze_factor,
|
769 |
+
conv_scale=conv_scale,
|
770 |
+
overlap_ratio=overlap_ratio,
|
771 |
+
mlp_ratio=mlp_ratio,
|
772 |
+
qkv_bias=qkv_bias,
|
773 |
+
qk_scale=qk_scale,
|
774 |
+
drop=drop,
|
775 |
+
attn_drop=attn_drop,
|
776 |
+
drop_path=drop_path,
|
777 |
+
norm_layer=norm_layer,
|
778 |
+
downsample=downsample,
|
779 |
+
use_checkpoint=use_checkpoint,
|
780 |
+
rope_mixed = rope_mixed, rope_theta=rope_theta)
|
781 |
+
|
782 |
+
if resi_connection == '1conv':
|
783 |
+
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
|
784 |
+
elif resi_connection == 'identity':
|
785 |
+
self.conv = nn.Identity()
|
786 |
+
|
787 |
+
self.patch_embed = PatchEmbed(
|
788 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
|
789 |
+
|
790 |
+
self.patch_unembed = PatchUnEmbed(
|
791 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
|
792 |
+
|
793 |
+
def forward(self, x, x_size, params):
|
794 |
+
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size, params), x_size))) + x
|
795 |
+
|
796 |
+
|
797 |
+
class PatchEmbed(nn.Module):
|
798 |
+
r""" Image to Patch Embedding
|
799 |
+
|
800 |
+
Args:
|
801 |
+
img_size (int): Image size. Default: 224.
|
802 |
+
patch_size (int): Patch token size. Default: 4.
|
803 |
+
in_chans (int): Number of input image channels. Default: 3.
|
804 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
805 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
806 |
+
"""
|
807 |
+
|
808 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
809 |
+
super().__init__()
|
810 |
+
img_size = to_2tuple(img_size)
|
811 |
+
patch_size = to_2tuple(patch_size)
|
812 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
813 |
+
self.img_size = img_size
|
814 |
+
self.patch_size = patch_size
|
815 |
+
self.patches_resolution = patches_resolution
|
816 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
817 |
+
|
818 |
+
self.in_chans = in_chans
|
819 |
+
self.embed_dim = embed_dim
|
820 |
+
|
821 |
+
if norm_layer is not None:
|
822 |
+
self.norm = norm_layer(embed_dim)
|
823 |
+
else:
|
824 |
+
self.norm = None
|
825 |
+
|
826 |
+
def forward(self, x):
|
827 |
+
x = x.flatten(2).transpose(1, 2) # b Ph*Pw c
|
828 |
+
if self.norm is not None:
|
829 |
+
x = self.norm(x)
|
830 |
+
return x
|
831 |
+
|
832 |
+
|
833 |
+
class PatchUnEmbed(nn.Module):
|
834 |
+
r""" Image to Patch Unembedding
|
835 |
+
|
836 |
+
Args:
|
837 |
+
img_size (int): Image size. Default: 224.
|
838 |
+
patch_size (int): Patch token size. Default: 4.
|
839 |
+
in_chans (int): Number of input image channels. Default: 3.
|
840 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
841 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
842 |
+
"""
|
843 |
+
|
844 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
845 |
+
super().__init__()
|
846 |
+
img_size = to_2tuple(img_size)
|
847 |
+
patch_size = to_2tuple(patch_size)
|
848 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
849 |
+
self.img_size = img_size
|
850 |
+
self.patch_size = patch_size
|
851 |
+
self.patches_resolution = patches_resolution
|
852 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
853 |
+
|
854 |
+
self.in_chans = in_chans
|
855 |
+
self.embed_dim = embed_dim
|
856 |
+
|
857 |
+
def forward(self, x, x_size):
|
858 |
+
x = x.transpose(1, 2).contiguous().view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b Ph*Pw c
|
859 |
+
return x
|
860 |
+
|
861 |
+
|
862 |
+
class Upsample(nn.Sequential):
|
863 |
+
"""Upsample module.
|
864 |
+
|
865 |
+
Args:
|
866 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
867 |
+
num_feat (int): Channel number of intermediate features.
|
868 |
+
"""
|
869 |
+
|
870 |
+
def __init__(self, scale, num_feat):
|
871 |
+
m = []
|
872 |
+
if (scale & (scale - 1)) == 0: # scale = 2^n
|
873 |
+
for _ in range(int(math.log(scale, 2))):
|
874 |
+
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
875 |
+
m.append(nn.PixelShuffle(2))
|
876 |
+
elif scale == 3:
|
877 |
+
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
878 |
+
m.append(nn.PixelShuffle(3))
|
879 |
+
else:
|
880 |
+
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
881 |
+
super(Upsample, self).__init__(*m)
|
882 |
+
|
883 |
+
|
884 |
+
|
885 |
+
|
886 |
+
class HATNOUP_ROPE_AMP(nn.Module):
|
887 |
+
def __init__(self,
|
888 |
+
img_size=64,
|
889 |
+
patch_size=1,
|
890 |
+
in_chans=3,
|
891 |
+
embed_dim=192,
|
892 |
+
depths=(6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6),
|
893 |
+
num_heads=(6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6),
|
894 |
+
window_size=16,
|
895 |
+
compress_ratio=3,
|
896 |
+
squeeze_factor=32,
|
897 |
+
conv_scale=0.01,
|
898 |
+
overlap_ratio=0.5,
|
899 |
+
mlp_ratio=2,
|
900 |
+
qkv_bias=True,
|
901 |
+
qk_scale=None,
|
902 |
+
drop_rate=0.,
|
903 |
+
attn_drop_rate=0.,
|
904 |
+
drop_path_rate=0.1,
|
905 |
+
norm_layer=nn.LayerNorm,
|
906 |
+
ape=False,
|
907 |
+
patch_norm=True,
|
908 |
+
use_checkpoint=False,
|
909 |
+
upscale=4,
|
910 |
+
img_range=1.,
|
911 |
+
upsampler='pixelshuffle',
|
912 |
+
resi_connection='1conv',
|
913 |
+
rope_mixed = True,
|
914 |
+
rope_theta=10.0,
|
915 |
+
**kwargs):
|
916 |
+
super(HATNOUP_ROPE_AMP, self).__init__()
|
917 |
+
|
918 |
+
self.window_size = window_size
|
919 |
+
self.shift_size = window_size // 2
|
920 |
+
self.overlap_ratio = overlap_ratio
|
921 |
+
|
922 |
+
num_in_ch = in_chans
|
923 |
+
num_out_ch = in_chans
|
924 |
+
num_feat = 64
|
925 |
+
self.img_range = img_range
|
926 |
+
if in_chans == 3:
|
927 |
+
rgb_mean = (0.4488, 0.4371, 0.4040)
|
928 |
+
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
|
929 |
+
else:
|
930 |
+
self.mean = torch.zeros(1, 1, 1, 1)
|
931 |
+
self.upscale = upscale
|
932 |
+
self.upsampler = upsampler
|
933 |
+
|
934 |
+
# relative position index
|
935 |
+
relative_position_index_SA = self.calculate_rpi_sa()
|
936 |
+
relative_position_index_OCA = self.calculate_rpi_oca()
|
937 |
+
self.register_buffer('relative_position_index_SA', relative_position_index_SA)
|
938 |
+
self.register_buffer('relative_position_index_OCA', relative_position_index_OCA)
|
939 |
+
|
940 |
+
# ------------------------- 1, shallow feature extraction ------------------------- #
|
941 |
+
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
|
942 |
+
|
943 |
+
# ------------------------- 2, deep feature extraction ------------------------- #
|
944 |
+
self.num_layers = len(depths)
|
945 |
+
self.embed_dim = embed_dim
|
946 |
+
self.ape = ape
|
947 |
+
self.patch_norm = patch_norm
|
948 |
+
self.num_features = embed_dim
|
949 |
+
self.mlp_ratio = mlp_ratio
|
950 |
+
|
951 |
+
# split image into non-overlapping patches
|
952 |
+
self.patch_embed = PatchEmbed(
|
953 |
+
img_size=img_size,
|
954 |
+
patch_size=patch_size,
|
955 |
+
in_chans=embed_dim,
|
956 |
+
embed_dim=embed_dim,
|
957 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
958 |
+
num_patches = self.patch_embed.num_patches
|
959 |
+
patches_resolution = self.patch_embed.patches_resolution
|
960 |
+
self.patches_resolution = patches_resolution
|
961 |
+
|
962 |
+
# merge non-overlapping patches into image
|
963 |
+
self.patch_unembed = PatchUnEmbed(
|
964 |
+
img_size=img_size,
|
965 |
+
patch_size=patch_size,
|
966 |
+
in_chans=embed_dim,
|
967 |
+
embed_dim=embed_dim,
|
968 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
969 |
+
|
970 |
+
# absolute position embedding
|
971 |
+
if self.ape:
|
972 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
973 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
974 |
+
|
975 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
976 |
+
|
977 |
+
# stochastic depth
|
978 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
979 |
+
|
980 |
+
# build Residual Hybrid Attention Groups (RHAG)
|
981 |
+
self.layers = nn.ModuleList()
|
982 |
+
for i_layer in range(self.num_layers):
|
983 |
+
layer = RHAG(
|
984 |
+
dim=embed_dim,
|
985 |
+
input_resolution=(patches_resolution[0], patches_resolution[1]),
|
986 |
+
depth=depths[i_layer],
|
987 |
+
num_heads=num_heads[i_layer],
|
988 |
+
window_size=window_size,
|
989 |
+
compress_ratio=compress_ratio,
|
990 |
+
squeeze_factor=squeeze_factor,
|
991 |
+
conv_scale=conv_scale,
|
992 |
+
overlap_ratio=overlap_ratio,
|
993 |
+
mlp_ratio=self.mlp_ratio,
|
994 |
+
qkv_bias=qkv_bias,
|
995 |
+
qk_scale=qk_scale,
|
996 |
+
drop=drop_rate,
|
997 |
+
attn_drop=attn_drop_rate,
|
998 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
999 |
+
norm_layer=norm_layer,
|
1000 |
+
downsample=None,
|
1001 |
+
use_checkpoint=use_checkpoint,
|
1002 |
+
img_size=img_size,
|
1003 |
+
patch_size=patch_size,
|
1004 |
+
resi_connection=resi_connection,
|
1005 |
+
rope_mixed = rope_mixed, rope_theta=rope_theta)
|
1006 |
+
self.layers.append(layer)
|
1007 |
+
self.norm = norm_layer(self.num_features)
|
1008 |
+
|
1009 |
+
self.use_checkpoint = use_checkpoint
|
1010 |
+
|
1011 |
+
# build the last conv layer in deep feature extraction
|
1012 |
+
if resi_connection == '1conv':
|
1013 |
+
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
|
1014 |
+
elif resi_connection == 'identity':
|
1015 |
+
self.conv_after_body = nn.Identity()
|
1016 |
+
|
1017 |
+
# ------------------------- 3, high quality image reconstruction ------------------------- #
|
1018 |
+
if self.upsampler == 'pixelshuffle':
|
1019 |
+
# for classical SR
|
1020 |
+
self.conv_before_upsample = nn.Sequential(
|
1021 |
+
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
|
1022 |
+
# self.upsample = Upsample(upscale, num_feat)
|
1023 |
+
# self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
1024 |
+
|
1025 |
+
self.apply(self._init_weights)
|
1026 |
+
|
1027 |
+
def _init_weights(self, m):
|
1028 |
+
if isinstance(m, nn.Linear):
|
1029 |
+
trunc_normal_(m.weight, std=.02)
|
1030 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
1031 |
+
nn.init.constant_(m.bias, 0)
|
1032 |
+
elif isinstance(m, nn.LayerNorm):
|
1033 |
+
nn.init.constant_(m.bias, 0)
|
1034 |
+
nn.init.constant_(m.weight, 1.0)
|
1035 |
+
|
1036 |
+
def calculate_rpi_sa(self):
|
1037 |
+
# calculate relative position index for SA
|
1038 |
+
coords_h = torch.arange(self.window_size)
|
1039 |
+
coords_w = torch.arange(self.window_size)
|
1040 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
1041 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
1042 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
1043 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
1044 |
+
relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0
|
1045 |
+
relative_coords[:, :, 1] += self.window_size - 1
|
1046 |
+
relative_coords[:, :, 0] *= 2 * self.window_size - 1
|
1047 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
1048 |
+
return relative_position_index
|
1049 |
+
|
1050 |
+
def calculate_rpi_oca(self):
|
1051 |
+
# calculate relative position index for OCA
|
1052 |
+
window_size_ori = self.window_size
|
1053 |
+
window_size_ext = self.window_size + int(self.overlap_ratio * self.window_size)
|
1054 |
+
|
1055 |
+
coords_h = torch.arange(window_size_ori)
|
1056 |
+
coords_w = torch.arange(window_size_ori)
|
1057 |
+
coords_ori = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, ws, ws
|
1058 |
+
coords_ori_flatten = torch.flatten(coords_ori, 1) # 2, ws*ws
|
1059 |
+
|
1060 |
+
coords_h = torch.arange(window_size_ext)
|
1061 |
+
coords_w = torch.arange(window_size_ext)
|
1062 |
+
coords_ext = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, wse, wse
|
1063 |
+
coords_ext_flatten = torch.flatten(coords_ext, 1) # 2, wse*wse
|
1064 |
+
|
1065 |
+
relative_coords = coords_ext_flatten[:, None, :] - coords_ori_flatten[:, :, None] # 2, ws*ws, wse*wse
|
1066 |
+
|
1067 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # ws*ws, wse*wse, 2
|
1068 |
+
relative_coords[:, :, 0] += window_size_ori - window_size_ext + 1 # shift to start from 0
|
1069 |
+
relative_coords[:, :, 1] += window_size_ori - window_size_ext + 1
|
1070 |
+
|
1071 |
+
relative_coords[:, :, 0] *= window_size_ori + window_size_ext - 1
|
1072 |
+
relative_position_index = relative_coords.sum(-1)
|
1073 |
+
return relative_position_index
|
1074 |
+
|
1075 |
+
def calculate_mask(self, x_size):
|
1076 |
+
# calculate attention mask for SW-MSA
|
1077 |
+
h, w = x_size
|
1078 |
+
img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1
|
1079 |
+
h_slices = (slice(0, -self.window_size), slice(-self.window_size,
|
1080 |
+
-self.shift_size), slice(-self.shift_size, None))
|
1081 |
+
w_slices = (slice(0, -self.window_size), slice(-self.window_size,
|
1082 |
+
-self.shift_size), slice(-self.shift_size, None))
|
1083 |
+
cnt = 0
|
1084 |
+
for h in h_slices:
|
1085 |
+
for w in w_slices:
|
1086 |
+
img_mask[:, h, w, :] = cnt
|
1087 |
+
cnt += 1
|
1088 |
+
|
1089 |
+
mask_windows = window_partition(img_mask, self.window_size) # nw, window_size, window_size, 1
|
1090 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
1091 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
1092 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
1093 |
+
|
1094 |
+
return attn_mask
|
1095 |
+
|
1096 |
+
@torch.jit.ignore
|
1097 |
+
def no_weight_decay(self):
|
1098 |
+
return {'absolute_pos_embed'}
|
1099 |
+
|
1100 |
+
@torch.jit.ignore
|
1101 |
+
def no_weight_decay_keywords(self):
|
1102 |
+
return {'relative_position_bias_table'}
|
1103 |
+
|
1104 |
+
def forward_features(self, x):
|
1105 |
+
x_size = (x.shape[2], x.shape[3])
|
1106 |
+
|
1107 |
+
# Calculate attention mask and relative position index in advance to speed up inference.
|
1108 |
+
# The original code is very time-consuming for large window size.
|
1109 |
+
attn_mask = self.calculate_mask(x_size).to(x.device)
|
1110 |
+
params = {'attn_mask': attn_mask, 'rpi_sa': self.relative_position_index_SA, 'rpi_oca': self.relative_position_index_OCA}
|
1111 |
+
|
1112 |
+
x = self.patch_embed(x)
|
1113 |
+
if self.ape:
|
1114 |
+
x = x + self.absolute_pos_embed
|
1115 |
+
x = self.pos_drop(x)
|
1116 |
+
|
1117 |
+
for layer in self.layers:
|
1118 |
+
x = layer(x, x_size, params)
|
1119 |
+
|
1120 |
+
x = self.norm(x) # b seq_len c
|
1121 |
+
x = self.patch_unembed(x, x_size)
|
1122 |
+
|
1123 |
+
return x
|
1124 |
+
|
1125 |
+
def forward(self, x):
|
1126 |
+
# self.mean = self.mean.type_as(x)
|
1127 |
+
# x = (x - self.mean) * self.img_range
|
1128 |
+
|
1129 |
+
if self.upsampler == 'pixelshuffle':
|
1130 |
+
# for classical SR
|
1131 |
+
x = self.conv_first(x)
|
1132 |
+
if self.use_checkpoint:
|
1133 |
+
x = self.conv_after_body(checkpoint(self.forward_features, x)) + x
|
1134 |
+
else:
|
1135 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
1136 |
+
x = self.conv_before_upsample(x)
|
1137 |
+
# x = self.conv_last(self.upsample(x))
|
1138 |
+
|
1139 |
+
# x = x / self.img_range + self.mean
|
1140 |
+
|
1141 |
+
return x
|
1142 |
+
|
1143 |
+
|
1144 |
+
if __name__ == '__main__':
|
1145 |
+
srcs = torch.randn(8, 3, 64, 64).cuda()
|
1146 |
+
encoder = HATNOUP_ROPE_AMP(upscale=4, in_chans=3, img_size=64, window_size=16, compress_ratio=3, squeeze_factor=32, conv_scale=0.01, overlap_ratio=0.5,
|
1147 |
+
img_range=1.,
|
1148 |
+
depths=(6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6),
|
1149 |
+
embed_dim=192,
|
1150 |
+
num_heads=(6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6),
|
1151 |
+
mlp_ratio=2,
|
1152 |
+
upsampler='pixelshuffle',
|
1153 |
+
resi_connection='1conv',
|
1154 |
+
use_checkpoint=False).cuda()
|
1155 |
+
feature = encoder(srcs)
|
1156 |
+
pass
|
utils/rdn.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections.abc
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torchvision
|
5 |
+
import warnings
|
6 |
+
from distutils.version import LooseVersion
|
7 |
+
from itertools import repeat
|
8 |
+
from torch import nn as nn
|
9 |
+
from torch.nn import functional as F
|
10 |
+
from torch.nn import init as init
|
11 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
12 |
+
|
13 |
+
class RDB_Conv(nn.Module):
|
14 |
+
def __init__(self, inChannels, growRate, kSize=3):
|
15 |
+
super(RDB_Conv, self).__init__()
|
16 |
+
Cin = inChannels
|
17 |
+
G = growRate
|
18 |
+
self.conv = nn.Sequential(*[
|
19 |
+
nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1),
|
20 |
+
nn.ReLU()
|
21 |
+
])
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
out = self.conv(x)
|
25 |
+
return torch.cat((x, out), 1)
|
26 |
+
|
27 |
+
class RDB(nn.Module):
|
28 |
+
def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
|
29 |
+
super(RDB, self).__init__()
|
30 |
+
G0 = growRate0
|
31 |
+
G = growRate
|
32 |
+
C = nConvLayers
|
33 |
+
|
34 |
+
convs = []
|
35 |
+
for c in range(C):
|
36 |
+
convs.append(RDB_Conv(G0 + c*G, G))
|
37 |
+
self.convs = nn.Sequential(*convs)
|
38 |
+
|
39 |
+
# Local Feature Fusion
|
40 |
+
self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1)
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
return self.LFF(self.convs(x)) + x
|
44 |
+
|
45 |
+
class RDNNOUP(nn.Module):
|
46 |
+
def __init__(self, G0 = 64, kSize = 3, r = 4, n_colors = 3, RDNconfig = 'B',
|
47 |
+
no_upsampling = True, img_range = 1.0):
|
48 |
+
super(RDNNOUP, self).__init__()
|
49 |
+
|
50 |
+
self.no_upsampling = no_upsampling
|
51 |
+
self.img_range = img_range
|
52 |
+
|
53 |
+
# number of RDB blocks, conv layers, out channels
|
54 |
+
self.D, C, G = {
|
55 |
+
'A': (20, 6, 32),
|
56 |
+
'B': (16, 8, 64),
|
57 |
+
}[RDNconfig]
|
58 |
+
|
59 |
+
# Shallow feature extraction net
|
60 |
+
self.SFENet1 = nn.Conv2d(n_colors, G0, kSize, padding=(kSize-1)//2, stride=1)
|
61 |
+
self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
|
62 |
+
|
63 |
+
# Redidual dense blocks and dense feature fusion
|
64 |
+
self.RDBs = nn.ModuleList()
|
65 |
+
for i in range(self.D):
|
66 |
+
self.RDBs.append(
|
67 |
+
RDB(growRate0 = G0, growRate = G, nConvLayers = C)
|
68 |
+
)
|
69 |
+
|
70 |
+
# Global Feature Fusion
|
71 |
+
self.GFF = nn.Sequential(*[
|
72 |
+
nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1),
|
73 |
+
nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
|
74 |
+
])
|
75 |
+
|
76 |
+
if no_upsampling:
|
77 |
+
self.out_dim = G0
|
78 |
+
else:
|
79 |
+
self.out_dim = n_colors
|
80 |
+
# Up-sampling net
|
81 |
+
if r == 2 or r == 3:
|
82 |
+
self.UPNet = nn.Sequential(*[
|
83 |
+
nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1),
|
84 |
+
nn.PixelShuffle(r),
|
85 |
+
nn.Conv2d(G, n_colors, kSize, padding=(kSize-1)//2, stride=1)
|
86 |
+
])
|
87 |
+
elif r == 4:
|
88 |
+
self.UPNet = nn.Sequential(*[
|
89 |
+
nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1),
|
90 |
+
nn.PixelShuffle(2),
|
91 |
+
nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1),
|
92 |
+
nn.PixelShuffle(2),
|
93 |
+
nn.Conv2d(G, n_colors, kSize, padding=(kSize-1)//2, stride=1)
|
94 |
+
])
|
95 |
+
else:
|
96 |
+
raise ValueError("scale must be 2 or 3 or 4.")
|
97 |
+
|
98 |
+
def forward(self, x):
|
99 |
+
x = x * self.img_range
|
100 |
+
f__1 = self.SFENet1(x)
|
101 |
+
x = self.SFENet2(f__1)
|
102 |
+
|
103 |
+
RDBs_out = []
|
104 |
+
for i in range(self.D):
|
105 |
+
x = self.RDBs[i](x)
|
106 |
+
RDBs_out.append(x)
|
107 |
+
|
108 |
+
x = self.GFF(torch.cat(RDBs_out,1))
|
109 |
+
x += f__1
|
110 |
+
|
111 |
+
if self.no_upsampling:
|
112 |
+
return x
|
113 |
+
else:
|
114 |
+
return self.UPNet(x)
|
115 |
+
|
116 |
+
if __name__ == '__main__':
|
117 |
+
x = torch.randn(8,3,48,48)
|
118 |
+
model = RDNNOUP()
|
119 |
+
y = model(x)
|
120 |
+
print(y.shape)
|
utils/split_and_joint_image.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import math
|
4 |
+
|
5 |
+
from utils.gaussian_splatting import generate_2D_gaussian_splatting_step, generate_2D_gaussian_splatting_step_buffer
|
6 |
+
|
7 |
+
|
8 |
+
### If the GPU memory is limited, please use the following code to do tiling process for input LR image
|
9 |
+
# def split_and_joint_image(lq, scale_factor, model_g, model_fea2gs, scale_modify, split_size = 48,
|
10 |
+
# overlap_size = 8,
|
11 |
+
# crop_size = 4,
|
12 |
+
# default_step_size = 1.2, mode = 'scale_modify',
|
13 |
+
# cuda_rendering = True,
|
14 |
+
# if_dmax = False,
|
15 |
+
# dmax_mode = 'fix',
|
16 |
+
# dmax = 0.1):
|
17 |
+
# h_lq, w_lq = lq.shape[-2:]
|
18 |
+
|
19 |
+
# assert overlap_size > 0 and overlap_size < split_size // 2, f"overlap size is wrong"
|
20 |
+
|
21 |
+
# tile_nums_h = math.ceil((h_lq - overlap_size) / (split_size - overlap_size))
|
22 |
+
# tile_nums_w = math.ceil((w_lq - overlap_size) / (split_size - overlap_size))
|
23 |
+
|
24 |
+
# pad_h_lq = tile_nums_h * (split_size - overlap_size) + overlap_size - h_lq
|
25 |
+
# pad_w_lq = tile_nums_w * (split_size - overlap_size) + overlap_size - w_lq
|
26 |
+
|
27 |
+
# lq_pad = F.pad(input=lq, pad=(0, pad_w_lq, 0, pad_h_lq), mode='reflect')
|
28 |
+
|
29 |
+
# split_size_sr = math.ceil(split_size * scale_factor)
|
30 |
+
# sr_tile_list = []
|
31 |
+
# for h_num in range(tile_nums_h):
|
32 |
+
# for w_num in range(tile_nums_w):
|
33 |
+
# tile_lq_position_start_h = h_num * (split_size - overlap_size)
|
34 |
+
# tile_lq_position_start_w = w_num * (split_size - overlap_size)
|
35 |
+
# tile_lq_position_end_h = tile_lq_position_start_h + split_size
|
36 |
+
# tile_lq_position_end_w = tile_lq_position_start_w + split_size
|
37 |
+
|
38 |
+
# input_tile = lq_pad[:,:, tile_lq_position_start_h:tile_lq_position_end_h, tile_lq_position_start_w:tile_lq_position_end_w]
|
39 |
+
|
40 |
+
# model_g_output = model_g(input_tile)
|
41 |
+
|
42 |
+
# scale_vector = scale_modify[0].unsqueeze(0).to(model_g_output.device)
|
43 |
+
# batch_gs_parameters = model_fea2gs(model_g_output, scale_vector)
|
44 |
+
|
45 |
+
# gs_parameters = batch_gs_parameters[0, :]
|
46 |
+
# b_output = generate_2D_gaussian_splatting_step(sr_size=torch.tensor([split_size_sr, split_size_sr]), gs_parameters=gs_parameters,
|
47 |
+
# lq=input_tile[0, :], scale=scale_factor, sample_coords=None,
|
48 |
+
# scale_modify = scale_modify,
|
49 |
+
# default_step_size = default_step_size, mode = mode,
|
50 |
+
# cuda_rendering = cuda_rendering,
|
51 |
+
# if_dmax = if_dmax,
|
52 |
+
# dmax_mode = dmax_mode,
|
53 |
+
# dmax = dmax)
|
54 |
+
# sr_tile_list.append(b_output.unsqueeze(0))
|
55 |
+
|
56 |
+
# tile_sr_h = sr_tile_list[0].shape[2]
|
57 |
+
# tile_sr_w = sr_tile_list[0].shape[3]
|
58 |
+
|
59 |
+
# assert tile_sr_w == split_size_sr and tile_sr_h == split_size_sr, \
|
60 |
+
# f'tile_sr_h-{tile_sr_w}, tile_sr_w-{tile_sr_w}, split_size_sr-{split_size_sr} is not the same'
|
61 |
+
|
62 |
+
# overlap_sr = math.ceil(overlap_size * scale_factor)
|
63 |
+
|
64 |
+
# sr_pad = torch.zeros(lq.shape[0], lq.shape[1],
|
65 |
+
# math.ceil(lq_pad.shape[2] * scale_factor),
|
66 |
+
# math.ceil(lq_pad.shape[3] * scale_factor),
|
67 |
+
# device=lq.device)
|
68 |
+
|
69 |
+
# idx = 0
|
70 |
+
# for h_num in range(tile_nums_h):
|
71 |
+
# for w_num in range(tile_nums_w):
|
72 |
+
# tile_sr_position_start_w = w_num * (split_size_sr - overlap_sr)
|
73 |
+
# tile_sr_position_end_w = tile_sr_position_start_w + split_size_sr
|
74 |
+
# tile_sr_position_start_h = h_num * (split_size_sr - overlap_sr)
|
75 |
+
# tile_sr_position_end_h = tile_sr_position_start_h + split_size_sr
|
76 |
+
# if h_num == 0 and w_num == 0:
|
77 |
+
# sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
|
78 |
+
# tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx]
|
79 |
+
# elif h_num == 0 and w_num !=0:
|
80 |
+
# sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
|
81 |
+
# tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,:,crop_size:]
|
82 |
+
# elif h_num != 0 and w_num ==0:
|
83 |
+
# sr_pad[:, :, tile_sr_position_start_h+crop_size:tile_sr_position_end_h,
|
84 |
+
# tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,:]
|
85 |
+
# else:
|
86 |
+
# sr_pad[:,:,tile_sr_position_start_h+crop_size:tile_sr_position_end_h,
|
87 |
+
# tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,crop_size:]
|
88 |
+
# idx = idx + 1
|
89 |
+
|
90 |
+
# print(f"sr_pad shape is {sr_pad.shape}")
|
91 |
+
|
92 |
+
# # sr_final = sr_pad[:,:, 0:math.ceil(h_lq * scale_factor), 0: math.ceil(w_lq * scale_factor)]
|
93 |
+
# sr_final = sr_pad
|
94 |
+
|
95 |
+
# return sr_final
|
96 |
+
|
97 |
+
|
98 |
+
def split_and_joint_image(lq, scale_factor, split_size,
|
99 |
+
overlap_size, model_g, model_fea2gs,
|
100 |
+
scale_modify, crop_size = 2,
|
101 |
+
default_step_size = 1.2, mode = 'scale_modify',
|
102 |
+
cuda_rendering = True,
|
103 |
+
if_dmax = False,
|
104 |
+
dmax_mode = 'fix',
|
105 |
+
dmax = 25):
|
106 |
+
h_lq, w_lq = lq.shape[-2:]
|
107 |
+
|
108 |
+
# assert h_lq > split_size, f'h_lq-{h_lq} should be larger than split_size-{split_size}, please do not use tile_process, or decrease the split_size'
|
109 |
+
# assert w_lq > split_size, f'w_lq-{w_lq} should be larger than split_size-{split_size}, please do not use tile_process, or decrease the split_size'
|
110 |
+
|
111 |
+
assert overlap_size > 0 and overlap_size < split_size // 2, f"overlap size is wrong"
|
112 |
+
|
113 |
+
tile_nums_h = math.ceil((h_lq - overlap_size) / (split_size - overlap_size))
|
114 |
+
tile_nums_w = math.ceil((w_lq - overlap_size) / (split_size - overlap_size))
|
115 |
+
|
116 |
+
pad_h_lq = tile_nums_h * (split_size - overlap_size) + overlap_size - h_lq
|
117 |
+
pad_w_lq = tile_nums_w * (split_size - overlap_size) + overlap_size - w_lq
|
118 |
+
|
119 |
+
assert pad_h_lq < h_lq, f'pad_h_lq-{pad_h_lq} should be smaller than h_lq-{h_lq}, please decrease the split_size-{split_size}'
|
120 |
+
assert pad_w_lq < w_lq, f'pad_w_lq-{pad_w_lq} should be smaller than w_lq-{w_lq}, please decrease the split_size-{split_size}'
|
121 |
+
|
122 |
+
lq_pad = F.pad(input=lq, pad=(0, pad_w_lq, 0, pad_h_lq), mode='reflect')
|
123 |
+
# lq_pad = F.pad(input=lq, pad=(0, pad_w_lq, 0, pad_h_lq), mode='constant', value=0)
|
124 |
+
|
125 |
+
split_size_sr = math.ceil(split_size * scale_factor)
|
126 |
+
sr_tile_list = []
|
127 |
+
for h_num in range(tile_nums_h):
|
128 |
+
for w_num in range(tile_nums_w):
|
129 |
+
tile_lq_position_start_h = h_num * (split_size - overlap_size)
|
130 |
+
tile_lq_position_start_w = w_num * (split_size - overlap_size)
|
131 |
+
tile_lq_position_end_h = tile_lq_position_start_h + split_size
|
132 |
+
tile_lq_position_end_w = tile_lq_position_start_w + split_size
|
133 |
+
|
134 |
+
input_tile = lq_pad[:,:, tile_lq_position_start_h:tile_lq_position_end_h, tile_lq_position_start_w:tile_lq_position_end_w]
|
135 |
+
|
136 |
+
model_g_output = model_g(input_tile)
|
137 |
+
|
138 |
+
scale_vector = scale_modify[0].unsqueeze(0).to(model_g_output.device)
|
139 |
+
batch_gs_parameters = model_fea2gs(model_g_output, scale_vector)
|
140 |
+
|
141 |
+
|
142 |
+
gs_parameters = batch_gs_parameters[0, :]
|
143 |
+
b_output = generate_2D_gaussian_splatting_step(sr_size=torch.tensor([split_size_sr, split_size_sr]), gs_parameters=gs_parameters,
|
144 |
+
scale=scale_factor, sample_coords=None,
|
145 |
+
scale_modify = scale_modify,
|
146 |
+
default_step_size = default_step_size, mode = mode,
|
147 |
+
cuda_rendering = cuda_rendering,
|
148 |
+
if_dmax = if_dmax,
|
149 |
+
dmax_mode = dmax_mode,
|
150 |
+
dmax = dmax)
|
151 |
+
sr_tile_list.append(b_output.unsqueeze(0))
|
152 |
+
|
153 |
+
tile_sr_h = sr_tile_list[0].shape[2]
|
154 |
+
tile_sr_w = sr_tile_list[0].shape[3]
|
155 |
+
|
156 |
+
assert tile_sr_w == split_size_sr and tile_sr_h == split_size_sr, \
|
157 |
+
f'tile_sr_h-{tile_sr_w}, tile_sr_w-{tile_sr_w}, split_size_sr-{split_size_sr} is not the same'
|
158 |
+
|
159 |
+
overlap_sr = math.ceil(overlap_size * scale_factor)
|
160 |
+
|
161 |
+
sr_pad = torch.zeros(lq.shape[0], lq.shape[1],
|
162 |
+
(tile_nums_h - 1) * (split_size_sr - overlap_sr) + split_size_sr,
|
163 |
+
(tile_nums_w - 1) * (split_size_sr - overlap_sr) + split_size_sr,
|
164 |
+
device=lq.device)
|
165 |
+
|
166 |
+
idx = 0
|
167 |
+
|
168 |
+
if scale_factor != int(scale_factor):
|
169 |
+
for h_num in range(tile_nums_h):
|
170 |
+
for w_num in range(tile_nums_w):
|
171 |
+
tile_sr_position_start_w = w_num * (split_size_sr - overlap_sr)
|
172 |
+
tile_sr_position_end_w = tile_sr_position_start_w + split_size_sr
|
173 |
+
tile_sr_position_start_h = h_num * (split_size_sr - overlap_sr)
|
174 |
+
tile_sr_position_end_h = tile_sr_position_start_h + split_size_sr
|
175 |
+
if h_num == 0 and w_num == 0:
|
176 |
+
sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
|
177 |
+
tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx]
|
178 |
+
elif h_num == 0 and w_num !=0:
|
179 |
+
if w_num != tile_nums_w - 1:
|
180 |
+
sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
|
181 |
+
tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,:,crop_size:]
|
182 |
+
else:
|
183 |
+
sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
|
184 |
+
tile_sr_position_start_w+crop_size:sr_pad.shape[3]] = sr_tile_list[idx][:,:,:,crop_size:sr_pad.shape[3] - tile_sr_position_start_w]
|
185 |
+
elif h_num != 0 and w_num ==0:
|
186 |
+
if h_num != tile_nums_h - 1:
|
187 |
+
sr_pad[:, :, tile_sr_position_start_h+crop_size:tile_sr_position_end_h,
|
188 |
+
tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,:]
|
189 |
+
else:
|
190 |
+
sr_pad[:, :, tile_sr_position_start_h+crop_size:sr_pad.shape[2],
|
191 |
+
tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:sr_pad.shape[2] - tile_sr_position_start_h,:]
|
192 |
+
else:
|
193 |
+
if w_num != tile_nums_w - 1 and h_num != tile_nums_h - 1:
|
194 |
+
sr_pad[:,:,tile_sr_position_start_h+crop_size:tile_sr_position_end_h,
|
195 |
+
tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,crop_size:]
|
196 |
+
elif w_num == tile_nums_w - 1 and h_num != tile_nums_h - 1:
|
197 |
+
sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
|
198 |
+
tile_sr_position_start_w+crop_size:sr_pad.shape[3]] = sr_tile_list[idx][:,:,:,crop_size:sr_pad.shape[3] - tile_sr_position_start_w]
|
199 |
+
elif w_num != tile_nums_w - 1 and h_num == tile_nums_h - 1:
|
200 |
+
sr_pad[:, :, tile_sr_position_start_h+crop_size:sr_pad.shape[2],
|
201 |
+
tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:sr_pad.shape[2] - tile_sr_position_start_h,:]
|
202 |
+
elif w_num == tile_nums_w - 1 and h_num == tile_nums_h - 1:
|
203 |
+
sr_pad[:,:,tile_sr_position_start_h+crop_size:sr_pad.shape[2],
|
204 |
+
tile_sr_position_start_w+crop_size:sr_pad.shape[3]] = sr_tile_list[idx][:,:,crop_size:sr_pad.shape[2] - tile_sr_position_start_h,crop_size:sr_pad.shape[3] - tile_sr_position_start_w]
|
205 |
+
idx = idx + 1
|
206 |
+
else:
|
207 |
+
for h_num in range(tile_nums_h):
|
208 |
+
for w_num in range(tile_nums_w):
|
209 |
+
tile_sr_position_start_w = w_num * (split_size_sr - overlap_sr)
|
210 |
+
tile_sr_position_end_w = tile_sr_position_start_w + split_size_sr
|
211 |
+
tile_sr_position_start_h = h_num * (split_size_sr - overlap_sr)
|
212 |
+
tile_sr_position_end_h = tile_sr_position_start_h + split_size_sr
|
213 |
+
if h_num == 0 and w_num == 0:
|
214 |
+
sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
|
215 |
+
tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx]
|
216 |
+
elif h_num == 0 and w_num !=0:
|
217 |
+
sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
|
218 |
+
tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,:,crop_size:]
|
219 |
+
elif h_num != 0 and w_num ==0:
|
220 |
+
sr_pad[:, :, tile_sr_position_start_h+crop_size:tile_sr_position_end_h,
|
221 |
+
tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,:]
|
222 |
+
else:
|
223 |
+
sr_pad[:,:,tile_sr_position_start_h+crop_size:tile_sr_position_end_h,
|
224 |
+
tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,crop_size:]
|
225 |
+
idx = idx + 1
|
226 |
+
|
227 |
+
print(f"sr_pad shape is {sr_pad.shape}")
|
228 |
+
|
229 |
+
# sr_final = sr_pad[:,:, 0:math.ceil(h_lq * scale_factor), 0: math.ceil(w_lq * scale_factor)]
|
230 |
+
sr_final = sr_pad
|
231 |
+
|
232 |
+
return sr_final
|
utils/swinir.py
ADDED
@@ -0,0 +1,1243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from https://github.com/JingyunLiang/SwinIR
|
2 |
+
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
|
3 |
+
# Originally Written by Ze Liu, Modified by Jingyun Liang.
|
4 |
+
|
5 |
+
import collections.abc
|
6 |
+
import torchvision
|
7 |
+
import warnings
|
8 |
+
from distutils.version import LooseVersion
|
9 |
+
from itertools import repeat
|
10 |
+
|
11 |
+
import math
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.utils.checkpoint as checkpoint
|
15 |
+
|
16 |
+
# From PyTorch
|
17 |
+
def _ntuple(n):
|
18 |
+
|
19 |
+
def parse(x):
|
20 |
+
if isinstance(x, collections.abc.Iterable):
|
21 |
+
return x
|
22 |
+
return tuple(repeat(x, n))
|
23 |
+
|
24 |
+
return parse
|
25 |
+
|
26 |
+
|
27 |
+
to_1tuple = _ntuple(1)
|
28 |
+
to_2tuple = _ntuple(2)
|
29 |
+
to_3tuple = _ntuple(3)
|
30 |
+
to_4tuple = _ntuple(4)
|
31 |
+
to_ntuple = _ntuple
|
32 |
+
|
33 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
34 |
+
# From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
|
35 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
36 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
37 |
+
def norm_cdf(x):
|
38 |
+
# Computes standard normal cumulative distribution function
|
39 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
40 |
+
|
41 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
42 |
+
warnings.warn(
|
43 |
+
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
|
44 |
+
'The distribution of values may be incorrect.',
|
45 |
+
stacklevel=2)
|
46 |
+
|
47 |
+
with torch.no_grad():
|
48 |
+
# Values are generated by using a truncated uniform distribution and
|
49 |
+
# then using the inverse CDF for the normal distribution.
|
50 |
+
# Get upper and lower cdf values
|
51 |
+
low = norm_cdf((a - mean) / std)
|
52 |
+
up = norm_cdf((b - mean) / std)
|
53 |
+
|
54 |
+
# Uniformly fill tensor with values from [low, up], then translate to
|
55 |
+
# [2l-1, 2u-1].
|
56 |
+
tensor.uniform_(2 * low - 1, 2 * up - 1)
|
57 |
+
|
58 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
59 |
+
# standard normal
|
60 |
+
tensor.erfinv_()
|
61 |
+
|
62 |
+
# Transform to proper mean, std
|
63 |
+
tensor.mul_(std * math.sqrt(2.))
|
64 |
+
tensor.add_(mean)
|
65 |
+
|
66 |
+
# Clamp to ensure it's in the proper range
|
67 |
+
tensor.clamp_(min=a, max=b)
|
68 |
+
return tensor
|
69 |
+
|
70 |
+
|
71 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
72 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
73 |
+
normal distribution.
|
74 |
+
|
75 |
+
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
|
76 |
+
|
77 |
+
The values are effectively drawn from the
|
78 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
79 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
80 |
+
the bounds. The method used for generating the random values works
|
81 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
tensor: an n-dimensional `torch.Tensor`
|
85 |
+
mean: the mean of the normal distribution
|
86 |
+
std: the standard deviation of the normal distribution
|
87 |
+
a: the minimum cutoff value
|
88 |
+
b: the maximum cutoff value
|
89 |
+
|
90 |
+
Examples:
|
91 |
+
>>> w = torch.empty(3, 5)
|
92 |
+
>>> nn.init.trunc_normal_(w)
|
93 |
+
"""
|
94 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
95 |
+
|
96 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
97 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
98 |
+
|
99 |
+
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
|
100 |
+
"""
|
101 |
+
if drop_prob == 0. or not training:
|
102 |
+
return x
|
103 |
+
keep_prob = 1 - drop_prob
|
104 |
+
shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
105 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
106 |
+
random_tensor.floor_() # binarize
|
107 |
+
output = x.div(keep_prob) * random_tensor
|
108 |
+
return output
|
109 |
+
|
110 |
+
|
111 |
+
class DropPath(nn.Module):
|
112 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
113 |
+
|
114 |
+
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
|
115 |
+
"""
|
116 |
+
|
117 |
+
def __init__(self, drop_prob=None):
|
118 |
+
super(DropPath, self).__init__()
|
119 |
+
self.drop_prob = drop_prob
|
120 |
+
|
121 |
+
def forward(self, x):
|
122 |
+
return drop_path(x, self.drop_prob, self.training)
|
123 |
+
|
124 |
+
|
125 |
+
class Mlp(nn.Module):
|
126 |
+
|
127 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
128 |
+
super().__init__()
|
129 |
+
out_features = out_features or in_features
|
130 |
+
hidden_features = hidden_features or in_features
|
131 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
132 |
+
self.act = act_layer()
|
133 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
134 |
+
self.drop = nn.Dropout(drop)
|
135 |
+
|
136 |
+
def forward(self, x):
|
137 |
+
x = self.fc1(x)
|
138 |
+
x = self.act(x)
|
139 |
+
x = self.drop(x)
|
140 |
+
x = self.fc2(x)
|
141 |
+
x = self.drop(x)
|
142 |
+
return x
|
143 |
+
|
144 |
+
|
145 |
+
def window_partition(x, window_size):
|
146 |
+
"""
|
147 |
+
Args:
|
148 |
+
x: (b, h, w, c)
|
149 |
+
window_size (int): window size
|
150 |
+
|
151 |
+
Returns:
|
152 |
+
windows: (num_windows*b, window_size, window_size, c)
|
153 |
+
"""
|
154 |
+
b, h, w, c = x.shape
|
155 |
+
x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
|
156 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
|
157 |
+
return windows
|
158 |
+
|
159 |
+
|
160 |
+
def window_reverse(windows, window_size, h, w):
|
161 |
+
"""
|
162 |
+
Args:
|
163 |
+
windows: (num_windows*b, window_size, window_size, c)
|
164 |
+
window_size (int): Window size
|
165 |
+
h (int): Height of image
|
166 |
+
w (int): Width of image
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
x: (b, h, w, c)
|
170 |
+
"""
|
171 |
+
b = int(windows.shape[0] / (h * w / window_size / window_size))
|
172 |
+
x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
|
173 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
|
174 |
+
return x
|
175 |
+
|
176 |
+
|
177 |
+
class WindowAttention(nn.Module):
|
178 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
179 |
+
It supports both of shifted and non-shifted window.
|
180 |
+
|
181 |
+
Args:
|
182 |
+
dim (int): Number of input channels.
|
183 |
+
window_size (tuple[int]): The height and width of the window.
|
184 |
+
num_heads (int): Number of attention heads.
|
185 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
186 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
187 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
188 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
189 |
+
"""
|
190 |
+
|
191 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
192 |
+
|
193 |
+
super().__init__()
|
194 |
+
self.dim = dim
|
195 |
+
self.window_size = window_size # Wh, Ww
|
196 |
+
self.num_heads = num_heads
|
197 |
+
head_dim = dim // num_heads
|
198 |
+
self.scale = qk_scale or head_dim**-0.5
|
199 |
+
|
200 |
+
# define a parameter table of relative position bias
|
201 |
+
self.relative_position_bias_table = nn.Parameter(
|
202 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
203 |
+
|
204 |
+
# get pair-wise relative position index for each token inside the window
|
205 |
+
coords_h = torch.arange(self.window_size[0])
|
206 |
+
coords_w = torch.arange(self.window_size[1])
|
207 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
208 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
209 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
210 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
211 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
212 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
213 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
214 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
215 |
+
self.register_buffer('relative_position_index', relative_position_index)
|
216 |
+
|
217 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
218 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
219 |
+
self.proj = nn.Linear(dim, dim)
|
220 |
+
|
221 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
222 |
+
|
223 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
224 |
+
self.softmax = nn.Softmax(dim=-1)
|
225 |
+
|
226 |
+
def forward(self, x, mask=None):
|
227 |
+
"""
|
228 |
+
Args:
|
229 |
+
x: input features with shape of (num_windows*b, n, c)
|
230 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
231 |
+
"""
|
232 |
+
b_, n, c = x.shape
|
233 |
+
qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
|
234 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
235 |
+
|
236 |
+
q = q * self.scale
|
237 |
+
attn = (q @ k.transpose(-2, -1))
|
238 |
+
|
239 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
240 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
241 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
242 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
243 |
+
|
244 |
+
if mask is not None:
|
245 |
+
nw = mask.shape[0]
|
246 |
+
attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
|
247 |
+
attn = attn.view(-1, self.num_heads, n, n)
|
248 |
+
attn = self.softmax(attn)
|
249 |
+
else:
|
250 |
+
attn = self.softmax(attn)
|
251 |
+
|
252 |
+
attn = self.attn_drop(attn)
|
253 |
+
|
254 |
+
x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
|
255 |
+
x = self.proj(x)
|
256 |
+
x = self.proj_drop(x)
|
257 |
+
return x
|
258 |
+
|
259 |
+
def extra_repr(self) -> str:
|
260 |
+
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
|
261 |
+
|
262 |
+
def flops(self, n):
|
263 |
+
# calculate flops for 1 window with token length of n
|
264 |
+
flops = 0
|
265 |
+
# qkv = self.qkv(x)
|
266 |
+
flops += n * self.dim * 3 * self.dim
|
267 |
+
# attn = (q @ k.transpose(-2, -1))
|
268 |
+
flops += self.num_heads * n * (self.dim // self.num_heads) * n
|
269 |
+
# x = (attn @ v)
|
270 |
+
flops += self.num_heads * n * n * (self.dim // self.num_heads)
|
271 |
+
# x = self.proj(x)
|
272 |
+
flops += n * self.dim * self.dim
|
273 |
+
return flops
|
274 |
+
|
275 |
+
|
276 |
+
class SwinTransformerBlock(nn.Module):
|
277 |
+
r""" Swin Transformer Block.
|
278 |
+
|
279 |
+
Args:
|
280 |
+
dim (int): Number of input channels.
|
281 |
+
input_resolution (tuple[int]): Input resolution.
|
282 |
+
num_heads (int): Number of attention heads.
|
283 |
+
window_size (int): Window size.
|
284 |
+
shift_size (int): Shift size for SW-MSA.
|
285 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
286 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
287 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
288 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
289 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
290 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
291 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
292 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
293 |
+
"""
|
294 |
+
|
295 |
+
def __init__(self,
|
296 |
+
dim,
|
297 |
+
input_resolution,
|
298 |
+
num_heads,
|
299 |
+
window_size=7,
|
300 |
+
shift_size=0,
|
301 |
+
mlp_ratio=4.,
|
302 |
+
qkv_bias=True,
|
303 |
+
qk_scale=None,
|
304 |
+
drop=0.,
|
305 |
+
attn_drop=0.,
|
306 |
+
drop_path=0.,
|
307 |
+
act_layer=nn.GELU,
|
308 |
+
norm_layer=nn.LayerNorm):
|
309 |
+
super().__init__()
|
310 |
+
self.dim = dim
|
311 |
+
self.input_resolution = input_resolution
|
312 |
+
self.num_heads = num_heads
|
313 |
+
self.window_size = window_size
|
314 |
+
self.shift_size = shift_size
|
315 |
+
self.mlp_ratio = mlp_ratio
|
316 |
+
if min(self.input_resolution) <= self.window_size:
|
317 |
+
# if window size is larger than input resolution, we don't partition windows
|
318 |
+
self.shift_size = 0
|
319 |
+
self.window_size = min(self.input_resolution)
|
320 |
+
assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'
|
321 |
+
|
322 |
+
self.norm1 = norm_layer(dim)
|
323 |
+
self.attn = WindowAttention(
|
324 |
+
dim,
|
325 |
+
window_size=to_2tuple(self.window_size),
|
326 |
+
num_heads=num_heads,
|
327 |
+
qkv_bias=qkv_bias,
|
328 |
+
qk_scale=qk_scale,
|
329 |
+
attn_drop=attn_drop,
|
330 |
+
proj_drop=drop)
|
331 |
+
|
332 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
333 |
+
self.norm2 = norm_layer(dim)
|
334 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
335 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
336 |
+
|
337 |
+
if self.shift_size > 0:
|
338 |
+
attn_mask = self.calculate_mask(self.input_resolution)
|
339 |
+
else:
|
340 |
+
attn_mask = None
|
341 |
+
|
342 |
+
self.register_buffer('attn_mask', attn_mask)
|
343 |
+
|
344 |
+
def calculate_mask(self, x_size):
|
345 |
+
# calculate attention mask for SW-MSA
|
346 |
+
h, w = x_size
|
347 |
+
img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1
|
348 |
+
h_slices = (slice(0, -self.window_size), slice(-self.window_size,
|
349 |
+
-self.shift_size), slice(-self.shift_size, None))
|
350 |
+
w_slices = (slice(0, -self.window_size), slice(-self.window_size,
|
351 |
+
-self.shift_size), slice(-self.shift_size, None))
|
352 |
+
cnt = 0
|
353 |
+
for h in h_slices:
|
354 |
+
for w in w_slices:
|
355 |
+
img_mask[:, h, w, :] = cnt
|
356 |
+
cnt += 1
|
357 |
+
|
358 |
+
mask_windows = window_partition(img_mask, self.window_size) # nw, window_size, window_size, 1
|
359 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
360 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
361 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
362 |
+
|
363 |
+
return attn_mask
|
364 |
+
|
365 |
+
def forward(self, x, x_size):
|
366 |
+
h, w = x_size
|
367 |
+
b, _, c = x.shape
|
368 |
+
# assert seq_len == h * w, "input feature has wrong size"
|
369 |
+
|
370 |
+
shortcut = x
|
371 |
+
x = self.norm1(x)
|
372 |
+
x = x.view(b, h, w, c)
|
373 |
+
|
374 |
+
# cyclic shift
|
375 |
+
if self.shift_size > 0:
|
376 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
377 |
+
else:
|
378 |
+
shifted_x = x
|
379 |
+
|
380 |
+
# partition windows
|
381 |
+
x_windows = window_partition(shifted_x, self.window_size) # nw*b, window_size, window_size, c
|
382 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c
|
383 |
+
|
384 |
+
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
|
385 |
+
if self.input_resolution == x_size:
|
386 |
+
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nw*b, window_size*window_size, c
|
387 |
+
else:
|
388 |
+
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
|
389 |
+
|
390 |
+
# merge windows
|
391 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
|
392 |
+
shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c
|
393 |
+
|
394 |
+
# reverse cyclic shift
|
395 |
+
if self.shift_size > 0:
|
396 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
397 |
+
else:
|
398 |
+
x = shifted_x
|
399 |
+
x = x.view(b, h * w, c)
|
400 |
+
|
401 |
+
# FFN
|
402 |
+
x = shortcut + self.drop_path(x)
|
403 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
404 |
+
|
405 |
+
return x
|
406 |
+
|
407 |
+
def extra_repr(self) -> str:
|
408 |
+
return (f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, '
|
409 |
+
f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}')
|
410 |
+
|
411 |
+
def flops(self):
|
412 |
+
flops = 0
|
413 |
+
h, w = self.input_resolution
|
414 |
+
# norm1
|
415 |
+
flops += self.dim * h * w
|
416 |
+
# W-MSA/SW-MSA
|
417 |
+
nw = h * w / self.window_size / self.window_size
|
418 |
+
flops += nw * self.attn.flops(self.window_size * self.window_size)
|
419 |
+
# mlp
|
420 |
+
flops += 2 * h * w * self.dim * self.dim * self.mlp_ratio
|
421 |
+
# norm2
|
422 |
+
flops += self.dim * h * w
|
423 |
+
return flops
|
424 |
+
|
425 |
+
|
426 |
+
class PatchMerging(nn.Module):
|
427 |
+
r""" Patch Merging Layer.
|
428 |
+
|
429 |
+
Args:
|
430 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
431 |
+
dim (int): Number of input channels.
|
432 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
433 |
+
"""
|
434 |
+
|
435 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
436 |
+
super().__init__()
|
437 |
+
self.input_resolution = input_resolution
|
438 |
+
self.dim = dim
|
439 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
440 |
+
self.norm = norm_layer(4 * dim)
|
441 |
+
|
442 |
+
def forward(self, x):
|
443 |
+
"""
|
444 |
+
x: b, h*w, c
|
445 |
+
"""
|
446 |
+
h, w = self.input_resolution
|
447 |
+
b, seq_len, c = x.shape
|
448 |
+
assert seq_len == h * w, 'input feature has wrong size'
|
449 |
+
assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) are not even.'
|
450 |
+
|
451 |
+
x = x.view(b, h, w, c)
|
452 |
+
|
453 |
+
x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c
|
454 |
+
x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c
|
455 |
+
x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c
|
456 |
+
x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c
|
457 |
+
x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c
|
458 |
+
x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c
|
459 |
+
|
460 |
+
x = self.norm(x)
|
461 |
+
x = self.reduction(x)
|
462 |
+
|
463 |
+
return x
|
464 |
+
|
465 |
+
def extra_repr(self) -> str:
|
466 |
+
return f'input_resolution={self.input_resolution}, dim={self.dim}'
|
467 |
+
|
468 |
+
def flops(self):
|
469 |
+
h, w = self.input_resolution
|
470 |
+
flops = h * w * self.dim
|
471 |
+
flops += (h // 2) * (w // 2) * 4 * self.dim * 2 * self.dim
|
472 |
+
return flops
|
473 |
+
|
474 |
+
|
475 |
+
class BasicLayer(nn.Module):
|
476 |
+
""" A basic Swin Transformer layer for one stage.
|
477 |
+
|
478 |
+
Args:
|
479 |
+
dim (int): Number of input channels.
|
480 |
+
input_resolution (tuple[int]): Input resolution.
|
481 |
+
depth (int): Number of blocks.
|
482 |
+
num_heads (int): Number of attention heads.
|
483 |
+
window_size (int): Local window size.
|
484 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
485 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
486 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
487 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
488 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
489 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
490 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
491 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
492 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
493 |
+
"""
|
494 |
+
|
495 |
+
def __init__(self,
|
496 |
+
dim,
|
497 |
+
input_resolution,
|
498 |
+
depth,
|
499 |
+
num_heads,
|
500 |
+
window_size,
|
501 |
+
mlp_ratio=4.,
|
502 |
+
qkv_bias=True,
|
503 |
+
qk_scale=None,
|
504 |
+
drop=0.,
|
505 |
+
attn_drop=0.,
|
506 |
+
drop_path=0.,
|
507 |
+
norm_layer=nn.LayerNorm,
|
508 |
+
downsample=None,
|
509 |
+
use_checkpoint=False):
|
510 |
+
|
511 |
+
super().__init__()
|
512 |
+
self.dim = dim
|
513 |
+
self.input_resolution = input_resolution
|
514 |
+
self.depth = depth
|
515 |
+
self.use_checkpoint = use_checkpoint
|
516 |
+
|
517 |
+
# build blocks
|
518 |
+
self.blocks = nn.ModuleList([
|
519 |
+
SwinTransformerBlock(
|
520 |
+
dim=dim,
|
521 |
+
input_resolution=input_resolution,
|
522 |
+
num_heads=num_heads,
|
523 |
+
window_size=window_size,
|
524 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
525 |
+
mlp_ratio=mlp_ratio,
|
526 |
+
qkv_bias=qkv_bias,
|
527 |
+
qk_scale=qk_scale,
|
528 |
+
drop=drop,
|
529 |
+
attn_drop=attn_drop,
|
530 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
531 |
+
norm_layer=norm_layer) for i in range(depth)
|
532 |
+
])
|
533 |
+
|
534 |
+
# patch merging layer
|
535 |
+
if downsample is not None:
|
536 |
+
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
537 |
+
else:
|
538 |
+
self.downsample = None
|
539 |
+
|
540 |
+
def forward(self, x, x_size):
|
541 |
+
for blk in self.blocks:
|
542 |
+
if self.use_checkpoint:
|
543 |
+
x = checkpoint.checkpoint(blk, x)
|
544 |
+
else:
|
545 |
+
x = blk(x, x_size)
|
546 |
+
if self.downsample is not None:
|
547 |
+
x = self.downsample(x)
|
548 |
+
return x
|
549 |
+
|
550 |
+
def extra_repr(self) -> str:
|
551 |
+
return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
|
552 |
+
|
553 |
+
def flops(self):
|
554 |
+
flops = 0
|
555 |
+
for blk in self.blocks:
|
556 |
+
flops += blk.flops()
|
557 |
+
if self.downsample is not None:
|
558 |
+
flops += self.downsample.flops()
|
559 |
+
return flops
|
560 |
+
|
561 |
+
|
562 |
+
class RSTB(nn.Module):
|
563 |
+
"""Residual Swin Transformer Block (RSTB).
|
564 |
+
|
565 |
+
Args:
|
566 |
+
dim (int): Number of input channels.
|
567 |
+
input_resolution (tuple[int]): Input resolution.
|
568 |
+
depth (int): Number of blocks.
|
569 |
+
num_heads (int): Number of attention heads.
|
570 |
+
window_size (int): Local window size.
|
571 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
572 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
573 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
574 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
575 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
576 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
577 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
578 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
579 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
580 |
+
img_size: Input image size.
|
581 |
+
patch_size: Patch size.
|
582 |
+
resi_connection: The convolutional block before residual connection.
|
583 |
+
"""
|
584 |
+
|
585 |
+
def __init__(self,
|
586 |
+
dim,
|
587 |
+
input_resolution,
|
588 |
+
depth,
|
589 |
+
num_heads,
|
590 |
+
window_size,
|
591 |
+
mlp_ratio=4.,
|
592 |
+
qkv_bias=True,
|
593 |
+
qk_scale=None,
|
594 |
+
drop=0.,
|
595 |
+
attn_drop=0.,
|
596 |
+
drop_path=0.,
|
597 |
+
norm_layer=nn.LayerNorm,
|
598 |
+
downsample=None,
|
599 |
+
use_checkpoint=False,
|
600 |
+
img_size=224,
|
601 |
+
patch_size=4,
|
602 |
+
resi_connection='1conv'):
|
603 |
+
super(RSTB, self).__init__()
|
604 |
+
|
605 |
+
self.dim = dim
|
606 |
+
self.input_resolution = input_resolution
|
607 |
+
|
608 |
+
self.residual_group = BasicLayer(
|
609 |
+
dim=dim,
|
610 |
+
input_resolution=input_resolution,
|
611 |
+
depth=depth,
|
612 |
+
num_heads=num_heads,
|
613 |
+
window_size=window_size,
|
614 |
+
mlp_ratio=mlp_ratio,
|
615 |
+
qkv_bias=qkv_bias,
|
616 |
+
qk_scale=qk_scale,
|
617 |
+
drop=drop,
|
618 |
+
attn_drop=attn_drop,
|
619 |
+
drop_path=drop_path,
|
620 |
+
norm_layer=norm_layer,
|
621 |
+
downsample=downsample,
|
622 |
+
use_checkpoint=use_checkpoint)
|
623 |
+
|
624 |
+
if resi_connection == '1conv':
|
625 |
+
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
|
626 |
+
elif resi_connection == '3conv':
|
627 |
+
# to save parameters and memory
|
628 |
+
self.conv = nn.Sequential(
|
629 |
+
nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
630 |
+
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
631 |
+
nn.Conv2d(dim // 4, dim, 3, 1, 1))
|
632 |
+
|
633 |
+
self.patch_embed = PatchEmbed(
|
634 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
|
635 |
+
|
636 |
+
self.patch_unembed = PatchUnEmbed(
|
637 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
|
638 |
+
|
639 |
+
def forward(self, x, x_size):
|
640 |
+
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
|
641 |
+
|
642 |
+
def flops(self):
|
643 |
+
flops = 0
|
644 |
+
flops += self.residual_group.flops()
|
645 |
+
h, w = self.input_resolution
|
646 |
+
flops += h * w * self.dim * self.dim * 9
|
647 |
+
flops += self.patch_embed.flops()
|
648 |
+
flops += self.patch_unembed.flops()
|
649 |
+
|
650 |
+
return flops
|
651 |
+
|
652 |
+
|
653 |
+
class PatchEmbed(nn.Module):
|
654 |
+
r""" Image to Patch Embedding
|
655 |
+
|
656 |
+
Args:
|
657 |
+
img_size (int): Image size. Default: 224.
|
658 |
+
patch_size (int): Patch token size. Default: 4.
|
659 |
+
in_chans (int): Number of input image channels. Default: 3.
|
660 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
661 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
662 |
+
"""
|
663 |
+
|
664 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
665 |
+
super().__init__()
|
666 |
+
img_size = to_2tuple(img_size)
|
667 |
+
patch_size = to_2tuple(patch_size)
|
668 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
669 |
+
self.img_size = img_size
|
670 |
+
self.patch_size = patch_size
|
671 |
+
self.patches_resolution = patches_resolution
|
672 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
673 |
+
|
674 |
+
self.in_chans = in_chans
|
675 |
+
self.embed_dim = embed_dim
|
676 |
+
|
677 |
+
if norm_layer is not None:
|
678 |
+
self.norm = norm_layer(embed_dim)
|
679 |
+
else:
|
680 |
+
self.norm = None
|
681 |
+
|
682 |
+
def forward(self, x):
|
683 |
+
x = x.flatten(2).transpose(1, 2) # b Ph*Pw c
|
684 |
+
if self.norm is not None:
|
685 |
+
x = self.norm(x)
|
686 |
+
return x
|
687 |
+
|
688 |
+
def flops(self):
|
689 |
+
flops = 0
|
690 |
+
h, w = self.img_size
|
691 |
+
if self.norm is not None:
|
692 |
+
flops += h * w * self.embed_dim
|
693 |
+
return flops
|
694 |
+
|
695 |
+
|
696 |
+
class PatchUnEmbed(nn.Module):
|
697 |
+
r""" Image to Patch Unembedding
|
698 |
+
|
699 |
+
Args:
|
700 |
+
img_size (int): Image size. Default: 224.
|
701 |
+
patch_size (int): Patch token size. Default: 4.
|
702 |
+
in_chans (int): Number of input image channels. Default: 3.
|
703 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
704 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
705 |
+
"""
|
706 |
+
|
707 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
708 |
+
super().__init__()
|
709 |
+
img_size = to_2tuple(img_size)
|
710 |
+
patch_size = to_2tuple(patch_size)
|
711 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
712 |
+
self.img_size = img_size
|
713 |
+
self.patch_size = patch_size
|
714 |
+
self.patches_resolution = patches_resolution
|
715 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
716 |
+
|
717 |
+
self.in_chans = in_chans
|
718 |
+
self.embed_dim = embed_dim
|
719 |
+
|
720 |
+
def forward(self, x, x_size):
|
721 |
+
x = x.transpose(1, 2).view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b Ph*Pw c
|
722 |
+
return x
|
723 |
+
|
724 |
+
def flops(self):
|
725 |
+
flops = 0
|
726 |
+
return flops
|
727 |
+
|
728 |
+
|
729 |
+
class Upsample(nn.Sequential):
|
730 |
+
"""Upsample module.
|
731 |
+
|
732 |
+
Args:
|
733 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
734 |
+
num_feat (int): Channel number of intermediate features.
|
735 |
+
"""
|
736 |
+
|
737 |
+
def __init__(self, scale, num_feat):
|
738 |
+
m = []
|
739 |
+
if (scale & (scale - 1)) == 0: # scale = 2^n
|
740 |
+
for _ in range(int(math.log(scale, 2))):
|
741 |
+
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
742 |
+
m.append(nn.PixelShuffle(2))
|
743 |
+
elif scale == 3:
|
744 |
+
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
745 |
+
m.append(nn.PixelShuffle(3))
|
746 |
+
else:
|
747 |
+
raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
|
748 |
+
super(Upsample, self).__init__(*m)
|
749 |
+
|
750 |
+
|
751 |
+
class UpsampleOneStep(nn.Sequential):
|
752 |
+
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
|
753 |
+
Used in lightweight SR to save parameters.
|
754 |
+
|
755 |
+
Args:
|
756 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
757 |
+
num_feat (int): Channel number of intermediate features.
|
758 |
+
|
759 |
+
"""
|
760 |
+
|
761 |
+
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
|
762 |
+
self.num_feat = num_feat
|
763 |
+
self.input_resolution = input_resolution
|
764 |
+
m = []
|
765 |
+
m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
|
766 |
+
m.append(nn.PixelShuffle(scale))
|
767 |
+
super(UpsampleOneStep, self).__init__(*m)
|
768 |
+
|
769 |
+
def flops(self):
|
770 |
+
h, w = self.input_resolution
|
771 |
+
flops = h * w * self.num_feat * 3 * 9
|
772 |
+
return flops
|
773 |
+
|
774 |
+
|
775 |
+
class SwinIR(nn.Module):
|
776 |
+
r""" SwinIR
|
777 |
+
A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
|
778 |
+
|
779 |
+
Args:
|
780 |
+
img_size (int | tuple(int)): Input image size. Default 64
|
781 |
+
patch_size (int | tuple(int)): Patch size. Default: 1
|
782 |
+
in_chans (int): Number of input image channels. Default: 3
|
783 |
+
embed_dim (int): Patch embedding dimension. Default: 96
|
784 |
+
depths (tuple(int)): Depth of each Swin Transformer layer.
|
785 |
+
num_heads (tuple(int)): Number of attention heads in different layers.
|
786 |
+
window_size (int): Window size. Default: 7
|
787 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
788 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
789 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
790 |
+
drop_rate (float): Dropout rate. Default: 0
|
791 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0
|
792 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
793 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
794 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
795 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
796 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
797 |
+
upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
|
798 |
+
img_range: Image range. 1. or 255.
|
799 |
+
upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
|
800 |
+
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
|
801 |
+
"""
|
802 |
+
|
803 |
+
def __init__(self,
|
804 |
+
img_size=64,
|
805 |
+
patch_size=1,
|
806 |
+
in_chans=3,
|
807 |
+
embed_dim=96,
|
808 |
+
depths=(6, 6, 6, 6),
|
809 |
+
num_heads=(6, 6, 6, 6),
|
810 |
+
window_size=7,
|
811 |
+
mlp_ratio=4.,
|
812 |
+
qkv_bias=True,
|
813 |
+
qk_scale=None,
|
814 |
+
drop_rate=0.,
|
815 |
+
attn_drop_rate=0.,
|
816 |
+
drop_path_rate=0.1,
|
817 |
+
norm_layer=nn.LayerNorm,
|
818 |
+
ape=False,
|
819 |
+
patch_norm=True,
|
820 |
+
use_checkpoint=False,
|
821 |
+
upscale=2,
|
822 |
+
img_range=1.,
|
823 |
+
upsampler='',
|
824 |
+
resi_connection='1conv',
|
825 |
+
**kwargs):
|
826 |
+
super(SwinIR, self).__init__()
|
827 |
+
num_in_ch = in_chans
|
828 |
+
num_out_ch = in_chans
|
829 |
+
num_feat = 64
|
830 |
+
self.img_range = img_range
|
831 |
+
if in_chans == 3:
|
832 |
+
rgb_mean = (0.4488, 0.4371, 0.4040)
|
833 |
+
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
|
834 |
+
else:
|
835 |
+
self.mean = torch.zeros(1, 1, 1, 1)
|
836 |
+
self.upscale = upscale
|
837 |
+
self.upsampler = upsampler
|
838 |
+
|
839 |
+
# ------------------------- 1, shallow feature extraction ------------------------- #
|
840 |
+
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
|
841 |
+
|
842 |
+
# ------------------------- 2, deep feature extraction ------------------------- #
|
843 |
+
self.num_layers = len(depths)
|
844 |
+
self.embed_dim = embed_dim
|
845 |
+
self.ape = ape
|
846 |
+
self.patch_norm = patch_norm
|
847 |
+
self.num_features = embed_dim
|
848 |
+
self.mlp_ratio = mlp_ratio
|
849 |
+
|
850 |
+
# split image into non-overlapping patches
|
851 |
+
self.patch_embed = PatchEmbed(
|
852 |
+
img_size=img_size,
|
853 |
+
patch_size=patch_size,
|
854 |
+
in_chans=embed_dim,
|
855 |
+
embed_dim=embed_dim,
|
856 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
857 |
+
num_patches = self.patch_embed.num_patches
|
858 |
+
patches_resolution = self.patch_embed.patches_resolution
|
859 |
+
self.patches_resolution = patches_resolution
|
860 |
+
|
861 |
+
# merge non-overlapping patches into image
|
862 |
+
self.patch_unembed = PatchUnEmbed(
|
863 |
+
img_size=img_size,
|
864 |
+
patch_size=patch_size,
|
865 |
+
in_chans=embed_dim,
|
866 |
+
embed_dim=embed_dim,
|
867 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
868 |
+
|
869 |
+
# absolute position embedding
|
870 |
+
if self.ape:
|
871 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
872 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
873 |
+
|
874 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
875 |
+
|
876 |
+
# stochastic depth
|
877 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
878 |
+
|
879 |
+
# build Residual Swin Transformer blocks (RSTB)
|
880 |
+
self.layers = nn.ModuleList()
|
881 |
+
for i_layer in range(self.num_layers):
|
882 |
+
layer = RSTB(
|
883 |
+
dim=embed_dim,
|
884 |
+
input_resolution=(patches_resolution[0], patches_resolution[1]),
|
885 |
+
depth=depths[i_layer],
|
886 |
+
num_heads=num_heads[i_layer],
|
887 |
+
window_size=window_size,
|
888 |
+
mlp_ratio=self.mlp_ratio,
|
889 |
+
qkv_bias=qkv_bias,
|
890 |
+
qk_scale=qk_scale,
|
891 |
+
drop=drop_rate,
|
892 |
+
attn_drop=attn_drop_rate,
|
893 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
894 |
+
norm_layer=norm_layer,
|
895 |
+
downsample=None,
|
896 |
+
use_checkpoint=use_checkpoint,
|
897 |
+
img_size=img_size,
|
898 |
+
patch_size=patch_size,
|
899 |
+
resi_connection=resi_connection)
|
900 |
+
self.layers.append(layer)
|
901 |
+
self.norm = norm_layer(self.num_features)
|
902 |
+
|
903 |
+
# build the last conv layer in deep feature extraction
|
904 |
+
if resi_connection == '1conv':
|
905 |
+
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
|
906 |
+
elif resi_connection == '3conv':
|
907 |
+
# to save parameters and memory
|
908 |
+
self.conv_after_body = nn.Sequential(
|
909 |
+
nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
910 |
+
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
911 |
+
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
|
912 |
+
|
913 |
+
# ------------------------- 3, high quality image reconstruction ------------------------- #
|
914 |
+
if self.upsampler == 'pixelshuffle':
|
915 |
+
# for classical SR
|
916 |
+
self.conv_before_upsample = nn.Sequential(
|
917 |
+
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
|
918 |
+
self.upsample = Upsample(upscale, num_feat)
|
919 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
920 |
+
elif self.upsampler == 'pixelshuffledirect':
|
921 |
+
# for lightweight SR (to save parameters)
|
922 |
+
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
|
923 |
+
(patches_resolution[0], patches_resolution[1]))
|
924 |
+
elif self.upsampler == 'nearest+conv':
|
925 |
+
# for real-world SR (less artifacts)
|
926 |
+
assert self.upscale == 4, 'only support x4 now.'
|
927 |
+
self.conv_before_upsample = nn.Sequential(
|
928 |
+
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
|
929 |
+
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
930 |
+
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
931 |
+
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
932 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
933 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
934 |
+
else:
|
935 |
+
# for image denoising and JPEG compression artifact reduction
|
936 |
+
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
|
937 |
+
|
938 |
+
self.apply(self._init_weights)
|
939 |
+
|
940 |
+
def _init_weights(self, m):
|
941 |
+
if isinstance(m, nn.Linear):
|
942 |
+
trunc_normal_(m.weight, std=.02)
|
943 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
944 |
+
nn.init.constant_(m.bias, 0)
|
945 |
+
elif isinstance(m, nn.LayerNorm):
|
946 |
+
nn.init.constant_(m.bias, 0)
|
947 |
+
nn.init.constant_(m.weight, 1.0)
|
948 |
+
|
949 |
+
@torch.jit.ignore
|
950 |
+
def no_weight_decay(self):
|
951 |
+
return {'absolute_pos_embed'}
|
952 |
+
|
953 |
+
@torch.jit.ignore
|
954 |
+
def no_weight_decay_keywords(self):
|
955 |
+
return {'relative_position_bias_table'}
|
956 |
+
|
957 |
+
def forward_features(self, x):
|
958 |
+
x_size = (x.shape[2], x.shape[3])
|
959 |
+
x = self.patch_embed(x)
|
960 |
+
if self.ape:
|
961 |
+
x = x + self.absolute_pos_embed
|
962 |
+
x = self.pos_drop(x)
|
963 |
+
|
964 |
+
for layer in self.layers:
|
965 |
+
x = layer(x, x_size)
|
966 |
+
|
967 |
+
x = self.norm(x) # b seq_len c
|
968 |
+
x = self.patch_unembed(x, x_size)
|
969 |
+
|
970 |
+
return x
|
971 |
+
|
972 |
+
def forward(self, x):
|
973 |
+
self.mean = self.mean.type_as(x)
|
974 |
+
x = (x - self.mean) * self.img_range
|
975 |
+
|
976 |
+
if self.upsampler == 'pixelshuffle':
|
977 |
+
# for classical SR
|
978 |
+
x = self.conv_first(x)
|
979 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
980 |
+
x = self.conv_before_upsample(x)
|
981 |
+
x = self.conv_last(self.upsample(x))
|
982 |
+
elif self.upsampler == 'pixelshuffledirect':
|
983 |
+
# for lightweight SR
|
984 |
+
x = self.conv_first(x)
|
985 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
986 |
+
x = self.upsample(x)
|
987 |
+
elif self.upsampler == 'nearest+conv':
|
988 |
+
# for real-world SR
|
989 |
+
x = self.conv_first(x)
|
990 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
991 |
+
x = self.conv_before_upsample(x)
|
992 |
+
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
993 |
+
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
994 |
+
x = self.conv_last(self.lrelu(self.conv_hr(x)))
|
995 |
+
else:
|
996 |
+
# for image denoising and JPEG compression artifact reduction
|
997 |
+
x_first = self.conv_first(x)
|
998 |
+
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
999 |
+
x = x + self.conv_last(res)
|
1000 |
+
|
1001 |
+
x = x / self.img_range + self.mean
|
1002 |
+
|
1003 |
+
return x
|
1004 |
+
|
1005 |
+
def flops(self):
|
1006 |
+
flops = 0
|
1007 |
+
h, w = self.patches_resolution
|
1008 |
+
flops += h * w * 3 * self.embed_dim * 9
|
1009 |
+
flops += self.patch_embed.flops()
|
1010 |
+
for layer in self.layers:
|
1011 |
+
flops += layer.flops()
|
1012 |
+
flops += h * w * 3 * self.embed_dim * self.embed_dim
|
1013 |
+
flops += self.upsample.flops()
|
1014 |
+
return flops
|
1015 |
+
|
1016 |
+
|
1017 |
+
|
1018 |
+
class SwinIRNOUP(nn.Module):
|
1019 |
+
def __init__(self,
|
1020 |
+
img_size=48,
|
1021 |
+
patch_size=1,
|
1022 |
+
in_chans=3,
|
1023 |
+
embed_dim=180,
|
1024 |
+
depths=(6, 6, 6, 6, 6, 6),
|
1025 |
+
num_heads=(6, 6, 6, 6, 6, 6),
|
1026 |
+
window_size=8,
|
1027 |
+
mlp_ratio=2,
|
1028 |
+
qkv_bias=True,
|
1029 |
+
qk_scale=None,
|
1030 |
+
drop_rate=0.,
|
1031 |
+
attn_drop_rate=0.,
|
1032 |
+
drop_path_rate=0.1,
|
1033 |
+
norm_layer=nn.LayerNorm,
|
1034 |
+
ape=False,
|
1035 |
+
patch_norm=True,
|
1036 |
+
use_checkpoint=False,
|
1037 |
+
upscale=4,
|
1038 |
+
img_range=1.,
|
1039 |
+
upsampler='pixelshuffle',
|
1040 |
+
resi_connection='1conv',
|
1041 |
+
**kwargs):
|
1042 |
+
super(SwinIRNOUP, self).__init__()
|
1043 |
+
num_in_ch = in_chans
|
1044 |
+
num_out_ch = in_chans
|
1045 |
+
num_feat = 64
|
1046 |
+
self.img_range = img_range
|
1047 |
+
self.upsampler = upsampler
|
1048 |
+
|
1049 |
+
|
1050 |
+
# ------------------------- 1, shallow feature extraction ------------------------- #
|
1051 |
+
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
|
1052 |
+
|
1053 |
+
# ------------------------- 2, deep feature extraction ------------------------- #
|
1054 |
+
self.num_layers = len(depths)
|
1055 |
+
self.embed_dim = embed_dim
|
1056 |
+
self.ape = ape
|
1057 |
+
self.patch_norm = patch_norm
|
1058 |
+
self.num_features = embed_dim
|
1059 |
+
self.mlp_ratio = mlp_ratio
|
1060 |
+
|
1061 |
+
# split image into non-overlapping patches
|
1062 |
+
self.patch_embed = PatchEmbed(
|
1063 |
+
img_size=img_size,
|
1064 |
+
patch_size=patch_size,
|
1065 |
+
in_chans=embed_dim,
|
1066 |
+
embed_dim=embed_dim,
|
1067 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
1068 |
+
num_patches = self.patch_embed.num_patches
|
1069 |
+
patches_resolution = self.patch_embed.patches_resolution
|
1070 |
+
self.patches_resolution = patches_resolution
|
1071 |
+
|
1072 |
+
# merge non-overlapping patches into image
|
1073 |
+
self.patch_unembed = PatchUnEmbed(
|
1074 |
+
img_size=img_size,
|
1075 |
+
patch_size=patch_size,
|
1076 |
+
in_chans=embed_dim,
|
1077 |
+
embed_dim=embed_dim,
|
1078 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
1079 |
+
|
1080 |
+
# absolute position embedding
|
1081 |
+
if self.ape:
|
1082 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
1083 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
1084 |
+
|
1085 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
1086 |
+
|
1087 |
+
# stochastic depth
|
1088 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
1089 |
+
|
1090 |
+
# build Residual Swin Transformer blocks (RSTB)
|
1091 |
+
self.layers = nn.ModuleList()
|
1092 |
+
for i_layer in range(self.num_layers):
|
1093 |
+
layer = RSTB(
|
1094 |
+
dim=embed_dim,
|
1095 |
+
input_resolution=(patches_resolution[0], patches_resolution[1]),
|
1096 |
+
depth=depths[i_layer],
|
1097 |
+
num_heads=num_heads[i_layer],
|
1098 |
+
window_size=window_size,
|
1099 |
+
mlp_ratio=self.mlp_ratio,
|
1100 |
+
qkv_bias=qkv_bias,
|
1101 |
+
qk_scale=qk_scale,
|
1102 |
+
drop=drop_rate,
|
1103 |
+
attn_drop=attn_drop_rate,
|
1104 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
1105 |
+
norm_layer=norm_layer,
|
1106 |
+
downsample=None,
|
1107 |
+
use_checkpoint=use_checkpoint,
|
1108 |
+
img_size=img_size,
|
1109 |
+
patch_size=patch_size,
|
1110 |
+
resi_connection=resi_connection)
|
1111 |
+
self.layers.append(layer)
|
1112 |
+
self.norm = norm_layer(self.num_features)
|
1113 |
+
|
1114 |
+
# build the last conv layer in deep feature extraction
|
1115 |
+
if resi_connection == '1conv':
|
1116 |
+
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
|
1117 |
+
elif resi_connection == '3conv':
|
1118 |
+
# to save parameters and memory
|
1119 |
+
self.conv_after_body = nn.Sequential(
|
1120 |
+
nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
1121 |
+
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
1122 |
+
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
|
1123 |
+
|
1124 |
+
# ------------------------- 3, high quality image reconstruction ------------------------- #
|
1125 |
+
if self.upsampler == 'pixelshuffle':
|
1126 |
+
# for classical SR
|
1127 |
+
self.conv_before_upsample = nn.Sequential(
|
1128 |
+
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
|
1129 |
+
|
1130 |
+
elif self.upsampler == 'pixelshuffledirect':
|
1131 |
+
# for lightweight SR (to save parameters)
|
1132 |
+
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
|
1133 |
+
(patches_resolution[0], patches_resolution[1]))
|
1134 |
+
elif self.upsampler == 'nearest+conv':
|
1135 |
+
# for real-world SR (less artifacts)
|
1136 |
+
assert self.upscale == 4, 'only support x4 now.'
|
1137 |
+
self.conv_before_upsample = nn.Sequential(
|
1138 |
+
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
|
1139 |
+
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
1140 |
+
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
1141 |
+
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
1142 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
1143 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
1144 |
+
else:
|
1145 |
+
# for image denoising and JPEG compression artifact reduction
|
1146 |
+
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
|
1147 |
+
|
1148 |
+
self.apply(self._init_weights)
|
1149 |
+
|
1150 |
+
def _init_weights(self, m):
|
1151 |
+
if isinstance(m, nn.Linear):
|
1152 |
+
trunc_normal_(m.weight, std=.02)
|
1153 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
1154 |
+
nn.init.constant_(m.bias, 0)
|
1155 |
+
elif isinstance(m, nn.LayerNorm):
|
1156 |
+
nn.init.constant_(m.bias, 0)
|
1157 |
+
nn.init.constant_(m.weight, 1.0)
|
1158 |
+
|
1159 |
+
@torch.jit.ignore
|
1160 |
+
def no_weight_decay(self):
|
1161 |
+
return {'absolute_pos_embed'}
|
1162 |
+
|
1163 |
+
@torch.jit.ignore
|
1164 |
+
def no_weight_decay_keywords(self):
|
1165 |
+
return {'relative_position_bias_table'}
|
1166 |
+
|
1167 |
+
def forward_features(self, x):
|
1168 |
+
x_size = (x.shape[2], x.shape[3])
|
1169 |
+
x = self.patch_embed(x)
|
1170 |
+
if self.ape:
|
1171 |
+
x = x + self.absolute_pos_embed
|
1172 |
+
x = self.pos_drop(x)
|
1173 |
+
|
1174 |
+
for layer in self.layers:
|
1175 |
+
x = layer(x, x_size)
|
1176 |
+
|
1177 |
+
x = self.norm(x) # b seq_len c
|
1178 |
+
x = self.patch_unembed(x, x_size)
|
1179 |
+
|
1180 |
+
return x
|
1181 |
+
|
1182 |
+
def forward(self, x):
|
1183 |
+
|
1184 |
+
if self.upsampler == 'pixelshuffle':
|
1185 |
+
# for classical SR
|
1186 |
+
x = self.conv_first(x)
|
1187 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
1188 |
+
x = self.conv_before_upsample(x)
|
1189 |
+
|
1190 |
+
elif self.upsampler == 'pixelshuffledirect':
|
1191 |
+
# for lightweight SR
|
1192 |
+
x = self.conv_first(x)
|
1193 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
1194 |
+
|
1195 |
+
elif self.upsampler == 'nearest+conv':
|
1196 |
+
# for real-world SR
|
1197 |
+
x = self.conv_first(x)
|
1198 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
1199 |
+
x = self.conv_before_upsample(x)
|
1200 |
+
|
1201 |
+
else:
|
1202 |
+
# for image denoising and JPEG compression artifact reduction
|
1203 |
+
x_first = self.conv_first(x)
|
1204 |
+
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
1205 |
+
x = x + self.conv_last(res)
|
1206 |
+
|
1207 |
+
return x
|
1208 |
+
|
1209 |
+
def flops(self):
|
1210 |
+
flops = 0
|
1211 |
+
h, w = self.patches_resolution
|
1212 |
+
flops += h * w * 3 * self.embed_dim * 9
|
1213 |
+
flops += self.patch_embed.flops()
|
1214 |
+
for layer in self.layers:
|
1215 |
+
flops += layer.flops()
|
1216 |
+
flops += h * w * 3 * self.embed_dim * self.embed_dim
|
1217 |
+
flops += self.upsample.flops()
|
1218 |
+
return flops
|
1219 |
+
|
1220 |
+
|
1221 |
+
|
1222 |
+
|
1223 |
+
if __name__ == '__main__':
|
1224 |
+
upscale = 4
|
1225 |
+
window_size = 8
|
1226 |
+
height = (1024 // upscale // window_size + 1) * window_size
|
1227 |
+
width = (720 // upscale // window_size + 1) * window_size
|
1228 |
+
model = SwinIR(
|
1229 |
+
upscale=2,
|
1230 |
+
img_size=(height, width),
|
1231 |
+
window_size=window_size,
|
1232 |
+
img_range=1.,
|
1233 |
+
depths=[6, 6, 6, 6],
|
1234 |
+
embed_dim=60,
|
1235 |
+
num_heads=[6, 6, 6, 6],
|
1236 |
+
mlp_ratio=2,
|
1237 |
+
upsampler='pixelshuffledirect')
|
1238 |
+
print(model)
|
1239 |
+
print(height, width, model.flops() / 1e9)
|
1240 |
+
|
1241 |
+
x = torch.randn((1, 3, height, width))
|
1242 |
+
x = model(x)
|
1243 |
+
print(x.shape)
|