mt-cly committed
Commit 909940e · 1 Parent(s): a6d2ec4
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ dist/gscuda-0.0.0-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,31 @@
1
+ # Python build and package directories
2
+ build/
3
+ gscuda.egg-info/
4
+
5
+ # Additional common Python ignore patterns
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+ *.so
10
+ *.egg
11
+ *.egg-info/
12
+
13
+ # IDE and editor files
14
+ .vscode/
15
+ .idea/
16
+ *.swp
17
+ *.swo
18
+ *~
19
+
20
+ # OS generated files
21
+ .DS_Store
22
+ .DS_Store?
23
+ ._*
24
+ .Spotlight-V100
25
+ .Trashes
26
+ ehthumbs.db
27
+ Thumbs.db
28
+
29
+ # Gradio cache
30
+ .gradio/
31
+ .setup_complete
README.md CHANGED
@@ -1,13 +1,16 @@
  ---
  title: GSASR
- emoji: 👀
- colorFrom: blue
+ emoji: 🌖
+ colorFrom: pink
  colorTo: yellow
  sdk: gradio
- sdk_version: 5.33.0
+ sdk_version: 4.44.1
+ python_version: 3.10
  app_file: app.py
  pinned: false
+ # suggested_hardware: zero-a10g
  license: mit
+ short_description: GSASR(2d gaussian for arbitrary-scale super-resolution)
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,359 @@
1
+
2
+
3
+ import torch
4
+ import numpy as np
5
+ import gradio as gr
6
+ from PIL import Image
7
+ import math
8
+ import torch.nn.functional as F
9
+ import os
10
+ import tempfile
11
+ import time
12
+ import threading
13
+
14
+ from utils.hatropeamp import HATNOUP_ROPE_AMP
15
+ from utils.fea2gsropeamp import Fea2GS_ROPE_AMP
16
+ from utils.edsrbaseline import EDSRNOUP
17
+ from utils.hatropeamp import HATNOUP_ROPE_AMP
18
+ from utils.rdn import RDNNOUP
19
+ from utils.swinir import SwinIRNOUP
20
+ from utils.fea2gsropeamp import Fea2GS_ROPE_AMP
21
+ from utils.gaussian_splatting import generate_2D_gaussian_splatting_step
22
+ from utils.split_and_joint_image import split_and_joint_image
23
+ from huggingface_hub import hf_hub_download
24
+ import subprocess
25
+ import sys
26
+ import spaces
27
+
28
+
29
+
30
+ # Device setup
31
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32
+
33
+ # Global stop flag for interrupting inference
34
+ stop_inference = False
35
+ inference_lock = threading.Lock()
36
+
37
+ def load_model(
38
+ pretrained_model_name_or_path: str = "mutou0308/GSASR",
39
+ model_name: str = "HATL_SA1B",
40
+ device: str | torch.device = "cuda"
41
+ ):
42
+ enc_path = hf_hub_download(
43
+ repo_id=pretrained_model_name_or_path, filename=os.path.join(model_name, 'encoder.pth')
44
+ )
45
+ dec_path = hf_hub_download(
46
+ repo_id=pretrained_model_name_or_path, filename=os.path.join(model_name, 'decoder.pth')
47
+ )
48
+
49
+ enc_weight = torch.load(enc_path, weights_only=True)['params_ema']
50
+ dec_weight = torch.load(dec_path, weights_only=True)['params_ema']
51
+
52
+ if model_name in ['EDSR_DIV2K', 'EDSR_DF2K']:
53
+ encoder = EDSRNOUP()
54
+ decoder = Fea2GS_ROPE_AMP()
55
+ elif model_name in ['RDN_DIV2K', 'RDN_DF2K']:
56
+ encoder = RDNNOUP()
57
+ decoder = Fea2GS_ROPE_AMP(num_crossattn_blocks = 2)
58
+ elif model_name in ['SwinIR_DIV2K', 'SwinIR_DF2K']:
59
+ encoder = SwinIRNOUP()
60
+ decoder = Fea2GS_ROPE_AMP(num_crossattn_blocks=2, num_crossattn_layers=4, num_gs_seed=256, window_size=16)
61
+ elif model_name in ['HATL_SA1B']:
62
+ encoder = HATNOUP_ROPE_AMP()
63
+ decoder = Fea2GS_ROPE_AMP(channel=192, num_crossattn_blocks=4, num_crossattn_layers=4, num_selfattn_blocks=8, num_selfattn_layers=6,
64
+ num_gs_seed=256, window_size=16)
65
+ else:
66
+ raise ValueError(f"args.model-{model_name} must be in ['EDSR_DIV2K', 'EDSR_DF2K', 'RDN_DIV2K', 'RDN_DF2K', 'SwinIR_DIV2K', 'SwinIR_DF2K', 'HATL_SA1B']")
67
+
68
+ encoder.load_state_dict(enc_weight, strict=True)
69
+ decoder.load_state_dict(dec_weight, strict=True)
70
+ encoder.eval()
71
+ decoder.eval()
72
+ encoder = encoder.to(device)
73
+ decoder = decoder.to(device)
74
+ return encoder, decoder
75
+
76
+
77
+ def preprocess(x, denominator=16):
78
+ """Preprocess image to ensure dimensions are multiples of denominator"""
79
+ _, c, h, w = x.shape
80
+ if h % denominator > 0:
81
+ pad_h = denominator - h % denominator
82
+ else:
83
+ pad_h = 0
84
+ if w % denominator > 0:
85
+ pad_w = denominator - w % denominator
86
+ else:
87
+ pad_w = 0
88
+ x_new = F.pad(x, (0, pad_w, 0, pad_h), 'reflect')
89
+ return x_new
90
+
91
+ def postprocess(x, gt_size_h, gt_size_w):
92
+ """Post-process by cropping to target size"""
93
+ x_new = x[:, :, :gt_size_h, :gt_size_w]
94
+ return x_new
95
+
96
+ def should_use_tile(image_height, image_width, threshold=1024):
97
+ """Determine if tile processing should be used based on image resolution"""
98
+ return max(image_height, image_width) > threshold
99
+
100
+ def set_stop_flag():
101
+ """Set the global stop flag to interrupt inference"""
102
+ global stop_inference
103
+ with inference_lock:
104
+ stop_inference = True
105
+ return "🛑 Stopping inference...", gr.update(interactive=False)
106
+
107
+ def reset_stop_flag():
108
+ """Reset the global stop flag"""
109
+ global stop_inference
110
+ with inference_lock:
111
+ stop_inference = False
112
+
113
+ def check_stop_flag():
114
+ """Check if inference should be stopped"""
115
+ global stop_inference
116
+ with inference_lock:
117
+ return stop_inference
118
+
119
+ @spaces.GPU
120
+ def super_resolution_inference(image, scale=4.0):
121
+ """Super-resolution inference function with automatic tile processing"""
122
+
123
+ # Check if gscuda setup has been run
124
+ setup_marker = ".setup_complete"
125
+ if not os.path.exists(setup_marker):
126
+ print("First run detected, installing dependencies...")
127
+ try:
128
+ # subprocess.check_call(["pip", "install", "-e", "."])
129
+ subprocess.check_call(["pip", "install", "dist/gscuda-0.0.0-cp310-cp310-linux_x86_64.whl"])
130
+ # Create marker file to indicate setup is complete
131
+ with open(setup_marker, "w") as f:
132
+ f.write("Setup completed")
133
+ print("Setup completed successfully!")
134
+ except subprocess.CalledProcessError as e:
135
+ return None, f"❌ Setup failed with error: {e}", None
136
+
137
+
138
+
139
+ if image is None:
140
+ return None, "Please upload an image", None
141
+
142
+ # Load model
143
+ encoder, decoder = load_model(model_name="HATL_SA1B")
144
+
145
+ # Reset stop flag at the beginning
146
+ reset_stop_flag()
147
+
148
+ # Fixed parameters
149
+ tile_overlap = 16 # Fixed overlap size
150
+ crop_size = 8 # Fixed crop size
151
+ tile_size = 1024 # Fixed tile size for large images
152
+
153
+ try:
154
+ # Check for interruption
155
+ if check_stop_flag():
156
+ return None, "❌ Inference interrupted", None
157
+
158
+ # Convert PIL image to numpy array
159
+ img_np = np.array(image)
160
+ if len(img_np.shape) == 3:
161
+ img_np = img_np[:, :, [2, 1, 0]] # RGB to BGR
162
+
163
+ # Convert to tensor
164
+ img = torch.from_numpy(np.transpose(img_np.astype(np.float32) / 255., (2, 0, 1))).float()
165
+ img = img.unsqueeze(0).to(device)
166
+
167
+ # Check for interruption
168
+ if check_stop_flag():
169
+ return None, "❌ Inference interrupted", None
170
+
171
+ # Calculate target size
172
+ gt_size = [math.floor(scale * img.shape[2]), math.floor(scale * img.shape[3])]
173
+
174
+ # Determine if tile processing should be used
175
+ use_tile = should_use_tile(img.shape[2], img.shape[3])
176
+
177
+ # Force AMP mixed precision
178
+ with torch.inference_mode():
179
+ with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
180
+ # Check for interruption before main processing
181
+ if check_stop_flag():
182
+ return None, "❌ Inference interrupted", None
183
+
184
+ if use_tile:
185
+ # Use tile processing
186
+ assert tile_size % 16 == 0, f"tile_size-{tile_size} must be divisible by 16"
187
+ assert 2 * tile_overlap < tile_size, f"2 * tile_overlap must be less than tile_size"
188
+ assert 2 * crop_size <= tile_overlap, f"2 * crop_size must be less than or equal to tile_overlap"
189
+
190
+ with torch.no_grad():
191
+ output = split_and_joint_image(
192
+ lq=img,
193
+ scale_factor=scale,
194
+ split_size=tile_size,
195
+ overlap_size=tile_overlap,
196
+ model_g=encoder,
197
+ model_fea2gs=decoder,
198
+ crop_size=crop_size,
199
+ scale_modify=torch.tensor([scale, scale]),
200
+ default_step_size=1.2,
201
+ cuda_rendering=True,
202
+ mode='scale_modify',
203
+ if_dmax=True,
204
+ dmax_mode='fix',
205
+ dmax=0.1
206
+ )
207
+ else:
208
+ # Direct processing without tiles
209
+ lq_pad = preprocess(img, 16) # denominator=16 for HATL
210
+ gt_size_pad = torch.tensor([math.floor(scale * lq_pad.shape[2]),
211
+ math.floor(scale * lq_pad.shape[3])])
212
+ gt_size_pad = gt_size_pad.unsqueeze(0)
213
+
214
+ with torch.no_grad():
215
+ # Check for interruption before encoder
216
+ if check_stop_flag():
217
+ return None, "❌ Inference interrupted", None
218
+
219
+ # Encoder output
220
+ encoder_output = encoder(lq_pad) # b,c,h,w
221
+
222
+ # Check for interruption before decoder
223
+ if check_stop_flag():
224
+ return None, "❌ Inference interrupted", None
225
+
226
+ scale_vector = torch.tensor(scale, dtype=torch.float32).unsqueeze(0).to(device)
227
+
228
+ # Decoder output
229
+ batch_gs_parameters = decoder(encoder_output, scale_vector)
230
+ gs_parameters = batch_gs_parameters[0, :]
231
+
232
+ # Check for interruption before gaussian rendering
233
+ if check_stop_flag():
234
+ return None, "❌ Inference interrupted", None
235
+
236
+ # Gaussian rendering
237
+ b_output = generate_2D_gaussian_splatting_step(
238
+ gs_parameters=gs_parameters,
239
+ sr_size=gt_size_pad[0],
240
+ scale=scale,
241
+ sample_coords=None,
242
+ scale_modify=torch.tensor([scale, scale]),
243
+ default_step_size=1.2,
244
+ cuda_rendering=True,
245
+ mode='scale_modify',
246
+ if_dmax=True,
247
+ dmax_mode='fix',
248
+ dmax=0.1
249
+ )
250
+ output = b_output.unsqueeze(0)
251
+
252
+ # Check for interruption before post-processing
253
+ if check_stop_flag():
254
+ return None, "❌ Inference interrupted", None
255
+
256
+ # Post-processing
257
+ output = postprocess(output, gt_size[0], gt_size[1])
258
+
259
+ # Convert back to PIL image format
260
+ output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
261
+ output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # BGR to RGB
262
+ output = (output * 255.0).round().astype(np.uint8)
263
+
264
+ # Convert to PIL image
265
+ output_pil = Image.fromarray(output)
266
+
267
+ # Generate result information
268
+ original_size = f"{img.shape[3]}x{img.shape[2]}"
269
+ output_size = f"{output.shape[1]}x{output.shape[0]}"
270
+ tile_info = f"Tile processing enabled (size: {tile_size})" if use_tile else "Direct processing (no tiles)"
271
+ result_info = f"✅ Processing completed successfully!\nOriginal size: {original_size}\nSuper-resolution size: {output_size}\nScale factor: {scale:.2f}x\nProcessing mode: {tile_info}\nAMP acceleration: Force enabled\nOverlap size: {tile_overlap}\nCrop size: {crop_size}"
272
+
273
+ return output_pil, result_info, output_pil
274
+
275
+ except Exception as e:
276
+ if check_stop_flag():
277
+ return None, "❌ Inference interrupted", None
278
+ return None, f"❌ Error during processing: {str(e)}", None
279
+
280
+ def predict(image, scale):
281
+ """Gradio prediction function"""
282
+ output_image, info, download_image = super_resolution_inference(image, scale)
283
+
284
+ # If processing successful, save image for download
285
+ if output_image is not None:
286
+ # Create temporary filename
287
+ timestamp = int(time.time())
288
+ temp_filename = f"GSASR_SR_result_{scale}x_{timestamp}.png"
289
+ temp_path = os.path.join(tempfile.gettempdir(), temp_filename)
290
+
291
+ # Save image
292
+ output_image.save(temp_path, "PNG")
293
+
294
+ return output_image, temp_path, "✅ Ready", gr.update(interactive=True)
295
+ else:
296
+ return output_image, None, info if info else "❌ Processing failed", gr.update(interactive=True)
297
+
298
+ # Create Gradio interface
299
+ with gr.Blocks(title="🚀 GSASR (2D Gaussian Splatting Super-Resolution)") as demo:
300
+ gr.Markdown("# **🚀 GSASR (Generalized and efficient 2d gaussian splatting for arbitrary-scale super-resolution)**")
301
+ gr.Markdown("Official demo for GSASR. Please refer to our [paper](https://arxiv.org/pdf/2501.06838), [project page](https://mt-cly.github.io/GSASR.github.io/), and [github](https://github.com/ChrisDud0257/GSASR) for more details.")
302
+
303
+ with gr.Row():
304
+ with gr.Column():
305
+ input_image = gr.Image(type="pil", label="Input Image")
306
+
307
+ # Scale parameters
308
+ with gr.Group():
309
+ gr.Markdown("### SR Scale")
310
+ scale_slider = gr.Slider(minimum=1.0, maximum=30.0, value=4.0, step=0.1, label="SR Scale")
311
+
312
+ # Control buttons
313
+ with gr.Row():
314
+ submit_btn = gr.Button("🚀 Start Super-Resolution", variant="primary")
315
+ stop_btn = gr.Button("🛑 Stop Inference", variant="stop")
316
+
317
+ with gr.Column():
318
+ output_image = gr.Image(type="pil", label="Super-Resolution Result")
319
+
320
+ # Status display
321
+ status_text = gr.Textbox(label="Status", value="✅ Ready", interactive=False)
322
+
323
+ # Download component
324
+ with gr.Group():
325
+ gr.Markdown("### 📥 Download Super-Resolution Result")
326
+ download_btn = gr.File(visible=True)
327
+
328
+ # Event handlers
329
+ submit_event = submit_btn.click(
330
+ fn=predict,
331
+ inputs=[input_image, scale_slider],
332
+ outputs=[output_image, download_btn, status_text, stop_btn]
333
+ )
334
+
335
+ stop_btn.click(
336
+ fn=set_stop_flag,
337
+ inputs=[],
338
+ outputs=[status_text, stop_btn],
339
+ cancels=[submit_event]
340
+ )
341
+
342
+ # Example images
343
+ gr.Markdown("### 📚 Example Images")
344
+ gr.Markdown("Try these examples with different scales:")
345
+
346
+ gr.Examples(
347
+ examples=[
348
+ ["assets/0846x4.png", 1.5],
349
+ ["assets/0892x4.png", 2.8],
350
+ ["assets/0873x4_cropped_120x120.png", 30.0]
351
+ ],
352
+ inputs=[input_image, scale_slider],
353
+ examples_per_page=3,
354
+ cache_examples=False,
355
+ label="Examples"
356
+ )
357
+
358
+ if __name__ == "__main__":
359
+ demo.launch(share=True, server_name="0.0.0.0")
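
Note (not part of the commit): a minimal sketch of running the same non-tiled pipeline that app.py wires into Gradio, assuming this Space's code and requirements (including the gscuda wheel from dist/ and the gradio/spaces packages), the pretrained weights on the mutou0308/GSASR hub repo, and a CUDA device are available. Names mirror app.py; treat it as an illustration, not an official entry point.

import math
import numpy as np
import torch
from PIL import Image

from app import load_model, preprocess, postprocess  # reuses the helpers defined above
from utils.gaussian_splatting import generate_2D_gaussian_splatting_step

device = torch.device('cuda')
encoder, decoder = load_model(model_name="HATL_SA1B", device=device)

scale = 4.0
img = np.array(Image.open("assets/0846x4.png").convert("RGB"))[:, :, [2, 1, 0]]  # RGB -> BGR, as in app.py
lq = torch.from_numpy(np.transpose(img.astype(np.float32) / 255., (2, 0, 1))).unsqueeze(0).to(device)

gt_size = [math.floor(scale * lq.shape[2]), math.floor(scale * lq.shape[3])]
with torch.inference_mode(), torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
    lq_pad = preprocess(lq, 16)                      # pad H and W to multiples of 16
    gt_size_pad = torch.tensor([[math.floor(scale * lq_pad.shape[2]),
                                 math.floor(scale * lq_pad.shape[3])]])
    feat = encoder(lq_pad)                           # b,c,h,w features
    scale_vec = torch.tensor(scale, dtype=torch.float32).unsqueeze(0).to(device)
    gs_parameters = decoder(feat, scale_vec)[0, :]   # per-image 2D Gaussian parameters
    sr = generate_2D_gaussian_splatting_step(
        gs_parameters=gs_parameters, sr_size=gt_size_pad[0], scale=scale,
        sample_coords=None, scale_modify=torch.tensor([scale, scale]),
        default_step_size=1.2, cuda_rendering=True, mode='scale_modify',
        if_dmax=True, dmax_mode='fix', dmax=0.1).unsqueeze(0)
    sr = postprocess(sr, gt_size[0], gt_size[1])     # crop away the padded border

out = sr.squeeze().float().cpu().clamp_(0, 1).numpy()
out = np.transpose(out[[2, 1, 0], :, :], (1, 2, 0))  # BGR -> RGB
Image.fromarray((out * 255.0).round().astype(np.uint8)).save("sr_result.png")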
assets/0846x4.png ADDED

Git LFS Details

  • SHA256: 1ed26d96cbd5885f73dfdffbebbe9a048276036bf050c435a4da190199a932a0
  • Pointer size: 131 Bytes
  • Size of remote file: 262 kB
assets/0873.png ADDED

Git LFS Details

  • SHA256: 3a76a1452be69f0a04bddaeffa825bff46027d7155bb24479fab93c12db9bd73
  • Pointer size: 131 Bytes
  • Size of remote file: 197 kB
assets/0873x4.png ADDED

Git LFS Details

  • SHA256: 2c034622b96885845c4438c25a5248afc2a5bff00b89734917189f292e57754f
  • Pointer size: 131 Bytes
  • Size of remote file: 331 kB
assets/0873x4_cropped_120x120.png ADDED

Git LFS Details

  • SHA256: b21380583809cce487b5129f93451148dc9954d9d4ebebcdbb824fdbdc1198a3
  • Pointer size: 130 Bytes
  • Size of remote file: 32.4 kB
assets/0892x4.png ADDED

Git LFS Details

  • SHA256: e95aebc62748c232bfc5942ad506e5d2d31323b7d10cb977a10287065293ce0b
  • Pointer size: 131 Bytes
  • Size of remote file: 315 kB
assets/Screenshot_cropped_180x100.png ADDED

Git LFS Details

  • SHA256: 30241bfe51891e2c19c3fe9b949dde8dee50d4baddcbd5e1612befed111543f8
  • Pointer size: 130 Bytes
  • Size of remote file: 48.2 kB
dist/gscuda-0.0.0-cp310-cp310-linux_x86_64.whl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:832b5f0cd6cd078e39a8bf68c481488cf606ec9633591d4d981794338a3f2b29
3
+ size 90122
requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ torch==2.5.1 --index-url https://download.pytorch.org/whl/cu124
2
+ torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu124
3
+ torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
4
+ # gradio==5.32.0
5
+ gradio==5.23.0
6
+ huggingface-hub==0.32.3
7
+ pillow==11.2.1
8
+ numpy==1.23.0
9
+ einops==0.8.1
10
+ opencv-python==4.11.0.86
11
+ pydantic==2.10.6
12
+ # dist/gscuda-0.0.0-cp310-cp310-linux_x86_64.whl
setup.py ADDED
@@ -0,0 +1,26 @@
1
+ from setuptools import setup
2
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3
+ import os
4
+ import torch
5
+
6
+ print("Building gscuda")
7
+ # Source files are assumed to live under the gs_cuda directory
8
+ file_path = "utils/gs_cuda_dmax"
9
+
10
+ setup(
11
+ name="gscuda", # module name
12
+ ext_modules=[
13
+ CUDAExtension(
14
+ name="gscuda", # importable directly as a module
15
+ sources=[
16
+ os.path.join(file_path, "gswrapper.cpp"),
17
+ os.path.join(file_path, "gs.cu")
18
+ ],
19
+ # Set the runtime library path (optional)
20
+ library_dirs=[os.path.join(os.path.dirname(torch.__file__), 'lib')],
21
+ )
22
+ ],
23
+ cmdclass={
24
+ "build_ext": BuildExtension
25
+ },
26
+ )
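
Note (not part of the commit): setup.py compiles utils/gs_cuda_dmax/gswrapper.cpp and gs.cu into the gscuda extension; the prebuilt wheel under dist/ is what app.py installs on first run. A hedged smoke test after building (e.g. with pip wheel .) or installing the wheel, assuming a CUDA build of torch matching the wheel's cp310 ABI:

import torch
import gscuda  # the compiled CUDA extension named in setup.py

from utils.gs_cuda_dmax.gswrapper import GSCUDA  # autograd wrapper used by utils/gaussian_splatting.py

print("CUDA available:", torch.cuda.is_available())
print("gscuda loaded from:", gscuda.__file__)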
utils/edsrbaseline.py ADDED
@@ -0,0 +1,113 @@
1
+ import collections.abc
2
+ import math
3
+ import torch
4
+ import torchvision
5
+ import warnings
6
+ from itertools import repeat
7
+ from torch import nn as nn
8
+ from torch.nn import functional as F
9
+ from torch.nn import init as init
10
+ from torch.nn.modules.batchnorm import _BatchNorm
11
+
12
+ @torch.no_grad()
13
+ def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
14
+ """Initialize network weights.
15
+
16
+ Args:
17
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
18
+ scale (float): Scale initialized weights, especially for residual
19
+ blocks. Default: 1.
20
+ bias_fill (float): The value to fill bias. Default: 0
21
+ kwargs (dict): Other arguments for initialization function.
22
+ """
23
+ if not isinstance(module_list, list):
24
+ module_list = [module_list]
25
+ for module in module_list:
26
+ for m in module.modules():
27
+ if isinstance(m, nn.Conv2d):
28
+ init.kaiming_normal_(m.weight, **kwargs)
29
+ m.weight.data *= scale
30
+ if m.bias is not None:
31
+ m.bias.data.fill_(bias_fill)
32
+ elif isinstance(m, nn.Linear):
33
+ init.kaiming_normal_(m.weight, **kwargs)
34
+ m.weight.data *= scale
35
+ if m.bias is not None:
36
+ m.bias.data.fill_(bias_fill)
37
+ elif isinstance(m, _BatchNorm):
38
+ init.constant_(m.weight, 1)
39
+ if m.bias is not None:
40
+ m.bias.data.fill_(bias_fill)
41
+
42
+ def make_layer(basic_block, num_basic_block, **kwarg):
43
+ """Make layers by stacking the same blocks.
44
+
45
+ Args:
46
+ basic_block (nn.module): nn.module class for basic block.
47
+ num_basic_block (int): number of blocks.
48
+
49
+ Returns:
50
+ nn.Sequential: Stacked blocks in nn.Sequential.
51
+ """
52
+ layers = []
53
+ for _ in range(num_basic_block):
54
+ layers.append(basic_block(**kwarg))
55
+ return nn.Sequential(*layers)
56
+
57
+ class ResidualBlockNoBN(nn.Module):
58
+ """Residual block without BN.
59
+
60
+ Args:
61
+ num_feat (int): Channel number of intermediate features.
62
+ Default: 64.
63
+ res_scale (float): Residual scale. Default: 1.
64
+ pytorch_init (bool): If set to True, use pytorch default init,
65
+ otherwise, use default_init_weights. Default: False.
66
+ """
67
+
68
+ def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
69
+ super(ResidualBlockNoBN, self).__init__()
70
+ self.res_scale = res_scale
71
+ self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
72
+ self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
73
+ self.relu = nn.ReLU(inplace=True)
74
+
75
+ if not pytorch_init:
76
+ default_init_weights([self.conv1, self.conv2], 0.1)
77
+
78
+ def forward(self, x):
79
+ identity = x
80
+ out = self.conv2(self.relu(self.conv1(x)))
81
+ return identity + out * self.res_scale
82
+
83
+
84
+
85
+ class EDSRNOUP(nn.Module):
86
+ def __init__(self,
87
+ num_in_ch=3,
88
+ num_out_ch=3,
89
+ num_feat=64,
90
+ num_block=16,
91
+ upscale=4,
92
+ res_scale=1):
93
+ super(EDSRNOUP, self).__init__()
94
+
95
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
96
+ self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True)
97
+ self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
98
+
99
+
100
+ def forward(self, x):
101
+
102
+ x = self.conv_first(x)
103
+ res = self.conv_after_body(self.body(x))
104
+ x = res + x
105
+
106
+ return res
107
+
108
+
109
+ if __name__ == '__main__':
110
+ x = torch.randn(8,3,48,48)
111
+ model = EDSRNOUP(num_in_ch=3, num_out_ch=3)
112
+ y = model(x)
113
+ print(y.shape)
utils/fea2gsropeamp.py ADDED
@@ -0,0 +1,749 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import print_function
3
+ from __future__ import division
4
+
5
+ import warnings
6
+ import math
7
+ import copy
8
+ from einops import rearrange
9
+ import torch
10
+ from torch import nn
11
+ import torch.nn.functional as F
12
+ from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_, kaiming_normal_
13
+ from einops import rearrange
14
+ from torch.utils.checkpoint import checkpoint
15
+ from functools import partial
16
+ from typing import Any, Optional, Tuple
17
+ import numpy as np
18
+
19
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
20
+ # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
21
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
22
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
23
+ def norm_cdf(x):
24
+ # Computes standard normal cumulative distribution function
25
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
26
+
27
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
28
+ warnings.warn(
29
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
30
+ 'The distribution of values may be incorrect.',
31
+ stacklevel=2)
32
+
33
+ with torch.no_grad():
34
+ # Values are generated by using a truncated uniform distribution and
35
+ # then using the inverse CDF for the normal distribution.
36
+ # Get upper and lower cdf values
37
+ low = norm_cdf((a - mean) / std)
38
+ up = norm_cdf((b - mean) / std)
39
+
40
+ # Uniformly fill tensor with values from [low, up], then translate to
41
+ # [2l-1, 2u-1].
42
+ tensor.uniform_(2 * low - 1, 2 * up - 1)
43
+
44
+ # Use inverse cdf transform for normal distribution to get truncated
45
+ # standard normal
46
+ tensor.erfinv_()
47
+
48
+ # Transform to proper mean, std
49
+ tensor.mul_(std * math.sqrt(2.))
50
+ tensor.add_(mean)
51
+
52
+ # Clamp to ensure it's in the proper range
53
+ tensor.clamp_(min=a, max=b)
54
+ return tensor
55
+
56
+
57
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
58
+ r"""Fills the input Tensor with values drawn from a truncated
59
+ normal distribution.
60
+
61
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
62
+
63
+ The values are effectively drawn from the
64
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
65
+ with values outside :math:`[a, b]` redrawn until they are within
66
+ the bounds. The method used for generating the random values works
67
+ best when :math:`a \leq \text{mean} \leq b`.
68
+
69
+ Args:
70
+ tensor: an n-dimensional `torch.Tensor`
71
+ mean: the mean of the normal distribution
72
+ std: the standard deviation of the normal distribution
73
+ a: the minimum cutoff value
74
+ b: the maximum cutoff value
75
+
76
+ Examples:
77
+ >>> w = torch.empty(3, 5)
78
+ >>> nn.init.trunc_normal_(w)
79
+ """
80
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
81
+
82
+ def init_t_xy(end_x: int, end_y: int, zero_center=False):
83
+ t = torch.arange(end_x * end_y, dtype=torch.float32)
84
+ t_x = (t % end_x).float()
85
+ t_y = torch.div(t, end_x, rounding_mode='floor').float()
86
+
87
+ return t_x, t_y
88
+
89
+ def init_random_2d_freqs(head_dim: int, num_heads: int, theta: float = 10.0, rotate: bool = True):
90
+ freqs_x = []
91
+ freqs_y = []
92
+ theta = theta
93
+ mag = 1 / (theta ** (torch.arange(0, head_dim, 4)[: (head_dim // 4)].float() / head_dim))
94
+ for i in range(num_heads):
95
+ angles = torch.rand(1) * 2 * torch.pi if rotate else torch.zeros(1)
96
+ fx = torch.cat([mag * torch.cos(angles), mag * torch.cos(torch.pi/2 + angles)], dim=-1)
97
+ fy = torch.cat([mag * torch.sin(angles), mag * torch.sin(torch.pi/2 + angles)], dim=-1)
98
+ freqs_x.append(fx)
99
+ freqs_y.append(fy)
100
+ freqs_x = torch.stack(freqs_x, dim=0)
101
+ freqs_y = torch.stack(freqs_y, dim=0)
102
+ freqs = torch.stack([freqs_x, freqs_y], dim=0)
103
+ return freqs
104
+
105
+ def compute_cis(freqs, t_x, t_y):
106
+ N = t_x.shape[0]
107
+ # No float 16 for this range
108
+ with torch.cuda.amp.autocast(enabled=False):
109
+ freqs_x = (t_x.unsqueeze(-1) @ freqs[0].unsqueeze(-2))
110
+ freqs_y = (t_y.unsqueeze(-1) @ freqs[1].unsqueeze(-2))
111
+ freqs_cis = torch.polar(torch.ones_like(freqs_x), freqs_x + freqs_y)
112
+
113
+ return freqs_cis
114
+
115
+
116
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
117
+ ndim = x.ndim
118
+ assert 0 <= 1 < ndim
119
+ # assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
120
+ # print(f"freqs_cis shape is {freqs_cis.shape}, x shape is {x.shape}")
121
+ if freqs_cis.shape == (x.shape[-2], x.shape[-1]):
122
+ shape = [d if i >= ndim-2 else 1 for i, d in enumerate(x.shape)]
123
+ elif freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1]):
124
+ shape = [d if i >= ndim-3 else 1 for i, d in enumerate(x.shape)]
125
+
126
+ return freqs_cis.view(*shape)
127
+
128
+ def apply_rotary_emb(
129
+ xq: torch.Tensor,
130
+ xk: torch.Tensor,
131
+ freqs_cis: torch.Tensor,
132
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
133
+ # print(f"xq shape is {xq.shape}, xq.shape[:-1] is {xq.shape[:-1]}")
134
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
135
+ # print(f"xq_ shape is {xq_.shape}")
136
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
137
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
138
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
139
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
140
+ return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
141
+
142
+ def apply_rotary_emb_single(x, freqs_cis):
143
+ x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
144
+ seq_len = x_.shape[2]
145
+ freqs_cis = freqs_cis[:, :seq_len, :]
146
+ freqs_cis = freqs_cis.unsqueeze(0).expand_as(x_)
147
+ x_out = torch.view_as_real(x_ * freqs_cis).flatten(3)
148
+ return x_out.type_as(x).to(x.device)
149
+
150
+ def window_partition(x, window_size):
151
+ # x is the feature from net_g
152
+ b, c, h, w = x.shape
153
+ windows = rearrange(x, 'b c (h_count dh) (w_count dw) -> (b h_count w_count) (dh dw) c', dh=window_size,
154
+ dw=window_size)
155
+ # h_count = h // window_size
156
+ # w_count = w // window_size
157
+ # windows = x.reshape(b,c,h_count, window_size, w_count, window_size)
158
+ # windows = windows.permute(0,1,2,4,3,5) #b,c,h_count,w_count,window_size,window_size
159
+ # windows = windows.reshape(b,c,h_count*w_count, window_size * window_size)
160
+ # windows = windows.permute(0,2,3,1) #b,h_count*w_count, window_size*window_size,c
161
+ # windows = windows.reshape(-1, window_size*window_size, c)
162
+
163
+ return windows
164
+
165
+
166
+ def with_pos_embed(tensor, pos):
167
+ return tensor if pos is None else tensor + pos
168
+
169
+
170
+ class MLP(nn.Module):
171
+ def __init__(self, in_features, hidden_features, out_features, act_layer=nn.ReLU):
172
+ super(MLP, self).__init__()
173
+ self.fc1 = nn.Linear(in_features, hidden_features)
174
+ self.act = act_layer()
175
+ self.fc2 = nn.Linear(hidden_features, out_features)
176
+
177
+ def forward(self, x):
178
+ x = self.fc1(x)
179
+ x = self.act(x)
180
+ x = self.fc2(x)
181
+ return x
182
+
183
+ class WindowCrossAttn(nn.Module):
184
+ def __init__(self, dim=180, num_heads=6, window_size=12, num_gs_seed=2304, rope_mixed = True, rope_theta = 10.0):
185
+ super(WindowCrossAttn, self).__init__()
186
+ self.dim = dim
187
+ self.num_heads = num_heads
188
+ self.window_size = window_size
189
+ self.num_gs_seed = num_gs_seed
190
+ self.num_gs_seed_sqrt = int(math.sqrt(num_gs_seed))
191
+
192
+
193
+ self.rope_mixed = rope_mixed
194
+
195
+ t_x, t_y = init_t_xy(end_x=max(self.num_gs_seed_sqrt, self.window_size), end_y=max(self.num_gs_seed_sqrt, self.window_size))
196
+ self.register_buffer('rope_t_x', t_x)
197
+ self.register_buffer('rope_t_y', t_y)
198
+
199
+ freqs = init_random_2d_freqs(
200
+ head_dim=self.dim // self.num_heads, num_heads=self.num_heads, theta=rope_theta,
201
+ rotate=self.rope_mixed
202
+ )
203
+ if self.rope_mixed:
204
+ self.rope_freqs = nn.Parameter(freqs, requires_grad=True)
205
+ else:
206
+ self.register_buffer('rope_freqs', freqs)
207
+ freqs_cis = compute_cis(self.rope_freqs, self.rope_t_x, self.rope_t_y)
208
+ self.rope_freqs_cis = freqs_cis
209
+
210
+ self.qhead = nn.Linear(dim, dim, bias=True)
211
+ self.khead = nn.Linear(dim, dim, bias=True)
212
+ self.vhead = nn.Linear(dim, dim, bias=True)
213
+
214
+ self.proj = nn.Linear(dim, dim)
215
+
216
+
217
+ def forward(self, gs, feat):
218
+ # gs shape: b*h_count*w_count, num_gs, c the input gs here should already include pos embedding and scale embedding
219
+ # feat shape: b*h_count*w_count, dh*dw, c dh=dw=window_size
220
+ b_, num_gs, c = gs.shape
221
+ b_, n, c = feat.shape
222
+
223
+ q = self.qhead(gs) # b_, num_gs_, c
224
+ q = q.reshape(b_, num_gs, self.num_heads, c // self.num_heads)
225
+ q = q.permute(0, 2, 1, 3) # b_, num_heads, n, c // num_heads
226
+
227
+ k = self.khead(feat) # b_, n_, c
228
+ k = k.reshape(b_, n, self.num_heads, c // self.num_heads)
229
+ k = k.permute(0, 2, 1, 3) # b_, num_heads, n, c // num_heads
230
+
231
+ v = self.vhead(feat) # b_, n_, c
232
+ v = v.reshape(b_, n, self.num_heads, c // self.num_heads)
233
+ v = v.permute(0, 2, 1, 3) # b_, num_heads, n, c // num_heads
234
+
235
+ ###### Apply rotary position embedding
236
+ if self.rope_mixed:
237
+ freqs_cis = compute_cis(self.rope_freqs, self.rope_t_x, self.rope_t_y)
238
+ else:
239
+ freqs_cis = self.rope_freqs_cis.to(gs.device)
240
+ q = apply_rotary_emb_single(q, freqs_cis)
241
+ k = apply_rotary_emb_single(k, freqs_cis)
242
+ #########
243
+
244
+ attn = F.scaled_dot_product_attention(q, k, v)
245
+
246
+ x = attn.transpose(1, 2).reshape(b_, num_gs, c)
247
+
248
+ x = self.proj(x)
249
+
250
+ return x
251
+
252
+
253
+ class WindowCrossAttnLayer(nn.Module):
254
+ def __init__(self, dim=180, num_heads=6, window_size=12, shift_size=0, num_gs_seed=2308, rope_mixed = True, rope_theta = 10.0):
255
+ super(WindowCrossAttnLayer, self).__init__()
256
+
257
+ self.gs_cross_attn_scale = nn.MultiheadAttention(dim, num_heads, batch_first=True)
258
+
259
+ self.norm1 = nn.LayerNorm(dim)
260
+ self.norm2 = nn.LayerNorm(dim)
261
+ self.norm3 = nn.LayerNorm(dim)
262
+ self.norm4 = nn.LayerNorm(dim)
263
+ self.shift_size = shift_size
264
+ self.window_size = window_size
265
+
266
+ self.window_cross_attn = WindowCrossAttn(dim=dim, num_heads=num_heads, window_size=window_size,
267
+ num_gs_seed=num_gs_seed, rope_mixed = rope_mixed, rope_theta = rope_theta)
268
+ self.mlp_crossattn_scale = MLP(in_features=dim, hidden_features=dim, out_features=dim)
269
+ self.mlp_crossattn_feature = MLP(in_features=dim, hidden_features=dim, out_features=dim)
270
+
271
+ def forward(self, x, query_pos, feat, scale_embedding):
272
+ # gs shape: b*h_count*w_count, num_gs, c
273
+ # query_pos shape: b*h_count*w_count, num_gs, c
274
+ # feat shape: b,c,h,w
275
+ # scale_embedding shape: b*h_count*w_count, 1, c
276
+
277
+ ###GS cross attn with scale embedding
278
+ resi = x
279
+ x = self.norm1(x)
280
+ # print(f"x: {x.shape} {x.device}, query_pos: {query_pos.shape}, {query_pos.device}, scale_embedding: {scale_embedding.shape}, {scale_embedding.device}")
281
+ x, _ = self.gs_cross_attn_scale(with_pos_embed(x, query_pos), scale_embedding, scale_embedding)
282
+ x = resi + x
283
+
284
+ ###FFN
285
+ resi = x
286
+ x = self.norm2(x)
287
+ x = self.mlp_crossattn_scale(x)
288
+ x = resi + x
289
+
290
+ ###cross attention for Q,K,V
291
+ resi = x
292
+ x = self.norm3(x)
293
+ if self.shift_size > 0:
294
+ shift_feat = torch.roll(feat, shifts=(-self.shift_size, -self.shift_size), dims=(2, 3))
295
+ else:
296
+ shift_feat = feat
297
+ shift_feat = window_partition(shift_feat, self.window_size) # b*h_count*w_count, dh*dw, c dh=dw=window_size
298
+ x = self.window_cross_attn(with_pos_embed(x, query_pos),
299
+ shift_feat) # b*h_count*w_count, num_gs, c dh=dw=window_size
300
+ x = resi + x
301
+
302
+ ###FFN
303
+ resi = x
304
+ x = self.norm4(x)
305
+ x = self.mlp_crossattn_feature(x)
306
+ x = resi + x
307
+
308
+ return x
309
+
310
+
311
+ class WindowCrossAttnBlock(nn.Module):
312
+ def __init__(self, dim=180, window_size=12, num_heads=6, num_layers=4, num_gs_seed=230, rope_mixed = True, rope_theta = 10.0):
313
+ super(WindowCrossAttnBlock, self).__init__()
314
+
315
+ self.num_gs_seed_sqrt = int(math.sqrt(num_gs_seed))
316
+
317
+ self.mlp = nn.Sequential(
318
+ nn.Linear(dim, dim),
319
+ nn.ReLU(),
320
+ nn.Linear(dim, dim)
321
+ )
322
+ self.norm = nn.LayerNorm(dim)
323
+ self.blocks = nn.ModuleList([
324
+ WindowCrossAttnLayer(
325
+ dim=dim,
326
+ num_heads=num_heads,
327
+ window_size=window_size,
328
+ shift_size=0 if i % 2 == 0 else window_size // 2,
329
+ num_gs_seed=num_gs_seed,
330
+ rope_mixed = rope_mixed, rope_theta = rope_theta) for i in range(num_layers)
331
+ ])
332
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
333
+
334
+ def forward(self, x, query_pos, feat, scale_embedding, h_count, w_count):
335
+ resi = x
336
+ x = self.norm(x)
337
+ for block in self.blocks:
338
+ x = block(x, query_pos, feat, scale_embedding)
339
+ x = self.mlp(x)
340
+
341
+ x = rearrange(x, '(b m n) (h w) c -> b c (m h) (n w)', m=h_count, n=w_count, h=self.num_gs_seed_sqrt)
342
+ x = self.conv(x)
343
+ x = rearrange(x, 'b c (m h) (n w) -> (b m n) (h w) c', m=h_count, n=w_count, h=self.num_gs_seed_sqrt)
344
+
345
+ x = resi + x
346
+ return x
347
+
348
+
349
+ class GSSelfAttn(nn.Module):
350
+ def __init__(self, dim=180, num_heads=6, num_gs_seed_sqrt = 12, rope_mixed = True, rope_theta=10.0):
351
+ super(GSSelfAttn, self).__init__()
352
+ self.dim = dim
353
+ self.num_heads = num_heads
354
+ self.num_gs_seed_sqrt = num_gs_seed_sqrt
355
+
356
+ self.proj = nn.Linear(dim, dim)
357
+ self.rope_mixed = rope_mixed
358
+
359
+ t_x, t_y = init_t_xy(end_x=self.num_gs_seed_sqrt, end_y=self.num_gs_seed_sqrt)
360
+ self.register_buffer('rope_t_x', t_x)
361
+ self.register_buffer('rope_t_y', t_y)
362
+
363
+ freqs = init_random_2d_freqs(
364
+ head_dim=self.dim // self.num_heads, num_heads=self.num_heads, theta=rope_theta,
365
+ rotate=self.rope_mixed
366
+ )
367
+ if self.rope_mixed:
368
+ self.rope_freqs = nn.Parameter(freqs, requires_grad=True)
369
+ else:
370
+ self.register_buffer('rope_freqs', freqs)
371
+ freqs_cis = compute_cis(self.rope_freqs, self.rope_t_x, self.rope_t_y)
372
+ self.rope_freqs_cis = freqs_cis
373
+
374
+ self.qhead = nn.Linear(dim, dim, bias=True)
375
+ self.khead = nn.Linear(dim, dim, bias=True)
376
+ self.vhead = nn.Linear(dim, dim, bias=True)
377
+
378
+ def forward(self, gs):
379
+ # gs shape: b*h_count*w_count, num_gs, c
380
+ # pos shape: b*h_count*w_count, num_gs, c
381
+ b_, num_gs, c = gs.shape
382
+
383
+ q = self.qhead(gs)
384
+ q = q.reshape(b_, num_gs, self.num_heads, c // self.num_heads)
385
+ q = q.permute(0, 2, 1, 3) # b_, num_heads, n, c // num_heads
386
+
387
+ k = self.khead(gs)
388
+ k = k.reshape(b_, num_gs, self.num_heads, c // self.num_heads)
389
+ k = k.permute(0, 2, 1, 3) # b_, num_heads, n, c // num_heads
390
+
391
+ v = self.vhead(gs)
392
+ v = v.reshape(b_, num_gs, self.num_heads, c // self.num_heads)
393
+ v = v.permute(0, 2, 1, 3) # b_, num_heads, n, c // num_heads
394
+
395
+ ###### Apply rotary position embedding
396
+ if self.rope_mixed:
397
+ freqs_cis = compute_cis(self.rope_freqs, self.rope_t_x, self.rope_t_y)
398
+ else:
399
+ freqs_cis = self.rope_freqs_cis.to(gs.device)
400
+ q, k = apply_rotary_emb(q, k, freqs_cis)
401
+ #########
402
+
403
+ attn = F.scaled_dot_product_attention(q, k, v)
404
+
405
+ attn = attn.transpose(1, 2).reshape(b_, num_gs, c)
406
+
407
+
408
+ attn = self.proj(attn)
409
+
410
+ return attn
411
+
412
+
413
+ class GSSelfAttnLayer(nn.Module):
414
+ def __init__(self, dim=180, num_heads=6, num_gs_seed_sqrt = 12, shift_size = 0, rope_mixed = True, rope_theta=10.0):
415
+ super(GSSelfAttnLayer, self).__init__()
416
+
417
+ self.norm1 = nn.LayerNorm(dim)
418
+ self.norm2 = nn.LayerNorm(dim)
419
+ self.norm3 = nn.LayerNorm(dim)
420
+ self.norm4 = nn.LayerNorm(dim)
421
+
422
+ self.gs_self_attn = GSSelfAttn(dim = dim, num_heads = num_heads, num_gs_seed_sqrt = num_gs_seed_sqrt, rope_mixed = rope_mixed, rope_theta=rope_theta)
423
+
424
+ self.mlp_selfattn = MLP(in_features=dim, hidden_features=dim, out_features=dim)
425
+
426
+ self.num_gs_seed_sqrt = num_gs_seed_sqrt
427
+ self.shift_size = shift_size
428
+
429
+ self.gs_cross_attn_scale = nn.MultiheadAttention(dim, num_heads, batch_first=True)
430
+
431
+ self.mlp_crossattn = MLP(in_features=dim, hidden_features=dim, out_features=dim)
432
+
433
+ def forward(self, gs, pos, h_count, w_count, scale_embedding):
434
+ # gs shape:b*h_count*w_count, num_gs_seed, channel
435
+ # pos shape: b*h_count*w_count, num_gs_seed, channel
436
+ # scale_embedding shape: b*h_count*w_count, 1, channel
437
+
438
+ # gs cross attn with scale_embedding
439
+ resi = gs
440
+ gs = self.norm3(gs)
441
+ gs, _ = self.gs_cross_attn_scale(with_pos_embed(gs, pos), scale_embedding, scale_embedding)
442
+ gs = gs + resi
443
+
444
+ # FFN
445
+ resi = gs
446
+ gs = self.norm4(gs)
447
+ gs = self.mlp_crossattn(gs)
448
+ gs = gs + resi
449
+
450
+ resi = gs
451
+ gs = self.norm1(gs)
452
+
453
+ #### shift gs
454
+ if self.shift_size > 0:
455
+ shift_gs = rearrange(gs, '(b m n) (h w) c -> b (m h) (n w) c', m=h_count, n=w_count, h=self.num_gs_seed_sqrt, w = self.num_gs_seed_sqrt)
456
+ shift_gs = torch.roll(shift_gs, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
457
+ shift_gs = rearrange(shift_gs, 'b (m h) (n w) c -> (b m n) (h w) c', m=h_count, n=w_count, h=self.num_gs_seed_sqrt, w = self.num_gs_seed_sqrt)
458
+ else:
459
+ shift_gs = gs
460
+
461
+ #### gs self attention
462
+ gs = self.gs_self_attn(shift_gs)
463
+
464
+ #### shift gs back
465
+ if self.shift_size > 0:
466
+ shift_gs = rearrange(gs, '(b m n) (h w) c -> b (m h) (n w) c', m=h_count, n=w_count, h=self.num_gs_seed_sqrt, w = self.num_gs_seed_sqrt)
467
+ shift_gs = torch.roll(shift_gs, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
468
+ shift_gs = rearrange(shift_gs, 'b (m h) (n w) c -> (b m n) (h w) c', m=h_count, n=w_count, h=self.num_gs_seed_sqrt, w = self.num_gs_seed_sqrt)
469
+ else:
470
+ shift_gs = gs
471
+
472
+ gs = shift_gs + resi
473
+
474
+ #FFN
475
+ resi = gs
476
+ gs = self.norm2(gs)
477
+ gs = self.mlp_selfattn(gs)
478
+ gs = gs + resi
479
+ return gs
480
+
481
+
482
+ class GSSelfAttnBlock(nn.Module):
483
+ def __init__(self, dim=180, num_heads=6, num_selfattn_layers=4, num_gs_seed_sqrt = 12, rope_mixed = True, rope_theta=10.0):
484
+ super(GSSelfAttnBlock, self).__init__()
485
+ self.num_gs_seed_sqrt = num_gs_seed_sqrt
486
+
487
+ self.mlp = nn.Sequential(
488
+ nn.Linear(dim, dim),
489
+ nn.ReLU(),
490
+ nn.Linear(dim, dim)
491
+ )
492
+ self.norm = nn.LayerNorm(dim)
493
+ self.blocks = nn.ModuleList([
494
+ GSSelfAttnLayer(
495
+ dim = dim,
496
+ num_heads = num_heads,
497
+ num_gs_seed_sqrt=num_gs_seed_sqrt,
498
+ shift_size=0 if i % 2 == 0 else num_gs_seed_sqrt // 2,
499
+ rope_mixed = rope_mixed, rope_theta=rope_theta
500
+ ) for i in range(num_selfattn_layers)
501
+ ])
502
+
503
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
504
+
505
+ def forward(self, gs, pos, h_count, w_count, scale_embedding):
506
+ resi = gs
507
+ gs = self.norm(gs)
508
+ for block in self.blocks:
509
+ gs = block(gs, pos, h_count, w_count, scale_embedding)
510
+
511
+ gs = self.mlp(gs)
512
+ gs = rearrange(gs, '(b m n) (h w) c -> b c (m h) (n w)', m=h_count, n=w_count, h=self.num_gs_seed_sqrt)
513
+ gs = self.conv(gs)
514
+ gs = rearrange(gs, 'b c (m h) (n w) -> (b m n) (h w) c', m=h_count, n=w_count, h=self.num_gs_seed_sqrt)
515
+ gs = gs + resi
516
+ return gs
517
+
518
+ class Fea2GS_ROPE_AMP(nn.Module):
519
+ def __init__(self, inchannel=64, channel=192, num_heads=6, num_crossattn_blocks=1, num_crossattn_layers=2, num_selfattn_blocks = 6, num_selfattn_layers = 6,
520
+ num_gs_seed=144, gs_up_factor=1.0, window_size=12, img_range=1.0, shuffle_scale1 = 2, shuffle_scale2 = 2, use_checkpoint = False,
521
+ rope_mixed = True, rope_theta = 10.0):
522
+ """
523
+ Args:
524
+ gs_repeat_factor: the ratio of GS embeddings to pixels along width & height; the decoder generates
525
+ (h * gs_repeat_factor) * (w * gs_repeat_factor) GS embeddings, so higher values mean more GS embeddings.
526
+ gs_up_factor: how many 2D Gaussians are generated by one Gaussian embedding.
527
+ """
528
+ super(Fea2GS_ROPE_AMP, self).__init__()
529
+ self.channel = channel
530
+ self.nhead = num_heads
531
+ self.gs_up_factor = gs_up_factor
532
+ self.num_gs_seed = num_gs_seed
533
+ self.window_size = window_size
534
+ self.img_range = img_range
535
+ self.use_checkpoint = use_checkpoint
536
+
537
+ self.num_gs_seed_sqrt = int(math.sqrt(num_gs_seed))
538
+ self.gs_up_factor_sqrt = int(math.sqrt(gs_up_factor))
539
+
540
+ self.shuffle_scale1 = shuffle_scale1
541
+ self.shuffle_scale2 = shuffle_scale2
542
+
543
+ # shared gaussian embedding and its pos embedding
544
+ self.gs_embedding = nn.Parameter(torch.randn(self.num_gs_seed, channel), requires_grad=True)
545
+ self.pos_embedding = nn.Parameter(torch.randn(self.num_gs_seed, channel), requires_grad=True)
546
+
547
+ self.img_feat_proj = nn.Sequential(
548
+ nn.Conv2d(inchannel, channel, 3, 1, 1),
549
+ nn.ReLU(),
550
+ nn.Conv2d(channel, channel, 3, 1, 1)
551
+ )
552
+
553
+ self.window_crossattn_blocks = nn.ModuleList([
554
+ WindowCrossAttnBlock(dim=channel,
555
+ window_size=window_size,
556
+ num_heads=num_heads,
557
+ num_layers=num_crossattn_layers,
558
+ num_gs_seed=num_gs_seed, rope_mixed = rope_mixed, rope_theta = rope_theta) for i in range(num_crossattn_blocks)
559
+ ])
560
+
561
+ self.gs_selfattn_blocks = nn.ModuleList([
562
+ GSSelfAttnBlock(dim=channel,
563
+ num_heads=num_heads,
564
+ num_selfattn_layers=num_selfattn_layers,
565
+ num_gs_seed_sqrt=self.num_gs_seed_sqrt,
566
+ rope_mixed = rope_mixed, rope_theta=rope_theta
567
+ ) for i in range(num_selfattn_blocks)
568
+ ])
569
+
570
+ # GS sigma_x, sigma_y
571
+ self.mlp_block_sigma = nn.Sequential(
572
+ nn.Linear(channel, channel),
573
+ nn.ReLU(),
574
+ nn.Linear(channel, channel * 4),
575
+ nn.ReLU(),
576
+ nn.Linear(channel * 4, int(2 * gs_up_factor))
577
+ )
578
+
579
+ # GS rho
580
+ self.mlp_block_rho = nn.Sequential(
581
+ nn.Linear(channel, channel),
582
+ nn.ReLU(),
583
+ nn.Linear(channel, channel * 4),
584
+ nn.ReLU(),
585
+ nn.Linear(channel * 4, int(1 * gs_up_factor))
586
+ )
587
+
588
+ # GS alpha
589
+ self.mlp_block_alpha = nn.Sequential(
590
+ nn.Linear(channel, channel),
591
+ nn.ReLU(),
592
+ nn.Linear(channel, channel * 4),
593
+ nn.ReLU(),
594
+ nn.Linear(channel * 4, int(1 * gs_up_factor))
595
+ )
596
+
597
+ # GS RGB values
598
+ self.mlp_block_rgb = nn.Sequential(
599
+ nn.Linear(channel, channel),
600
+ nn.ReLU(),
601
+ nn.Linear(channel, channel * 4),
602
+ nn.ReLU(),
603
+ nn.Linear(channel * 4, int(3 * gs_up_factor))
604
+ )
605
+
606
+ # GS mean_x, mean_y
607
+ self.mlp_block_mean = nn.Sequential(
608
+ nn.Linear(channel, channel),
609
+ nn.ReLU(),
610
+ nn.Linear(channel, channel * 4),
611
+ nn.ReLU(),
612
+ nn.Linear(channel * 4, int(2 * gs_up_factor))
613
+ )
614
+
615
+ self.scale_mlp = nn.Sequential(
616
+ nn.Linear(1, channel * 4),
617
+ nn.ReLU(),
618
+ nn.Linear(channel * 4, channel)
619
+ )
620
+
621
+ self.UPNet = nn.Sequential(
622
+ nn.Conv2d(channel, channel * self.shuffle_scale1 * self.shuffle_scale1, 3, 1, 1),
623
+ nn.PixelShuffle(self.shuffle_scale1),
624
+ nn.Conv2d(channel, channel * self.shuffle_scale2 * self.shuffle_scale2, 3, 1, 1),
625
+ nn.PixelShuffle(self.shuffle_scale2)
626
+ )
627
+
628
+ self.conv_final = nn.Conv2d(channel, channel, 3, 1, 1)
629
+
630
+ @staticmethod
631
+ def get_N_reference_points(h, w, device='cuda'):
632
+ # step_y = 1/(h+1)
633
+ # step_x = 1/(w+1)
634
+ step_y = 1 / h
635
+ step_x = 1 / w
636
+ ref_y, ref_x = torch.meshgrid(torch.linspace(step_y / 2, 1 - step_y / 2, h, dtype=torch.float32, device=device),
637
+ torch.linspace(step_x / 2, 1 - step_x / 2, w, dtype=torch.float32, device=device))
638
+ reference_points = torch.stack((ref_x.reshape(-1), ref_y.reshape(-1)), -1)
639
+ reference_points = reference_points[None, :, None]
640
+ return reference_points
641
+
642
+ def forward(self, srcs, scale):
643
+ '''
644
+ using deformable detr decoder for cross attention
645
+ Args:
646
+ query: (batch_size, num_query, dim)
647
+ query_pos: (batch_size, num_query, dim)
648
+ srcs: (batch_size, dim, h1, w1)
649
+ '''
650
+ b, c, h, w = srcs.shape ###srcs is pad to the size that could be divided by window_size
651
+ query = self.gs_embedding.unsqueeze(0).unsqueeze(1).repeat(b, (h // self.window_size) * (w // self.window_size),
652
+ 1, 1) # b, h_count*w_count, num_gs_seed, channel
653
+ query = query.reshape(b * (h // self.window_size) * (w // self.window_size), -1,
654
+ self.channel) # b*h_count*w_count, num_gs_seed, channel
655
+
656
+ scale = 1 / scale
657
+ scale = scale.unsqueeze(1) # b*1
658
+ scale_embedding = self.scale_mlp(scale) # b*channel
659
+ scale_embedding = scale_embedding.unsqueeze(1).unsqueeze(2).repeat(1, (h // self.window_size) * (
660
+ w // self.window_size), self.num_gs_seed, 1) # b, h_count*w_count, num_gs_seed, channel
661
+ scale_embedding = scale_embedding.reshape(b * (h // self.window_size) * (w // self.window_size), -1,
662
+ self.channel) # b*h_count*w_count, num_gs_seed, channel
663
+
664
+ query_pos = self.pos_embedding.unsqueeze(0).unsqueeze(1).repeat(b, (h // self.window_size) * (
665
+ w // self.window_size), 1, 1) # b, h_count*w_count, num_gs_seed, channel
666
+
667
+ feat = self.img_feat_proj(srcs) # b*channel*h*w
668
+
669
+ query_pos = query_pos.reshape(b * (h // self.window_size) * (w // self.window_size), -1,
670
+ self.channel) # b*h_count*w_count, num_gs_seed, channel
671
+
672
+ for block in self.window_crossattn_blocks:
673
+ if self.use_checkpoint:
674
+ query = checkpoint(block, query, query_pos, feat, scale_embedding, h // self.window_size, w // self.window_size)
675
+ else:
676
+ query = block(query, query_pos, feat, scale_embedding, h // self.window_size, w // self.window_size) # b*h_count*w_count, num_gs_seed, channel
677
+
678
+ resi = query
679
+ for block in self.gs_selfattn_blocks:
680
+ if self.use_checkpoint:
681
+ query = checkpoint(block, query, query_pos, h // self.window_size, w // self.window_size, scale_embedding)
682
+ else:
683
+ query = block(query, query_pos, h // self.window_size, w // self.window_size, scale_embedding)
684
+
685
+
686
+ query = rearrange(query, '(b m n) (h w) c -> b c (m h) (n w)', m=h // self.window_size, n=w // self.window_size,
687
+ h=self.num_gs_seed_sqrt)
688
+ query = self.conv_final(query)
689
+
690
+
691
+ resi = rearrange(resi, '(b m n) (h w) c -> b c (m h) (n w)', m=h // self.window_size, n=w // self.window_size,
692
+ h=self.num_gs_seed_sqrt)
693
+
694
+ query = query + resi
695
+ query = self.UPNet(query)
696
+ query = query.permute(0,2,3,1)
697
+
698
+ # query = rearrange(query, '(b m n) (h w) c -> b m h n w c', m=h // self.window_size, n=w // self.window_size,
699
+ # h=self.num_gs_seed_sqrt)
700
+
701
+ query_sigma = self.mlp_block_sigma(query).reshape(b, -1, 2)
702
+ query_rho = self.mlp_block_rho(query).reshape(b, -1, 1)
703
+ query_alpha = self.mlp_block_alpha(query).reshape(b, -1, 1)
704
+ query_rgb = self.mlp_block_rgb(query).reshape(b, -1, 3)
705
+ query_mean = self.mlp_block_mean(query).reshape(b, -1, 2)
706
+
707
+ query_mean = query_mean / torch.tensor(
708
+ [self.num_gs_seed_sqrt * (w // self.window_size) * self.shuffle_scale1 * self.shuffle_scale2,
709
+ self.num_gs_seed_sqrt * (h // self.window_size) * self.shuffle_scale1 * self.shuffle_scale2])[
710
+ None, None].to(query_mean.device) # b, h_count*w_count*num_gs_seed, 2
711
+
712
+ reference_offset = self.get_N_reference_points(self.num_gs_seed_sqrt * (h // self.window_size) * self.shuffle_scale1 * self.shuffle_scale2,
713
+ self.num_gs_seed_sqrt * (w // self.window_size) * self.shuffle_scale1 * self.shuffle_scale2, srcs.device)
714
+ query_mean = query_mean + reference_offset.reshape(1, -1, 2)
715
+
716
+ query = torch.cat([query_sigma, query_rho, query_alpha, query_rgb, query_mean],
717
+ dim=-1) # b, h_count*w_count*num_gs_seed, 9
718
+
719
+ return query
720
+
721
+
722
+ if __name__ == '__main__':
723
+ srcs = torch.randn(6, 64, 64, 64, requires_grad = True).cuda()
724
+ scale = torch.randn(6).cuda()
725
+ decoder = Fea2GS_ROPE_AMP(inchannel=64, channel=192, num_heads=6,
726
+ num_crossattn_blocks=1, num_crossattn_layers=2,
727
+ num_selfattn_blocks = 6, num_selfattn_layers = 6,
728
+ num_gs_seed=256, gs_up_factor=1.0, window_size=16,
729
+ img_range=1.0, shuffle_scale1 = 2, shuffle_scale2 = 2).cuda()
730
+ import time
731
+
732
+ for i in range(10):
733
+ torch.cuda.synchronize()
734
+ time1 = time.time()
735
+ # with torch.autocast(device_type = 'cuda'):
736
+ y = decoder(srcs, scale)
737
+ torch.cuda.synchronize()
738
+ time2 = time.time()
739
+ print(f"decoder time is {time2 - time1}")
740
+ print(y.shape)
741
+
742
+ torch.cuda.synchronize()
743
+ time3 = time.time()
744
+ y.sum().backward()
745
+ torch.cuda.synchronize()
746
+ time4 = time.time()
747
+ print(f"backward time is {time4 - time3}")
748
+
749
+
utils/gaussian_splatting.py ADDED
@@ -0,0 +1,265 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ import math
5
+ import torch.nn as nn
6
+
7
+ import torchvision.utils
8
+ from torchvision.utils import save_image
9
+
10
+
11
+ def rendering_python(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device):
12
+ sr_h, sr_w = sr_size[0], sr_size[1]
13
+ num_gs = sigma_x.shape[0]
14
+
15
+ sigma_x = sigma_x[...,None]
16
+ sigma_y = sigma_y[...,None]
17
+ rho = rho[...,None]
18
+ covariance = torch.stack(
19
+ [torch.stack([sigma_x**2, rho*sigma_x*sigma_y], dim=-1),
20
+ torch.stack([rho*sigma_x*sigma_y, sigma_y**2], dim=-1)],
21
+ dim=-2
22
+ )
23
+
24
+ # Check for positive semi-definiteness
25
+ determinant = (sigma_x**2) * (sigma_y**2) - (rho * sigma_x * sigma_y)**2
26
+ if (determinant < 0).any():
27
+ raise ValueError("Covariance matrix must be positive semi-definite")
28
+
29
+ inv_covariance = torch.inverse(covariance)
30
+
31
+ # Sampling progress
32
+ num_step = int(10 * 2 / step_size)
33
+ ax_h_batch = torch.tensor([i * step_size for i in range(num_step)]).to(device)[None]
34
+ ax_h_batch -= ax_h_batch.mean()
35
+ ax_w_batch = torch.tensor([i * step_size for i in range(num_step)]).to(device)[None]
36
+ ax_w_batch -= ax_w_batch.mean()
37
+
38
+ # Expanding dims for broadcasting
39
+ ax_batch_expanded_x = ax_h_batch.unsqueeze(-1).expand(-1, -1, num_step)
40
+ ax_batch_expanded_y = ax_w_batch.unsqueeze(1).expand(-1, num_step, -1)
41
+
42
+ # Creating a batch-wise meshgrid using broadcasting
43
+ xx, yy = ax_batch_expanded_x, ax_batch_expanded_y
44
+
45
+ xy = torch.stack([xx, yy], dim=-1)
46
+
47
+ max_buffer = 2000
48
+ final_image = torch.zeros((3, sr_h, sr_w), device=device)
49
+ for i in range(num_gs // max_buffer + 1):
50
+ # print('processing gs buffer id:', i, num_gs // max_buffer )
51
+ s_idx, e_idx = i * max_buffer, min((i + 1) * max_buffer, num_gs)
52
+ buffer_size = e_idx - s_idx
53
+ if buffer_size == 0:
54
+ break
55
+ # print(f"buffer_size is {buffer_size}")
56
+ buff_inv_covariance = inv_covariance[s_idx:e_idx]
57
+ buff_covariance = covariance[s_idx:e_idx]
58
+ buffer_pixel_coords = coords[s_idx:e_idx]
59
+ buffer_alpha = colours_with_alpha[s_idx:e_idx].unsqueeze(-1).unsqueeze(-1)
60
+
61
+ z = torch.einsum('b...i,b...ij,b...j->b...', xy, -0.5 * buff_inv_covariance, xy)
62
+ kernel = torch.exp(z) / (2 * torch.tensor(np.pi, device=device) * torch.sqrt(torch.det(buff_covariance)).view(buffer_size, 1, 1))
63
+
64
+ kernel_max = kernel.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0]
65
+ kernel_normalized = kernel / (kernel_max + 1e-4)
66
+ kernel_reshaped = kernel_normalized.repeat(1, 3, 1).view(buffer_size * 3, num_step, num_step)
67
+ kernel_reshaped = kernel_reshaped.unsqueeze(0).reshape(buffer_size, 3, num_step, num_step)
68
+
69
+ b, c, h, w = kernel_reshaped.shape
70
+
71
+ # Create a batch of 2D affine matrices
72
+ theta = torch.zeros(b, 2, 3, dtype=torch.float32, device=device)
73
+ theta[:, 0, 0] = 1 * sr_w / num_step
74
+ theta[:, 1, 1] = 1 * sr_h / num_step
75
+ theta[:, 0, 2] = -buffer_pixel_coords[:, 0] * sr_w / num_step # !!!!!!!! note -1
76
+ theta[:, 1, 2] = -buffer_pixel_coords[:, 1] * sr_h / num_step # !!!!!!!! note -1
77
+
78
+ grid = F.affine_grid(theta, size=(b, c, sr_h, sr_w), align_corners=False) # !!!!! align_corners=False
79
+ kernel_reshaped_translated = F.grid_sample(kernel_reshaped, grid,
80
+ align_corners=False) # !!!! align_corners=False
81
+ buffer_final_image = buffer_alpha * kernel_reshaped_translated
82
+ final_image += buffer_final_image.sum(0)
83
+
84
+ return final_image
85
+
86
+ def rendering_cuda(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device):
87
+ from utils.gs_cuda.gswrapper import GSCUDA
88
+ sigmas = torch.cat([sigma_y/step_size*2/(sr_size[1] - 1), sigma_x/step_size*2/(sr_size[0] - 1), rho], dim=-1).contiguous() # (gs num, 3)
89
+ coords[:, 0] = (coords[:, 0] + 1 - 1/sr_size[1]) * sr_size[1] / (sr_size[1] - 1) - 1.0
90
+ coords[:, 1] = (coords[:, 1] + 1 - 1/sr_size[0]) * sr_size[0] / (sr_size[0] - 1) - 1.0
91
+ colours_with_alpha = colours_with_alpha.contiguous() # (gs num, 3)
92
+ rendered_img = torch.zeros(sr_size[0], sr_size[1], 3).to(device).type(torch.float32).contiguous()
93
+ # with torch.no_grad():
94
+ # final_image = GSCUDA.apply(sigmas, coords, colours_with_alpha, rendered_img)
95
+ # final_image = (torch.sum(sigmas)+torch.sum(coords)+torch.sum(colours_with_alpha))*final_image
96
+ final_image = GSCUDA.apply(sigmas, coords, colours_with_alpha, rendered_img)
97
+ final_image = final_image.permute(2, 0, 1).contiguous()
98
+ return final_image
99
+
100
+ def rendering_cuda_buffer(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device, buffer_size = 1000000):
101
+ from utils.gs_cuda.gswrapper import GSCUDA
102
+ sigmas = torch.cat([sigma_y/step_size*2/(sr_size[1] - 1), sigma_x/step_size*2/(sr_size[0] - 1), rho], dim=-1).contiguous() # (gs num, 3)
103
+ coords[:, 0] = (coords[:, 0] + 1 - 1/sr_size[1]) * sr_size[1] / (sr_size[1] - 1) - 1.0
104
+ coords[:, 1] = (coords[:, 1] + 1 - 1/sr_size[0]) * sr_size[0] / (sr_size[0] - 1) - 1.0
105
+ colours_with_alpha = colours_with_alpha.contiguous() # (gs num, 3)
106
+ final_image = torch.zeros(sr_size[0], sr_size[1], 3).to(device).type(torch.float32).contiguous()
107
+
108
+ # buffer
109
+ buffer_num = len(sigma_x)// buffer_size+1
110
+ for buffer_id in range(buffer_num):
111
+ # print(f'processing{buffer_id+1}/{buffer_num}')
112
+ idx_start, idx_end = buffer_id * buffer_size, (buffer_id+1) * buffer_size
113
+ final_image = GSCUDA.apply(sigmas[idx_start:idx_end], coords[idx_start:idx_end],
114
+ colours_with_alpha[idx_start:idx_end], final_image)
115
+ # final_image += buffer_image
116
+ final_image = final_image.permute(2, 0, 1).contiguous()
117
+ return final_image
118
+
119
+ def rendering_cuda_dmax(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device, dmax=1):
120
+ from utils.gs_cuda_dmax.gswrapper import GSCUDA
121
+ sigmas = torch.cat([sigma_y/step_size*2/(sr_size[1] - 1), sigma_x/step_size*2/(sr_size[0] - 1), rho], dim=-1).contiguous() # (gs num, 3)
122
+ coords[:, 0] = (coords[:, 0] + 1 - 1/sr_size[1]) * sr_size[1] / (sr_size[1] - 1) - 1.0
123
+ coords[:, 1] = (coords[:, 1] + 1 - 1/sr_size[0]) * sr_size[0] / (sr_size[0] - 1) - 1.0
124
+ colours_with_alpha = colours_with_alpha.contiguous() # (gs num, 3)
125
+ rendered_img = torch.zeros(sr_size[0], sr_size[1], 3).to(device).type(torch.float32).contiguous()
126
+ # with torch.no_grad():
127
+ # final_image = GSCUDA.apply(sigmas, coords, colours_with_alpha, rendered_img, dmax)
128
+ # final_image = (torch.sum(sigmas)+torch.sum(coords)+torch.sum(colours_with_alpha))*final_image
129
+ final_image = GSCUDA.apply(sigmas, coords, colours_with_alpha, rendered_img, dmax)
130
+ final_image = final_image.permute(2, 0, 1).contiguous()
131
+ return final_image
132
+
133
+ def rendering_cuda_dmax_buffer(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device, dmax=1, buffer_size = 1000000):
134
+ from utils.gs_cuda_dmax.gswrapper import GSCUDA
135
+ sigmas = torch.cat([sigma_y/step_size*2/(sr_size[1] - 1), sigma_x/step_size*2/(sr_size[0] - 1), rho], dim=-1).contiguous() # (gs num, 3)
136
+ coords[:, 0] = (coords[:, 0] + 1 - 1/sr_size[1]) * sr_size[1] / (sr_size[1] - 1) - 1.0
137
+ coords[:, 1] = (coords[:, 1] + 1 - 1/sr_size[0]) * sr_size[0] / (sr_size[0] - 1) - 1.0
138
+ colours_with_alpha = colours_with_alpha.contiguous() # (gs num, 3)
139
+
140
+ final_image = torch.zeros(sr_size[0], sr_size[1], 3).to(device).type(torch.float32).contiguous()
141
+ # with torch.no_grad():
142
+ # final_image = GSCUDA.apply(sigmas, coords, colours_with_alpha, rendered_img, dmax)
143
+ # final_image = (torch.sum(sigmas)+torch.sum(coords)+torch.sum(colours_with_alpha))*final_image
144
+
145
+ # buffer
146
+ buffer_num = len(sigma_x)// buffer_size+1
147
+ for buffer_id in range(buffer_num):
148
+ # print(f'processing{buffer_id+1}/{buffer_num}')
149
+ idx_start, idx_end = buffer_id * buffer_size, (buffer_id+1) * buffer_size
150
+ final_image = GSCUDA.apply(sigmas[idx_start:idx_end], coords[idx_start:idx_end],
151
+ colours_with_alpha[idx_start:idx_end], final_image, dmax)
152
+ # final_image += buffer_image
153
+
154
+ final_image = final_image.permute(2, 0, 1).contiguous()
155
+ return final_image
156
+
157
+
158
+ def generate_2D_gaussian_splatting_step(sr_size, gs_parameters, scale, scale_modify,
159
+ sample_coords = None, default_step_size = 1.2,
160
+ cuda_rendering=True, mode = 'scale_modify',
161
+ if_dmax = True,
162
+ dmax_mode = 'fix',
163
+ dmax = 25):
164
+
165
+ # set step_size according to scale factor
166
+ if mode == 'scale':
167
+ final_scale = scale
168
+ elif mode == 'scale_modify':
169
+ assert scale_modify[0] == scale_modify[1], f"scale_modify values must be equal, got {scale_modify}"
170
+ final_scale = scale_modify[0]
171
+ step_size = default_step_size/ final_scale
172
+
173
+ # prepare gaussian properties
174
+ sigma_x = 0.99999 * torch.sigmoid(gs_parameters[:, 0:1]) + 1e-6
175
+ sigma_y = 0.99999 * torch.sigmoid(gs_parameters[:, 1:2]) + 1e-6
176
+ rho = 0.999999 * torch.tanh(gs_parameters[:, 2:3])
177
+ alpha = torch.sigmoid(gs_parameters[:, 3:4])
178
+ colours = torch.sigmoid(gs_parameters[:, 4:7])
179
+ coords = (gs_parameters[:, 7:9] * 2 - 1)
180
+ colours_with_alpha = colours * alpha
181
+
182
+
183
+ ## todo for save GS parameters
184
+ # GS_parameters = torch.cat([sigma_x, sigma_y, rho, alpha, colours, coords], dim = 1)
185
+ # torch.save(GS_parameters.cpu(), "/home/notebook/code/personal/S9053766/chendu/myprojects/GSSR_20240606/results/0804_48*48.pt")
186
+ # print(f"GS_parameter shape is {GS_parameters.shape}")
187
+ # print(f"-------")
188
+
189
+ # todo for visualization the position of Gaussian
190
+ # select = (torch.randn_like(alpha[..., 0])>2.5)
191
+ # colours_with_alpha[select, 0] = 1
192
+ # colours_with_alpha[select, 1] = 0
193
+ # colours_with_alpha[select, 2] = 0
194
+ # todo for visualization the shape of Gaussian
195
+ # sigma_x = torch.ones_like(sigma_x)*0.05
196
+ # sigma_y = torch.ones_like(sigma_y)*0.05
197
+ # rho = torch.ones_like(rho) * 0
198
+ # colours_with_alpha = torch.ones_like(colours_with_alpha)*0.5
199
+
200
+ # rendering
201
+ if cuda_rendering:
202
+ if if_dmax:
203
+ if dmax_mode == 'dynamic':
204
+ dmax = (dmax + 2) / min(sr_size[0], sr_size[1])
205
+ elif dmax_mode == 'fix':
206
+ pass
207
+ else:
208
+ raise ValueError(f"dmax_mode-{dmax_mode} must be fix or dynamic")
209
+ final_image = rendering_cuda_dmax(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, dmax=dmax, device=sigma_x.device)
210
+ else:
211
+ final_image = rendering_cuda(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device=sigma_x.device)
212
+ else:
213
+ final_image = rendering_python(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device=sigma_x.device)
214
+ if sample_coords is not None:
215
+ sample_RGB_values = [final_image[:, coord[0], coord[1]] for coord in sample_coords]
216
+ final_image = torch.stack(sample_RGB_values, dim = 1)
217
+ return final_image
218
+
219
+ def generate_2D_gaussian_splatting_step_buffer(sr_size, gs_parameters, scale, scale_modify,
220
+ sample_coords = None, default_step_size = 1.2,
221
+ cuda_rendering=True, mode = 'scale_modify',
222
+ if_dmax = True,
223
+ dmax_mode = 'fix',
224
+ dmax = 25,
225
+ buffer_size = 4000000):
226
+
227
+ # set step_size according to scale factor
228
+ if mode == 'scale':
229
+ final_scale = scale
230
+ elif mode == 'scale_modify':
231
+ assert scale_modify[0] == scale_modify[1], f"scale_modify values must be equal, got {scale_modify}"
232
+ final_scale = scale_modify[0]
233
+ step_size = default_step_size/ final_scale
234
+
235
+ # prepare gaussian properties
236
+ sigma_x = 0.99999 * torch.sigmoid(gs_parameters[:, 0:1]) + 1e-6
237
+ sigma_y = 0.99999 * torch.sigmoid(gs_parameters[:, 1:2]) + 1e-6
238
+ rho = 0.999999 * torch.tanh(gs_parameters[:, 2:3])
239
+ alpha = torch.sigmoid(gs_parameters[:, 3:4])
240
+ colours = torch.sigmoid(gs_parameters[:, 4:7])
241
+ coords = (gs_parameters[:, 7:9] * 2 - 1)
242
+ colours_with_alpha = colours * alpha
243
+
244
+ # rendering
245
+ if cuda_rendering:
246
+ if if_dmax:
247
+ if dmax_mode == 'dynamic':
248
+ dmax = (dmax + 2) / min(sr_size[0], sr_size[1])
249
+ elif dmax_mode == 'fix':
250
+ pass
251
+ else:
252
+ raise ValueError(f"dmax_mode-{dmax_mode} must be fix or dynamic")
253
+ final_image = rendering_cuda_dmax_buffer(sigma_x, sigma_y, rho, coords, colours_with_alpha,
254
+ sr_size, step_size, dmax=dmax, device=sigma_x.device,
255
+ buffer_size = buffer_size)
256
+ else:
257
+ final_image = rendering_cuda_buffer(sigma_x, sigma_y, rho, coords, colours_with_alpha,
258
+ sr_size, step_size, device=sigma_x.device,
259
+ buffer_size = buffer_size)
260
+ else:
261
+ final_image = rendering_python(sigma_x, sigma_y, rho, coords, colours_with_alpha, sr_size, step_size, device=sigma_x.device)
262
+ if sample_coords is not None:
263
+ sample_RGB_values = [final_image[:, coord[0], coord[1]] for coord in sample_coords]
264
+ final_image = torch.stack(sample_RGB_values, dim = 1)
265
+ return final_image
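Note: the renderer above can be exercised on its own. The sketch below is illustrative only; it assumes nothing beyond the 9-channel gs_parameters layout that generate_2D_gaussian_splatting_step slices out (sigma_x, sigma_y, rho, alpha, RGB, xy), the helper name demo_render is hypothetical, and cuda_rendering=False selects the pure-PyTorch rendering_python path so no compiled extension is needed.

import torch

def demo_render(num_gs=512, sr_size=(96, 96), scale=4.0, device='cpu'):
    # Raw (pre-activation) Gaussian parameters; sigmoid/tanh are applied inside.
    gs_parameters = torch.randn(num_gs, 9, device=device)
    # scale_modify must hold two equal values (see the assert above).
    img = generate_2D_gaussian_splatting_step(
        sr_size, gs_parameters, scale, scale_modify=(scale, scale),
        cuda_rendering=False, mode='scale_modify')
    return img  # (3, H, W) tensor

# img = demo_render()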
utils/gs_cuda/check.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from gswrapper import gaussiansplatting_render
3
+
4
+ def torch_version(sigmas, coords, colors, image_size):
5
+ h, w = image_size
6
+ c = colors.shape[-1]
7
+
8
+ if h >= 50 or w >= 50:
9
+ print(f'warning: too large values for h({h}), w({w}); the torch version will be slow')
10
+
11
+ rendered_img = torch.zeros(h, w, c).to(colors.device).to(torch.float32)
12
+
13
+ for hi in range(h):
14
+ for wi in range(w):
15
+ curh = 2*hi/(h-1)-1.0
16
+ curw = 2*wi/(w-1)-1.0
17
+
18
+ v = (curw-coords[:,0])**2/sigmas[:,0]**2
19
+ v -= (2*sigmas[:,2])*(curw-coords[:,0])*(curh-coords[:,1])/sigmas[:,0]/sigmas[:,1]
20
+ v += (curh-coords[:,1])**2/sigmas[:,1]**2
21
+ v *= -1.0/(2.0*(1-sigmas[:,2]**2))
22
+ v = torch.exp(v)
23
+
24
+ for ci in range(c):
25
+ rendered_img[hi, wi, ci] = torch.sum(v*colors[:, ci])
26
+
27
+ return rendered_img
28
+
29
+
30
+ if __name__ == "__main__":
31
+ s = 40 # the number of gs
32
+ image_size = (49, 49)
33
+
34
+ for _ in range(1):
35
+ print(f"--------------------------- begins --------------------------------")
36
+
37
+ sigmas = 0.999*torch.rand(s, 3).to(torch.float32).to("cuda")
38
+ # sigmas[:,:2] = 5*sigmas[:, :2]
39
+ coords = 2*torch.rand(s, 2).to(torch.float32).to("cuda")-1.0
40
+ colors = torch.rand(s, 3).to(torch.float32).to("cuda")
41
+
42
+ # sigmas = torch.Tensor([[0.9196, 0.3979, 0.7784]]).to(torch.float32).to("cuda")
43
+ # coords = torch.Tensor([[-0.0469, -0.1726]]).to(torch.float32).to("cuda")
44
+ # colors = torch.Tensor([[0.3775, 0.2346, 0.1513]]).to(torch.float32).to("cuda")
45
+ # colors = torch.ones_like(coords[:,0:1])
46
+
47
+ print(f"sigmas: {sigmas}, \ncoords:{coords}, \ncolors:{colors}")
48
+
49
+ # --- check forward ---
50
+ with torch.no_grad():
51
+ rendered_img_th = torch_version(sigmas,coords,colors,image_size)
52
+ rendered_img_cuda = gaussiansplatting_render(sigmas,coords,colors,image_size)
53
+
54
+ #
55
+ distance = (rendered_img_th-rendered_img_cuda)**2
56
+ print(f"check forward - torch: {rendered_img_th[:2,:2,0]}")
57
+ print(f"check forward - cuda: {rendered_img_cuda[:2,:2,0]}")
58
+ print(f"check forward - distance: {distance[:2, :2, 0]}")
59
+ print(f"check forward - sum: {torch.sum(distance)}\n")
60
+ # --- ends ---
61
+
62
+ # --- check backward ---
63
+ sigmas.requires_grad_(True)
64
+ coords.requires_grad_(True)
65
+ colors.requires_grad_(True)
66
+ # sigmas.retain_grad()
67
+ # coords.retain_grad()
68
+ # colors.retain_grad()
69
+ weight = torch.rand_like(rendered_img_th) # make each pixel has different grads
70
+
71
+ sigmas.grad = None
72
+ coords.grad = None
73
+ colors.grad = None
74
+ rendered_img_th = torch_version(sigmas,coords,colors,image_size)
75
+ loss_th = torch.sum(weight*rendered_img_th)
76
+ loss_th.backward()
77
+
78
+ sigmas_grad_th = sigmas.grad
79
+ coords_grad_th = coords.grad
80
+ colors_grad_th = colors.grad
81
+
82
+ sigmas.grad = None
83
+ coords.grad = None
84
+ colors.grad = None
85
+ rendered_img_cuda = gaussiansplatting_render(sigmas,coords,colors,image_size)
86
+ loss_cuda = torch.sum(weight*rendered_img_cuda)
87
+ # loss_cuda = torch.sum(rendered_img_cuda)
88
+ loss_cuda.backward()
89
+
90
+ sigmas_grad_cuda = sigmas.grad
91
+ coords_grad_cuda = coords.grad
92
+ colors_grad_cuda = colors.grad
93
+
94
+ distance_sigmas_grad = (sigmas_grad_th-sigmas_grad_cuda)**2
95
+ distance_coords_grad = (coords_grad_th-coords_grad_cuda)**2
96
+ distance_colors_grad = (colors_grad_th-colors_grad_cuda)**2
97
+
98
+ print(f"check backward - sigmas - torch: {sigmas_grad_th[:2]}")
99
+ print(f"check backward - sigmas - cuda: {sigmas_grad_cuda[:2]}")
100
+ print(f"check backward - sigmas - distance: {distance_sigmas_grad[:2]}")
101
+ print(f"check backward - sigmas - sum: {torch.sum(distance_sigmas_grad)}\n")
102
+
103
+ print(f"check backward - coords - torch: {coords_grad_th[:2]}")
104
+ print(f"check backward - coords - cuda: {coords_grad_cuda[:2]}")
105
+ print(f"check backward - coords - distance: {distance_coords_grad[:2]}")
106
+ print(f"check backward - coords - sum: {torch.sum(distance_coords_grad)}\n")
107
+
108
+ print(f"check backward - colors - torch: {colors_grad_th[:2]}")
109
+ print(f"check backward - colors - cuda: {colors_grad_cuda[:2]}")
110
+ print(f"check backward - colors - distance: {distance_colors_grad[:2]}")
111
+ print(f"check backward - colors - sum: {torch.sum(distance_colors_grad)}\n")
112
+
113
+ print(f"--------------------------- ends --------------------------------\n\n")
114
+
115
+
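Note: since the squared distances printed above scale with the magnitude of the rendered values, a relative-error summary can make the forward/backward comparison easier to read. A small sketch follows; the helper name rel_error is illustrative and not part of this file.

import torch

def rel_error(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Maximum element-wise relative difference between two tensors.
    return (torch.abs(a - b) / (torch.abs(a) + torch.abs(b) + eps)).max()

# e.g. rel_error(rendered_img_th, rendered_img_cuda) after the forward check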
utils/gs_cuda/gs.cu ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <stdio.h>
2
+ #include <cmath>
3
+ #include <curand_kernel.h>
4
+
5
+ #define PI 3.1415926536
6
+ #define PI2 6.2831853072
7
+
8
+ extern "C"
9
+ __global__ void _gs_render_cuda(
10
+ const float *sigmas,
11
+ const float *coords,
12
+ const float *colors,
13
+ float *rendered_img,
14
+ const int s, // gs num
15
+ const int h,
16
+ const int w,
17
+ const int c
18
+ ){
19
+
20
+ int index = blockIdx.x*blockDim.x + threadIdx.x;
21
+ int curw = index % w;
22
+ int curh = int((index-curw)/w);
23
+ if(curw >= w || curh >=h){
24
+ return;
25
+ }
26
+
27
+ float curw_f = 2.0*curw/(w-1) - 1.0;
28
+ float curh_f = 2.0*curh/(h-1) - 1.0;
29
+
30
+ // printf("index:%d, curw:%d, curh:%d, curw_f:%f, curh_f:%f\n",index,curw,curh,curw_f,curh_f);
31
+
32
+ for(int si=0; si<s; si++){
33
+
34
+ // compute the 2d gs value
35
+ float sigma_x = sigmas[si*3+0];
36
+ float sigma_y = sigmas[si*3+1];
37
+ float rho = sigmas[si*3+2];
38
+ float x = coords[si*2+0];
39
+ float y = coords[si*2+1];
40
+
41
+ //
42
+ float one_div_one_minus_rho2 = 1.0 / (1-rho*rho) ;
43
+ float one_div_sigma_x = 1.0 / sigma_x;
44
+ float one_div_sigma_y = 1.0 / sigma_y;
45
+ float d_x = curw_f - x;
46
+ float d_y = curh_f - y;
47
+
48
+ float v = one_div_sigma_x*one_div_sigma_x*d_x*d_x;
49
+ v -= 2*rho*d_x*d_y*one_div_sigma_x*one_div_sigma_y;
50
+ v += d_y*d_y*one_div_sigma_y*one_div_sigma_y;
51
+ v *= -one_div_one_minus_rho2 / 2.0;
52
+ v = exp(v);
53
+ // since v is normalized by its max later, this constant factor is omitted (result is unchanged)
54
+ // v *= one_div_sigma_x * one_div_sigma_y * pow(one_div_one_minus_rho2, 0.5) / PI2 ;
55
+ // printf("si:%d, sigma_x: %f, sigma_y:%f, rho:%f, x:%f, y:%f, v:%f\n", si, sigma_x, sigma_y, rho, x,y,v);
56
+
57
+ for(int ci=0; ci<c; ci++){
58
+ rendered_img[(curh*w+curw)*c+ci] += v*colors[si*3+ci];
59
+ }
60
+ }
61
+ }
62
+
63
+
64
+ void _gs_render(
65
+ const float *sigmas,
66
+ const float *coords,
67
+ const float *colors,
68
+ float *rendered_img,
69
+ const int s,
70
+ const int h,
71
+ const int w,
72
+ const int c
73
+ ) {
74
+
75
+ int threads=64;
76
+ dim3 grid( h*w, 1);
77
+ dim3 block( threads, 1);
78
+ _gs_render_cuda<<<grid, block>>>(sigmas, coords, colors, rendered_img, s, h, w, c);
79
+ }
80
+
81
+ extern "C"
82
+ __global__ void _gs_render_backward_cuda(
83
+ const float *sigmas,
84
+ const float *coords,
85
+ const float *colors,
86
+ const float *grads,
87
+ float *grads_sigmas,
88
+ float *grads_coords,
89
+ float *grads_colors,
90
+ const int s, // gs num
91
+ const int h,
92
+ const int w,
93
+ const int c
94
+ ){
95
+
96
+ int curs = blockIdx.x*blockDim.x + threadIdx.x;
97
+ if(curs >= s){
98
+ return ;
99
+ }
100
+
101
+ // obtain parameters of gs
102
+ float sigma_x = sigmas[curs*3+0];
103
+ float sigma_y = sigmas[curs*3+1];
104
+ float rho = sigmas[curs*3+2];
105
+ float x = coords[curs*2+0];
106
+ float y = coords[curs*2+1];
107
+ float cr = colors[curs*3+0];
108
+ float cg = colors[curs*3+1];
109
+ float cb = colors[curs*3+2];
110
+
111
+ //
112
+ float w1 = -0.5 / (1-rho*rho) ;
113
+ float w2 = 1.0 / (sigma_x*sigma_x);
114
+ float w3 = 1.0 / (sigma_x*sigma_y);
115
+ float w4 = 1.0 / (sigma_y*sigma_y);
116
+ float od_sx = 1.0 / sigma_x;
117
+ float od_sy = 1.0 / sigma_y;
118
+
119
+ // init
120
+ float _gr=0.0, _gg=0.0, _gb=0.0;
121
+ float _gx=0.0, _gy=0.0;
122
+ float _gsx=0.0, _gsy=0.0, _gsr=0.0;
123
+
124
+ for(int hi = 0; hi < h; hi++){
125
+ for( int wi=0; wi < w; wi++){
126
+
127
+ float curw_f = 2.0*wi/(w-1) - 1.0;
128
+ float curh_f = 2.0*hi/(h-1) - 1.0;
129
+
130
+ // obtain grad to p^t_r, p^t_g, p^t_b
131
+ float gptr = grads[(hi*w+wi)*c+0]; // grad of loss to P^t_r
132
+ float gptg = grads[(hi*w+wi)*c+1];
133
+ float gptb = grads[(hi*w+wi)*c+2];
134
+
135
+ // compute the 2d gs value
136
+
137
+ float d_x = curw_f - x; // distance along x axis
138
+ float d_y = curh_f - y;
139
+ float d = w2*d_x*d_x - 2*rho*w3*d_x*d_y + w4*d_y*d_y;
140
+ float v = w1*d;
141
+ v = exp(v);
142
+ // printf("si:%d, sigma_x: %f, sigma_y:%f, rho:%f, x:%f, y:%f, v:%f\n", si, sigma_x, sigma_y, rho, x,y,v);
143
+
144
+ // compute grad of colors
145
+ _gr += v*gptr;
146
+ _gg += v*gptg;
147
+ _gb += v*gptb;
148
+
149
+ // compute grad of coords
150
+ float gpt = gptr*cr+gptg*cg+gptb*cb;
151
+ float v_2_w1 = v*2*w1;
152
+
153
+ float g_vst_to_gsx = v_2_w1*(-w2*d_x+rho*w3*d_y); // grad of v^{st} to G^s_x
154
+ _gx += gpt*g_vst_to_gsx;
155
+ float g_vst_to_gsy = v_2_w1*(-w4*d_y+rho*w3*d_x); // grad of v^{st} to G^s_y
156
+ _gy += gpt*g_vst_to_gsy;
157
+
158
+ // compute grad of sigmas
159
+ float g_vst_to_gsigx = v_2_w1*od_sx* (w3*rho*d_x*d_y - w2*d_x*d_x);
160
+ _gsx += gpt*g_vst_to_gsigx;
161
+ float g_vst_to_gsigy = v_2_w1*od_sy* (w3*rho*d_x*d_y - w4*d_y*d_y);
162
+ _gsy += gpt*g_vst_to_gsigy;
163
+ float g_vst_to_rho = -v_2_w1*(2*w1*rho*d+w3*d_x*d_y);
164
+ _gsr += gpt*g_vst_to_rho;
165
+ }
166
+ }
167
+
168
+ // write the values
169
+ grads_sigmas[curs*3+0] = _gsx;
170
+ grads_sigmas[curs*3+1] = _gsy;
171
+ grads_sigmas[curs*3+2] = _gsr;
172
+ grads_coords[curs*2+0] = _gx;
173
+ grads_coords[curs*2+1] = _gy;
174
+ grads_colors[curs*3+0] = _gr;
175
+ grads_colors[curs*3+1] = _gg;
176
+ grads_colors[curs*3+2] = _gb;
177
+
178
+ }
179
+
180
+ void _gs_render_backward(
181
+ const float *sigmas,
182
+ const float *coords,
183
+ const float *colors,
184
+ const float *grads, // (h, w, c)
185
+ float *grads_sigmas,
186
+ float *grads_coords,
187
+ float *grads_colors,
188
+ const int s,
189
+ const int h,
190
+ const int w,
191
+ const int c
192
+ ) {
193
+
194
+ int threads=64;
195
+ dim3 grid(s, 1);
196
+ dim3 block( threads, 1);
197
+ _gs_render_backward_cuda<<<grid, block>>>(sigmas, coords, colors, grads, grads_sigmas, grads_coords, grads_colors, s, h, w, c);
198
+ }
199
+
utils/gs_cuda/gs.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ void _gs_render(
2
+ const float *sigmas,
3
+ const float *coords,
4
+ const float *colors,
5
+ float *rendered_img,
6
+ const int s,
7
+ const int h,
8
+ const int w,
9
+ const int c
10
+ );
11
+
12
+ void _gs_render_backward(
13
+ const float *sigmas,
14
+ const float *coords,
15
+ const float *colors,
16
+ const float *grads,
17
+ float *grads_sigmas,
18
+ float *grads_coords,
19
+ float *grads_colors,
20
+ const int s,
21
+ const int h,
22
+ const int w,
23
+ const int c
24
+ );
utils/gs_cuda/gswrapper.cpp ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "gs.h"
2
+ #include <torch/extension.h>
3
+ #include <c10/cuda/CUDAGuard.h>
4
+
5
+ #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
6
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
7
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
8
+
9
+ void gs_render(
10
+ torch::Tensor &sigmas,
11
+ torch::Tensor &coords,
12
+ torch::Tensor &colors,
13
+ torch::Tensor &rendered_img,
14
+ const int s,
15
+ const int h,
16
+ const int w,
17
+ const int c
18
+ ){
19
+
20
+ CHECK_INPUT(sigmas);
21
+ CHECK_INPUT(coords);
22
+ CHECK_INPUT(colors);
23
+ CHECK_INPUT(rendered_img);
24
+
25
+ // run the kernel on the same CUDA device as the input tensors
26
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(sigmas));
27
+
28
+ _gs_render(
29
+ (const float *) sigmas.data_ptr(),
30
+ (const float *) coords.data_ptr(),
31
+ (const float *) colors.data_ptr(),
32
+ (float *) rendered_img.data_ptr(),
33
+ s, h, w, c);
34
+ }
35
+
36
+ void gs_render_backward(
37
+ torch::Tensor &sigmas,
38
+ torch::Tensor &coords,
39
+ torch::Tensor &colors,
40
+ torch::Tensor &grads,
41
+ torch::Tensor &grads_sigmas,
42
+ torch::Tensor &grads_coords,
43
+ torch::Tensor &grads_colors,
44
+ const int s,
45
+ const int h,
46
+ const int w,
47
+ const int c
48
+ ){
49
+
50
+ CHECK_INPUT(sigmas);
51
+ CHECK_INPUT(coords);
52
+ CHECK_INPUT(colors);
53
+ CHECK_INPUT(grads);
54
+ CHECK_INPUT(grads_sigmas);
55
+ CHECK_INPUT(grads_coords);
56
+ CHECK_INPUT(grads_colors);
57
+
58
+
59
+ // run the kernel on the same CUDA device as the input tensors
60
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(sigmas));
61
+
62
+ _gs_render_backward(
63
+ (const float *) sigmas.data_ptr(),
64
+ (const float *) coords.data_ptr(),
65
+ (const float *) colors.data_ptr(),
66
+ (const float *) grads.data_ptr(),
67
+ (float *) grads_sigmas.data_ptr(),
68
+ (float *) grads_coords.data_ptr(),
69
+ (float *) grads_colors.data_ptr(),
70
+ s, h, w, c);
71
+ }
72
+
73
+ PYBIND11_MODULE( TORCH_EXTENSION_NAME, m) {
74
+ m.def( "gs_render",
75
+ &gs_render,
76
+ "cuda forward wrapper");
77
+ m.def( "gs_render_backward",
78
+ &gs_render_backward,
79
+ "cuda backward wrapper");
80
+ }
utils/gs_cuda/gswrapper.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torch.utils.cpp_extension import load
4
+ from torch.autograd import Function
5
+ from torch.autograd.function import once_differentiable
6
+
7
+ build_path = os.path.join(os.path.split(os.path.abspath(__file__))[0], 'build')
8
+ os.makedirs(build_path, exist_ok=True)
9
+
10
+ file_path = os.path.split(os.path.abspath(__file__))[0]
11
+ GSWrapper = load(
12
+ name="gscuda",
13
+ # sources=["gs_cuda/gswrapper.cpp", "gs_cuda/gs.cu"],
14
+ sources=[os.path.join(file_path, "gswrapper.cpp"),
15
+ os.path.join(file_path, "gs.cu")],
16
+ build_directory=build_path,
17
+ verbose=True)
18
+
19
+ class GSCUDA(Function):
20
+
21
+ @staticmethod
22
+ def forward(ctx, sigmas, coords, colors, rendered_img):
23
+ ctx.save_for_backward(sigmas, coords, colors)
24
+ h, w, c = rendered_img.shape
25
+ s = sigmas.shape[0]
26
+ GSWrapper.gs_render(sigmas, coords, colors, rendered_img, s, h, w, c)
27
+ return rendered_img
28
+
29
+ @staticmethod
30
+ @once_differentiable
31
+ def backward(ctx, grad_output):
32
+ sigmas, coords, colors = ctx.saved_tensors
33
+ h, w, c = grad_output.shape
34
+ s = sigmas.shape[0]
35
+ grads_sigmas = torch.zeros_like(sigmas)
36
+ grads_coords = torch.zeros_like(coords)
37
+ grads_colors = torch.zeros_like(colors)
38
+ GSWrapper.gs_render_backward(sigmas, coords, colors, grad_output.contiguous(), grads_sigmas, grads_coords, grads_colors, s, h, w, c)
39
+ return (grads_sigmas, grads_coords, grads_colors, None)
40
+
41
+ def gaussiansplatting_render(sigmas, coords, colors, image_size):
42
+ sigmas = sigmas.contiguous() # (gs num, 3)
43
+ coords = coords.contiguous() # (gs num, 2)
44
+ colors = colors.contiguous() # (gs num, c)
45
+ h, w = image_size[:2]
46
+ c = colors.shape[-1]
47
+ rendered_img = torch.zeros(h, w, c).to(colors.device).to(torch.float32)
48
+ return GSCUDA.apply(sigmas, coords, colors, rendered_img)
49
+
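Note: gswrapper.py JIT-compiles the extension with torch.utils.cpp_extension.load at import time. The prebuilt wheel tracked in dist/ (gscuda-0.0.0-cp310-cp310-linux_x86_64.whl, see .gitattributes) points at an ahead-of-time build as an alternative. A minimal setup.py sketch for that, assuming the same two sources are compiled into a module named gscuda, could look like the following; it is not part of this commit.

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="gscuda",
    version="0.0.0",
    ext_modules=[
        CUDAExtension(
            name="gscuda",
            # same sources as the JIT load() call in gswrapper.py
            sources=["utils/gs_cuda/gswrapper.cpp", "utils/gs_cuda/gs.cu"],
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)

Built with `python setup.py bdist_wheel`, this would produce a wheel under dist/ that can be installed instead of compiling on first import.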
utils/gs_cuda/mylineprofiler.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import sys
4
+ import timeit
5
+ import tokenize
6
+ import torch
7
+ import psutil
8
+ import inspect
9
+ from loguru import logger
10
+ from prettytable import PrettyTable
11
+
12
+ # implement by xtudbxk
13
+ # github: https://github.com/xtudbxk/lineprofiler
14
+ class MyLineProfiler():
15
+ def __init__(self, base='ms', cuda_sync=True, gpuids=(0,), warmup=0, warmup_lineno=-1):
16
+
17
+ if base == 'ms':
18
+ self.base_n = 1000
19
+ elif base == 's':
20
+ self.base_n = 1
21
+ else:
22
+ logger.warning(f'Unsupported base - {base}, using "s" instead'); self.base_n = 1; base = 's'
23
+
24
+ self.base = base
25
+ self.cuda_sync = cuda_sync
26
+ self.gpuids = gpuids
27
+ self.warmup = warmup
28
+ self.warmup_counter = warmup
29
+ # we should wait this line execute warup_counter times
30
+ # before recording the stats
31
+ self.warmup_lineno = warmup_lineno
32
+
33
+ # for time profiling
34
+ self._times = {}
35
+ self._func_name = None
36
+ self._func_filename = None
37
+ self._last_time = -1
38
+ self._last_lineno = -1
39
+ self._func_hit_count = 0
40
+ self._func_firstlineno = 0
41
+
42
+ # for memory profiling
43
+ self._process = psutil.Process(os.getpid())
44
+ self._memory = {}
45
+ self._last_memory = 0
46
+
47
+ # for cuda memory profiling
48
+ self._gpu_memory = {}
49
+ self._gpu_last_memory = 0
50
+
51
+ def __trace_func__(self, frame, event, arg):
52
+ # print(f'in {frame.f_code.co_filename} func {frame.f_code.co_name} line {frame.f_lineno}, event - {event}')
53
+
54
+ # check if run into the decorated func
55
+ if self._func_firstlineno == frame.f_code.co_firstlineno and frame.f_code.co_name == self._func_name and frame.f_code.co_filename == self._func_filename:
56
+
57
+ # --- obtain info for current hit ---
58
+ # cuda related
59
+ if self.cuda_sync is True:
60
+ torch.cuda.synchronize()
61
+
62
+ current_time = timeit.default_timer()
63
+ memory = self._process.memory_info().rss
64
+ gpu_memory = torch.cuda.memory_allocated()
65
+ # --- ends ---
66
+
67
+ # --- initialize the info on the first hit of this line ---
68
+ if frame.f_lineno not in self._times: # first hit time for this line
69
+ self._times[frame.f_lineno] = {'hit':0, 'time': 0}
70
+ self._memory[frame.f_lineno] = 0
71
+ self._gpu_memory[frame.f_lineno] = 0
72
+ # --- ends ---
73
+
74
+ # --- record info before call the decorated func ---
75
+ # 'call' - before call the func
76
+ if event == 'call':
77
+ self._last_time = current_time
78
+ self._last_lineno = frame.f_lineno
79
+ self._last_memory = memory
80
+ self._last_gpu_memory = gpu_memory
81
+
82
+ if self.warmup_lineno < 0:
83
+ self.warmup_counter -= 1
84
+ if self.warmup_counter < 0:
85
+ self._func_hit_count += 1
86
+ # --- ends ---
87
+
88
+ # 'line' - after excuting the line
89
+ # 'return' - return from the function
90
+ if event == 'line' or event == 'return':
91
+
92
+ if event == 'line' and self.warmup_counter < 0:
93
+ self._times[frame.f_lineno]['hit'] += 1
94
+
95
+
96
+ # --- obtain the memory and time consumed by this line ---
97
+ if self.warmup_counter < 0:
98
+ self._times[self._last_lineno]['time'] += current_time - self._last_time
99
+ self._memory[self._last_lineno] += memory - self._last_memory
100
+ self._gpu_memory[self._last_lineno] += gpu_memory - self._gpu_last_memory
101
+ # --- ends ---
102
+
103
+ if self.cuda_sync is True:
104
+ torch.cuda.synchronize()
105
+
106
+ self._last_time = timeit.default_timer()
107
+ self._last_memory = memory
108
+ self._gpu_last_memory = gpu_memory
109
+ self._last_lineno = frame.f_lineno
110
+
111
+ return self.__trace_func__
112
+
113
+ def decorate(self, func):
114
+ if self._func_name is not None:
115
+ logger.warning(f'Only one function can be decorated; already decorated "{self._func_name}"')
116
+ self._func_name = func.__name__
117
+ self._func_filename = func.__code__.co_filename
118
+ self._func_firstlineno = func.__code__.co_firstlineno
119
+
120
+ def _f(*args, **kwargs):
121
+ origin_trace_func = sys.gettrace()
122
+ sys.settrace(self.__trace_func__)
123
+ ret = func(*args, **kwargs)
124
+ sys.settrace(origin_trace_func)
125
+ return ret
126
+ return _f
127
+
128
+ def _get_table(self):
129
+
130
+ if len(self._times) <= 0:
131
+ logger.warning(f"un recorded datas, please ensure the function is executed")
132
+ return None
133
+
134
+ # --- load the source code ---
135
+ with open(self._func_filename, 'r') as f:
136
+ source_lines = [line.strip('\n') for line in f.readlines()]
137
+ code_str = "\n".join(source_lines)
138
+
139
+ def_lineno = min(self._times.keys())
140
+ final_lineno = max(self._times.keys())
141
+
142
+ # remove the additional blank content
143
+ pre_blank_count = len(source_lines[def_lineno-1]) - len(source_lines[def_lineno-1].lstrip(' ').lstrip('\t'))
144
+ # --- ends ---
145
+
146
+ # --- analyze the source code and collect info for multi-line statements ---
147
+ new_logic_linenos = [token.start[0] for token in tokenize.generate_tokens(
148
+ io.StringIO(code_str).readline) if token.type == 4]
149
+ # --- ends ---
150
+
151
+ # --- merge the stats multi-line code ---
152
+ sorted_linenos = [lineno for lineno in self._times.keys()]
153
+ sorted_linenos.sort(key=int)
154
+
155
+ lineno_cache = []
156
+ for lineno in sorted_linenos:
157
+ if lineno not in new_logic_linenos:
158
+ lineno_cache.append(lineno)
159
+ else:
160
+ # we should merge its info to the prev_lineno
161
+ if len(lineno_cache) <= 0:
162
+ continue
163
+ else:
164
+ lineno_cache.append(lineno)
165
+ first_lineno = lineno_cache[0]
166
+ for prev_lineno in lineno_cache[1:]:
167
+ self._times[first_lineno]["hit"] = min(self._times[first_lineno]["hit"], self._times[prev_lineno]["hit"])
168
+ self._times[first_lineno]["time"] += self._times[prev_lineno]["time"]
169
+ del self._times[prev_lineno]
170
+
171
+ self._memory[first_lineno] += self._memory[prev_lineno]
172
+ del self._memory[prev_lineno]
173
+
174
+ self._gpu_memory[first_lineno] += self._gpu_memory[prev_lineno]
175
+ del self._gpu_memory[prev_lineno]
176
+ lineno_cache = []
177
+ # --- ends ---
178
+
179
+ # --- initialize the pretty table for output ---
180
+ table = PrettyTable(['lineno', 'hits', 'time', 'time per hit', 'hit perc', 'time perc', 'mem inc', 'mem peak', 'gpu mem inc', 'gpu mem peak'])
181
+ # --- ends ---
182
+
183
+ # --- compute some statistics ---
184
+ total_hit = 0 # for compute the hit percentage
185
+ total_time = 0
186
+ for lineno, stats in self._times.items():
187
+ if lineno == def_lineno: continue
188
+ total_hit += stats['hit']
189
+ total_time += stats['time']
190
+
191
+ total_memory = sum([m for l,m in self._memory.items()]) / 1024 / 1024
192
+ total_gpu_memory = sum([m for l,m in self._gpu_memory.items()]) / 1024 / 1024
193
+ # --- ends ---
194
+
195
+ peak_cpu_memory = 0
196
+ peak_gpu_memory = 0
197
+ for lineno in range(def_lineno, final_lineno+1):
198
+ if lineno not in self._times:
199
+ # the comment line, empty line or merged line from multi-lines code
200
+ table.add_row([lineno, '-', '-', '-', '-', '-', '-',f'{peak_cpu_memory:5.3f} MB', '-', f'{peak_gpu_memory:5.3f} MB'])
201
+ else:
202
+ stats = self._times[lineno]
203
+ if lineno == def_lineno:
204
+ table.add_row([lineno, self._func_hit_count, f'{total_time*self.base_n:.4f} {self.base}', f'{total_time/self._func_hit_count*self.base_n:.4f} {self.base}', '-', '-', f'{total_memory:5.3f} MB', 'baseline', f'{total_gpu_memory:5.3f} MB', 'baseline'])
205
+ else:
206
+
207
+ line_result = [lineno, stats['hit'],
208
+ f'{stats["time"]*self.base_n:.4f} {self.base}',
209
+ f'{stats["time"]/stats["hit"]*self.base_n:.4f} {self.base}' if stats['hit'] > 0 else 'nan',
210
+ f'{stats["hit"]/total_hit*100:.3f}%' if total_hit > 0 else 'nan',
211
+ f'{stats["time"]/total_time*100:.3f}%'] if total_time > 0 else 'nan'
212
+
213
+ line_result += [f'{self._memory[lineno]/1024/1024:5.3f} MB' if stats['hit'] > 0 else '0 MB']
214
+ peak_cpu_memory = peak_cpu_memory + self._memory[lineno]/1024/1024
215
+ line_result += [f'{peak_cpu_memory:5.3f} MB']
216
+
217
+ line_result += [f'{self._gpu_memory[lineno]/1024/1024:5.3f} MB' if stats['hit'] > 0 else '0 MB']
218
+ peak_gpu_memory = peak_gpu_memory + self._gpu_memory[lineno]/1024/1024
219
+ line_result += [f'{peak_gpu_memory:5.3f} MB']
220
+
221
+ table.add_row(line_result)
222
+
223
+ table.add_column('sources', [source_lines[i-1][pre_blank_count:] if len(source_lines[i-1])>pre_blank_count else '' for i in range(def_lineno, final_lineno+1)], 'l')
224
+ return table
225
+
226
+ def print(self, filename=None, mode="w"):
227
+ introducation = '''
228
+ 1. The first row of the table reports the overall results of the whole function; the following rows report per-line statistics.
229
+ 2. `hit perc` and `time perc` stand for `hit percentage` and `time percentage`.
230
+ 3. For memory there are four columns: `mem inc`, `mem peak`, `gpu mem inc` and `gpu mem peak`, i.e. the CPU/GPU memory increase and peak. All results are collected in the last run. The increase columns give the memory growth attributed to each line (the first row refers to the whole function); the per-line numbers can be far smaller than the first row, since Python may release unused memory only after the function returns. The peak column is a running sum of the increase values of the lines above it, indicating the possible maximum memory usage inside the function.
231
+ 4. For any issue, please contact us via https://github.com/xtudbxk/lineprofiler or zhengqiang.zhang@hotmail.com
232
+ '''
233
+ print(introducation)
234
+
235
+ table = PrettyTable(['lineno', 'hits', 'time', 'time per hit', 'hit perc', 'time perc', 'mem inc', 'mem peak', 'gpu mem inc', 'gpu mem peak'])
236
+ table = self._get_table()
237
+ print(table)
238
+ if filename is not None:
239
+ with open(filename, mode) as f:
240
+ f.write(introducation)
241
+ f.write(f"args - base={self.base}, cuda_sync={self.cuda_sync}, gpuids={self.gpuids}, warmup={self.warmup}\n")
242
+ f.write(str(table))
243
+
244
+ if __name__ == '__main__':
245
+ import numpy as np
246
+ def mytest(h='hello',
247
+ xx="xx"):
248
+
249
+ h = h + 'world'
250
+ a = []
251
+ for _ in range(200):
252
+ # a = np.zeros((1000, 1000), dtype=np.float32)
253
+ a.append(np.zeros((1000, 1000), dtype=np.float32))
254
+ a.append(
255
+ np.zeros((1000, 1000),
256
+ dtype=np.float32))
257
+ # print(a[0,0])
258
+ print(h)
259
+
260
+ profiler = MyLineProfiler(cuda_sync=False, warmup=2)
261
+ mytest = profiler.decorate(mytest)
262
+ for _ in range(5):
263
+ mytest()
264
+ profiler.print()
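Note: the profiler above can be pointed at the splatting renderer in the same way as the __main__ example. The sketch below is illustrative only; the import paths are assumed from the file layout of this commit, and cuda_sync is enabled only when a GPU is present.

import torch
from utils.gaussian_splatting import generate_2D_gaussian_splatting_step
from utils.gs_cuda.mylineprofiler import MyLineProfiler

profiler = MyLineProfiler(cuda_sync=torch.cuda.is_available())
render = profiler.decorate(generate_2D_gaussian_splatting_step)

# random pre-activation Gaussian parameters, pure-PyTorch rendering path
gs_parameters = torch.randn(1024, 9)
for _ in range(3):
    render((64, 64), gs_parameters, 2.0, scale_modify=(2.0, 2.0), cuda_rendering=False)
profiler.print("render_profile.log")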
utils/gs_cuda/profile.log ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ 1. The first line of table reports the overall results of the whole function and the following lines reports the statistics of each line in the function.
3
+ 2. The `hit perc` and `time perc` represent `hit percentage` and `time percentage`.
4
+ 3. For memory, there exists four categories `mem inc`, `mem peak`, `gpu mem inc` and `gpu mem peak`. They denotes `cpu memory increasement`, `cpu memory peak`, `gpu memory increasement` and `gpu memory peak`. All the results are collected in the last run. The number in the increasement field denots the increasement of corresponding memory of each line (the first line is related to the whole function). Sometimes, the number of each line is far less of the number of the first line, which is valid since python may auto release the unused memory after the function execution. The number of each line in the peak filed is a simple sum of the numbers of above lines in the increasement field, which is used to demonstrate the possible maxinum memory usage in the function.
5
+ 4. For any issue, please concact us via https://github.com/xtudbxk/lineprofiler or zhengqiang.zhang@hotmail.com
6
+ args - base=ms, cuda_sync=True, gpuids=(0,), warmup=0
7
+ +--------+------+------------+--------------+----------+-----------+----------+----------+-------------+--------------+-----------------------------------------------------------------------------+
8
+ | lineno | hits | time | time per hit | hit perc | time perc | mem inc | mem peak | gpu mem inc | gpu mem peak | sources |
9
+ +--------+------+------------+--------------+----------+-----------+----------+----------+-------------+--------------+-----------------------------------------------------------------------------+
10
+ | 41 | 1 | 76.8299 ms | 76.8299 ms | - | - | 0.902 MB | baseline | 3.500 MB | baseline | def gaussiansplatting_render(sigmas, coords, colors, image_size): |
11
+ | 42 | 1 | 0.0353 ms | 0.0353 ms | 14.286% | 0.046% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | sigmas = sigmas.contiguous() # (gs num, 3) |
12
+ | 43 | 1 | 0.0078 ms | 0.0078 ms | 14.286% | 0.010% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | coords = coords.contiguous() # (gs num, 2) |
13
+ | 44 | 1 | 0.0063 ms | 0.0063 ms | 14.286% | 0.008% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | colors = colors.contiguous() # (gs num, c) |
14
+ | 45 | 1 | 0.0063 ms | 0.0063 ms | 14.286% | 0.008% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | h, w = image_size[:2] |
15
+ | 46 | 1 | 0.0093 ms | 0.0093 ms | 14.286% | 0.012% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | c = colors.shape[-1] |
16
+ | 47 | 1 | 1.8306 ms | 1.8306 ms | 14.286% | 2.383% | 0.438 MB | 0.438 MB | 3.000 MB | 3.000 MB | rendered_img = torch.zeros(h, w, c).to(colors.device).to(torch.float32) |
17
+ | 48 | 1 | 74.9344 ms | 74.9344 ms | 14.286% | 97.533% | 0.465 MB | 0.902 MB | 0.000 MB | 3.000 MB | return GSCUDA.apply(sigmas, coords, colors, rendered_img) |
18
+ +--------+------+------------+--------------+----------+-----------+----------+----------+-------------+--------------+-----------------------------------------------------------------------------+
19
+ 1. The first line of table reports the overall results of the whole function and the following lines reports the statistics of each line in the function.
20
+ 2. The `hit perc` and `time perc` represent `hit percentage` and `time percentage`.
21
+ 3. For memory, there exists four categories `mem inc`, `mem peak`, `gpu mem inc` and `gpu mem peak`. They denotes `cpu memory increasement`, `cpu memory peak`, `gpu memory increasement` and `gpu memory peak`. All the results are collected in the last run. The number in the increasement field denots the increasement of corresponding memory of each line (the first line is related to the whole function). Sometimes, the number of each line is far less of the number of the first line, which is valid since python may auto release the unused memory after the function execution. The number of each line in the peak filed is a simple sum of the numbers of above lines in the increasement field, which is used to demonstrate the possible maxinum memory usage in the function.
22
+ 4. For any issue, please concact us via https://github.com/xtudbxk/lineprofiler or zhengqiang.zhang@hotmail.com
23
+ args - base=ms, cuda_sync=True, gpuids=(0,), warmup=0
24
+ +--------+------+--------------+--------------+----------+-----------+----------+----------+-------------+--------------+-----------------------------------------------------------------------------+
25
+ | lineno | hits | time | time per hit | hit perc | time perc | mem inc | mem peak | gpu mem inc | gpu mem peak | sources |
26
+ +--------+------+--------------+--------------+----------+-----------+----------+----------+-------------+--------------+-----------------------------------------------------------------------------+
27
+ | 41 | 1 | 1175.7406 ms | 1175.7406 ms | - | - | 0.777 MB | baseline | 12.000 MB | baseline | def gaussiansplatting_render(sigmas, coords, colors, image_size): |
28
+ | 42 | 1 | 0.0304 ms | 0.0304 ms | 14.286% | 0.003% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | sigmas = sigmas.contiguous() # (gs num, 3) |
29
+ | 43 | 1 | 0.0069 ms | 0.0069 ms | 14.286% | 0.001% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | coords = coords.contiguous() # (gs num, 2) |
30
+ | 44 | 1 | 0.0064 ms | 0.0064 ms | 14.286% | 0.001% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | colors = colors.contiguous() # (gs num, c) |
31
+ | 45 | 1 | 0.0065 ms | 0.0065 ms | 14.286% | 0.001% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | h, w = image_size[:2] |
32
+ | 46 | 1 | 0.0099 ms | 0.0099 ms | 14.286% | 0.001% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | c = colors.shape[-1] |
33
+ | 47 | 1 | 1.2594 ms | 1.2594 ms | 14.286% | 0.107% | 0.133 MB | 0.133 MB | 3.000 MB | 3.000 MB | rendered_img = torch.zeros(h, w, c).to(colors.device).to(torch.float32) |
34
+ | 48 | 1 | 1174.4211 ms | 1174.4211 ms | 14.286% | 99.888% | 0.645 MB | 0.777 MB | 0.000 MB | 3.000 MB | return GSCUDA.apply(sigmas, coords, colors, rendered_img) |
35
+ +--------+------+--------------+--------------+----------+-----------+----------+----------+-------------+--------------+-----------------------------------------------------------------------------+
36
+ 1. The first line of table reports the overall results of the whole function and the following lines reports the statistics of each line in the function.
37
+ 2. The `hit perc` and `time perc` represent `hit percentage` and `time percentage`.
38
+ 3. For memory, there exists four categories `mem inc`, `mem peak`, `gpu mem inc` and `gpu mem peak`. They denotes `cpu memory increasement`, `cpu memory peak`, `gpu memory increasement` and `gpu memory peak`. All the results are collected in the last run. The number in the increasement field denots the increasement of corresponding memory of each line (the first line is related to the whole function). Sometimes, the number of each line is far less of the number of the first line, which is valid since python may auto release the unused memory after the function execution. The number of each line in the peak filed is a simple sum of the numbers of above lines in the increasement field, which is used to demonstrate the possible maxinum memory usage in the function.
39
+ 4. For any issue, please concact us via https://github.com/xtudbxk/lineprofiler or zhengqiang.zhang@hotmail.com
40
+ args - base=ms, cuda_sync=True, gpuids=(0,), warmup=0
41
+ +--------+------+---------------+--------------+----------+-----------+-----------+-----------+-------------+--------------+-----------------------------------------------------------------------------+
42
+ | lineno | hits | time | time per hit | hit perc | time perc | mem inc | mem peak | gpu mem inc | gpu mem peak | sources |
43
+ +--------+------+---------------+--------------+----------+-----------+-----------+-----------+-------------+--------------+-----------------------------------------------------------------------------+
44
+ | 41 | 10 | 11844.9229 ms | 1184.4923 ms | - | - | 20.227 MB | baseline | 15.000 MB | baseline | def gaussiansplatting_render(sigmas, coords, colors, image_size): |
45
+ | 42 | 10 | 0.1342 ms | 0.0134 ms | 14.286% | 0.001% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | sigmas = sigmas.contiguous() # (gs num, 3) |
46
+ | 43 | 10 | 0.0654 ms | 0.0065 ms | 14.286% | 0.001% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | coords = coords.contiguous() # (gs num, 2) |
47
+ | 44 | 10 | 0.0618 ms | 0.0062 ms | 14.286% | 0.001% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | colors = colors.contiguous() # (gs num, c) |
48
+ | 45 | 10 | 0.0710 ms | 0.0071 ms | 14.286% | 0.001% | 0.000 MB | 0.000 MB | 0.000 MB | 0.000 MB | h, w = image_size[:2] |
49
+ | 46 | 10 | 0.0803 ms | 0.0080 ms | 14.286% | 0.001% | 0.062 MB | 0.062 MB | 0.000 MB | 0.000 MB | c = colors.shape[-1] |
50
+ | 47 | 10 | 7.2555 ms | 0.7256 ms | 14.286% | 0.061% | 19.105 MB | 19.168 MB | 30.000 MB | 30.000 MB | rendered_img = torch.zeros(h, w, c).to(colors.device).to(torch.float32) |
51
+ | 48 | 10 | 11837.2547 ms | 1183.7255 ms | 14.286% | 99.935% | 1.059 MB | 20.227 MB | 0.000 MB | 30.000 MB | return GSCUDA.apply(sigmas, coords, colors, rendered_img) |
52
+ +--------+------+---------------+--------------+----------+-----------+-----------+-----------+-------------+--------------+-----------------------------------------------------------------------------+
53
+ 1. The first line of table reports the overall results of the whole function and the following lines reports the statistics of each line in the function.
54
+ 2. The `hit perc` and `time perc` represent `hit percentage` and `time percentage`.
55
+ 3. For memory, there exists four categories `mem inc`, `mem peak`, `gpu mem inc` and `gpu mem peak`. They denotes `cpu memory increasement`, `cpu memory peak`, `gpu memory increasement` and `gpu memory peak`. All the results are collected in the last run. The number in the increasement field denots the increasement of corresponding memory of each line (the first line is related to the whole function). Sometimes, the number of each line is far less of the number of the first line, which is valid since python may auto release the unused memory after the function execution. The number of each line in the peak filed is a simple sum of the numbers of above lines in the increasement field, which is used to demonstrate the possible maxinum memory usage in the function.
56
+ 4. For any issue, please concact us via https://github.com/xtudbxk/lineprofiler or zhengqiang.zhang@hotmail.com
57
+ args - base=ms, cuda_sync=True, gpuids=(0,), warmup=0
58
+ +--------+------+---------------+--------------+----------+-----------+-----------+-----------+-------------+--------------+-----------------------------------------------------------------------------+
59
+ | lineno | hits | time | time per hit | hit perc | time perc | mem inc | mem peak | gpu mem inc | gpu mem peak | sources |
60
+ +--------+------+---------------+--------------+----------+-----------+-----------+-----------+-------------+--------------+-----------------------------------------------------------------------------+
61
+ | 41 | 10 | 11855.0900 ms | 1185.5090 ms | - | - | 20.242 MB | baseline | 15.000 MB | baseline | def gaussiansplatting_render(sigmas, coords, colors, image_size): |
62
+ | 42 | 10 | 0.1263 ms | 0.0126 ms | 14.286% | 0.001% | 0.078 MB | 0.078 MB | 0.000 MB | 0.000 MB | sigmas = sigmas.contiguous() # (gs num, 3) |
63
+ | 43 | 10 | 0.0632 ms | 0.0063 ms | 14.286% | 0.001% | 0.000 MB | 0.078 MB | 0.000 MB | 0.000 MB | coords = coords.contiguous() # (gs num, 2) |
64
+ | 44 | 10 | 0.0588 ms | 0.0059 ms | 14.286% | 0.000% | 0.000 MB | 0.078 MB | 0.000 MB | 0.000 MB | colors = colors.contiguous() # (gs num, c) |
65
+ | 45 | 10 | 0.0626 ms | 0.0063 ms | 14.286% | 0.001% | 0.000 MB | 0.078 MB | 0.000 MB | 0.000 MB | h, w = image_size[:2] |
66
+ | 46 | 10 | 0.0747 ms | 0.0075 ms | 14.286% | 0.001% | 0.000 MB | 0.078 MB | 0.000 MB | 0.000 MB | c = colors.shape[-1] |
67
+ | 47 | 10 | 7.0497 ms | 0.7050 ms | 14.286% | 0.059% | 19.078 MB | 19.156 MB | 30.000 MB | 30.000 MB | rendered_img = torch.zeros(h, w, c).to(colors.device).to(torch.float32) |
68
+ | 48 | 10 | 11847.6547 ms | 1184.7655 ms | 14.286% | 99.937% | 0.820 MB | 19.977 MB | 0.000 MB | 30.000 MB | return GSCUDA.apply(sigmas, coords, colors, rendered_img) |
69
+ +--------+------+---------------+--------------+----------+-----------+-----------+-----------+-------------+--------------+-----------------------------------------------------------------------------+
utils/gs_cuda/profile.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ import torch.nn.functional as F
5
+ from torchvision.utils import save_image
6
+ from gswrapper import gaussiansplatting_render
7
+
8
+ def generate_2D_gaussian_splatting(kernel_size, sigma_x, sigma_y, rho, coords,
9
+ colours, image_size=(256, 256, 3), device="cuda"):
10
+
11
+ batch_size = colours.shape[0]
12
+
13
+ sigma_x = sigma_x.view(batch_size, 1, 1)
14
+ sigma_y = sigma_y.view(batch_size, 1, 1)
15
+ rho = rho.view(batch_size, 1, 1)
16
+
17
+ covariance = torch.stack(
18
+ [torch.stack([sigma_x**2, rho*sigma_x*sigma_y], dim=-1),
19
+ torch.stack([rho*sigma_x*sigma_y, sigma_y**2], dim=-1)],
20
+ dim=-2
21
+ )
22
+
23
+ # Check for positive semi-definiteness
24
+ # determinant = (sigma_x**2) * (sigma_y**2) - (rho * sigma_x * sigma_y)**2
25
+ # if (determinant <= 0).any():
26
+ # raise ValueError("Covariance matrix must be positive semi-definite")
27
+
28
+ inv_covariance = torch.inverse(covariance)
29
+
30
+ # Choosing quite a broad range for the distribution [-5,5] to avoid any clipping
31
+ start = torch.tensor([-5.0], device=device).view(-1, 1)
32
+ end = torch.tensor([5.0], device=device).view(-1, 1)
33
+ base_linspace = torch.linspace(0, 1, steps=kernel_size, device=device)
34
+ ax_batch = start + (end - start) * base_linspace
35
+
36
+ # Expanding dims for broadcasting
37
+ ax_batch_expanded_x = ax_batch.unsqueeze(-1).expand(-1, -1, kernel_size)
38
+ ax_batch_expanded_y = ax_batch.unsqueeze(1).expand(-1, kernel_size, -1)
39
+
40
+ # Creating a batch-wise meshgrid using broadcasting
41
+ xx, yy = ax_batch_expanded_x, ax_batch_expanded_y # (batchsize, kernelsize, kernelsize)
42
+
43
+ xy = torch.stack([xx, yy], dim=-1) # (batchsize, kernelsize, kernelsize, 2)
44
+ z = torch.einsum('b...i,b...ij,b...j->b...', xy, -0.5 * inv_covariance, xy) # (batchsize, kernelsize, kernelsize, 2)
45
+ kernel = torch.exp(z) / (2 * torch.tensor(np.pi, device=device) * torch.sqrt(torch.det(covariance)).view(batch_size, 1, 1)) # (batchsize, kernelsize, kernelsize)
46
+
47
+
48
+ kernel_max_1, _ = kernel.max(dim=-1, keepdim=True) # Find max along the last dimension
49
+ kernel_max_2, _ = kernel_max_1.max(dim=-2, keepdim=True) # Find max along the second-to-last dimension
50
+ kernel_normalized = kernel / kernel_max_2 # (batchsize, kernelsize, kernelsize)
51
+
52
+
53
+ kernel_reshaped = kernel_normalized.repeat(1, 3, 1).view(batch_size * 3, kernel_size, kernel_size)
54
+ kernel_rgb = kernel_reshaped.unsqueeze(0).reshape(batch_size, 3, kernel_size, kernel_size) # (batchsize, 3, kernelsize, kernelsize)
55
+
56
+ # Calculating the padding needed to match the image size
57
+ pad_h = image_size[0] - kernel_size
58
+ pad_w = image_size[1] - kernel_size
59
+
60
+ if pad_h < 0 or pad_w < 0:
61
+ raise ValueError("Kernel size should be smaller or equal to the image size.")
62
+
63
+ # Adding padding to make kernel size equal to the image size
64
+ padding = (pad_w // 2, pad_w // 2 + pad_w % 2, # padding left and right
65
+ pad_h // 2, pad_h // 2 + pad_h % 2) # padding top and bottom
66
+
67
+ kernel_rgb_padded = torch.nn.functional.pad(kernel_rgb, padding, "constant", 0) # (batchsize, 3, h, w)
68
+
69
+ # Extracting shape information
70
+ b, c, h, w = kernel_rgb_padded.shape
71
+
72
+ # Create a batch of 2D affine matrices
73
+ theta = torch.zeros(b, 2, 3, dtype=torch.float32, device=device)
74
+ theta[:, 0, 0] = 1.0
75
+ theta[:, 1, 1] = 1.0
76
+ theta[:, :, 2] = -coords # (b, 2) - the offset of the gaussian splatting
77
+
78
+ # Creating grid and performing grid sampling
79
+ grid = F.affine_grid(theta, size=(b, c, h, w), align_corners=True) # (b, 3, h, w)
80
+ # grid_y = torch.linspace(-1, 1, steps=h, device=device).reshape(1, h, 1, 1).repeat(1, 1, w, 1)
81
+ # grid_x = torch.linspace(-1, 1, steps=w, device=device).reshape(1, 1, w, 1).repeat(1, h, 1, 1)
82
+ # grid = torch.cat([grid_x, grid_y], dim=-1)
83
+ # grid = grid - coords.reshape(-1, 1, 1, 2)
84
+
85
+ kernel_rgb_padded_translated = F.grid_sample(kernel_rgb_padded, grid, align_corners=True) # (b, 3, h, w)
86
+
87
+ rgb_values_reshaped = colours.unsqueeze(-1).unsqueeze(-1)
88
+
89
+ final_image_layers = rgb_values_reshaped * kernel_rgb_padded_translated
90
+ final_image = final_image_layers.sum(dim=0)
91
+ # final_image = torch.clamp(final_image, 0, 1)
92
+ final_image = final_image.permute(1,2,0)
93
+
94
+ return final_image
95
+
96
+
97
+ if __name__ == "__main__":
98
+ from mylineprofiler import MyLineProfiler
99
+ profiler_th = MyLineProfiler(cuda_sync=True)
100
+ generate_2D_gaussian_splatting = profiler_th.decorate(generate_2D_gaussian_splatting)
101
+ profiler_cuda = MyLineProfiler(cuda_sync=True)
102
+ gaussiansplatting_render = profiler_cuda.decorate(gaussiansplatting_render)
103
+
104
+
105
+ # --- test ---
106
+ s = int(512 * 512)
107
+ # s = 5
108
+ image_size = (512, 512, 3)
109
+
110
+ sigmas = 0.2*torch.rand(s, 3).to(torch.float32).to("cuda")
111
+ sigmas[:,:2] = 5*sigmas[:, :2]
112
+ coords = 2*torch.rand(s, 2).to(torch.float32).to("cuda")-1.0
113
+ colors = torch.rand(s, 3).to(torch.float32).to("cuda")
114
+
115
+ # --- torch version ---
116
+ import gc
117
+ # gc.collect()
118
+ # torch.cuda.empty_cache()
119
+ # for _ in range(1):
120
+ # img_python = generate_2D_gaussian_splatting(128, sigmas[:,1], sigmas[:,0], sigmas[:,2], coords, colors, image_size)
121
+ # profiler_th.print("profile.log", "w")
122
+ # cv2.imwrite("th.png", 255.0*img_python.detach().clamp(0,1).cpu().numpy())
123
+ # --- ends ---
124
+
125
+ # --- cuda version ---
126
+ sigmas[:, 0] = sigmas[:, 0]
127
+ sigmas[:, 1] = sigmas[:, 1]
128
+ gc.collect()
129
+ torch.cuda.empty_cache()
130
+ for _ in range(10):
131
+ with torch.no_grad():
132
+ img_cuda = gaussiansplatting_render(sigmas, coords, colors, image_size)
133
+
134
+ profiler_cuda.print("profile.log", "a")
135
+ cv2.imwrite("cuda.png", 255.0*img_cuda.detach().clamp(0,1).cpu().numpy())
136
+ # --- ends ---
137
+ pass
utils/gs_cuda_dmax/__init__.py ADDED
File without changes
utils/gs_cuda_dmax/check.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from gswrapper import gaussiansplatting_render
3
+
4
+ def torch_version(sigmas, coords, colors, image_size, dmax=100):
5
+ h, w = image_size
6
+ c = colors.shape[-1]
7
+
8
+ if h >= 50 or w >= 50:
9
+ logger.warning(f'h({h}) or w({w}) is large; the torch reference version will be slow')
10
+
11
+ rendered_img = torch.zeros(h, w, c).to(colors.device).to(torch.float32)
12
+
13
+ for hi in range(h):
14
+ for wi in range(w):
15
+ curh = 2*hi/(h-1)-1.0
16
+ curw = 2*wi/(w-1)-1.0
17
+
18
+ v = (curw-coords[:,0])**2/sigmas[:,0]**2
19
+ v -= (2*sigmas[:,2])*(curw-coords[:,0])*(curh-coords[:,1])/sigmas[:,0]/sigmas[:,1]
20
+ v += (curh-coords[:,1])**2/sigmas[:,1]**2
21
+ v *= -1.0/(2.0*(1-sigmas[:,2]**2))
22
+ v = torch.exp(v)
23
+
24
+ mask_w = abs(curw-coords[:,0]) <= dmax
25
+ mask_h = abs(curh-coords[:,1]) <= dmax
26
+ mask = torch.logical_and(mask_w, mask_h)
27
+
28
+ for ci in range(c):
29
+ rendered_img[hi, wi, ci] = torch.sum((v*colors[:, ci])[mask])
30
+
31
+ return rendered_img
32
+
33
+
34
+ if __name__ == "__main__":
35
+ s = 4 # the number of gs
36
+ image_size = (10, 10)
37
+
38
+ for _ in range(1):
39
+ print(f"--------------------------- begins --------------------------------")
40
+
41
+ sigmas = 0.999*torch.rand(s, 3).to(torch.float32).to("cuda")
42
+ sigmas[:,:2] = 5*sigmas[:, :2]
43
+ coords = 2*torch.rand(s, 2).to(torch.float32).to("cuda")-1.0
44
+ colors = torch.rand(s, 3).to(torch.float32).to("cuda")
45
+ # colors = torch.rand(s, 5).to(torch.float32).to("cuda")
46
+ dmax = 0.5
47
+
48
+ # sigmas = torch.Tensor([[0.9196, 0.3979, 0.7784]]).to(torch.float32).to("cuda")
49
+ # coords = torch.Tensor([[-0.0469, -0.1726]]).to(torch.float32).to("cuda")
50
+ # colors = torch.Tensor([[0.3775, 0.2346, 0.1513]]).to(torch.float32).to("cuda")
51
+ # colors = torch.ones_like(coords[:,0:1])
52
+
53
+ print(f"sigmas: {sigmas}, \ncoords:{coords}, \ncolors:{colors}\ndmax:{dmax}")
54
+
55
+ # --- check forward ---
56
+ with torch.no_grad():
57
+ rendered_img_th = torch_version(sigmas,coords,colors,image_size,dmax)
58
+ rendered_img_cuda = gaussiansplatting_render(sigmas,coords,colors,image_size,dmax)
59
+
60
+ #
61
+ distance = (rendered_img_th-rendered_img_cuda)**2
62
+ print(f"check forward - torch: {rendered_img_th[:2,:2,0]}")
63
+ print(f"check forward - cuda: {rendered_img_cuda[:2,:2,0]}")
64
+ print(f"check forward - distance: {distance[:2, :2, 0]}")
65
+ print(f"check forward - sum: {torch.sum(distance)}\n")
66
+ # --- ends ---
67
+
68
+ # --- check backward ---
69
+ sigmas.requires_grad_(True)
70
+ coords.requires_grad_(True)
71
+ colors.requires_grad_(True)
72
+ # sigmas.retain_grad()
73
+ # coords.retain_grad()
74
+ # colors.retain_grad()
75
+ weight = torch.rand_like(rendered_img_th) # make each pixel has different grads
76
+
77
+ sigmas.grad = None
78
+ coords.grad = None
79
+ colors.grad = None
80
+ rendered_img_th = torch_version(sigmas,coords,colors,image_size,dmax)
81
+ loss_th = torch.sum(weight*rendered_img_th)
82
+ # loss_th = torch.sum(rendered_img_th)
83
+ loss_th.backward()
84
+
85
+ sigmas_grad_th = sigmas.grad
86
+ coords_grad_th = coords.grad
87
+ colors_grad_th = colors.grad
88
+
89
+ sigmas.grad = None
90
+ coords.grad = None
91
+ colors.grad = None
92
+ rendered_img_cuda = gaussiansplatting_render(sigmas,coords,colors,image_size,dmax)
93
+ loss_cuda = torch.sum(weight*rendered_img_cuda)
94
+ # loss_cuda = torch.sum(rendered_img_cuda)
95
+ loss_cuda.backward()
96
+
97
+ sigmas_grad_cuda = sigmas.grad
98
+ coords_grad_cuda = coords.grad
99
+ colors_grad_cuda = colors.grad
100
+
101
+ distance_sigmas_grad = (sigmas_grad_th-sigmas_grad_cuda)**2
102
+ distance_coords_grad = (coords_grad_th-coords_grad_cuda)**2
103
+ distance_colors_grad = (colors_grad_th-colors_grad_cuda)**2
104
+
105
+ print(f"check backward - sigmas - torch: {sigmas_grad_th[:2]}")
106
+ print(f"check backward - sigmas - cuda: {sigmas_grad_cuda[:2]}")
107
+ print(f"check backward - sigmas - distance: {distance_sigmas_grad[:2]}")
108
+ print(f"check backward - sigmas - sum: {torch.sum(distance_sigmas_grad)}\n")
109
+
110
+ print(f"check backward - coords - torch: {coords_grad_th[:2]}")
111
+ print(f"check backward - coords - cuda: {coords_grad_cuda[:2]}")
112
+ print(f"check backward - coords - distance: {distance_coords_grad[:2]}")
113
+ print(f"check backward - coords - sum: {torch.sum(distance_coords_grad)}\n")
114
+
115
+ print(f"check backward - colors - torch: {colors_grad_th[:2]}")
116
+ print(f"check backward - colors - cuda: {colors_grad_cuda[:2]}")
117
+ print(f"check backward - colors - distance: {distance_colors_grad[:2]}")
118
+ print(f"check backward - colors - sum: {torch.sum(distance_colors_grad)}\n")
119
+
120
+ print(f"--------------------------- ends --------------------------------\n\n")
121
+
122
+
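check.py above validates the CUDA renderer against the slow per-pixel torch reference by printing element-wise squared distances for the forward image and for the gradients of sigmas, coords and colors. The same comparison can be collapsed into assertions; the sketch below is only an illustrative variant of those printed checks and assumes the torch_version and gaussiansplatting_render functions from this directory.

import torch
from check import torch_version
from gswrapper import gaussiansplatting_render

def assert_parity(sigmas, coords, colors, image_size=(10, 10), dmax=0.5, atol=1e-4):
    for t in (sigmas, coords, colors):
        t.requires_grad_(True)
        t.grad = None
    ref = torch_version(sigmas, coords, colors, image_size, dmax)
    cud = gaussiansplatting_render(sigmas, coords, colors, image_size, dmax)
    assert torch.allclose(ref, cud, atol=atol), "forward mismatch"
    weight = torch.rand_like(ref)              # distinct per-pixel grads, as in check.py
    (weight * ref).sum().backward()
    grads_ref = [t.grad.clone() for t in (sigmas, coords, colors)]
    for t in (sigmas, coords, colors):
        t.grad = None
    (weight * cud).sum().backward()
    for g_ref, t in zip(grads_ref, (sigmas, coords, colors)):
        assert torch.allclose(g_ref, t.grad, atol=atol), "backward mismatch"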
utils/gs_cuda_dmax/gs copy.cu ADDED
@@ -0,0 +1,212 @@
1
+ #include <stdio.h>
2
+ #include <cmath>
3
+
4
+ #define PI 3.1415926536
5
+ #define PI2 6.2831853072
6
+
7
+ __global__ void _gs_render_cuda(
8
+ const float *sigmas,
9
+ const float *coords,
10
+ const float *colors,
11
+ float *rendered_img,
12
+ const int s, // gs num
13
+ const int h,
14
+ const int w,
15
+ const int c,
16
+ const float dmax
17
+ ){
18
+
19
+ int index = blockIdx.x*blockDim.x + threadIdx.x;
20
+ int curw = index % w;
21
+ int curh = int((index-curw)/w);
22
+ if(curw >= w || curh >=h){
23
+ return;
24
+ }
25
+
26
+ float curw_f = 2.0*curw/(w-1) - 1.0;
27
+ float curh_f = 2.0*curh/(h-1) - 1.0;
28
+
29
+ // printf("index:%d, curw:%d, curh:%d, curw_f:%f, curh_f:%f\n",index,curw,curh,curw_f,curh_f);
30
+
31
+ for(int si=0; si<s; si++){
32
+
33
+ // compute the 2d gs value
34
+ float sigma_x = sigmas[si*3+0];
35
+ float sigma_y = sigmas[si*3+1];
36
+ float rho = sigmas[si*3+2];
37
+ float x = coords[si*2+0];
38
+ float y = coords[si*2+1];
39
+
40
+ //
41
+ float one_div_one_minus_rho2 = 1.0 / (1-rho*rho) ;
42
+ float one_div_sigma_x = 1.0 / sigma_x;
43
+ float one_div_sigma_y = 1.0 / sigma_y;
44
+ float d_x = curw_f - x;
45
+ float d_y = curh_f - y;
46
+
47
+ if(d_x > dmax || d_x < -dmax || d_y > dmax || d_y < -dmax){
48
+ continue;
49
+ }
50
+
51
+ float v = one_div_sigma_x*one_div_sigma_x*d_x*d_x;
52
+ v -= 2*rho*d_x*d_y*one_div_sigma_x*one_div_sigma_y;
53
+ v += d_y*d_y*one_div_sigma_y*one_div_sigma_y;
54
+ v *= -one_div_one_minus_rho2 / 2.0;
55
+ v = exp(v);
56
+ // since we normalize v by its maximum, this step is removed so both paths give an equal result
57
+ // v *= one_div_sigma_x * one_div_sigma_y * pow(one_div_one_minus_rho2, 0.5) / PI2 ;
58
+ // printf("si:%d, sigma_x: %f, sigma_y:%f, rho:%f, x:%f, y:%f, v:%f\n", si, sigma_x, sigma_y, rho, x,y,v);
59
+
60
+ for(int ci=0; ci<c; ci++){
61
+ rendered_img[(curh*w+curw)*c+ci] += v*colors[si*c+ci];
62
+ }
63
+ }
64
+
65
+ }
66
+
67
+
68
+ void _gs_render(
69
+ const float *sigmas,
70
+ const float *coords,
71
+ const float *colors,
72
+ float *rendered_img,
73
+ const int s,
74
+ const int h,
75
+ const int w,
76
+ const int c,
77
+ const float dmax
78
+ ) {
79
+
80
+ int threads=16;
81
+ dim3 grid( h*w, 1);
82
+ dim3 block( threads, 1);
83
+ _gs_render_cuda<<<grid, block>>>(sigmas, coords, colors, rendered_img, s, h, w, c, dmax);
84
+ }
85
+
86
+
87
+ __global__ void _gs_render_backward_cuda(
88
+ const float *sigmas,
89
+ const float *coords,
90
+ const float *colors,
91
+ const float *grads,
92
+ float *grads_sigmas,
93
+ float *grads_coords,
94
+ float *grads_colors,
95
+ const int s, // gs num
96
+ const int h,
97
+ const int w,
98
+ const int c,
99
+ const float dmax
100
+
101
+ ){
102
+
103
+ int curs = blockIdx.x*blockDim.x + threadIdx.x;
104
+ if(curs >= s){
105
+ return ;
106
+ }
107
+
108
+ // obtain parameters of gs
109
+ float sigma_x = sigmas[curs*3+0];
110
+ float sigma_y = sigmas[curs*3+1];
111
+ float rho = sigmas[curs*3+2];
112
+ float x = coords[curs*2+0];
113
+ float y = coords[curs*2+1];
114
+ float cr = colors[curs*3+0];
115
+ float cg = colors[curs*3+1];
116
+ float cb = colors[curs*3+2];
117
+
118
+ //
119
+ float w1 = -0.5 / (1-rho*rho) ;
120
+ float w2 = 1.0 / (sigma_x*sigma_x);
121
+ float w3 = 1.0 / (sigma_x*sigma_y);
122
+ float w4 = 1.0 / (sigma_y*sigma_y);
123
+ float od_sx = 1.0 / sigma_x;
124
+ float od_sy = 1.0 / sigma_y;
125
+
126
+ // init
127
+ float _gr=0.0, _gg=0.0, _gb=0.0;
128
+ float _gx=0.0, _gy=0.0;
129
+ float _gsx=0.0, _gsy=0.0, _gsr=0.0;
130
+
131
+ for(int hi = 0; hi < h; hi++){
132
+ for( int wi=0; wi < w; wi++){
133
+
134
+ float curw_f = 2.0*wi/(w-1) - 1.0;
135
+ float curh_f = 2.0*hi/(h-1) - 1.0;
136
+
137
+ // obtain grad to p^t_r, p^t_g, p^t_b
138
+ float gptr = grads[(hi*w+wi)*c+0]; // grad of loss to P^t_r
139
+ float gptg = grads[(hi*w+wi)*c+1];
140
+ float gptb = grads[(hi*w+wi)*c+2];
141
+
142
+ // compute the 2d gs value
143
+
144
+ float d_x = curw_f - x; // distance along x axis
145
+ float d_y = curh_f - y;
146
+ // if(d_x > dmax || d_x < -dmax || d_y > dmax || d_y < -dmax){
147
+ // continue;
148
+ // }
149
+ // printf("here");
150
+
151
+ float d = w2*d_x*d_x - 2*rho*w3*d_x*d_y + w4*d_y*d_y;
152
+ float v = w1*d;
153
+ v = exp(v);
154
+ // printf("si:%d, sigma_x: %f, sigma_y:%f, rho:%f, x:%f, y:%f, v:%f\n", si, sigma_x, sigma_y, rho, x,y,v);
155
+
156
+ // compute grad of colors
157
+ _gr += v*gptr;
158
+ _gg += v*gptg;
159
+ _gb += v*gptb;
160
+
161
+ // compute grad of coords
162
+ float gpt = gptr*cr+gptg*cg+gptb*cb;
163
+ float v_2_w1 = v*2*w1;
164
+
165
+ float g_vst_to_gsx = v_2_w1*(-w2*d_x+rho*w3*d_y); // grad of v^{st} to G^s_x
166
+ _gx += gpt*g_vst_to_gsx;
167
+ float g_vst_to_gsy = v_2_w1*(-w4*d_y+rho*w3*d_x); // grad of v^{st} to G^s_y
168
+ _gy += gpt*g_vst_to_gsy;
169
+
170
+ // compute grad of sigmas
171
+ float g_vst_to_gsigx = v_2_w1*od_sx* (w3*rho*d_x*d_y - w2*d_x*d_x);
172
+ _gsx += gpt*g_vst_to_gsigx;
173
+ float g_vst_to_gsigy = v_2_w1*od_sy* (w3*rho*d_x*d_y - w4*d_y*d_y);
174
+ _gsy += gpt*g_vst_to_gsigy;
175
+ float g_vst_to_rho = -v_2_w1*(2*w1*rho*d+w3*d_x*d_y);
176
+ _gsr += gpt*g_vst_to_rho;
177
+ }
178
+ }
179
+
180
+ // write the values
181
+ grads_sigmas[curs*3+0] = _gsx;
182
+ grads_sigmas[curs*3+1] = _gsy;
183
+ grads_sigmas[curs*3+2] = _gsr;
184
+ grads_coords[curs*2+0] = _gx;
185
+ grads_coords[curs*2+1] = _gy;
186
+ grads_colors[curs*3+0] = _gr;
187
+ grads_colors[curs*3+1] = _gg;
188
+ grads_colors[curs*3+2] = _gb;
189
+
190
+ }
191
+
192
+ void _gs_render_backward(
193
+ const float *sigmas,
194
+ const float *coords,
195
+ const float *colors,
196
+ const float *grads, // (h, w, c)
197
+ float *grads_sigmas,
198
+ float *grads_coords,
199
+ float *grads_colors,
200
+ const int s,
201
+ const int h,
202
+ const int w,
203
+ const int c,
204
+ const float dmax
205
+ ) {
206
+
207
+ int threads=16;
208
+ dim3 grid(s, 1);
209
+ dim3 block( threads, 1);
210
+ _gs_render_backward_cuda<<<grid, block>>>(sigmas, coords, colors, grads, grads_sigmas, grads_coords, grads_colors, s, h, w, c, dmax);
211
+ }
212
+
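For reference, both the CUDA kernels and the torch reference in check.py evaluate, for a pixel t at normalized coordinates (x_t, y_t) in [-1, 1] and a Gaussian s with parameters (sigma_x, sigma_y, rho) and center (x_s, y_s), the unnormalized value

$$v^{st} = \exp\left(-\frac{1}{2(1-\rho^2)}\left[\frac{(x_t-x_s)^2}{\sigma_x^2} - \frac{2\rho\,(x_t-x_s)(y_t-y_s)}{\sigma_x\sigma_y} + \frac{(y_t-y_s)^2}{\sigma_y^2}\right]\right)$$

and accumulate $P_t = \sum_s v^{st} c_s$ over the Gaussians whose center satisfies $|x_t-x_s| \le d_{\max}$ and $|y_t-y_s| \le d_{\max}$. The bivariate-normal prefactor $1/(2\pi\sigma_x\sigma_y\sqrt{1-\rho^2})$ is intentionally dropped (it is the commented-out line above), since the torch path normalizes each kernel by its maximum; $d_{\max}$ simply truncates Gaussians far from the pixel.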
utils/gs_cuda_dmax/gs.backup.cu ADDED
@@ -0,0 +1,188 @@
1
+ #include <stdio.h>
2
+ #include <cmath>
3
+
4
+ #define PI 3.1415926536
5
+ #define PI2 6.2831853072
6
+
7
+ __global__ void _gs_render_cuda(
8
+ const float *sigmas,
9
+ const float *coords,
10
+ const float *colors,
11
+ float *rendered_img,
12
+ const int s, // gs num
13
+ const int h,
14
+ const int w,
15
+ const int c,
16
+ const float dmax
17
+ ){
18
+
19
+ int index = blockIdx.x*blockDim.x + threadIdx.x;
20
+ int curw = index % w;
21
+ int curh = int((index-curw)/w);
22
+ if(curw >= w || curh >=h){
23
+ return;
24
+ }
25
+
26
+ float curw_f = 2.0*curw/(w-1) - 1.0;
27
+ float curh_f = 2.0*curh/(h-1) - 1.0;
28
+
29
+ // printf("index:%d, curw:%d, curh:%d, curw_f:%f, curh_f:%f\n",index,curw,curh,curw_f,curh_f);
30
+
31
+ for(int si=0; si<s; si++){
32
+
33
+ // compute the 2d gs value
34
+ float sigma_x = sigmas[si*3+0];
35
+ float sigma_y = sigmas[si*3+1];
36
+ float rho = sigmas[si*3+2];
37
+ float x = coords[si*2+0];
38
+ float y = coords[si*2+1];
39
+
40
+ //
41
+ float one_div_one_minus_rho2 = 1.0 / (1-rho*rho) ;
42
+ float one_div_sigma_x = 1.0 / sigma_x;
43
+ float one_div_sigma_y = 1.0 / sigma_y;
44
+ float d_x = curw_f - x;
45
+ float d_y = curh_f - y;
46
+
47
+ if(d_x > dmax || d_x < -dmax || d_y > dmax || d_y < -dmax){
48
+ continue;
49
+ }
50
+
51
+ float v = one_div_sigma_x*one_div_sigma_x*d_x*d_x;
52
+ v -= 2*rho*d_x*d_y*one_div_sigma_x*one_div_sigma_y;
53
+ v += d_y*d_y*one_div_sigma_y*one_div_sigma_y;
54
+ v *= -one_div_one_minus_rho2 / 2.0;
55
+ v = exp(v);
56
+ // since we normalize v by its maximum, this step is removed so both paths give an equal result
57
+ // v *= one_div_sigma_x * one_div_sigma_y * pow(one_div_one_minus_rho2, 0.5) / PI2 ;
58
+ // printf("si:%d, sigma_x: %f, sigma_y:%f, rho:%f, x:%f, y:%f, v:%f\n", si, sigma_x, sigma_y, rho, x,y,v);
59
+
60
+ for(int ci=0; ci<c; ci++){
61
+ rendered_img[(curh*w+curw)*c+ci] += v*colors[si*c+ci];
62
+ }
63
+ }
64
+
65
+ }
66
+
67
+
68
+ void _gs_render(
69
+ const float *sigmas,
70
+ const float *coords,
71
+ const float *colors,
72
+ float *rendered_img,
73
+ const int s,
74
+ const int h,
75
+ const int w,
76
+ const int c,
77
+ const float dmax
78
+ ) {
79
+
80
+ int threads=16;
81
+ dim3 grid( h*w, 1);
82
+ dim3 block( threads, 1);
83
+ _gs_render_cuda<<<grid, block>>>(sigmas, coords, colors, rendered_img, s, h, w, c, dmax);
84
+ }
85
+
86
+ __global__ void _gs_render_backward_cuda(
87
+ const float *sigmas,
88
+ const float *coords,
89
+ const float *colors,
90
+ const float *grads,
91
+ float *grads_sigmas,
92
+ float *grads_coords,
93
+ float *grads_colors,
94
+ const int s, // gs num
95
+ const int h,
96
+ const int w,
97
+ const int c,
98
+ const float dmax
99
+ ){
100
+
101
+ int curs = blockIdx.x*blockDim.x + threadIdx.x;
102
+ if(curs >= s){
103
+ return ;
104
+ }
105
+
106
+ // obtain parameters of gs
107
+ float sigma_x = sigmas[curs*3+0];
108
+ float sigma_y = sigmas[curs*3+1];
109
+ float rho = sigmas[curs*3+2];
110
+ float x = coords[curs*2+0];
111
+ float y = coords[curs*2+1];
112
+
113
+ //
114
+ float w1 = -0.5 / (1-rho*rho) ;
115
+ float w2 = 1.0 / (sigma_x*sigma_x);
116
+ float w3 = 1.0 / (sigma_x*sigma_y);
117
+ float w4 = 1.0 / (sigma_y*sigma_y);
118
+ float od_sx = 1.0 / sigma_x;
119
+ float od_sy = 1.0 / sigma_y;
120
+
121
+ // init
122
+ for(int hi = 0; hi < h; hi++){
123
+ for( int wi=0; wi < w; wi++){
124
+
125
+ float curw_f = 2.0*wi/(w-1) - 1.0;
126
+ float curh_f = 2.0*hi/(h-1) - 1.0;
127
+
128
+ // compute the 2d gs value
129
+ float d_x = curw_f - x; // distance along x axis
130
+ float d_y = curh_f - y;
131
+ if(d_x > dmax || d_x < -dmax || d_y > dmax || d_y < -dmax){
132
+ continue;
133
+ }
134
+ float d = w2*d_x*d_x - 2*rho*w3*d_x*d_y + w4*d_y*d_y;
135
+ float v = w1*d;
136
+ v = exp(v);
137
+ // printf("si:%d, sigma_x: %f, sigma_y:%f, rho:%f, x:%f, y:%f, v:%f\n", si, sigma_x, sigma_y, rho, x,y,v);
138
+
139
+ // compute grad of coords
140
+ float v_2_w1 = v*2*w1;
141
+ float g_vst_to_gsx = v_2_w1*(-w2*d_x+rho*w3*d_y); // grad of v^{st} to G^s_x
142
+ float g_vst_to_gsy = v_2_w1*(-w4*d_y+rho*w3*d_x); // grad of v^{st} to G^s_y
143
+
144
+ // compute grad of sigmas
145
+ float g_vst_to_gsigx = v_2_w1*od_sx* (w3*rho*d_x*d_y - w2*d_x*d_x);
146
+ float g_vst_to_gsigy = v_2_w1*od_sy* (w3*rho*d_x*d_y - w4*d_y*d_y);
147
+ float g_vst_to_rho = -v_2_w1*(2*w1*rho*d+w3*d_x*d_y);
148
+
149
+ for(int ci=0; ci<c; ci++){
150
+ float _gptc = grads[(hi*w+wi)*c+ci];
151
+ float _gpt = _gptc*colors[curs*c+ci];
152
+
153
+ grads_colors[curs*c+ci] += v*_gptc;
154
+
155
+ grads_coords[curs*2+0] += _gpt*g_vst_to_gsx;
156
+ grads_coords[curs*2+1] += _gpt*g_vst_to_gsy;
157
+
158
+ grads_sigmas[curs*3+0] += _gpt*g_vst_to_gsigx;
159
+ grads_sigmas[curs*3+1] += _gpt*g_vst_to_gsigy;
160
+ grads_sigmas[curs*3+2] += _gpt*g_vst_to_rho;
161
+ }
162
+
163
+ }
164
+ }
165
+
166
+ }
167
+
168
+ void _gs_render_backward(
169
+ const float *sigmas,
170
+ const float *coords,
171
+ const float *colors,
172
+ const float *grads, // (h, w, c)
173
+ float *grads_sigmas,
174
+ float *grads_coords,
175
+ float *grads_colors,
176
+ const int s,
177
+ const int h,
178
+ const int w,
179
+ const int c,
180
+ const float dmax
181
+ ) {
182
+
183
+ int threads=16;
184
+ dim3 grid(s, 1);
185
+ dim3 block( threads, 1);
186
+ _gs_render_backward_cuda<<<grid, block>>>(sigmas, coords, colors, grads, grads_sigmas, grads_coords, grads_colors, s, h, w, c, dmax);
187
+ }
188
+
utils/gs_cuda_dmax/gs.cu ADDED
@@ -0,0 +1,187 @@
1
+ #include <stdio.h>
2
+ #include <cmath>
3
+
4
+ #define PI 3.1415926536
5
+ #define PI2 6.2831853072
6
+
7
+ __global__ void _gs_render_cuda(
8
+ const float *sigmas,
9
+ const float *coords,
10
+ const float *colors,
11
+ float *rendered_img,
12
+ const int s, // gs num
13
+ const int h,
14
+ const int w,
15
+ const int c,
16
+ const float dmax
17
+ ){
18
+
19
+ int curs = blockIdx.x*blockDim.x + threadIdx.x;
20
+ if(curs >= s){
21
+ return;
22
+ }
23
+
24
+ float sigma_x = sigmas[curs*3+0];
25
+ float sigma_y = sigmas[curs*3+1];
26
+ float rho = sigmas[curs*3+2];
27
+ float x = coords[curs*2+0];
28
+ float y = coords[curs*2+1];
29
+ float r = colors[curs*3];
30
+ float g = colors[curs*3+1];
31
+ float b = colors[curs*3+2];
32
+
33
+ float negative_half_one_div_one_minus_rho2 = -0.5 / (1-rho*rho);
34
+ float one_div_sigma_x_2 = 1.0 / sigma_x / sigma_x;
35
+ float one_div_sigma_y_2 = 1.0 / sigma_y / sigma_y;
36
+ float two_rho_div_sigma_x_one_div_sigma_y = 2*rho / sigma_x / sigma_y;
37
+
38
+ for(int hi=0; hi<h; hi++){
39
+ float curh_f = 2.0*hi/(h-1) - 1.0;
40
+ float d_y = curh_f - y;
41
+ if(d_y > dmax || d_y < -dmax){
42
+ continue;
43
+ }
44
+
45
+ for(int wi=0; wi<w; wi++){
46
+ float curw_f = 2.0*wi/(w-1) - 1.0;
47
+ float d_x = curw_f - x;
48
+ if(d_x > dmax || d_x < -dmax){
49
+ continue;
50
+ }
51
+
52
+ float v = one_div_sigma_x_2*d_x*d_x;
53
+ v -= two_rho_div_sigma_x_one_div_sigma_y*d_x*d_y;
54
+ v += one_div_sigma_y_2*d_y*d_y;
55
+ v *= negative_half_one_div_one_minus_rho2;
56
+ v = exp(v);
57
+
58
+ atomicAdd(&rendered_img[(hi*w+wi)*c+0], v*r);
59
+ atomicAdd(&rendered_img[(hi*w+wi)*c+1], v*g);
60
+ atomicAdd(&rendered_img[(hi*w+wi)*c+2], v*b);
61
+ }
62
+ }
63
+
64
+ }
65
+
66
+
67
+ void _gs_render(
68
+ const float *sigmas,
69
+ const float *coords,
70
+ const float *colors,
71
+ float *rendered_img,
72
+ const int s,
73
+ const int h,
74
+ const int w,
75
+ const int c,
76
+ const float dmax
77
+ ) {
78
+
79
+ int threads=64;
80
+ dim3 grid(int(s/threads)+1);
81
+ dim3 block(threads);
82
+ _gs_render_cuda<<<grid, block>>>(sigmas, coords, colors, rendered_img, s, h, w, c, dmax);
83
+ }
84
+
85
+ __global__ void _gs_render_backward_cuda(
86
+ const float *sigmas,
87
+ const float *coords,
88
+ const float *colors,
89
+ const float *grads,
90
+ float *grads_sigmas,
91
+ float *grads_coords,
92
+ float *grads_colors,
93
+ const int s, // gs num
94
+ const int h,
95
+ const int w,
96
+ const int c,
97
+ const float dmax
98
+ ){
99
+
100
+ int curs = blockIdx.x*blockDim.x + threadIdx.x;
101
+ if(curs >= s){
102
+ return ;
103
+ }
104
+
105
+ // obtain parameters of gs
106
+ float sigma_x = sigmas[curs*3+0];
107
+ float sigma_y = sigmas[curs*3+1];
108
+ float rho = sigmas[curs*3+2];
109
+ float x = coords[curs*2+0];
110
+ float y = coords[curs*2+1];
111
+
112
+ //
113
+ float w1 = -0.5 / (1-rho*rho) ;
114
+ float w2 = 1.0 / (sigma_x*sigma_x);
115
+ float w3 = 1.0 / (sigma_x*sigma_y);
116
+ float w4 = 1.0 / (sigma_y*sigma_y);
117
+ float od_sx = 1.0 / sigma_x;
118
+ float od_sy = 1.0 / sigma_y;
119
+
120
+ // init
121
+ for(int hi = 0; hi < h; hi++){
122
+ for( int wi=0; wi < w; wi++){
123
+
124
+ float curw_f = 2.0*wi/(w-1) - 1.0;
125
+ float curh_f = 2.0*hi/(h-1) - 1.0;
126
+
127
+ // compute the 2d gs value
128
+ float d_x = curw_f - x; // distance along x axis
129
+ float d_y = curh_f - y;
130
+ if(d_x > dmax || d_x < -dmax || d_y > dmax || d_y < -dmax){
131
+ continue;
132
+ }
133
+ float d = w2*d_x*d_x - 2*rho*w3*d_x*d_y + w4*d_y*d_y;
134
+ float v = w1*d;
135
+ v = exp(v);
136
+ // printf("si:%d, sigma_x: %f, sigma_y:%f, rho:%f, x:%f, y:%f, v:%f\n", si, sigma_x, sigma_y, rho, x,y,v);
137
+
138
+ // compute grad of coords
139
+ float v_2_w1 = v*2*w1;
140
+ float g_vst_to_gsx = v_2_w1*(-w2*d_x+rho*w3*d_y); // grad of v^{st} to G^s_x
141
+ float g_vst_to_gsy = v_2_w1*(-w4*d_y+rho*w3*d_x); // grad of v^{st} to G^s_y
142
+
143
+ // compute grad of sigmas
144
+ float g_vst_to_gsigx = v_2_w1*od_sx* (w3*rho*d_x*d_y - w2*d_x*d_x);
145
+ float g_vst_to_gsigy = v_2_w1*od_sy* (w3*rho*d_x*d_y - w4*d_y*d_y);
146
+ float g_vst_to_rho = -v_2_w1*(2*w1*rho*d+w3*d_x*d_y);
147
+
148
+ for(int ci=0; ci<c; ci++){
149
+ float _gptc = grads[(hi*w+wi)*c+ci];
150
+ float _gpt = _gptc*colors[curs*c+ci];
151
+
152
+ grads_colors[curs*c+ci] += v*_gptc;
153
+
154
+ grads_coords[curs*2+0] += _gpt*g_vst_to_gsx;
155
+ grads_coords[curs*2+1] += _gpt*g_vst_to_gsy;
156
+
157
+ grads_sigmas[curs*3+0] += _gpt*g_vst_to_gsigx;
158
+ grads_sigmas[curs*3+1] += _gpt*g_vst_to_gsigy;
159
+ grads_sigmas[curs*3+2] += _gpt*g_vst_to_rho;
160
+ }
161
+
162
+ }
163
+ }
164
+
165
+ }
166
+
167
+ void _gs_render_backward(
168
+ const float *sigmas,
169
+ const float *coords,
170
+ const float *colors,
171
+ const float *grads, // (h, w, c)
172
+ float *grads_sigmas,
173
+ float *grads_coords,
174
+ float *grads_colors,
175
+ const int s,
176
+ const int h,
177
+ const int w,
178
+ const int c,
179
+ const float dmax
180
+ ) {
181
+
182
+ int threads=64;
183
+ dim3 grid(s, 1);
184
+ dim3 block( threads, 1);
185
+ _gs_render_backward_cuda<<<grid, block>>>(sigmas, coords, colors, grads, grads_sigmas, grads_coords, grads_colors, s, h, w, c, dmax);
186
+ }
187
+
utils/gs_cuda_dmax/gs.h ADDED
@@ -0,0 +1,26 @@
1
+ void _gs_render(
2
+ const float *sigmas,
3
+ const float *coords,
4
+ const float *colors,
5
+ float *rendered_img,
6
+ const int s,
7
+ const int h,
8
+ const int w,
9
+ const int c,
10
+ const float dmax
11
+ );
12
+
13
+ void _gs_render_backward(
14
+ const float *sigmas,
15
+ const float *coords,
16
+ const float *colors,
17
+ const float *grads,
18
+ float *grads_sigmas,
19
+ float *grads_coords,
20
+ float *grads_colors,
21
+ const int s,
22
+ const int h,
23
+ const int w,
24
+ const int c,
25
+ const float dmax
26
+ );
utils/gs_cuda_dmax/gswrapper.cpp ADDED
@@ -0,0 +1,82 @@
1
+ #include "gs.h"
2
+ #include <torch/extension.h>
3
+ #include <c10/cuda/CUDAGuard.h>
4
+
5
+ #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
6
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
7
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
8
+
9
+ void gs_render(
10
+ torch::Tensor &sigmas,
11
+ torch::Tensor &coords,
12
+ torch::Tensor &colors,
13
+ torch::Tensor &rendered_img,
14
+ const int s,
15
+ const int h,
16
+ const int w,
17
+ const int c,
18
+ const float dmax
19
+ ){
20
+
21
+ CHECK_INPUT(sigmas);
22
+ CHECK_INPUT(coords);
23
+ CHECK_INPUT(colors);
24
+ CHECK_INPUT(rendered_img);
25
+
26
+ // run the kernel on the same CUDA device as the input tensors
27
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(sigmas));
28
+
29
+ _gs_render(
30
+ (const float *) sigmas.data_ptr(),
31
+ (const float *) coords.data_ptr(),
32
+ (const float *) colors.data_ptr(),
33
+ (float *) rendered_img.data_ptr(),
34
+ s, h, w, c, dmax);
35
+ }
36
+
37
+ void gs_render_backward(
38
+ torch::Tensor &sigmas,
39
+ torch::Tensor &coords,
40
+ torch::Tensor &colors,
41
+ torch::Tensor &grads,
42
+ torch::Tensor &grads_sigmas,
43
+ torch::Tensor &grads_coords,
44
+ torch::Tensor &grads_colors,
45
+ const int s,
46
+ const int h,
47
+ const int w,
48
+ const int c,
49
+ const float dmax
50
+ ){
51
+
52
+ CHECK_INPUT(sigmas);
53
+ CHECK_INPUT(coords);
54
+ CHECK_INPUT(colors);
55
+ CHECK_INPUT(grads);
56
+ CHECK_INPUT(grads_sigmas);
57
+ CHECK_INPUT(grads_coords);
58
+ CHECK_INPUT(grads_colors);
59
+
60
+
61
+ // run the kernel on the same CUDA device as the input tensors
62
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(sigmas));
63
+
64
+ _gs_render_backward(
65
+ (const float *) sigmas.data_ptr(),
66
+ (const float *) coords.data_ptr(),
67
+ (const float *) colors.data_ptr(),
68
+ (const float *) grads.data_ptr(),
69
+ (float *) grads_sigmas.data_ptr(),
70
+ (float *) grads_coords.data_ptr(),
71
+ (float *) grads_colors.data_ptr(),
72
+ s, h, w, c, dmax);
73
+ }
74
+
75
+ PYBIND11_MODULE( TORCH_EXTENSION_NAME, m) {
76
+ m.def( "gs_render",
77
+ &gs_render,
78
+ "cuda forward wrapper");
79
+ m.def( "gs_render_backward",
80
+ &gs_render_backward,
81
+ "cuda backward wrapper");
82
+ }
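gswrapper.cpp only checks inputs, sets the CUDA device guard and registers the pybind bindings; the kernels themselves are in gs.cu. The Python wrapper that follows imports a prebuilt gscuda module, but it also carries a commented-out torch.utils.cpp_extension.load call, which suggests the extension can instead be JIT-compiled from these two sources. A minimal sketch of that route, assuming nvcc and the CUDA toolkit are available:

import os
from torch.utils.cpp_extension import load

here = os.path.dirname(os.path.abspath(__file__))
gscuda = load(
    name="gscuda",
    sources=[os.path.join(here, "gswrapper.cpp"),
             os.path.join(here, "gs.cu")],
    verbose=True,
)  # exposes gscuda.gs_render and gscuda.gs_render_backward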
utils/gs_cuda_dmax/gswrapper.py ADDED
@@ -0,0 +1,63 @@
1
+ import os
2
+ import torch
3
+ from torch.utils.cpp_extension import load
4
+ from torch.autograd import Function
5
+ from torch.autograd.function import once_differentiable
6
+
7
+ #
8
+ build_path = os.path.join(os.path.split(os.path.abspath(__file__))[0], 'build')
9
+ os.makedirs(build_path, exist_ok=True)
10
+
11
+ file_path = os.path.split(os.path.abspath(__file__))[0]
12
+ # GSWrapper = load(
13
+ # name="gscuda",
14
+ # # sources=["gs_cuda/gswrapper.cpp", "gs_cuda/gs.cu"],
15
+ # sources=[os.path.join(file_path, "gswrapper.cpp"),
16
+ # os.path.join(file_path, "gs.cu")],
17
+ # build_directory=build_path,
18
+ # verbose=True)
19
+
20
+ import gscuda
21
+ GSWrapper = gscuda
22
+
23
+ class GSCUDA(Function):
24
+
25
+ @staticmethod
26
+ def forward(ctx, sigmas, coords, colors, rendered_img, dmax):
27
+ ctx.save_for_backward(sigmas, coords, colors)
28
+ ctx.dmax = dmax
29
+ h, w, c = rendered_img.shape
30
+ s = sigmas.shape[0]
31
+ GSWrapper.gs_render(sigmas, coords, colors, rendered_img, s, h, w, c, dmax)
32
+ return rendered_img
33
+
34
+ @staticmethod
35
+ @once_differentiable
36
+ def backward(ctx, grad_output):
37
+ sigmas, coords, colors = ctx.saved_tensors
38
+ dmax = ctx.dmax
39
+ h, w, c = grad_output.shape
40
+ s = sigmas.shape[0]
41
+ grads_sigmas = torch.zeros_like(sigmas)
42
+ grads_coords = torch.zeros_like(coords)
43
+ grads_colors = torch.zeros_like(colors)
44
+ GSWrapper.gs_render_backward(sigmas, coords, colors, grad_output.contiguous(), grads_sigmas, grads_coords, grads_colors, s, h, w, c, dmax)
45
+ return (grads_sigmas, grads_coords, grads_colors, None, None)
46
+
47
+ def gaussiansplatting_render(sigmas, coords, colors, image_size, dmax=100):
48
+ sigmas = sigmas.contiguous() # (gs num, 3)
49
+ coords = coords.contiguous() # (gs num, 2)
50
+ colors = colors.contiguous() # (gs num, c)
51
+ h, w = image_size[:2]
52
+ c = colors.shape[-1]
53
+ rendered_img = torch.zeros(h, w, c).to(colors.device).to(torch.float32)
54
+ return GSCUDA.apply(sigmas, coords, colors, rendered_img, dmax)
55
+
56
+ if __name__ == "__main__":
57
+ sigmas = torch.randn(10, 3).cuda()
58
+ coords = torch.randn(10, 2).cuda()
59
+ colors = torch.randn(10, 3).cuda()
60
+ image_size = (100, 100)
61
+ dmax = 0.1
62
+ rendered_img = gaussiansplatting_render(sigmas, coords, colors, image_size, dmax)
63
+ print(rendered_img.shape)
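Since GSCUDA.backward returns gradients for sigmas, coords and colors, gaussiansplatting_render can be used like any differentiable torch op. A small usage sketch (shapes follow the comments above; the loss is arbitrary and only for illustration):

import torch
from gswrapper import gaussiansplatting_render

s = 1024                                                                   # number of gaussians
sigmas = (0.05 + 0.5 * torch.rand(s, 3, device="cuda")).requires_grad_(True)  # (sigma_x, sigma_y, rho)
coords = (2 * torch.rand(s, 2, device="cuda") - 1).requires_grad_(True)       # centers in [-1, 1]
colors = torch.rand(s, 3, device="cuda").requires_grad_(True)

img = gaussiansplatting_render(sigmas, coords, colors, (64, 64), dmax=0.5)    # (64, 64, 3)
img.mean().backward()  # fills sigmas.grad, coords.grad and colors.grad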
utils/gs_cuda_dmax/mylineprofiler.py ADDED
@@ -0,0 +1,264 @@
1
+ import os
2
+ import io
3
+ import sys
4
+ import timeit
5
+ import tokenize
6
+ import torch
7
+ import psutil
8
+ import inspect
9
+ from loguru import logger
10
+ from prettytable import PrettyTable
11
+
12
+ # implement by xtudbxk
13
+ # github: https://github.com/xtudbxk/lineprofiler
14
+ class MyLineProfiler():
15
+ def __init__(self, base='ms', cuda_sync=True, gpuids=(0,), warmup=0, warmup_lineno=-1):
16
+
17
+ if base == 'ms':
18
+ self.base_n = 1000
19
+ elif base == 's':
20
+ self.base_n = 1
21
+ else:
22
+ logger.warning(f'Unsupported base - {base}, using "s" instead')
+ self.base_n = 1 # fall back to seconds so self.base_n is always defined
23
+
24
+ self.base = base
25
+ self.cuda_sync = cuda_sync
26
+ self.gpuids = gpuids
27
+ self.warmup = warmup
28
+ self.warmup_counter = warmup
29
+ # we should wait until this line has executed warmup_counter times
30
+ # before recording the stats
31
+ self.warmup_lineno = warmup_lineno
32
+
33
+ # for time profiling
34
+ self._times = {}
35
+ self._func_name = None
36
+ self._func_filename = None
37
+ self._last_time = -1
38
+ self._last_lineno = -1
39
+ self._func_hit_count = 0
40
+ self._func_firstlineno = 0
41
+
42
+ # for memory profiling
43
+ self._process = psutil.Process(os.getpid())
44
+ self._memory = {}
45
+ self._last_memory = 0
46
+
47
+ # for cuda memory profiling
48
+ self._gpu_memory = {}
49
+ self._gpu_last_memory = 0
50
+
51
+ def __trace_func__(self, frame, event, arg):
52
+ # print(f'in {frame.f_code.co_filename} func {frame.f_code.co_name} line {frame.f_lineno}, event - {event}')
53
+
54
+ # check if run into the decorated func
55
+ if self._func_firstlineno == frame.f_code.co_firstlineno and frame.f_code.co_name == self._func_name and frame.f_code.co_filename == self._func_filename:
56
+
57
+ # --- obtain info for current hit ---
58
+ # cuda related
59
+ if self.cuda_sync is True:
60
+ torch.cuda.synchronize()
61
+
62
+ current_time = timeit.default_timer()
63
+ memory = self._process.memory_info().rss
64
+ gpu_memory = torch.cuda.memory_allocated()
65
+ # --- ends ---
66
+
67
+ # --- initialize the info on the first hit ---
68
+ if frame.f_lineno not in self._times: # first hit time for this line
69
+ self._times[frame.f_lineno] = {'hit':0, 'time': 0}
70
+ self._memory[frame.f_lineno] = 0
71
+ self._gpu_memory[frame.f_lineno] = 0
72
+ # --- ends ---
73
+
74
+ # --- record info before calling the decorated func ---
75
+ # 'call' - before call the func
76
+ if event == 'call':
77
+ self._last_time = current_time
78
+ self._last_lineno = frame.f_lineno
79
+ self._last_memory = memory
80
+ self._last_gpu_memory = gpu_memory
81
+
82
+ if self.warmup_lineno < 0:
83
+ self.warmup_counter -= 1
84
+ if self.warmup_counter < 0:
85
+ self._func_hit_count += 1
86
+ # --- ends ---
87
+
88
+ # 'line' - after executing the line
89
+ # 'return' - return from the function
90
+ if event == 'line' or event == 'return':
91
+
92
+ if event == 'line' and self.warmup_counter < 0:
93
+ self._times[frame.f_lineno]['hit'] += 1
94
+
95
+
96
+ # --- obtain the memory and time consumed by this line ---
97
+ if self.warmup_counter < 0:
98
+ self._times[self._last_lineno]['time'] += current_time - self._last_time
99
+ self._memory[self._last_lineno] += memory - self._last_memory
100
+ self._gpu_memory[self._last_lineno] += gpu_memory - self._gpu_last_memory
101
+ # --- ends ---
102
+
103
+ if self.cuda_sync is True:
104
+ torch.cuda.synchronize()
105
+
106
+ self._last_time = timeit.default_timer()
107
+ self._last_memory = memory
108
+ self._gpu_last_memory = gpu_memory
109
+ self._last_lineno = frame.f_lineno
110
+
111
+ return self.__trace_func__
112
+
113
+ def decorate(self, func):
114
+ if self._func_name is not None:
115
+ logger.warning(f'Only one function can be decorated; already decorated "{self._func_name}"')
116
+ self._func_name = func.__name__
117
+ self._func_filename = func.__code__.co_filename
118
+ self._func_firstlineno = func.__code__.co_firstlineno
119
+
120
+ def _f(*args, **kwargs):
121
+ origin_trace_func = sys.gettrace()
122
+ sys.settrace(self.__trace_func__)
123
+ ret = func(*args, **kwargs)
124
+ sys.settrace(origin_trace_func)
125
+ return ret
126
+ return _f
127
+
128
+ def _get_table(self):
129
+
130
+ if len(self._times) <= 0:
131
+ logger.warning(f"un recorded datas, please ensure the function is executed")
132
+ return None
133
+
134
+ # --- load the source code ---
135
+ with open(self._func_filename, 'r') as f:
136
+ source_lines = [line.strip('\n') for line in f.readlines()]
137
+ code_str = "\n".join(source_lines)
138
+
139
+ def_lineno = min(self._times.keys())
140
+ final_lineno = max(self._times.keys())
141
+
142
+ # remove the additional blank content
143
+ pre_blank_count = len(source_lines[def_lineno-1]) - len(source_lines[def_lineno-1].lstrip(' ').lstrip('\t'))
144
+ # --- ends ---
145
+
146
+ # --- analyze the source code and collect info for multi-line statements ---
147
+ new_logic_linenos = [token.start[0] for token in tokenize.generate_tokens(
148
+ io.StringIO(code_str).readline) if token.type == 4]
149
+ # --- ends ---
150
+
151
+ # --- merge the stats multi-line code ---
152
+ sorted_linenos = [lineno for lineno in self._times.keys()]
153
+ sorted_linenos.sort(key=int)
154
+
155
+ lineno_cache = []
156
+ for lineno in sorted_linenos:
157
+ if lineno not in new_logic_linenos:
158
+ lineno_cache.append(lineno)
159
+ else:
160
+ # we should merge its info to the prev_lineno
161
+ if len(lineno_cache) <= 0:
162
+ continue
163
+ else:
164
+ lineno_cache.append(lineno)
165
+ first_lineno = lineno_cache[0]
166
+ for prev_lineno in lineno_cache[1:]:
167
+ self._times[first_lineno]["hit"] = min(self._times[first_lineno]["hit"], self._times[prev_lineno]["hit"])
168
+ self._times[first_lineno]["time"] += self._times[prev_lineno]["time"]
169
+ del self._times[prev_lineno]
170
+
171
+ self._memory[first_lineno] += self._memory[prev_lineno]
172
+ del self._memory[prev_lineno]
173
+
174
+ self._gpu_memory[first_lineno] += self._gpu_memory[prev_lineno]
175
+ del self._gpu_memory[prev_lineno]
176
+ lineno_cache = []
177
+ # --- ends ---
178
+
179
+ # --- initialize the pretty table for output ---
180
+ table = PrettyTable(['lineno', 'hits', 'time', 'time per hit', 'hit perc', 'time perc', 'mem inc', 'mem peak', 'gpu mem inc', 'gpu mem peak'])
181
+ # --- ends ---
182
+
183
+ # --- compute some statistics ---
184
+ total_hit = 0 # to compute the hit percentage
185
+ total_time = 0
186
+ for lineno, stats in self._times.items():
187
+ if lineno == def_lineno: continue
188
+ total_hit += stats['hit']
189
+ total_time += stats['time']
190
+
191
+ total_memory = sum([m for l,m in self._memory.items()]) / 1024 / 1024
192
+ total_gpu_memory = sum([m for l,m in self._gpu_memory.items()]) / 1024 / 1024
193
+ # --- ends ---
194
+
195
+ peak_cpu_memory = 0
196
+ peak_gpu_memory = 0
197
+ for lineno in range(def_lineno, final_lineno+1):
198
+ if lineno not in self._times:
199
+ # the comment line, empty line or merged line from multi-lines code
200
+ table.add_row([lineno, '-', '-', '-', '-', '-', '-',f'{peak_cpu_memory:5.3f} MB', '-', f'{peak_gpu_memory:5.3f} MB'])
201
+ else:
202
+ stats = self._times[lineno]
203
+ if lineno == def_lineno:
204
+ table.add_row([lineno, self._func_hit_count, f'{total_time*self.base_n:.4f} {self.base}', f'{total_time/self._func_hit_count*self.base_n:.4f} {self.base}', '-', '-', f'{total_memory:5.3f} MB', 'baseline', f'{total_gpu_memory:5.3f} MB', 'baseline'])
205
+ else:
206
+
207
+ line_result = [lineno, stats['hit'],
208
+ f'{stats["time"]*self.base_n:.4f} {self.base}',
209
+ f'{stats["time"]/stats["hit"]*self.base_n:.4f} {self.base}' if stats['hit'] > 0 else 'nan',
210
+ f'{stats["hit"]/total_hit*100:.3f}%' if total_hit > 0 else 'nan',
211
+ f'{stats["time"]/total_time*100:.3f}%'] if total_time > 0 else 'nan'
212
+
213
+ line_result += [f'{self._memory[lineno]/1024/1024:5.3f} MB' if stats['hit'] > 0 else '0 MB']
214
+ peak_cpu_memory = peak_cpu_memory + self._memory[lineno]/1024/1024
215
+ line_result += [f'{peak_cpu_memory:5.3f} MB']
216
+
217
+ line_result += [f'{self._gpu_memory[lineno]/1024/1024:5.3f} MB' if stats['hit'] > 0 else '0 MB']
218
+ peak_gpu_memory = peak_gpu_memory + self._gpu_memory[lineno]/1024/1024
219
+ line_result += [f'{peak_gpu_memory:5.3f} MB']
220
+
221
+ table.add_row(line_result)
222
+
223
+ table.add_column('sources', [source_lines[i-1][pre_blank_count:] if len(source_lines[i-1])>pre_blank_count else '' for i in range(def_lineno, final_lineno+1)], 'l')
224
+ return table
225
+
226
+ def print(self, filename=None, mode="w"):
227
+ introduction = '''
228
+ 1. The first row of the table reports the overall result of the whole function; the following rows report the statistics of each line in the function.
229
+ 2. `hit perc` and `time perc` stand for `hit percentage` and `time percentage`.
230
+ 3. For memory there are four columns: `mem inc`, `mem peak`, `gpu mem inc` and `gpu mem peak`, i.e. the CPU/GPU memory increase and peak. All values are collected in the last run. The increase column gives the memory added by each line (the first row refers to the whole function); the per-line numbers may sum to far less than the function total, since Python may release unused memory while the function runs. The peak column is the running sum of the increases above it and indicates the maximum memory the function may need.
231
+ 4. For any issue, please contact us via https://github.com/xtudbxk/lineprofiler or zhengqiang.zhang@hotmail.com
232
+ '''
233
+ print(introduction)
234
+
235
+ table = PrettyTable(['lineno', 'hits', 'time', 'time per hit', 'hit perc', 'time perc', 'mem inc', 'mem peak', 'gpu mem inc', 'gpu mem peak'])
236
+ table = self._get_table()
237
+ print(table)
238
+ if filename is not None:
239
+ with open(filename, mode) as f:
240
+ f.write(introduction)
241
+ f.write(f"args - base={self.base}, cuda_sync={self.cuda_sync}, gpuids={self.gpuids}, warmup={self.warmup}\n")
242
+ f.write(str(table))
243
+
244
+ if __name__ == '__main__':
245
+ import numpy as np
246
+ def mytest(h='hello',
247
+ xx="xx"):
248
+
249
+ h = h + 'world'
250
+ a = []
251
+ for _ in range(200):
252
+ # a = np.zeros((1000, 1000), dtype=np.float32)
253
+ a.append(np.zeros((1000, 1000), dtype=np.float32))
254
+ a.append(
255
+ np.zeros((1000, 1000),
256
+ dtype=np.float32))
257
+ # print(a[0,0])
258
+ print(h)
259
+
260
+ profiler = MyLineProfiler(cuda_sync=False, warmup=2)
261
+ mytest = profiler.decorate(mytest)
262
+ for _ in range(5):
263
+ mytest()
264
+ profiler.print()
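The __main__ block above is the intended usage pattern: build one profiler, decorate exactly one function, run it a few times, then print. In this repo the same pattern wraps the renderers in profile.py below. A compact example (my_func stands in for any function of interest; warmup discards the first runs, and cuda_sync brackets every line with torch.cuda.synchronize() so GPU time is attributed to the line that launched it):

from mylineprofiler import MyLineProfiler

profiler = MyLineProfiler(base='ms', cuda_sync=True, warmup=1)
my_func = profiler.decorate(my_func)   # only one function may be decorated
for _ in range(5):
    my_func()
profiler.print("profile.log", "w")     # prints the table and writes it to profile.log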
utils/gs_cuda_dmax/profile.py ADDED
@@ -0,0 +1,142 @@
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ import torch.nn.functional as F
5
+
6
+ from gswrapper import gaussiansplatting_render
7
+
8
+ def generate_2D_gaussian_splatting(kernel_size, sigma_x, sigma_y, rho, coords,
9
+ colours, image_size=(256, 256, 3), device="cuda"):
10
+
11
+ batch_size = colours.shape[0]
12
+
13
+ sigma_x = sigma_x.view(batch_size, 1, 1)
14
+ sigma_y = sigma_y.view(batch_size, 1, 1)
15
+ rho = rho.view(batch_size, 1, 1)
16
+
17
+ covariance = torch.stack(
18
+ [torch.stack([sigma_x**2, rho*sigma_x*sigma_y], dim=-1),
19
+ torch.stack([rho*sigma_x*sigma_y, sigma_y**2], dim=-1)],
20
+ dim=-2
21
+ )
22
+
23
+ # Check for positive semi-definiteness
24
+ # determinant = (sigma_x**2) * (sigma_y**2) - (rho * sigma_x * sigma_y)**2
25
+ # if (determinant <= 0).any():
26
+ # raise ValueError("Covariance matrix must be positive semi-definite")
27
+
28
+ inv_covariance = torch.inverse(covariance)
29
+
30
+ # Choosing quite a broad range for the distribution [-5,5] to avoid any clipping
31
+ start = torch.tensor([-5.0], device=device).view(-1, 1)
32
+ end = torch.tensor([5.0], device=device).view(-1, 1)
33
+ base_linspace = torch.linspace(0, 1, steps=kernel_size, device=device)
34
+ ax_batch = start + (end - start) * base_linspace
35
+
36
+ # Expanding dims for broadcasting
37
+ ax_batch_expanded_x = ax_batch.unsqueeze(-1).expand(-1, -1, kernel_size)
38
+ ax_batch_expanded_y = ax_batch.unsqueeze(1).expand(-1, kernel_size, -1)
39
+
40
+ # Creating a batch-wise meshgrid using broadcasting
41
+ xx, yy = ax_batch_expanded_x, ax_batch_expanded_y # (batchsize, kernelsize, kernelsize)
42
+
43
+ xy = torch.stack([xx, yy], dim=-1) # (batchsize, kernelsize, kernelsize, 2)
44
+ z = torch.einsum('b...i,b...ij,b...j->b...', xy, -0.5 * inv_covariance, xy) # (batchsize, kernelsize, kernelsize)
45
+ kernel = torch.exp(z) / (2 * torch.tensor(np.pi, device=device) * torch.sqrt(torch.det(covariance)).view(batch_size, 1, 1)) # (batchsize, kernelsize, kernelsize)
46
+
47
+
48
+ kernel_max_1, _ = kernel.max(dim=-1, keepdim=True) # Find max along the last dimension
49
+ kernel_max_2, _ = kernel_max_1.max(dim=-2, keepdim=True) # Find max along the second-to-last dimension
50
+ kernel_normalized = kernel / kernel_max_2 # (batchsize, kernelsize, kernelsize)
51
+
52
+
53
+ kernel_reshaped = kernel_normalized.repeat(1, 3, 1).view(batch_size * 3, kernel_size, kernel_size)
54
+ kernel_rgb = kernel_reshaped.unsqueeze(0).reshape(batch_size, 3, kernel_size, kernel_size) # (batchsize, 3, kernelsize, kernelsize)
55
+
56
+ # Calculating the padding needed to match the image size
57
+ pad_h = image_size[0] - kernel_size
58
+ pad_w = image_size[1] - kernel_size
59
+
60
+ if pad_h < 0 or pad_w < 0:
61
+ raise ValueError("Kernel size should be smaller or equal to the image size.")
62
+
63
+ # Adding padding to make kernel size equal to the image size
64
+ padding = (pad_w // 2, pad_w // 2 + pad_w % 2, # padding left and right
65
+ pad_h // 2, pad_h // 2 + pad_h % 2) # padding top and bottom
66
+
67
+ kernel_rgb_padded = torch.nn.functional.pad(kernel_rgb, padding, "constant", 0) # (batchsize, 3, h, w)
68
+
69
+ # Extracting shape information
70
+ b, c, h, w = kernel_rgb_padded.shape
71
+
72
+ # Create a batch of 2D affine matrices
73
+ theta = torch.zeros(b, 2, 3, dtype=torch.float32, device=device)
74
+ theta[:, 0, 0] = 1.0
75
+ theta[:, 1, 1] = 1.0
76
+ theta[:, :, 2] = -coords # (b, 2) - the offset of gaussian splatting
77
+
78
+ # Creating grid and performing grid sampling
79
+ grid = F.affine_grid(theta, size=(b, c, h, w), align_corners=True) # (b, h, w, 2)
80
+ # grid_y = torch.linspace(-1, 1, steps=h, device=device).reshape(1, h, 1, 1).repeat(1, 1, w, 1)
81
+ # grid_x = torch.linspace(-1, 1, steps=w, device=device).reshape(1, 1, w, 1).repeat(1, h, 1, 1)
82
+ # grid = torch.cat([grid_x, grid_y], dim=-1)
83
+ # grid = grid - coords.reshape(-1, 1, 1, 2)
84
+
85
+ kernel_rgb_padded_translated = F.grid_sample(kernel_rgb_padded, grid, align_corners=True) # (b, 3, h, w)
86
+
87
+ rgb_values_reshaped = colours.unsqueeze(-1).unsqueeze(-1)
88
+
89
+ final_image_layers = rgb_values_reshaped * kernel_rgb_padded_translated
90
+ final_image = final_image_layers.sum(dim=0)
91
+ # final_image = torch.clamp(final_image, 0, 1)
92
+ final_image = final_image.permute(1,2,0)
93
+
94
+ return final_image
95
+
96
+
97
+ if __name__ == "__main__":
98
+ from mylineprofiler import MyLineProfiler
99
+ profiler_th = MyLineProfiler(cuda_sync=True)
100
+ generate_2D_gaussian_splatting = profiler_th.decorate(generate_2D_gaussian_splatting)
101
+ profiler_cuda = MyLineProfiler(cuda_sync=True)
102
+ gaussiansplatting_render = profiler_cuda.decorate(gaussiansplatting_render)
103
+
104
+
105
+ # --- test ---
106
+ # s = 1000
107
+ s = 5
108
+ # image_size = (512, 512, 3)
109
+ image_size = (511, 511, 3)
110
+ # image_size = (256, 512, 3)
111
+ # image_size = (256, 256, 3)
112
+
113
+ sigmas = 0.999*torch.rand(s, 3).to(torch.float32).to("cuda")
114
+ sigmas[:,:2] = 5*sigmas[:, :2]
115
+ coords = 2*torch.rand(s, 2).to(torch.float32).to("cuda")-1.0
116
+ colors = torch.rand(s, 3).to(torch.float32).to("cuda")
117
+
118
+ # --- torch version ---
119
+ import gc
120
+ gc.collect()
121
+ torch.cuda.empty_cache()
122
+ for _ in range(20):
123
+ img = generate_2D_gaussian_splatting(101, sigmas[:,1], sigmas[:,0], sigmas[:,2], coords, colors, image_size)
124
+ profiler_th.print("profile.log", "w")
125
+ cv2.imwrite("th.png", 255.0 * img.detach().clamp(0, 1).cpu().numpy())
126
+ # --- ends ---
127
+
128
+ # --- cuda version ---
129
+ _stepsize_of_gs_th = 10 / (101-1)
130
+ _stepsize_of_gs_cuda_w = 2 / (image_size[1]-1)
131
+ _stepsize_of_gs_cuda_h = 2 / (image_size[0]-1)
132
+ sigmas[:, 0] = sigmas[:, 0] * _stepsize_of_gs_cuda_w / _stepsize_of_gs_th
133
+ sigmas[:, 1] = sigmas[:, 1] * _stepsize_of_gs_cuda_h / _stepsize_of_gs_th
134
+ dmax = 101/2*_stepsize_of_gs_cuda_w
135
+ gc.collect()
136
+ torch.cuda.empty_cache()
137
+ for _ in range(20):
138
+ img = gaussiansplatting_render(sigmas, coords, colors, image_size, dmax)
139
+
140
+ profiler_cuda.print("profile.log", "a")
141
+ cv2.imwrite("cuda.png", 255.0 * img.detach().clamp(0, 1).cpu().numpy())
142
+ # --- ends ---
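The rescaling block above (sigmas[:, 0] = sigmas[:, 0] * _stepsize_of_gs_cuda_w / _stepsize_of_gs_th, dmax = 101/2*_stepsize_of_gs_cuda_w) is what makes the two renderers comparable: the torch reference samples each Gaussian on its own kernel_size x kernel_size grid spanning [-5, 5], while the CUDA kernel works directly in the image's normalized [-1, 1] coordinates, so the sigmas must be scaled by the ratio of the two grid step sizes and dmax set to half the kernel extent. A small helper capturing that arithmetic (illustrative only; the name is not from the repo):

def torch_to_cuda_params(sigmas, kernel_size, image_size):
    step_th = 10.0 / (kernel_size - 1)        # step of the torch kernel grid over [-5, 5]
    step_w = 2.0 / (image_size[1] - 1)        # step of the CUDA pixel grid along x
    step_h = 2.0 / (image_size[0] - 1)        # step of the CUDA pixel grid along y
    sigmas = sigmas.clone()
    sigmas[:, 0] *= step_w / step_th          # sigma_x in normalized image units
    sigmas[:, 1] *= step_h / step_th          # sigma_y likewise
    dmax = kernel_size / 2 * step_w           # truncate at half the kernel extent
    return sigmas, dmax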
utils/hatropeamp.py ADDED
@@ -0,0 +1,1156 @@
1
+ import math
+ import warnings # used by _no_grad_trunc_normal_ below
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.utils.checkpoint import checkpoint
5
+ import torch.nn.functional as F
6
+ import collections.abc
7
+ from itertools import repeat
8
+
9
+ from functools import partial
10
+ from typing import Any, Optional, Tuple
11
+
12
+ from einops import rearrange
13
+
14
+ # From PyTorch
15
+ def _ntuple(n):
16
+
17
+ def parse(x):
18
+ if isinstance(x, collections.abc.Iterable):
19
+ return x
20
+ return tuple(repeat(x, n))
21
+
22
+ return parse
23
+
24
+
25
+ to_1tuple = _ntuple(1)
26
+ to_2tuple = _ntuple(2)
27
+ to_3tuple = _ntuple(3)
28
+ to_4tuple = _ntuple(4)
29
+ to_ntuple = _ntuple
30
+
31
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
32
+ # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
33
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
34
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
35
+ def norm_cdf(x):
36
+ # Computes standard normal cumulative distribution function
37
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
38
+
39
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
40
+ warnings.warn(
41
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
42
+ 'The distribution of values may be incorrect.',
43
+ stacklevel=2)
44
+
45
+ with torch.no_grad():
46
+ # Values are generated by using a truncated uniform distribution and
47
+ # then using the inverse CDF for the normal distribution.
48
+ # Get upper and lower cdf values
49
+ low = norm_cdf((a - mean) / std)
50
+ up = norm_cdf((b - mean) / std)
51
+
52
+ # Uniformly fill tensor with values from [low, up], then translate to
53
+ # [2l-1, 2u-1].
54
+ tensor.uniform_(2 * low - 1, 2 * up - 1)
55
+
56
+ # Use inverse cdf transform for normal distribution to get truncated
57
+ # standard normal
58
+ tensor.erfinv_()
59
+
60
+ # Transform to proper mean, std
61
+ tensor.mul_(std * math.sqrt(2.))
62
+ tensor.add_(mean)
63
+
64
+ # Clamp to ensure it's in the proper range
65
+ tensor.clamp_(min=a, max=b)
66
+ return tensor
67
+
68
+
69
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
70
+ r"""Fills the input Tensor with values drawn from a truncated
71
+ normal distribution.
72
+
73
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
74
+
75
+ The values are effectively drawn from the
76
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
77
+ with values outside :math:`[a, b]` redrawn until they are within
78
+ the bounds. The method used for generating the random values works
79
+ best when :math:`a \leq \text{mean} \leq b`.
80
+
81
+ Args:
82
+ tensor: an n-dimensional `torch.Tensor`
83
+ mean: the mean of the normal distribution
84
+ std: the standard deviation of the normal distribution
85
+ a: the minimum cutoff value
86
+ b: the maximum cutoff value
87
+
88
+ Examples:
89
+ >>> w = torch.empty(3, 5)
90
+ >>> nn.init.trunc_normal_(w)
91
+ """
92
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
93
+
94
+ def init_t_xy(end_x: int, end_y: int, zero_center=False):
95
+ t = torch.arange(end_x * end_y, dtype=torch.float32)
96
+ t_x = (t % end_x).float()
97
+ t_y = torch.div(t, end_x, rounding_mode='floor').float()
98
+
99
+ return t_x, t_y
100
+
101
+ def init_random_2d_freqs(head_dim: int, num_heads: int, theta: float = 10.0, rotate: bool = True):
102
+ freqs_x = []
103
+ freqs_y = []
104
+ theta = theta
105
+ mag = 1 / (theta ** (torch.arange(0, head_dim, 4)[: (head_dim // 4)].float() / head_dim))
106
+ for i in range(num_heads):
107
+ angles = torch.rand(1) * 2 * torch.pi if rotate else torch.zeros(1)
108
+ fx = torch.cat([mag * torch.cos(angles), mag * torch.cos(torch.pi/2 + angles)], dim=-1)
109
+ fy = torch.cat([mag * torch.sin(angles), mag * torch.sin(torch.pi/2 + angles)], dim=-1)
110
+ freqs_x.append(fx)
111
+ freqs_y.append(fy)
112
+ freqs_x = torch.stack(freqs_x, dim=0)
113
+ freqs_y = torch.stack(freqs_y, dim=0)
114
+ freqs = torch.stack([freqs_x, freqs_y], dim=0)
115
+ return freqs
116
+
117
+ def compute_cis(freqs, t_x, t_y):
118
+ N = t_x.shape[0]
119
+ # No float 16 for this range
120
+ with torch.cuda.amp.autocast(enabled=False):
121
+ freqs_x = (t_x.unsqueeze(-1) @ freqs[0].unsqueeze(-2))
122
+ freqs_y = (t_y.unsqueeze(-1) @ freqs[1].unsqueeze(-2))
123
+ freqs_cis = torch.polar(torch.ones_like(freqs_x), freqs_x + freqs_y)
124
+
125
+ return freqs_cis
126
+
127
+
128
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
129
+ ndim = x.ndim
130
+ assert 0 <= 1 < ndim
131
+ # assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
132
+ # print(f"freqs_cis shape is {freqs_cis.shape}, x shape is {x.shape}")
133
+ if freqs_cis.shape == (x.shape[-2], x.shape[-1]):
134
+ shape = [d if i >= ndim-2 else 1 for i, d in enumerate(x.shape)]
135
+ elif freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1]):
136
+ shape = [d if i >= ndim-3 else 1 for i, d in enumerate(x.shape)]
137
+
138
+ return freqs_cis.view(*shape)
139
+
140
+ def apply_rotary_emb(
141
+ xq: torch.Tensor,
142
+ xk: torch.Tensor,
143
+ freqs_cis: torch.Tensor,
144
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
145
+ # print(f"xq shape is {xq.shape}, xq.shape[:-1] is {xq.shape[:-1]}")
146
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
147
+ # print(f"xq_ shape is {xq_.shape}")
148
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
149
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
150
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
151
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
152
+ return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
153
+
154
+ def apply_rotary_emb_single(x, freqs_cis):
155
+ x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
156
+ seq_len = x_.shape[2]
157
+ freqs_cis = freqs_cis[:, :seq_len, :]
158
+ freqs_cis = freqs_cis.unsqueeze(0).expand_as(x_)
159
+ x_out = torch.view_as_real(x_ * freqs_cis).flatten(3)
160
+ return x_out.type_as(x).to(x.device)
161
+
162
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
163
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
164
+
165
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
166
+ """
167
+ if drop_prob == 0. or not training:
168
+ return x
169
+ keep_prob = 1 - drop_prob
170
+ shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
171
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
172
+ random_tensor.floor_() # binarize
173
+ output = x.div(keep_prob) * random_tensor
174
+ return output
175
+
176
+
177
+ class DropPath(nn.Module):
178
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
179
+
180
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
181
+ """
182
+
183
+ def __init__(self, drop_prob=None):
184
+ super(DropPath, self).__init__()
185
+ self.drop_prob = drop_prob
186
+
187
+ def forward(self, x):
188
+ return drop_path(x, self.drop_prob, self.training)
189
+
190
+
191
+ class ChannelAttention(nn.Module):
192
+ """Channel attention used in RCAN.
193
+ Args:
194
+ num_feat (int): Channel number of intermediate features.
195
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
196
+ """
197
+
198
+ def __init__(self, num_feat, squeeze_factor=16):
199
+ super(ChannelAttention, self).__init__()
200
+ self.attention = nn.Sequential(
201
+ nn.AdaptiveAvgPool2d(1),
202
+ nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
203
+ nn.ReLU(inplace=True),
204
+ nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
205
+ nn.Sigmoid())
206
+
207
+ def forward(self, x):
208
+ y = self.attention(x)
209
+ return x * y
210
+
211
+
212
+ class CAB(nn.Module):
213
+
214
+ def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30):
215
+ super(CAB, self).__init__()
216
+
217
+ self.cab = nn.Sequential(
218
+ nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
219
+ nn.GELU(),
220
+ nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
221
+ ChannelAttention(num_feat, squeeze_factor)
222
+ )
223
+
224
+ def forward(self, x):
225
+ return self.cab(x)
226
+
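# A minimal shape sketch for CAB (illustrative, not from the original file): the 3x3 convs squeeze
# the channels by compress_ratio and the channel attention rescales them per channel, so the block
# is shape-preserving.
# >>> cab = CAB(num_feat=96, compress_ratio=3, squeeze_factor=30)
# >>> cab(torch.randn(1, 96, 16, 16)).shape
# torch.Size([1, 96, 16, 16])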
227
+
228
+ class Mlp(nn.Module):
229
+
230
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
231
+ super().__init__()
232
+ out_features = out_features or in_features
233
+ hidden_features = hidden_features or in_features
234
+ self.fc1 = nn.Linear(in_features, hidden_features)
235
+ self.act = act_layer()
236
+ self.fc2 = nn.Linear(hidden_features, out_features)
237
+ self.drop = nn.Dropout(drop)
238
+
239
+ def forward(self, x):
240
+ x = self.fc1(x)
241
+ x = self.act(x)
242
+ x = self.drop(x)
243
+ x = self.fc2(x)
244
+ x = self.drop(x)
245
+ return x
246
+
247
+
248
+ def window_partition(x, window_size):
249
+ """
250
+ Args:
251
+ x: (b, h, w, c)
252
+ window_size (int): window size
253
+
254
+ Returns:
255
+ windows: (num_windows*b, window_size, window_size, c)
256
+ """
257
+ b, h, w, c = x.shape
258
+ x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
259
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
260
+ return windows
261
+
262
+
263
+ def window_reverse(windows, window_size, h, w):
264
+ """
265
+ Args:
266
+ windows: (num_windows*b, window_size, window_size, c)
267
+ window_size (int): Window size
268
+ h (int): Height of image
269
+ w (int): Width of image
270
+
271
+ Returns:
272
+ x: (b, h, w, c)
273
+ """
274
+ b = int(windows.shape[0] / (h * w / window_size / window_size))
275
+ x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
276
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
277
+ return x
278
+
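# A minimal round-trip sketch (illustrative, not from the original file): window_reverse inverts
# window_partition when h and w are multiples of window_size, which the attention blocks below rely on.
# >>> x = torch.randn(2, 32, 32, 96)                     # (b, h, w, c)
# >>> wins = window_partition(x, 16)                     # (2*2*2, 16, 16, 96)
# >>> torch.equal(window_reverse(wins, 16, 32, 32), x)
# True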
279
+
280
+ class WindowAttention(nn.Module):
281
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
282
+ It supports both shifted and non-shifted windows.
283
+
284
+ Args:
285
+ dim (int): Number of input channels.
286
+ window_size (tuple[int]): The height and width of the window.
287
+ num_heads (int): Number of attention heads.
288
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
289
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
290
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
291
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
292
+ """
293
+
294
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., rope_mixed = True, rope_theta=10.0):
295
+
296
+ super().__init__()
297
+ self.dim = dim
298
+ self.window_size = window_size # Wh, Ww
299
+ self.num_heads = num_heads
300
+ head_dim = dim // num_heads
301
+
302
+ self.rope_mixed = rope_mixed
303
+ t_x, t_y = init_t_xy(end_x=self.window_size[1], end_y=self.window_size[0])
304
+ self.register_buffer('rope_t_x', t_x)
305
+ self.register_buffer('rope_t_y', t_y)
306
+
307
+ freqs = init_random_2d_freqs(
308
+ head_dim=self.dim // self.num_heads, num_heads=self.num_heads, theta=rope_theta,
309
+ rotate=self.rope_mixed
310
+ )
311
+ if self.rope_mixed:
312
+ self.rope_freqs = nn.Parameter(freqs, requires_grad=True)
313
+ else:
314
+ self.register_buffer('rope_freqs', freqs)
315
+ freqs_cis = compute_cis(self.rope_freqs, self.rope_t_x, self.rope_t_y)
316
+ self.rope_freqs_cis = freqs_cis
317
+
318
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
319
+ self.attn_drop = nn.Dropout(attn_drop)
320
+ self.proj = nn.Linear(dim, dim)
321
+
322
+ self.proj_drop = nn.Dropout(proj_drop)
323
+
324
+
325
+ def forward(self, x, rpi, mask=None):
326
+ """
327
+ Args:
328
+ x: input features with shape of (num_windows*b, n, c)
329
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
330
+ """
331
+ b_, n, c = x.shape
332
+ qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
333
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
334
+
335
+ ###### Apply rotary position embedding
336
+ if self.rope_mixed:
337
+ freqs_cis = compute_cis(self.rope_freqs, self.rope_t_x, self.rope_t_y)
338
+ else:
339
+ freqs_cis = self.rope_freqs_cis.to(x.device)
340
+ q, k = apply_rotary_emb(q, k, freqs_cis)
341
+ #########
342
+
343
+ attn = F.scaled_dot_product_attention(q, k, v)
344
+
345
+ attn = attn.transpose(1, 2).reshape(b_, n, c)
346
+
347
+ x = self.proj(attn)
348
+ x = self.proj_drop(x)
349
+ return x
350
+
351
+
352
+ class HAB(nn.Module):
353
+ r""" Hybrid Attention Block.
354
+
355
+ Args:
356
+ dim (int): Number of input channels.
357
+ input_resolution (tuple[int]): Input resolution.
358
+ num_heads (int): Number of attention heads.
359
+ window_size (int): Window size.
360
+ shift_size (int): Shift size for SW-MSA.
361
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
362
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
363
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
364
+ drop (float, optional): Dropout rate. Default: 0.0
365
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
366
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
367
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
368
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
369
+ """
370
+
371
+ def __init__(self,
372
+ dim,
373
+ input_resolution,
374
+ num_heads,
375
+ window_size=7,
376
+ shift_size=0,
377
+ compress_ratio=3,
378
+ squeeze_factor=30,
379
+ conv_scale=0.01,
380
+ mlp_ratio=4.,
381
+ qkv_bias=True,
382
+ qk_scale=None,
383
+ drop=0.,
384
+ attn_drop=0.,
385
+ drop_path=0.,
386
+ act_layer=nn.GELU,
387
+ norm_layer=nn.LayerNorm,
388
+ rope_mixed = True, rope_theta=10.0):
389
+ super().__init__()
390
+ self.dim = dim
391
+ self.input_resolution = input_resolution
392
+ self.num_heads = num_heads
393
+ self.window_size = window_size
394
+ self.shift_size = shift_size
395
+ self.mlp_ratio = mlp_ratio
396
+ if min(self.input_resolution) <= self.window_size:
397
+ # if window size is larger than input resolution, we don't partition windows
398
+ self.shift_size = 0
399
+ self.window_size = min(self.input_resolution)
400
+ assert 0 <= self.shift_size < self.window_size, 'shift_size must be in [0, window_size)'
401
+
402
+ self.norm1 = norm_layer(dim)
403
+ self.attn = WindowAttention(
404
+ dim,
405
+ window_size=to_2tuple(self.window_size),
406
+ num_heads=num_heads,
407
+ qkv_bias=qkv_bias,
408
+ qk_scale=qk_scale,
409
+ attn_drop=attn_drop,
410
+ proj_drop=drop,
411
+ rope_mixed = rope_mixed, rope_theta=rope_theta)
412
+
413
+ self.conv_scale = conv_scale
414
+ self.conv_block = CAB(num_feat=dim, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor)
415
+
416
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
417
+ self.norm2 = norm_layer(dim)
418
+ mlp_hidden_dim = int(dim * mlp_ratio)
419
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
420
+
421
+ def forward(self, x, x_size, rpi_sa, attn_mask):
422
+ h, w = x_size
423
+ b, _, c = x.shape
424
+ # assert seq_len == h * w, "input feature has wrong size"
425
+
426
+ shortcut = x
427
+ x = self.norm1(x)
428
+ x = x.view(b, h, w, c)
429
+
430
+ # Conv_X
431
+ conv_x = self.conv_block(x.permute(0, 3, 1, 2).contiguous())
432
+ conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(b, h * w, c)
433
+
434
+ # cyclic shift
435
+ if self.shift_size > 0:
436
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
437
+ attn_mask = attn_mask
438
+ else:
439
+ shifted_x = x
440
+ attn_mask = None
441
+
442
+ # partition windows
443
+ x_windows = window_partition(shifted_x, self.window_size) # nw*b, window_size, window_size, c
444
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c
445
+
446
+ # W-MSA/SW-MSA (compatible with testing on images whose sizes are multiples of the window size)
447
+ attn_windows = self.attn(x_windows, rpi=rpi_sa, mask=attn_mask)
448
+
449
+ # merge windows
450
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
451
+ shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c
452
+
453
+ # reverse cyclic shift
454
+ if self.shift_size > 0:
455
+ attn_x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
456
+ else:
457
+ attn_x = shifted_x
458
+ attn_x = attn_x.view(b, h * w, c)
459
+
460
+ # FFN
461
+ x = shortcut + self.drop_path(attn_x) + conv_x * self.conv_scale
462
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
463
+
464
+ return x
465
+
466
+
467
+ class PatchMerging(nn.Module):
468
+ r""" Patch Merging Layer.
469
+
470
+ Args:
471
+ input_resolution (tuple[int]): Resolution of input feature.
472
+ dim (int): Number of input channels.
473
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
474
+ """
475
+
476
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
477
+ super().__init__()
478
+ self.input_resolution = input_resolution
479
+ self.dim = dim
480
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
481
+ self.norm = norm_layer(4 * dim)
482
+
483
+ def forward(self, x):
484
+ """
485
+ x: b, h*w, c
486
+ """
487
+ h, w = self.input_resolution
488
+ b, seq_len, c = x.shape
489
+ assert seq_len == h * w, 'input feature has wrong size'
490
+ assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) is not even.'
491
+
492
+ x = x.view(b, h, w, c)
493
+
494
+ x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c
495
+ x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c
496
+ x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c
497
+ x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c
498
+ x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c
499
+ x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c
500
+
501
+ x = self.norm(x)
502
+ x = self.reduction(x)
503
+
504
+ return x
505
+
506
+
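# A minimal shape sketch for PatchMerging (illustrative, not from the original file; the class is
# kept here but HATNOUP_ROPE_AMP below passes downsample=None, so it is unused in this pipeline):
# it halves the spatial resolution and doubles the channel dimension.
# >>> pm = PatchMerging(input_resolution=(32, 32), dim=96)
# >>> pm(torch.randn(2, 32 * 32, 96)).shape
# torch.Size([2, 256, 192])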
507
+ class OCAB(nn.Module):
508
+ # overlapping cross-attention block
509
+
510
+ def __init__(self, dim,
511
+ input_resolution,
512
+ window_size,
513
+ overlap_ratio,
514
+ num_heads,
515
+ qkv_bias=True,
516
+ qk_scale=None,
517
+ mlp_ratio=2,
518
+ norm_layer=nn.LayerNorm,
519
+ rope_mixed = True, rope_theta = 10.0
520
+ ):
521
+
522
+ super().__init__()
523
+ self.dim = dim
524
+ self.input_resolution = input_resolution
525
+ self.window_size = window_size
526
+ self.num_heads = num_heads
527
+ head_dim = dim // num_heads
528
+ self.rope_mixed = rope_mixed
529
+
530
+ self.overlap_win_size = int(window_size * overlap_ratio) + window_size
531
+
532
+ self.norm1 = norm_layer(dim)
533
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
534
+ self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size, padding=(self.overlap_win_size-window_size)//2)
535
+
536
+ t_x, t_y = init_t_xy(end_x=max(self.window_size, self.overlap_win_size), end_y=max(self.window_size, self.overlap_win_size))
537
+ self.register_buffer('rope_t_x', t_x)
538
+ self.register_buffer('rope_t_y', t_y)
539
+
540
+ freqs = init_random_2d_freqs(
541
+ head_dim=self.dim // self.num_heads, num_heads=self.num_heads, theta=rope_theta,
542
+ rotate=self.rope_mixed
543
+ )
544
+ if self.rope_mixed:
545
+ self.rope_freqs = nn.Parameter(freqs, requires_grad=True)
546
+ else:
547
+ self.register_buffer('rope_freqs', freqs)
548
+ freqs_cis = compute_cis(self.rope_freqs, self.rope_t_x, self.rope_t_y)
549
+ self.rope_freqs_cis = freqs_cis
550
+
551
+
552
+ self.proj = nn.Linear(dim,dim)
553
+
554
+ self.norm2 = norm_layer(dim)
555
+ mlp_hidden_dim = int(dim * mlp_ratio)
556
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=nn.GELU)
557
+
558
+ def forward(self, x, x_size, rpi):
559
+ h, w = x_size
560
+ b, _, c = x.shape
561
+
562
+ shortcut = x
563
+ x = self.norm1(x)
564
+ x = x.view(b, h, w, c)
565
+
566
+ qkv = self.qkv(x).reshape(b, h, w, 3, c).permute(3, 0, 4, 1, 2).contiguous() # 3, b, c, h, w
567
+ q = qkv[0].permute(0, 2, 3, 1).contiguous() # b, h, w, c
568
+ kv = torch.cat((qkv[1], qkv[2]), dim=1) # b, 2*c, h, w
569
+
570
+ # partition windows
571
+ q_windows = window_partition(q, self.window_size) # nw*b, window_size, window_size, c
572
+ q_windows = q_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c
573
+
574
+ kv_windows = self.unfold(kv) # b, c*w*w, nw
575
+ kv_windows = rearrange(kv_windows, 'b (nc ch owh oww) nw -> nc (b nw) (owh oww) ch', nc=2, ch=c, owh=self.overlap_win_size, oww=self.overlap_win_size).contiguous() # 2, nw*b, ow*ow, c
576
+ k_windows, v_windows = kv_windows[0], kv_windows[1] # nw*b, ow*ow, c
577
+
578
+ b_, nq, _ = q_windows.shape
579
+ _, n, _ = k_windows.shape
580
+ # print(f"nq is {nq}, n is {n}")
581
+ d = self.dim // self.num_heads
582
+ q = q_windows.reshape(b_, nq, self.num_heads, d).permute(0, 2, 1, 3).contiguous() # nw*b, nH, nq, d
583
+ k = k_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3).contiguous() # nw*b, nH, n, d
584
+ v = v_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3).contiguous() # nw*b, nH, n, d
585
+
586
+ ###### Apply rotary position embedding
587
+ if self.rope_mixed:
588
+ freqs_cis = compute_cis(self.rope_freqs, self.rope_t_x, self.rope_t_y)
589
+ else:
590
+ freqs_cis = self.rope_freqs_cis.to(x.device)
591
+ q = apply_rotary_emb_single(q, freqs_cis)
592
+ k = apply_rotary_emb_single(k, freqs_cis)
593
+ #########
594
+
595
+ attn = F.scaled_dot_product_attention(q, k, v)
596
+ attn_windows = attn.transpose(1, 2).reshape(b_, nq, self.dim)
597
+
598
+ # merge windows
599
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.dim)
600
+ x = window_reverse(attn_windows, self.window_size, h, w) # b h w c
601
+ x = x.view(b, h * w, self.dim)
602
+
603
+ x = self.proj(x) + shortcut
604
+
605
+ x = x + self.mlp(self.norm2(x))
606
+ return x
607
+
608
+
609
+ class AttenBlocks(nn.Module):
610
+ """ A series of attention blocks for one RHAG.
611
+
612
+ Args:
613
+ dim (int): Number of input channels.
614
+ input_resolution (tuple[int]): Input resolution.
615
+ depth (int): Number of blocks.
616
+ num_heads (int): Number of attention heads.
617
+ window_size (int): Local window size.
618
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
619
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
620
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
621
+ drop (float, optional): Dropout rate. Default: 0.0
622
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
623
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
624
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
625
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
626
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
627
+ """
628
+
629
+ def __init__(self,
630
+ dim,
631
+ input_resolution,
632
+ depth,
633
+ num_heads,
634
+ window_size,
635
+ compress_ratio,
636
+ squeeze_factor,
637
+ conv_scale,
638
+ overlap_ratio,
639
+ mlp_ratio=4.,
640
+ qkv_bias=True,
641
+ qk_scale=None,
642
+ drop=0.,
643
+ attn_drop=0.,
644
+ drop_path=0.,
645
+ norm_layer=nn.LayerNorm,
646
+ downsample=None,
647
+ use_checkpoint=False,
648
+ rope_mixed = True, rope_theta=10.0):
649
+
650
+ super().__init__()
651
+ self.dim = dim
652
+ self.input_resolution = input_resolution
653
+ self.depth = depth
654
+ self.use_checkpoint = use_checkpoint
655
+
656
+ # build blocks
657
+ self.blocks = nn.ModuleList([
658
+ HAB(
659
+ dim=dim,
660
+ input_resolution=input_resolution,
661
+ num_heads=num_heads,
662
+ window_size=window_size,
663
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
664
+ compress_ratio=compress_ratio,
665
+ squeeze_factor=squeeze_factor,
666
+ conv_scale=conv_scale,
667
+ mlp_ratio=mlp_ratio,
668
+ qkv_bias=qkv_bias,
669
+ qk_scale=qk_scale,
670
+ drop=drop,
671
+ attn_drop=attn_drop,
672
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
673
+ norm_layer=norm_layer,
674
+ rope_mixed = rope_mixed, rope_theta=rope_theta) for i in range(depth)
675
+ ])
676
+
677
+ # OCAB
678
+ self.overlap_attn = OCAB(
679
+ dim=dim,
680
+ input_resolution=input_resolution,
681
+ window_size=window_size,
682
+ overlap_ratio=overlap_ratio,
683
+ num_heads=num_heads,
684
+ qkv_bias=qkv_bias,
685
+ qk_scale=qk_scale,
686
+ mlp_ratio=mlp_ratio,
687
+ norm_layer=norm_layer,
688
+ rope_mixed = rope_mixed, rope_theta = rope_theta)
689
+
690
+
691
+ # patch merging layer
692
+ if downsample is not None:
693
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
694
+ else:
695
+ self.downsample = None
696
+
697
+ def forward(self, x, x_size, params):
698
+ for blk in self.blocks:
699
+ x = blk(x, x_size, params['rpi_sa'], params['attn_mask'])
700
+
701
+
702
+ x = self.overlap_attn(x, x_size, params['rpi_oca'])
703
+
704
+
705
+ if self.downsample is not None:
706
+ x = self.downsample(x)
707
+ return x
708
+
709
+
710
+ class RHAG(nn.Module):
711
+ """Residual Hybrid Attention Group (RHAG).
712
+
713
+ Args:
714
+ dim (int): Number of input channels.
715
+ input_resolution (tuple[int]): Input resolution.
716
+ depth (int): Number of blocks.
717
+ num_heads (int): Number of attention heads.
718
+ window_size (int): Local window size.
719
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
720
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
721
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
722
+ drop (float, optional): Dropout rate. Default: 0.0
723
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
724
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
725
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
726
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
727
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
728
+ img_size: Input image size.
729
+ patch_size: Patch size.
730
+ resi_connection: The convolutional block before residual connection.
731
+ """
732
+
733
+ def __init__(self,
734
+ dim,
735
+ input_resolution,
736
+ depth,
737
+ num_heads,
738
+ window_size,
739
+ compress_ratio,
740
+ squeeze_factor,
741
+ conv_scale,
742
+ overlap_ratio,
743
+ mlp_ratio=4.,
744
+ qkv_bias=True,
745
+ qk_scale=None,
746
+ drop=0.,
747
+ attn_drop=0.,
748
+ drop_path=0.,
749
+ norm_layer=nn.LayerNorm,
750
+ downsample=None,
751
+ use_checkpoint=False,
752
+ img_size=224,
753
+ patch_size=4,
754
+ resi_connection='1conv',
755
+ rope_mixed = True, rope_theta=10.0):
756
+ super(RHAG, self).__init__()
757
+
758
+ self.dim = dim
759
+ self.input_resolution = input_resolution
760
+
761
+ self.residual_group = AttenBlocks(
762
+ dim=dim,
763
+ input_resolution=input_resolution,
764
+ depth=depth,
765
+ num_heads=num_heads,
766
+ window_size=window_size,
767
+ compress_ratio=compress_ratio,
768
+ squeeze_factor=squeeze_factor,
769
+ conv_scale=conv_scale,
770
+ overlap_ratio=overlap_ratio,
771
+ mlp_ratio=mlp_ratio,
772
+ qkv_bias=qkv_bias,
773
+ qk_scale=qk_scale,
774
+ drop=drop,
775
+ attn_drop=attn_drop,
776
+ drop_path=drop_path,
777
+ norm_layer=norm_layer,
778
+ downsample=downsample,
779
+ use_checkpoint=use_checkpoint,
780
+ rope_mixed = rope_mixed, rope_theta=rope_theta)
781
+
782
+ if resi_connection == '1conv':
783
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
784
+ elif resi_connection == 'identity':
785
+ self.conv = nn.Identity()
786
+
787
+ self.patch_embed = PatchEmbed(
788
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
789
+
790
+ self.patch_unembed = PatchUnEmbed(
791
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
792
+
793
+ def forward(self, x, x_size, params):
794
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size, params), x_size))) + x
795
+
796
+
797
+ class PatchEmbed(nn.Module):
798
+ r""" Image to Patch Embedding
799
+
800
+ Args:
801
+ img_size (int): Image size. Default: 224.
802
+ patch_size (int): Patch token size. Default: 4.
803
+ in_chans (int): Number of input image channels. Default: 3.
804
+ embed_dim (int): Number of linear projection output channels. Default: 96.
805
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
806
+ """
807
+
808
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
809
+ super().__init__()
810
+ img_size = to_2tuple(img_size)
811
+ patch_size = to_2tuple(patch_size)
812
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
813
+ self.img_size = img_size
814
+ self.patch_size = patch_size
815
+ self.patches_resolution = patches_resolution
816
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
817
+
818
+ self.in_chans = in_chans
819
+ self.embed_dim = embed_dim
820
+
821
+ if norm_layer is not None:
822
+ self.norm = norm_layer(embed_dim)
823
+ else:
824
+ self.norm = None
825
+
826
+ def forward(self, x):
827
+ x = x.flatten(2).transpose(1, 2) # b Ph*Pw c
828
+ if self.norm is not None:
829
+ x = self.norm(x)
830
+ return x
831
+
832
+
833
+ class PatchUnEmbed(nn.Module):
834
+ r""" Image to Patch Unembedding
835
+
836
+ Args:
837
+ img_size (int): Image size. Default: 224.
838
+ patch_size (int): Patch token size. Default: 4.
839
+ in_chans (int): Number of input image channels. Default: 3.
840
+ embed_dim (int): Number of linear projection output channels. Default: 96.
841
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
842
+ """
843
+
844
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
845
+ super().__init__()
846
+ img_size = to_2tuple(img_size)
847
+ patch_size = to_2tuple(patch_size)
848
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
849
+ self.img_size = img_size
850
+ self.patch_size = patch_size
851
+ self.patches_resolution = patches_resolution
852
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
853
+
854
+ self.in_chans = in_chans
855
+ self.embed_dim = embed_dim
856
+
857
+ def forward(self, x, x_size):
858
+ x = x.transpose(1, 2).contiguous().view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b Ph*Pw c
859
+ return x
860
+
861
+
862
+ class Upsample(nn.Sequential):
863
+ """Upsample module.
864
+
865
+ Args:
866
+ scale (int): Scale factor. Supported scales: 2^n and 3.
867
+ num_feat (int): Channel number of intermediate features.
868
+ """
869
+
870
+ def __init__(self, scale, num_feat):
871
+ m = []
872
+ if (scale & (scale - 1)) == 0: # scale = 2^n
873
+ for _ in range(int(math.log(scale, 2))):
874
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
875
+ m.append(nn.PixelShuffle(2))
876
+ elif scale == 3:
877
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
878
+ m.append(nn.PixelShuffle(3))
879
+ else:
880
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
881
+ super(Upsample, self).__init__(*m)
882
+
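# A minimal shape sketch for Upsample (illustrative, not from the original file; the pixelshuffle
# tail of HATNOUP_ROPE_AMP below is commented out, so this module is kept for reference): a
# power-of-two scale doubles the spatial size repeatedly while keeping num_feat channels.
# >>> up = Upsample(scale=4, num_feat=64)
# >>> up(torch.randn(1, 64, 48, 48)).shape
# torch.Size([1, 64, 192, 192])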
883
+
884
+
885
+
886
+ class HATNOUP_ROPE_AMP(nn.Module):
887
+ def __init__(self,
888
+ img_size=64,
889
+ patch_size=1,
890
+ in_chans=3,
891
+ embed_dim=192,
892
+ depths=(6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6),
893
+ num_heads=(6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6),
894
+ window_size=16,
895
+ compress_ratio=3,
896
+ squeeze_factor=32,
897
+ conv_scale=0.01,
898
+ overlap_ratio=0.5,
899
+ mlp_ratio=2,
900
+ qkv_bias=True,
901
+ qk_scale=None,
902
+ drop_rate=0.,
903
+ attn_drop_rate=0.,
904
+ drop_path_rate=0.1,
905
+ norm_layer=nn.LayerNorm,
906
+ ape=False,
907
+ patch_norm=True,
908
+ use_checkpoint=False,
909
+ upscale=4,
910
+ img_range=1.,
911
+ upsampler='pixelshuffle',
912
+ resi_connection='1conv',
913
+ rope_mixed = True,
914
+ rope_theta=10.0,
915
+ **kwargs):
916
+ super(HATNOUP_ROPE_AMP, self).__init__()
917
+
918
+ self.window_size = window_size
919
+ self.shift_size = window_size // 2
920
+ self.overlap_ratio = overlap_ratio
921
+
922
+ num_in_ch = in_chans
923
+ num_out_ch = in_chans
924
+ num_feat = 64
925
+ self.img_range = img_range
926
+ if in_chans == 3:
927
+ rgb_mean = (0.4488, 0.4371, 0.4040)
928
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
929
+ else:
930
+ self.mean = torch.zeros(1, 1, 1, 1)
931
+ self.upscale = upscale
932
+ self.upsampler = upsampler
933
+
934
+ # relative position index
935
+ relative_position_index_SA = self.calculate_rpi_sa()
936
+ relative_position_index_OCA = self.calculate_rpi_oca()
937
+ self.register_buffer('relative_position_index_SA', relative_position_index_SA)
938
+ self.register_buffer('relative_position_index_OCA', relative_position_index_OCA)
939
+
940
+ # ------------------------- 1, shallow feature extraction ------------------------- #
941
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
942
+
943
+ # ------------------------- 2, deep feature extraction ------------------------- #
944
+ self.num_layers = len(depths)
945
+ self.embed_dim = embed_dim
946
+ self.ape = ape
947
+ self.patch_norm = patch_norm
948
+ self.num_features = embed_dim
949
+ self.mlp_ratio = mlp_ratio
950
+
951
+ # split image into non-overlapping patches
952
+ self.patch_embed = PatchEmbed(
953
+ img_size=img_size,
954
+ patch_size=patch_size,
955
+ in_chans=embed_dim,
956
+ embed_dim=embed_dim,
957
+ norm_layer=norm_layer if self.patch_norm else None)
958
+ num_patches = self.patch_embed.num_patches
959
+ patches_resolution = self.patch_embed.patches_resolution
960
+ self.patches_resolution = patches_resolution
961
+
962
+ # merge non-overlapping patches into image
963
+ self.patch_unembed = PatchUnEmbed(
964
+ img_size=img_size,
965
+ patch_size=patch_size,
966
+ in_chans=embed_dim,
967
+ embed_dim=embed_dim,
968
+ norm_layer=norm_layer if self.patch_norm else None)
969
+
970
+ # absolute position embedding
971
+ if self.ape:
972
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
973
+ trunc_normal_(self.absolute_pos_embed, std=.02)
974
+
975
+ self.pos_drop = nn.Dropout(p=drop_rate)
976
+
977
+ # stochastic depth
978
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
979
+
980
+ # build Residual Hybrid Attention Groups (RHAG)
981
+ self.layers = nn.ModuleList()
982
+ for i_layer in range(self.num_layers):
983
+ layer = RHAG(
984
+ dim=embed_dim,
985
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
986
+ depth=depths[i_layer],
987
+ num_heads=num_heads[i_layer],
988
+ window_size=window_size,
989
+ compress_ratio=compress_ratio,
990
+ squeeze_factor=squeeze_factor,
991
+ conv_scale=conv_scale,
992
+ overlap_ratio=overlap_ratio,
993
+ mlp_ratio=self.mlp_ratio,
994
+ qkv_bias=qkv_bias,
995
+ qk_scale=qk_scale,
996
+ drop=drop_rate,
997
+ attn_drop=attn_drop_rate,
998
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
999
+ norm_layer=norm_layer,
1000
+ downsample=None,
1001
+ use_checkpoint=use_checkpoint,
1002
+ img_size=img_size,
1003
+ patch_size=patch_size,
1004
+ resi_connection=resi_connection,
1005
+ rope_mixed = rope_mixed, rope_theta=rope_theta)
1006
+ self.layers.append(layer)
1007
+ self.norm = norm_layer(self.num_features)
1008
+
1009
+ self.use_checkpoint = use_checkpoint
1010
+
1011
+ # build the last conv layer in deep feature extraction
1012
+ if resi_connection == '1conv':
1013
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
1014
+ elif resi_connection == 'identity':
1015
+ self.conv_after_body = nn.Identity()
1016
+
1017
+ # ------------------------- 3, high quality image reconstruction ------------------------- #
1018
+ if self.upsampler == 'pixelshuffle':
1019
+ # for classical SR
1020
+ self.conv_before_upsample = nn.Sequential(
1021
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
1022
+ # self.upsample = Upsample(upscale, num_feat)
1023
+ # self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1024
+
1025
+ self.apply(self._init_weights)
1026
+
1027
+ def _init_weights(self, m):
1028
+ if isinstance(m, nn.Linear):
1029
+ trunc_normal_(m.weight, std=.02)
1030
+ if isinstance(m, nn.Linear) and m.bias is not None:
1031
+ nn.init.constant_(m.bias, 0)
1032
+ elif isinstance(m, nn.LayerNorm):
1033
+ nn.init.constant_(m.bias, 0)
1034
+ nn.init.constant_(m.weight, 1.0)
1035
+
1036
+ def calculate_rpi_sa(self):
1037
+ # calculate relative position index for SA
1038
+ coords_h = torch.arange(self.window_size)
1039
+ coords_w = torch.arange(self.window_size)
1040
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
1041
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
1042
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
1043
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
1044
+ relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0
1045
+ relative_coords[:, :, 1] += self.window_size - 1
1046
+ relative_coords[:, :, 0] *= 2 * self.window_size - 1
1047
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
1048
+ return relative_position_index
1049
+
1050
+ def calculate_rpi_oca(self):
1051
+ # calculate relative position index for OCA
1052
+ window_size_ori = self.window_size
1053
+ window_size_ext = self.window_size + int(self.overlap_ratio * self.window_size)
1054
+
1055
+ coords_h = torch.arange(window_size_ori)
1056
+ coords_w = torch.arange(window_size_ori)
1057
+ coords_ori = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, ws, ws
1058
+ coords_ori_flatten = torch.flatten(coords_ori, 1) # 2, ws*ws
1059
+
1060
+ coords_h = torch.arange(window_size_ext)
1061
+ coords_w = torch.arange(window_size_ext)
1062
+ coords_ext = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, wse, wse
1063
+ coords_ext_flatten = torch.flatten(coords_ext, 1) # 2, wse*wse
1064
+
1065
+ relative_coords = coords_ext_flatten[:, None, :] - coords_ori_flatten[:, :, None] # 2, ws*ws, wse*wse
1066
+
1067
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # ws*ws, wse*wse, 2
1068
+ relative_coords[:, :, 0] += window_size_ori - window_size_ext + 1 # shift to start from 0
1069
+ relative_coords[:, :, 1] += window_size_ori - window_size_ext + 1
1070
+
1071
+ relative_coords[:, :, 0] *= window_size_ori + window_size_ext - 1
1072
+ relative_position_index = relative_coords.sum(-1)
1073
+ return relative_position_index
1074
+
1075
+ def calculate_mask(self, x_size):
1076
+ # calculate attention mask for SW-MSA
1077
+ h, w = x_size
1078
+ img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1
1079
+ h_slices = (slice(0, -self.window_size), slice(-self.window_size,
1080
+ -self.shift_size), slice(-self.shift_size, None))
1081
+ w_slices = (slice(0, -self.window_size), slice(-self.window_size,
1082
+ -self.shift_size), slice(-self.shift_size, None))
1083
+ cnt = 0
1084
+ for h in h_slices:
1085
+ for w in w_slices:
1086
+ img_mask[:, h, w, :] = cnt
1087
+ cnt += 1
1088
+
1089
+ mask_windows = window_partition(img_mask, self.window_size) # nw, window_size, window_size, 1
1090
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
1091
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
1092
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
1093
+
1094
+ return attn_mask
1095
+
1096
+ @torch.jit.ignore
1097
+ def no_weight_decay(self):
1098
+ return {'absolute_pos_embed'}
1099
+
1100
+ @torch.jit.ignore
1101
+ def no_weight_decay_keywords(self):
1102
+ return {'relative_position_bias_table'}
1103
+
1104
+ def forward_features(self, x):
1105
+ x_size = (x.shape[2], x.shape[3])
1106
+
1107
+ # Calculate attention mask and relative position index in advance to speed up inference.
1108
+ # The original code is very time-consuming for large window size.
1109
+ attn_mask = self.calculate_mask(x_size).to(x.device)
1110
+ params = {'attn_mask': attn_mask, 'rpi_sa': self.relative_position_index_SA, 'rpi_oca': self.relative_position_index_OCA}
1111
+
1112
+ x = self.patch_embed(x)
1113
+ if self.ape:
1114
+ x = x + self.absolute_pos_embed
1115
+ x = self.pos_drop(x)
1116
+
1117
+ for layer in self.layers:
1118
+ x = layer(x, x_size, params)
1119
+
1120
+ x = self.norm(x) # b seq_len c
1121
+ x = self.patch_unembed(x, x_size)
1122
+
1123
+ return x
1124
+
1125
+ def forward(self, x):
1126
+ # self.mean = self.mean.type_as(x)
1127
+ # x = (x - self.mean) * self.img_range
1128
+
1129
+ if self.upsampler == 'pixelshuffle':
1130
+ # for classical SR
1131
+ x = self.conv_first(x)
1132
+ if self.use_checkpoint:
1133
+ x = self.conv_after_body(checkpoint(self.forward_features, x)) + x
1134
+ else:
1135
+ x = self.conv_after_body(self.forward_features(x)) + x
1136
+ x = self.conv_before_upsample(x)
1137
+ # x = self.conv_last(self.upsample(x))
1138
+
1139
+ # x = x / self.img_range + self.mean
1140
+
1141
+ return x
1142
+
1143
+
1144
+ if __name__ == '__main__':
1145
+ srcs = torch.randn(8, 3, 64, 64).cuda()
1146
+ encoder = HATNOUP_ROPE_AMP(upscale=4, in_chans=3, img_size=64, window_size=16, compress_ratio=3, squeeze_factor=32, conv_scale=0.01, overlap_ratio=0.5,
1147
+ img_range=1.,
1148
+ depths=(6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6),
1149
+ embed_dim=192,
1150
+ num_heads=(6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6),
1151
+ mlp_ratio=2,
1152
+ upsampler='pixelshuffle',
1153
+ resi_connection='1conv',
1154
+ use_checkpoint=False).cuda()
1155
+ feature = encoder(srcs)
1156
+ pass
utils/rdn.py ADDED
@@ -0,0 +1,120 @@
1
+ import collections.abc
2
+ import math
3
+ import torch
4
+ import torchvision
5
+ import warnings
6
+ from distutils.version import LooseVersion
7
+ from itertools import repeat
8
+ from torch import nn as nn
9
+ from torch.nn import functional as F
10
+ from torch.nn import init as init
11
+ from torch.nn.modules.batchnorm import _BatchNorm
12
+
13
+ class RDB_Conv(nn.Module):
14
+ def __init__(self, inChannels, growRate, kSize=3):
15
+ super(RDB_Conv, self).__init__()
16
+ Cin = inChannels
17
+ G = growRate
18
+ self.conv = nn.Sequential(*[
19
+ nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1),
20
+ nn.ReLU()
21
+ ])
22
+
23
+ def forward(self, x):
24
+ out = self.conv(x)
25
+ return torch.cat((x, out), 1)
26
+
27
+ class RDB(nn.Module):
28
+ def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
29
+ super(RDB, self).__init__()
30
+ G0 = growRate0
31
+ G = growRate
32
+ C = nConvLayers
33
+
34
+ convs = []
35
+ for c in range(C):
36
+ convs.append(RDB_Conv(G0 + c*G, G))
37
+ self.convs = nn.Sequential(*convs)
38
+
39
+ # Local Feature Fusion
40
+ self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1)
41
+
42
+ def forward(self, x):
43
+ return self.LFF(self.convs(x)) + x
44
+
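# A minimal shape sketch for RDB (illustrative, not from the original file): each RDB_Conv grows
# the channel count by G via concatenation, and the 1x1 LFF fuses it back to G0 before the
# residual add, so the block is shape-preserving.
# >>> rdb = RDB(growRate0=64, growRate=64, nConvLayers=8)
# >>> rdb(torch.randn(1, 64, 48, 48)).shape
# torch.Size([1, 64, 48, 48])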
45
+ class RDNNOUP(nn.Module):
46
+ def __init__(self, G0 = 64, kSize = 3, r = 4, n_colors = 3, RDNconfig = 'B',
47
+ no_upsampling = True, img_range = 1.0):
48
+ super(RDNNOUP, self).__init__()
49
+
50
+ self.no_upsampling = no_upsampling
51
+ self.img_range = img_range
52
+
53
+ # number of RDB blocks, conv layers, out channels
54
+ self.D, C, G = {
55
+ 'A': (20, 6, 32),
56
+ 'B': (16, 8, 64),
57
+ }[RDNconfig]
58
+
59
+ # Shallow feature extraction net
60
+ self.SFENet1 = nn.Conv2d(n_colors, G0, kSize, padding=(kSize-1)//2, stride=1)
61
+ self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
62
+
63
+ # Residual dense blocks and dense feature fusion
64
+ self.RDBs = nn.ModuleList()
65
+ for i in range(self.D):
66
+ self.RDBs.append(
67
+ RDB(growRate0 = G0, growRate = G, nConvLayers = C)
68
+ )
69
+
70
+ # Global Feature Fusion
71
+ self.GFF = nn.Sequential(*[
72
+ nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1),
73
+ nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
74
+ ])
75
+
76
+ if no_upsampling:
77
+ self.out_dim = G0
78
+ else:
79
+ self.out_dim = n_colors
80
+ # Up-sampling net
81
+ if r == 2 or r == 3:
82
+ self.UPNet = nn.Sequential(*[
83
+ nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1),
84
+ nn.PixelShuffle(r),
85
+ nn.Conv2d(G, n_colors, kSize, padding=(kSize-1)//2, stride=1)
86
+ ])
87
+ elif r == 4:
88
+ self.UPNet = nn.Sequential(*[
89
+ nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1),
90
+ nn.PixelShuffle(2),
91
+ nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1),
92
+ nn.PixelShuffle(2),
93
+ nn.Conv2d(G, n_colors, kSize, padding=(kSize-1)//2, stride=1)
94
+ ])
95
+ else:
96
+ raise ValueError("scale must be 2 or 3 or 4.")
97
+
98
+ def forward(self, x):
99
+ x = x * self.img_range
100
+ f__1 = self.SFENet1(x)
101
+ x = self.SFENet2(f__1)
102
+
103
+ RDBs_out = []
104
+ for i in range(self.D):
105
+ x = self.RDBs[i](x)
106
+ RDBs_out.append(x)
107
+
108
+ x = self.GFF(torch.cat(RDBs_out,1))
109
+ x += f__1
110
+
111
+ if self.no_upsampling:
112
+ return x
113
+ else:
114
+ return self.UPNet(x)
115
+
116
+ if __name__ == '__main__':
117
+ x = torch.randn(8,3,48,48)
118
+ model = RDNNOUP()
119
+ y = model(x)
120
+ print(y.shape)
utils/split_and_joint_image.py ADDED
@@ -0,0 +1,232 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+
5
+ from utils.gaussian_splatting import generate_2D_gaussian_splatting_step, generate_2D_gaussian_splatting_step_buffer
6
+
7
+
8
+ ### If GPU memory is limited, use the following code to tile the input LR image during inference
9
+ # def split_and_joint_image(lq, scale_factor, model_g, model_fea2gs, scale_modify, split_size = 48,
10
+ # overlap_size = 8,
11
+ # crop_size = 4,
12
+ # default_step_size = 1.2, mode = 'scale_modify',
13
+ # cuda_rendering = True,
14
+ # if_dmax = False,
15
+ # dmax_mode = 'fix',
16
+ # dmax = 0.1):
17
+ # h_lq, w_lq = lq.shape[-2:]
18
+
19
+ # assert overlap_size > 0 and overlap_size < split_size // 2, f"overlap size is wrong"
20
+
21
+ # tile_nums_h = math.ceil((h_lq - overlap_size) / (split_size - overlap_size))
22
+ # tile_nums_w = math.ceil((w_lq - overlap_size) / (split_size - overlap_size))
23
+
24
+ # pad_h_lq = tile_nums_h * (split_size - overlap_size) + overlap_size - h_lq
25
+ # pad_w_lq = tile_nums_w * (split_size - overlap_size) + overlap_size - w_lq
26
+
27
+ # lq_pad = F.pad(input=lq, pad=(0, pad_w_lq, 0, pad_h_lq), mode='reflect')
28
+
29
+ # split_size_sr = math.ceil(split_size * scale_factor)
30
+ # sr_tile_list = []
31
+ # for h_num in range(tile_nums_h):
32
+ # for w_num in range(tile_nums_w):
33
+ # tile_lq_position_start_h = h_num * (split_size - overlap_size)
34
+ # tile_lq_position_start_w = w_num * (split_size - overlap_size)
35
+ # tile_lq_position_end_h = tile_lq_position_start_h + split_size
36
+ # tile_lq_position_end_w = tile_lq_position_start_w + split_size
37
+
38
+ # input_tile = lq_pad[:,:, tile_lq_position_start_h:tile_lq_position_end_h, tile_lq_position_start_w:tile_lq_position_end_w]
39
+
40
+ # model_g_output = model_g(input_tile)
41
+
42
+ # scale_vector = scale_modify[0].unsqueeze(0).to(model_g_output.device)
43
+ # batch_gs_parameters = model_fea2gs(model_g_output, scale_vector)
44
+
45
+ # gs_parameters = batch_gs_parameters[0, :]
46
+ # b_output = generate_2D_gaussian_splatting_step(sr_size=torch.tensor([split_size_sr, split_size_sr]), gs_parameters=gs_parameters,
47
+ # lq=input_tile[0, :], scale=scale_factor, sample_coords=None,
48
+ # scale_modify = scale_modify,
49
+ # default_step_size = default_step_size, mode = mode,
50
+ # cuda_rendering = cuda_rendering,
51
+ # if_dmax = if_dmax,
52
+ # dmax_mode = dmax_mode,
53
+ # dmax = dmax)
54
+ # sr_tile_list.append(b_output.unsqueeze(0))
55
+
56
+ # tile_sr_h = sr_tile_list[0].shape[2]
57
+ # tile_sr_w = sr_tile_list[0].shape[3]
58
+
59
+ # assert tile_sr_w == split_size_sr and tile_sr_h == split_size_sr, \
60
+ # f'tile_sr_h-{tile_sr_w}, tile_sr_w-{tile_sr_w}, split_size_sr-{split_size_sr} is not the same'
61
+
62
+ # overlap_sr = math.ceil(overlap_size * scale_factor)
63
+
64
+ # sr_pad = torch.zeros(lq.shape[0], lq.shape[1],
65
+ # math.ceil(lq_pad.shape[2] * scale_factor),
66
+ # math.ceil(lq_pad.shape[3] * scale_factor),
67
+ # device=lq.device)
68
+
69
+ # idx = 0
70
+ # for h_num in range(tile_nums_h):
71
+ # for w_num in range(tile_nums_w):
72
+ # tile_sr_position_start_w = w_num * (split_size_sr - overlap_sr)
73
+ # tile_sr_position_end_w = tile_sr_position_start_w + split_size_sr
74
+ # tile_sr_position_start_h = h_num * (split_size_sr - overlap_sr)
75
+ # tile_sr_position_end_h = tile_sr_position_start_h + split_size_sr
76
+ # if h_num == 0 and w_num == 0:
77
+ # sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
78
+ # tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx]
79
+ # elif h_num == 0 and w_num !=0:
80
+ # sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
81
+ # tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,:,crop_size:]
82
+ # elif h_num != 0 and w_num ==0:
83
+ # sr_pad[:, :, tile_sr_position_start_h+crop_size:tile_sr_position_end_h,
84
+ # tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,:]
85
+ # else:
86
+ # sr_pad[:,:,tile_sr_position_start_h+crop_size:tile_sr_position_end_h,
87
+ # tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,crop_size:]
88
+ # idx = idx + 1
89
+
90
+ # print(f"sr_pad shape is {sr_pad.shape}")
91
+
92
+ # # sr_final = sr_pad[:,:, 0:math.ceil(h_lq * scale_factor), 0: math.ceil(w_lq * scale_factor)]
93
+ # sr_final = sr_pad
94
+
95
+ # return sr_final
96
+
97
+
98
+ def split_and_joint_image(lq, scale_factor, split_size,
99
+ overlap_size, model_g, model_fea2gs,
100
+ scale_modify, crop_size = 2,
101
+ default_step_size = 1.2, mode = 'scale_modify',
102
+ cuda_rendering = True,
103
+ if_dmax = False,
104
+ dmax_mode = 'fix',
105
+ dmax = 25):
106
+ h_lq, w_lq = lq.shape[-2:]
107
+
108
+ # assert h_lq > split_size, f'h_lq-{h_lq} should be larger than split_size-{split_size}, please do not use tile_process, or decrease the split_size'
109
+ # assert w_lq > split_size, f'w_lq-{w_lq} should be larger than split_size-{split_size}, please do not use tile_process, or decrease the split_size'
110
+
111
+ assert overlap_size > 0 and overlap_size < split_size // 2, f"overlap_size-{overlap_size} should be larger than 0 and smaller than split_size // 2-{split_size // 2}"
112
+
113
+ tile_nums_h = math.ceil((h_lq - overlap_size) / (split_size - overlap_size))
114
+ tile_nums_w = math.ceil((w_lq - overlap_size) / (split_size - overlap_size))
115
+
116
+ pad_h_lq = tile_nums_h * (split_size - overlap_size) + overlap_size - h_lq
117
+ pad_w_lq = tile_nums_w * (split_size - overlap_size) + overlap_size - w_lq
118
+
119
+ assert pad_h_lq < h_lq, f'pad_h_lq-{pad_h_lq} should be smaller than h_lq-{h_lq}, please decrease the split_size-{split_size}'
120
+ assert pad_w_lq < w_lq, f'pad_w_lq-{pad_w_lq} should be smaller than w_lq-{w_lq}, please decrease the split_size-{split_size}'
121
+
122
+ lq_pad = F.pad(input=lq, pad=(0, pad_w_lq, 0, pad_h_lq), mode='reflect')
123
+ # lq_pad = F.pad(input=lq, pad=(0, pad_w_lq, 0, pad_h_lq), mode='constant', value=0)
124
+
125
+ split_size_sr = math.ceil(split_size * scale_factor)
126
+ sr_tile_list = []
127
+ for h_num in range(tile_nums_h):
128
+ for w_num in range(tile_nums_w):
129
+ tile_lq_position_start_h = h_num * (split_size - overlap_size)
130
+ tile_lq_position_start_w = w_num * (split_size - overlap_size)
131
+ tile_lq_position_end_h = tile_lq_position_start_h + split_size
132
+ tile_lq_position_end_w = tile_lq_position_start_w + split_size
133
+
134
+ input_tile = lq_pad[:,:, tile_lq_position_start_h:tile_lq_position_end_h, tile_lq_position_start_w:tile_lq_position_end_w]
135
+
136
+ model_g_output = model_g(input_tile)
137
+
138
+ scale_vector = scale_modify[0].unsqueeze(0).to(model_g_output.device)
139
+ batch_gs_parameters = model_fea2gs(model_g_output, scale_vector)
140
+
141
+
142
+ gs_parameters = batch_gs_parameters[0, :]
143
+ b_output = generate_2D_gaussian_splatting_step(sr_size=torch.tensor([split_size_sr, split_size_sr]), gs_parameters=gs_parameters,
144
+ scale=scale_factor, sample_coords=None,
145
+ scale_modify = scale_modify,
146
+ default_step_size = default_step_size, mode = mode,
147
+ cuda_rendering = cuda_rendering,
148
+ if_dmax = if_dmax,
149
+ dmax_mode = dmax_mode,
150
+ dmax = dmax)
151
+ sr_tile_list.append(b_output.unsqueeze(0))
152
+
153
+ tile_sr_h = sr_tile_list[0].shape[2]
154
+ tile_sr_w = sr_tile_list[0].shape[3]
155
+
156
+ assert tile_sr_w == split_size_sr and tile_sr_h == split_size_sr, \
157
+ f'tile_sr_h-{tile_sr_h}, tile_sr_w-{tile_sr_w}, split_size_sr-{split_size_sr} are not the same'
158
+
159
+ overlap_sr = math.ceil(overlap_size * scale_factor)
160
+
161
+ sr_pad = torch.zeros(lq.shape[0], lq.shape[1],
162
+ (tile_nums_h - 1) * (split_size_sr - overlap_sr) + split_size_sr,
163
+ (tile_nums_w - 1) * (split_size_sr - overlap_sr) + split_size_sr,
164
+ device=lq.device)
165
+
166
+ idx = 0
167
+
168
+ if scale_factor != int(scale_factor):
169
+ for h_num in range(tile_nums_h):
170
+ for w_num in range(tile_nums_w):
171
+ tile_sr_position_start_w = w_num * (split_size_sr - overlap_sr)
172
+ tile_sr_position_end_w = tile_sr_position_start_w + split_size_sr
173
+ tile_sr_position_start_h = h_num * (split_size_sr - overlap_sr)
174
+ tile_sr_position_end_h = tile_sr_position_start_h + split_size_sr
175
+ if h_num == 0 and w_num == 0:
176
+ sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
177
+ tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx]
178
+ elif h_num == 0 and w_num !=0:
179
+ if w_num != tile_nums_w - 1:
180
+ sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
181
+ tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,:,crop_size:]
182
+ else:
183
+ sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
184
+ tile_sr_position_start_w+crop_size:sr_pad.shape[3]] = sr_tile_list[idx][:,:,:,crop_size:sr_pad.shape[3] - tile_sr_position_start_w]
185
+ elif h_num != 0 and w_num ==0:
186
+ if h_num != tile_nums_h - 1:
187
+ sr_pad[:, :, tile_sr_position_start_h+crop_size:tile_sr_position_end_h,
188
+ tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,:]
189
+ else:
190
+ sr_pad[:, :, tile_sr_position_start_h+crop_size:sr_pad.shape[2],
191
+ tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:sr_pad.shape[2] - tile_sr_position_start_h,:]
192
+ else:
193
+ if w_num != tile_nums_w - 1 and h_num != tile_nums_h - 1:
194
+ sr_pad[:,:,tile_sr_position_start_h+crop_size:tile_sr_position_end_h,
195
+ tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,crop_size:]
196
+ elif w_num == tile_nums_w - 1 and h_num != tile_nums_h - 1:
197
+ sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
198
+ tile_sr_position_start_w+crop_size:sr_pad.shape[3]] = sr_tile_list[idx][:,:,:,crop_size:sr_pad.shape[3] - tile_sr_position_start_w]
199
+ elif w_num != tile_nums_w - 1 and h_num == tile_nums_h - 1:
200
+ sr_pad[:, :, tile_sr_position_start_h+crop_size:sr_pad.shape[2],
201
+ tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:sr_pad.shape[2] - tile_sr_position_start_h,:]
202
+ elif w_num == tile_nums_w - 1 and h_num == tile_nums_h - 1:
203
+ sr_pad[:,:,tile_sr_position_start_h+crop_size:sr_pad.shape[2],
204
+ tile_sr_position_start_w+crop_size:sr_pad.shape[3]] = sr_tile_list[idx][:,:,crop_size:sr_pad.shape[2] - tile_sr_position_start_h,crop_size:sr_pad.shape[3] - tile_sr_position_start_w]
205
+ idx = idx + 1
206
+ else:
207
+ for h_num in range(tile_nums_h):
208
+ for w_num in range(tile_nums_w):
209
+ tile_sr_position_start_w = w_num * (split_size_sr - overlap_sr)
210
+ tile_sr_position_end_w = tile_sr_position_start_w + split_size_sr
211
+ tile_sr_position_start_h = h_num * (split_size_sr - overlap_sr)
212
+ tile_sr_position_end_h = tile_sr_position_start_h + split_size_sr
213
+ if h_num == 0 and w_num == 0:
214
+ sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
215
+ tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx]
216
+ elif h_num == 0 and w_num !=0:
217
+ sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h,
218
+ tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,:,crop_size:]
219
+ elif h_num != 0 and w_num ==0:
220
+ sr_pad[:, :, tile_sr_position_start_h+crop_size:tile_sr_position_end_h,
221
+ tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,:]
222
+ else:
223
+ sr_pad[:,:,tile_sr_position_start_h+crop_size:tile_sr_position_end_h,
224
+ tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,crop_size:]
225
+ idx = idx + 1
226
+
227
+ print(f"sr_pad shape is {sr_pad.shape}")
228
+
229
+ # sr_final = sr_pad[:,:, 0:math.ceil(h_lq * scale_factor), 0: math.ceil(w_lq * scale_factor)]
230
+ sr_final = sr_pad
231
+
232
+ return sr_final
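# A worked example of the tiling arithmetic above (illustrative, with assumed sizes, not from the
# original file): for a 100x150 LR input with split_size=48 and overlap_size=8,
# >>> import math
# >>> math.ceil((100 - 8) / (48 - 8)), math.ceil((150 - 8) / (48 - 8))   # tile_nums_h, tile_nums_w
# (3, 4)
# >>> 3 * (48 - 8) + 8 - 100, 4 * (48 - 8) + 8 - 150                     # pad_h_lq, pad_w_lq
# (28, 18)
# At scale_factor 4 the stitched sr_pad therefore covers the padded LR, i.e. 512 x 672; cropping
# back to ceil(h_lq * scale_factor) x ceil(w_lq * scale_factor) is left to the caller, since the
# crop above is commented out.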
utils/swinir.py ADDED
@@ -0,0 +1,1243 @@
1
+ # Modified from https://github.com/JingyunLiang/SwinIR
2
+ # SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
3
+ # Originally Written by Ze Liu, Modified by Jingyun Liang.
4
+
5
+ import collections.abc
6
+ import torchvision
7
+ import warnings
8
+ from distutils.version import LooseVersion
9
+ from itertools import repeat
10
+
11
+ import math
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.utils.checkpoint as checkpoint
15
+
16
+ # From PyTorch
17
+ def _ntuple(n):
18
+
19
+ def parse(x):
20
+ if isinstance(x, collections.abc.Iterable):
21
+ return x
22
+ return tuple(repeat(x, n))
23
+
24
+ return parse
25
+
26
+
27
+ to_1tuple = _ntuple(1)
28
+ to_2tuple = _ntuple(2)
29
+ to_3tuple = _ntuple(3)
30
+ to_4tuple = _ntuple(4)
31
+ to_ntuple = _ntuple
32
+
33
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
34
+ # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
35
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
36
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
37
+ def norm_cdf(x):
38
+ # Computes standard normal cumulative distribution function
39
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
40
+
41
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
42
+ warnings.warn(
43
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
44
+ 'The distribution of values may be incorrect.',
45
+ stacklevel=2)
46
+
47
+ with torch.no_grad():
48
+ # Values are generated by using a truncated uniform distribution and
49
+ # then using the inverse CDF for the normal distribution.
50
+ # Get upper and lower cdf values
51
+ low = norm_cdf((a - mean) / std)
52
+ up = norm_cdf((b - mean) / std)
53
+
54
+ # Uniformly fill tensor with values from [low, up], then translate to
55
+ # [2l-1, 2u-1].
56
+ tensor.uniform_(2 * low - 1, 2 * up - 1)
57
+
58
+ # Use inverse cdf transform for normal distribution to get truncated
59
+ # standard normal
60
+ tensor.erfinv_()
61
+
62
+ # Transform to proper mean, std
63
+ tensor.mul_(std * math.sqrt(2.))
64
+ tensor.add_(mean)
65
+
66
+ # Clamp to ensure it's in the proper range
67
+ tensor.clamp_(min=a, max=b)
68
+ return tensor
69
+
70
+
71
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
72
+ r"""Fills the input Tensor with values drawn from a truncated
73
+ normal distribution.
74
+
75
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
76
+
77
+ The values are effectively drawn from the
78
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
79
+ with values outside :math:`[a, b]` redrawn until they are within
80
+ the bounds. The method used for generating the random values works
81
+ best when :math:`a \leq \text{mean} \leq b`.
82
+
83
+ Args:
84
+ tensor: an n-dimensional `torch.Tensor`
85
+ mean: the mean of the normal distribution
86
+ std: the standard deviation of the normal distribution
87
+ a: the minimum cutoff value
88
+ b: the maximum cutoff value
89
+
90
+ Examples:
91
+ >>> w = torch.empty(3, 5)
92
+ >>> nn.init.trunc_normal_(w)
93
+ """
94
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
95
+
96
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
97
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
98
+
99
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
100
+ """
101
+ if drop_prob == 0. or not training:
102
+ return x
103
+ keep_prob = 1 - drop_prob
104
+ shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
105
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
106
+ random_tensor.floor_() # binarize
107
+ output = x.div(keep_prob) * random_tensor
108
+ return output
109
+
110
+
111
+ class DropPath(nn.Module):
112
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
113
+
114
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
115
+ """
116
+
117
+ def __init__(self, drop_prob=None):
118
+ super(DropPath, self).__init__()
119
+ self.drop_prob = drop_prob
120
+
121
+ def forward(self, x):
122
+ return drop_path(x, self.drop_prob, self.training)
123
+
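# A minimal sketch of DropPath (hypothetical helper, illustrative shapes): in
# eval mode it is the identity; in training it zeroes whole samples with
# probability drop_prob and rescales the survivors by 1 / (1 - drop_prob).
def _drop_path_example():
    layer = DropPath(drop_prob=0.1)
    x = torch.ones(4, 8, 16)
    layer.eval()
    assert torch.equal(layer(x), x)       # identity when not training
    layer.train()
    return layer(x).shape                 # torch.Size([4, 8, 16]); some samples may be all zeros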
124
+
125
+ class Mlp(nn.Module):
126
+
127
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
128
+ super().__init__()
129
+ out_features = out_features or in_features
130
+ hidden_features = hidden_features or in_features
131
+ self.fc1 = nn.Linear(in_features, hidden_features)
132
+ self.act = act_layer()
133
+ self.fc2 = nn.Linear(hidden_features, out_features)
134
+ self.drop = nn.Dropout(drop)
135
+
136
+ def forward(self, x):
137
+ x = self.fc1(x)
138
+ x = self.act(x)
139
+ x = self.drop(x)
140
+ x = self.fc2(x)
141
+ x = self.drop(x)
142
+ return x
143
+
144
+
145
+ def window_partition(x, window_size):
146
+ """
147
+ Args:
148
+ x: (b, h, w, c)
149
+ window_size (int): window size
150
+
151
+ Returns:
152
+ windows: (num_windows*b, window_size, window_size, c)
153
+ """
154
+ b, h, w, c = x.shape
155
+ x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
156
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
157
+ return windows
158
+
159
+
160
+ def window_reverse(windows, window_size, h, w):
161
+ """
162
+ Args:
163
+ windows: (num_windows*b, window_size, window_size, c)
164
+ window_size (int): Window size
165
+ h (int): Height of image
166
+ w (int): Width of image
167
+
168
+ Returns:
169
+ x: (b, h, w, c)
170
+ """
171
+ b = int(windows.shape[0] / (h * w / window_size / window_size))
172
+ x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
173
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
174
+ return x
175
+
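# A minimal round-trip sketch for window_partition / window_reverse
# (hypothetical helper; assumes h and w are multiples of window_size):
def _window_roundtrip_example():
    x = torch.randn(2, 16, 16, 32)        # (b, h, w, c)
    windows = window_partition(x, 8)      # (num_windows*b, 8, 8, c) = (8, 8, 8, 32)
    y = window_reverse(windows, 8, 16, 16)
    return torch.equal(x, y)              # True: both ops are pure reshapes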
176
+
177
+ class WindowAttention(nn.Module):
178
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
179
+ It supports both of shifted and non-shifted window.
180
+
181
+ Args:
182
+ dim (int): Number of input channels.
183
+ window_size (tuple[int]): The height and width of the window.
184
+ num_heads (int): Number of attention heads.
185
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
186
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
187
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
188
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
189
+ """
190
+
191
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
192
+
193
+ super().__init__()
194
+ self.dim = dim
195
+ self.window_size = window_size # Wh, Ww
196
+ self.num_heads = num_heads
197
+ head_dim = dim // num_heads
198
+ self.scale = qk_scale or head_dim**-0.5
199
+
200
+ # define a parameter table of relative position bias
201
+ self.relative_position_bias_table = nn.Parameter(
202
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
203
+
204
+ # get pair-wise relative position index for each token inside the window
205
+ coords_h = torch.arange(self.window_size[0])
206
+ coords_w = torch.arange(self.window_size[1])
207
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
208
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
209
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
210
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
211
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
212
+ relative_coords[:, :, 1] += self.window_size[1] - 1
213
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
214
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
215
+ self.register_buffer('relative_position_index', relative_position_index)
216
+
217
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
218
+ self.attn_drop = nn.Dropout(attn_drop)
219
+ self.proj = nn.Linear(dim, dim)
220
+
221
+ self.proj_drop = nn.Dropout(proj_drop)
222
+
223
+ trunc_normal_(self.relative_position_bias_table, std=.02)
224
+ self.softmax = nn.Softmax(dim=-1)
225
+
226
+ def forward(self, x, mask=None):
227
+ """
228
+ Args:
229
+ x: input features with shape of (num_windows*b, n, c)
230
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
231
+ """
232
+ b_, n, c = x.shape
233
+ qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
234
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
235
+
236
+ q = q * self.scale
237
+ attn = (q @ k.transpose(-2, -1))
238
+
239
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
240
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
241
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
242
+ attn = attn + relative_position_bias.unsqueeze(0)
243
+
244
+ if mask is not None:
245
+ nw = mask.shape[0]
246
+ attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
247
+ attn = attn.view(-1, self.num_heads, n, n)
248
+ attn = self.softmax(attn)
249
+ else:
250
+ attn = self.softmax(attn)
251
+
252
+ attn = self.attn_drop(attn)
253
+
254
+ x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
255
+ x = self.proj(x)
256
+ x = self.proj_drop(x)
257
+ return x
258
+
259
+ def extra_repr(self) -> str:
260
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
261
+
262
+ def flops(self, n):
263
+ # calculate flops for 1 window with token length of n
264
+ flops = 0
265
+ # qkv = self.qkv(x)
266
+ flops += n * self.dim * 3 * self.dim
267
+ # attn = (q @ k.transpose(-2, -1))
268
+ flops += self.num_heads * n * (self.dim // self.num_heads) * n
269
+ # x = (attn @ v)
270
+ flops += self.num_heads * n * n * (self.dim // self.num_heads)
271
+ # x = self.proj(x)
272
+ flops += n * self.dim * self.dim
273
+ return flops
274
+
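# A minimal usage sketch for WindowAttention (hypothetical helper, illustrative
# dims): inputs are flattened windows of shape (num_windows*b, Wh*Ww, dim) and
# the output keeps that shape.
def _window_attention_example():
    attn = WindowAttention(dim=96, window_size=to_2tuple(8), num_heads=6)
    x = torch.randn(4, 64, 96)            # 4 windows of 8x8 tokens, 96 channels
    return attn(x).shape                  # torch.Size([4, 64, 96])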
275
+
276
+ class SwinTransformerBlock(nn.Module):
277
+ r""" Swin Transformer Block.
278
+
279
+ Args:
280
+ dim (int): Number of input channels.
281
+ input_resolution (tuple[int]): Input resolution.
282
+ num_heads (int): Number of attention heads.
283
+ window_size (int): Window size.
284
+ shift_size (int): Shift size for SW-MSA.
285
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
286
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
287
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
288
+ drop (float, optional): Dropout rate. Default: 0.0
289
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
290
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
291
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
292
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
293
+ """
294
+
295
+ def __init__(self,
296
+ dim,
297
+ input_resolution,
298
+ num_heads,
299
+ window_size=7,
300
+ shift_size=0,
301
+ mlp_ratio=4.,
302
+ qkv_bias=True,
303
+ qk_scale=None,
304
+ drop=0.,
305
+ attn_drop=0.,
306
+ drop_path=0.,
307
+ act_layer=nn.GELU,
308
+ norm_layer=nn.LayerNorm):
309
+ super().__init__()
310
+ self.dim = dim
311
+ self.input_resolution = input_resolution
312
+ self.num_heads = num_heads
313
+ self.window_size = window_size
314
+ self.shift_size = shift_size
315
+ self.mlp_ratio = mlp_ratio
316
+ if min(self.input_resolution) <= self.window_size:
317
+ # if window size is larger than input resolution, we don't partition windows
318
+ self.shift_size = 0
319
+ self.window_size = min(self.input_resolution)
320
+ assert 0 <= self.shift_size < self.window_size, 'shift_size must be in [0, window_size)'
321
+
322
+ self.norm1 = norm_layer(dim)
323
+ self.attn = WindowAttention(
324
+ dim,
325
+ window_size=to_2tuple(self.window_size),
326
+ num_heads=num_heads,
327
+ qkv_bias=qkv_bias,
328
+ qk_scale=qk_scale,
329
+ attn_drop=attn_drop,
330
+ proj_drop=drop)
331
+
332
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
333
+ self.norm2 = norm_layer(dim)
334
+ mlp_hidden_dim = int(dim * mlp_ratio)
335
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
336
+
337
+ if self.shift_size > 0:
338
+ attn_mask = self.calculate_mask(self.input_resolution)
339
+ else:
340
+ attn_mask = None
341
+
342
+ self.register_buffer('attn_mask', attn_mask)
343
+
344
+ def calculate_mask(self, x_size):
345
+ # calculate attention mask for SW-MSA
346
+ h, w = x_size
347
+ img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1
348
+ h_slices = (slice(0, -self.window_size), slice(-self.window_size,
349
+ -self.shift_size), slice(-self.shift_size, None))
350
+ w_slices = (slice(0, -self.window_size), slice(-self.window_size,
351
+ -self.shift_size), slice(-self.shift_size, None))
352
+ cnt = 0
353
+ for h in h_slices:
354
+ for w in w_slices:
355
+ img_mask[:, h, w, :] = cnt
356
+ cnt += 1
357
+
358
+ mask_windows = window_partition(img_mask, self.window_size) # nw, window_size, window_size, 1
359
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
360
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
361
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
362
+
363
+ return attn_mask
364
+
365
+ def forward(self, x, x_size):
366
+ h, w = x_size
367
+ b, _, c = x.shape
368
+ # assert seq_len == h * w, "input feature has wrong size"
369
+
370
+ shortcut = x
371
+ x = self.norm1(x)
372
+ x = x.view(b, h, w, c)
373
+
374
+ # cyclic shift
375
+ if self.shift_size > 0:
376
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
377
+ else:
378
+ shifted_x = x
379
+
380
+ # partition windows
381
+ x_windows = window_partition(shifted_x, self.window_size) # nw*b, window_size, window_size, c
382
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c
383
+
384
+ # W-MSA/SW-MSA (to be compatible with testing on images whose shapes are multiples of the window size)
385
+ if self.input_resolution == x_size:
386
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nw*b, window_size*window_size, c
387
+ else:
388
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
389
+
390
+ # merge windows
391
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
392
+ shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c
393
+
394
+ # reverse cyclic shift
395
+ if self.shift_size > 0:
396
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
397
+ else:
398
+ x = shifted_x
399
+ x = x.view(b, h * w, c)
400
+
401
+ # FFN
402
+ x = shortcut + self.drop_path(x)
403
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
404
+
405
+ return x
406
+
407
+ def extra_repr(self) -> str:
408
+ return (f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, '
409
+ f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}')
410
+
411
+ def flops(self):
412
+ flops = 0
413
+ h, w = self.input_resolution
414
+ # norm1
415
+ flops += self.dim * h * w
416
+ # W-MSA/SW-MSA
417
+ nw = h * w / self.window_size / self.window_size
418
+ flops += nw * self.attn.flops(self.window_size * self.window_size)
419
+ # mlp
420
+ flops += 2 * h * w * self.dim * self.dim * self.mlp_ratio
421
+ # norm2
422
+ flops += self.dim * h * w
423
+ return flops
424
+
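# A minimal usage sketch for SwinTransformerBlock (hypothetical helper,
# illustrative dims): tokens come in as (b, h*w, c) together with x_size; the
# shifted-window mask is recomputed on the fly if x_size differs from
# input_resolution.
def _swin_block_example():
    blk = SwinTransformerBlock(dim=96, input_resolution=(64, 64), num_heads=6,
                               window_size=8, shift_size=4)
    x = torch.randn(2, 64 * 64, 96)
    return blk(x, (64, 64)).shape         # torch.Size([2, 4096, 96])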
425
+
426
+ class PatchMerging(nn.Module):
427
+ r""" Patch Merging Layer.
428
+
429
+ Args:
430
+ input_resolution (tuple[int]): Resolution of input feature.
431
+ dim (int): Number of input channels.
432
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
433
+ """
434
+
435
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
436
+ super().__init__()
437
+ self.input_resolution = input_resolution
438
+ self.dim = dim
439
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
440
+ self.norm = norm_layer(4 * dim)
441
+
442
+ def forward(self, x):
443
+ """
444
+ x: b, h*w, c
445
+ """
446
+ h, w = self.input_resolution
447
+ b, seq_len, c = x.shape
448
+ assert seq_len == h * w, 'input feature has wrong size'
449
+ assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) are not even.'
450
+
451
+ x = x.view(b, h, w, c)
452
+
453
+ x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c
454
+ x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c
455
+ x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c
456
+ x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c
457
+ x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c
458
+ x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c
459
+
460
+ x = self.norm(x)
461
+ x = self.reduction(x)
462
+
463
+ return x
464
+
465
+ def extra_repr(self) -> str:
466
+ return f'input_resolution={self.input_resolution}, dim={self.dim}'
467
+
468
+ def flops(self):
469
+ h, w = self.input_resolution
470
+ flops = h * w * self.dim
471
+ flops += (h // 2) * (w // 2) * 4 * self.dim * 2 * self.dim
472
+ return flops
473
+
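# A minimal usage sketch for PatchMerging (hypothetical helper, illustrative
# dims): each spatial dimension is halved and the channel count is doubled
# (the 4c concatenated neighbours are reduced to 2c by the linear layer).
def _patch_merging_example():
    merge = PatchMerging(input_resolution=(8, 8), dim=32)
    x = torch.randn(2, 8 * 8, 32)
    return merge(x).shape                 # torch.Size([2, 16, 64])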
474
+
475
+ class BasicLayer(nn.Module):
476
+ """ A basic Swin Transformer layer for one stage.
477
+
478
+ Args:
479
+ dim (int): Number of input channels.
480
+ input_resolution (tuple[int]): Input resolution.
481
+ depth (int): Number of blocks.
482
+ num_heads (int): Number of attention heads.
483
+ window_size (int): Local window size.
484
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
485
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
486
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
487
+ drop (float, optional): Dropout rate. Default: 0.0
488
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
489
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
490
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
491
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
492
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
493
+ """
494
+
495
+ def __init__(self,
496
+ dim,
497
+ input_resolution,
498
+ depth,
499
+ num_heads,
500
+ window_size,
501
+ mlp_ratio=4.,
502
+ qkv_bias=True,
503
+ qk_scale=None,
504
+ drop=0.,
505
+ attn_drop=0.,
506
+ drop_path=0.,
507
+ norm_layer=nn.LayerNorm,
508
+ downsample=None,
509
+ use_checkpoint=False):
510
+
511
+ super().__init__()
512
+ self.dim = dim
513
+ self.input_resolution = input_resolution
514
+ self.depth = depth
515
+ self.use_checkpoint = use_checkpoint
516
+
517
+ # build blocks
518
+ self.blocks = nn.ModuleList([
519
+ SwinTransformerBlock(
520
+ dim=dim,
521
+ input_resolution=input_resolution,
522
+ num_heads=num_heads,
523
+ window_size=window_size,
524
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
525
+ mlp_ratio=mlp_ratio,
526
+ qkv_bias=qkv_bias,
527
+ qk_scale=qk_scale,
528
+ drop=drop,
529
+ attn_drop=attn_drop,
530
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
531
+ norm_layer=norm_layer) for i in range(depth)
532
+ ])
533
+
534
+ # patch merging layer
535
+ if downsample is not None:
536
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
537
+ else:
538
+ self.downsample = None
539
+
540
+ def forward(self, x, x_size):
541
+ for blk in self.blocks:
542
+ if self.use_checkpoint:
543
+ x = checkpoint.checkpoint(blk, x)
544
+ else:
545
+ x = blk(x, x_size)
546
+ if self.downsample is not None:
547
+ x = self.downsample(x)
548
+ return x
549
+
550
+ def extra_repr(self) -> str:
551
+ return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
552
+
553
+ def flops(self):
554
+ flops = 0
555
+ for blk in self.blocks:
556
+ flops += blk.flops()
557
+ if self.downsample is not None:
558
+ flops += self.downsample.flops()
559
+ return flops
560
+
561
+
562
+ class RSTB(nn.Module):
563
+ """Residual Swin Transformer Block (RSTB).
564
+
565
+ Args:
566
+ dim (int): Number of input channels.
567
+ input_resolution (tuple[int]): Input resolution.
568
+ depth (int): Number of blocks.
569
+ num_heads (int): Number of attention heads.
570
+ window_size (int): Local window size.
571
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
572
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
573
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
574
+ drop (float, optional): Dropout rate. Default: 0.0
575
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
576
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
577
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
578
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
579
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
580
+ img_size: Input image size.
581
+ patch_size: Patch size.
582
+ resi_connection: The convolutional block before residual connection.
583
+ """
584
+
585
+ def __init__(self,
586
+ dim,
587
+ input_resolution,
588
+ depth,
589
+ num_heads,
590
+ window_size,
591
+ mlp_ratio=4.,
592
+ qkv_bias=True,
593
+ qk_scale=None,
594
+ drop=0.,
595
+ attn_drop=0.,
596
+ drop_path=0.,
597
+ norm_layer=nn.LayerNorm,
598
+ downsample=None,
599
+ use_checkpoint=False,
600
+ img_size=224,
601
+ patch_size=4,
602
+ resi_connection='1conv'):
603
+ super(RSTB, self).__init__()
604
+
605
+ self.dim = dim
606
+ self.input_resolution = input_resolution
607
+
608
+ self.residual_group = BasicLayer(
609
+ dim=dim,
610
+ input_resolution=input_resolution,
611
+ depth=depth,
612
+ num_heads=num_heads,
613
+ window_size=window_size,
614
+ mlp_ratio=mlp_ratio,
615
+ qkv_bias=qkv_bias,
616
+ qk_scale=qk_scale,
617
+ drop=drop,
618
+ attn_drop=attn_drop,
619
+ drop_path=drop_path,
620
+ norm_layer=norm_layer,
621
+ downsample=downsample,
622
+ use_checkpoint=use_checkpoint)
623
+
624
+ if resi_connection == '1conv':
625
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
626
+ elif resi_connection == '3conv':
627
+ # to save parameters and memory
628
+ self.conv = nn.Sequential(
629
+ nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
630
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
631
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
632
+
633
+ self.patch_embed = PatchEmbed(
634
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
635
+
636
+ self.patch_unembed = PatchUnEmbed(
637
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
638
+
639
+ def forward(self, x, x_size):
640
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
641
+
642
+ def flops(self):
643
+ flops = 0
644
+ flops += self.residual_group.flops()
645
+ h, w = self.input_resolution
646
+ flops += h * w * self.dim * self.dim * 9
647
+ flops += self.patch_embed.flops()
648
+ flops += self.patch_unembed.flops()
649
+
650
+ return flops
651
+
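# A minimal usage sketch for RSTB (hypothetical helper, illustrative dims): a
# residual group of Swin blocks plus a conv, so the token shape is preserved.
def _rstb_example():
    rstb = RSTB(dim=60, input_resolution=(48, 48), depth=2, num_heads=6,
                window_size=8)
    x = torch.randn(1, 48 * 48, 60)
    return rstb(x, (48, 48)).shape        # torch.Size([1, 2304, 60])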
652
+
653
+ class PatchEmbed(nn.Module):
654
+ r""" Image to Patch Embedding
655
+
656
+ Args:
657
+ img_size (int): Image size. Default: 224.
658
+ patch_size (int): Patch token size. Default: 4.
659
+ in_chans (int): Number of input image channels. Default: 3.
660
+ embed_dim (int): Number of linear projection output channels. Default: 96.
661
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
662
+ """
663
+
664
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
665
+ super().__init__()
666
+ img_size = to_2tuple(img_size)
667
+ patch_size = to_2tuple(patch_size)
668
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
669
+ self.img_size = img_size
670
+ self.patch_size = patch_size
671
+ self.patches_resolution = patches_resolution
672
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
673
+
674
+ self.in_chans = in_chans
675
+ self.embed_dim = embed_dim
676
+
677
+ if norm_layer is not None:
678
+ self.norm = norm_layer(embed_dim)
679
+ else:
680
+ self.norm = None
681
+
682
+ def forward(self, x):
683
+ x = x.flatten(2).transpose(1, 2) # b Ph*Pw c
684
+ if self.norm is not None:
685
+ x = self.norm(x)
686
+ return x
687
+
688
+ def flops(self):
689
+ flops = 0
690
+ h, w = self.img_size
691
+ if self.norm is not None:
692
+ flops += h * w * self.embed_dim
693
+ return flops
694
+
695
+
696
+ class PatchUnEmbed(nn.Module):
697
+ r""" Image to Patch Unembedding
698
+
699
+ Args:
700
+ img_size (int): Image size. Default: 224.
701
+ patch_size (int): Patch token size. Default: 4.
702
+ in_chans (int): Number of input image channels. Default: 3.
703
+ embed_dim (int): Number of linear projection output channels. Default: 96.
704
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
705
+ """
706
+
707
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
708
+ super().__init__()
709
+ img_size = to_2tuple(img_size)
710
+ patch_size = to_2tuple(patch_size)
711
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
712
+ self.img_size = img_size
713
+ self.patch_size = patch_size
714
+ self.patches_resolution = patches_resolution
715
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
716
+
717
+ self.in_chans = in_chans
718
+ self.embed_dim = embed_dim
719
+
720
+ def forward(self, x, x_size):
721
+ x = x.transpose(1, 2).view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b Ph*Pw c
722
+ return x
723
+
724
+ def flops(self):
725
+ flops = 0
726
+ return flops
727
+
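# A minimal round-trip sketch for PatchEmbed / PatchUnEmbed (hypothetical
# helper; with patch_size=1, as used throughout this file, they simply flatten
# (b, c, h, w) into tokens and back).
def _patch_embed_roundtrip_example():
    embed = PatchEmbed(img_size=48, patch_size=1, embed_dim=60)
    unembed = PatchUnEmbed(img_size=48, patch_size=1, embed_dim=60)
    x = torch.randn(2, 60, 48, 48)
    tokens = embed(x)                     # (2, 2304, 60)
    return torch.equal(unembed(tokens, (48, 48)), x)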
728
+
729
+ class Upsample(nn.Sequential):
730
+ """Upsample module.
731
+
732
+ Args:
733
+ scale (int): Scale factor. Supported scales: 2^n and 3.
734
+ num_feat (int): Channel number of intermediate features.
735
+ """
736
+
737
+ def __init__(self, scale, num_feat):
738
+ m = []
739
+ if (scale & (scale - 1)) == 0: # scale = 2^n
740
+ for _ in range(int(math.log(scale, 2))):
741
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
742
+ m.append(nn.PixelShuffle(2))
743
+ elif scale == 3:
744
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
745
+ m.append(nn.PixelShuffle(3))
746
+ else:
747
+ raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
748
+ super(Upsample, self).__init__(*m)
749
+
750
+
751
+ class UpsampleOneStep(nn.Sequential):
752
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
753
+ Used in lightweight SR to save parameters.
754
+
755
+ Args:
756
+ scale (int): Scale factor. Supported scales: 2^n and 3.
757
+ num_feat (int): Channel number of intermediate features.
758
+
759
+ """
760
+
761
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
762
+ self.num_feat = num_feat
763
+ self.input_resolution = input_resolution
764
+ m = []
765
+ m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
766
+ m.append(nn.PixelShuffle(scale))
767
+ super(UpsampleOneStep, self).__init__(*m)
768
+
769
+ def flops(self):
770
+ h, w = self.input_resolution
771
+ flops = h * w * self.num_feat * 3 * 9
772
+ return flops
773
+
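# A minimal shape sketch for the two upsamplers (hypothetical helper,
# illustrative dims): both enlarge the spatial size by `scale`; Upsample keeps
# num_feat channels, UpsampleOneStep maps directly to num_out_ch.
def _upsample_example():
    feat = torch.randn(1, 64, 12, 12)
    up = Upsample(scale=4, num_feat=64)
    one_step = UpsampleOneStep(scale=4, num_feat=64, num_out_ch=3)
    return up(feat).shape, one_step(feat).shape   # (1, 64, 48, 48), (1, 3, 48, 48)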
774
+
775
+ class SwinIR(nn.Module):
776
+ r""" SwinIR
777
+ A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
778
+
779
+ Args:
780
+ img_size (int | tuple(int)): Input image size. Default 64
781
+ patch_size (int | tuple(int)): Patch size. Default: 1
782
+ in_chans (int): Number of input image channels. Default: 3
783
+ embed_dim (int): Patch embedding dimension. Default: 96
784
+ depths (tuple(int)): Depth of each Swin Transformer layer.
785
+ num_heads (tuple(int)): Number of attention heads in different layers.
786
+ window_size (int): Window size. Default: 7
787
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
788
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
789
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
790
+ drop_rate (float): Dropout rate. Default: 0
791
+ attn_drop_rate (float): Attention dropout rate. Default: 0
792
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
793
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
794
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
795
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
796
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
797
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
798
+ img_range: Image range. 1. or 255.
799
+ upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
800
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
801
+ """
802
+
803
+ def __init__(self,
804
+ img_size=64,
805
+ patch_size=1,
806
+ in_chans=3,
807
+ embed_dim=96,
808
+ depths=(6, 6, 6, 6),
809
+ num_heads=(6, 6, 6, 6),
810
+ window_size=7,
811
+ mlp_ratio=4.,
812
+ qkv_bias=True,
813
+ qk_scale=None,
814
+ drop_rate=0.,
815
+ attn_drop_rate=0.,
816
+ drop_path_rate=0.1,
817
+ norm_layer=nn.LayerNorm,
818
+ ape=False,
819
+ patch_norm=True,
820
+ use_checkpoint=False,
821
+ upscale=2,
822
+ img_range=1.,
823
+ upsampler='',
824
+ resi_connection='1conv',
825
+ **kwargs):
826
+ super(SwinIR, self).__init__()
827
+ num_in_ch = in_chans
828
+ num_out_ch = in_chans
829
+ num_feat = 64
830
+ self.img_range = img_range
831
+ if in_chans == 3:
832
+ rgb_mean = (0.4488, 0.4371, 0.4040)
833
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
834
+ else:
835
+ self.mean = torch.zeros(1, 1, 1, 1)
836
+ self.upscale = upscale
837
+ self.upsampler = upsampler
838
+
839
+ # ------------------------- 1, shallow feature extraction ------------------------- #
840
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
841
+
842
+ # ------------------------- 2, deep feature extraction ------------------------- #
843
+ self.num_layers = len(depths)
844
+ self.embed_dim = embed_dim
845
+ self.ape = ape
846
+ self.patch_norm = patch_norm
847
+ self.num_features = embed_dim
848
+ self.mlp_ratio = mlp_ratio
849
+
850
+ # split image into non-overlapping patches
851
+ self.patch_embed = PatchEmbed(
852
+ img_size=img_size,
853
+ patch_size=patch_size,
854
+ in_chans=embed_dim,
855
+ embed_dim=embed_dim,
856
+ norm_layer=norm_layer if self.patch_norm else None)
857
+ num_patches = self.patch_embed.num_patches
858
+ patches_resolution = self.patch_embed.patches_resolution
859
+ self.patches_resolution = patches_resolution
860
+
861
+ # merge non-overlapping patches into image
862
+ self.patch_unembed = PatchUnEmbed(
863
+ img_size=img_size,
864
+ patch_size=patch_size,
865
+ in_chans=embed_dim,
866
+ embed_dim=embed_dim,
867
+ norm_layer=norm_layer if self.patch_norm else None)
868
+
869
+ # absolute position embedding
870
+ if self.ape:
871
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
872
+ trunc_normal_(self.absolute_pos_embed, std=.02)
873
+
874
+ self.pos_drop = nn.Dropout(p=drop_rate)
875
+
876
+ # stochastic depth
877
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
878
+
879
+ # build Residual Swin Transformer blocks (RSTB)
880
+ self.layers = nn.ModuleList()
881
+ for i_layer in range(self.num_layers):
882
+ layer = RSTB(
883
+ dim=embed_dim,
884
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
885
+ depth=depths[i_layer],
886
+ num_heads=num_heads[i_layer],
887
+ window_size=window_size,
888
+ mlp_ratio=self.mlp_ratio,
889
+ qkv_bias=qkv_bias,
890
+ qk_scale=qk_scale,
891
+ drop=drop_rate,
892
+ attn_drop=attn_drop_rate,
893
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
894
+ norm_layer=norm_layer,
895
+ downsample=None,
896
+ use_checkpoint=use_checkpoint,
897
+ img_size=img_size,
898
+ patch_size=patch_size,
899
+ resi_connection=resi_connection)
900
+ self.layers.append(layer)
901
+ self.norm = norm_layer(self.num_features)
902
+
903
+ # build the last conv layer in deep feature extraction
904
+ if resi_connection == '1conv':
905
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
906
+ elif resi_connection == '3conv':
907
+ # to save parameters and memory
908
+ self.conv_after_body = nn.Sequential(
909
+ nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
910
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
911
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
912
+
913
+ # ------------------------- 3, high quality image reconstruction ------------------------- #
914
+ if self.upsampler == 'pixelshuffle':
915
+ # for classical SR
916
+ self.conv_before_upsample = nn.Sequential(
917
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
918
+ self.upsample = Upsample(upscale, num_feat)
919
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
920
+ elif self.upsampler == 'pixelshuffledirect':
921
+ # for lightweight SR (to save parameters)
922
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
923
+ (patches_resolution[0], patches_resolution[1]))
924
+ elif self.upsampler == 'nearest+conv':
925
+ # for real-world SR (fewer artifacts)
926
+ assert self.upscale == 4, 'only support x4 now.'
927
+ self.conv_before_upsample = nn.Sequential(
928
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
929
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
930
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
931
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
932
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
933
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
934
+ else:
935
+ # for image denoising and JPEG compression artifact reduction
936
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
937
+
938
+ self.apply(self._init_weights)
939
+
940
+ def _init_weights(self, m):
941
+ if isinstance(m, nn.Linear):
942
+ trunc_normal_(m.weight, std=.02)
943
+ if isinstance(m, nn.Linear) and m.bias is not None:
944
+ nn.init.constant_(m.bias, 0)
945
+ elif isinstance(m, nn.LayerNorm):
946
+ nn.init.constant_(m.bias, 0)
947
+ nn.init.constant_(m.weight, 1.0)
948
+
949
+ @torch.jit.ignore
950
+ def no_weight_decay(self):
951
+ return {'absolute_pos_embed'}
952
+
953
+ @torch.jit.ignore
954
+ def no_weight_decay_keywords(self):
955
+ return {'relative_position_bias_table'}
956
+
957
+ def forward_features(self, x):
958
+ x_size = (x.shape[2], x.shape[3])
959
+ x = self.patch_embed(x)
960
+ if self.ape:
961
+ x = x + self.absolute_pos_embed
962
+ x = self.pos_drop(x)
963
+
964
+ for layer in self.layers:
965
+ x = layer(x, x_size)
966
+
967
+ x = self.norm(x) # b seq_len c
968
+ x = self.patch_unembed(x, x_size)
969
+
970
+ return x
971
+
972
+ def forward(self, x):
973
+ self.mean = self.mean.type_as(x)
974
+ x = (x - self.mean) * self.img_range
975
+
976
+ if self.upsampler == 'pixelshuffle':
977
+ # for classical SR
978
+ x = self.conv_first(x)
979
+ x = self.conv_after_body(self.forward_features(x)) + x
980
+ x = self.conv_before_upsample(x)
981
+ x = self.conv_last(self.upsample(x))
982
+ elif self.upsampler == 'pixelshuffledirect':
983
+ # for lightweight SR
984
+ x = self.conv_first(x)
985
+ x = self.conv_after_body(self.forward_features(x)) + x
986
+ x = self.upsample(x)
987
+ elif self.upsampler == 'nearest+conv':
988
+ # for real-world SR
989
+ x = self.conv_first(x)
990
+ x = self.conv_after_body(self.forward_features(x)) + x
991
+ x = self.conv_before_upsample(x)
992
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
993
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
994
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
995
+ else:
996
+ # for image denoising and JPEG compression artifact reduction
997
+ x_first = self.conv_first(x)
998
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
999
+ x = x + self.conv_last(res)
1000
+
1001
+ x = x / self.img_range + self.mean
1002
+
1003
+ return x
1004
+
1005
+ def flops(self):
1006
+ flops = 0
1007
+ h, w = self.patches_resolution
1008
+ flops += h * w * 3 * self.embed_dim * 9
1009
+ flops += self.patch_embed.flops()
1010
+ for layer in self.layers:
1011
+ flops += layer.flops()
1012
+ flops += h * w * 3 * self.embed_dim * self.embed_dim
1013
+ flops += self.upsample.flops()
1014
+ return flops
1015
+
1016
+
1017
+
1018
+ class SwinIRNOUP(nn.Module):
1019
+ def __init__(self,
1020
+ img_size=48,
1021
+ patch_size=1,
1022
+ in_chans=3,
1023
+ embed_dim=180,
1024
+ depths=(6, 6, 6, 6, 6, 6),
1025
+ num_heads=(6, 6, 6, 6, 6, 6),
1026
+ window_size=8,
1027
+ mlp_ratio=2,
1028
+ qkv_bias=True,
1029
+ qk_scale=None,
1030
+ drop_rate=0.,
1031
+ attn_drop_rate=0.,
1032
+ drop_path_rate=0.1,
1033
+ norm_layer=nn.LayerNorm,
1034
+ ape=False,
1035
+ patch_norm=True,
1036
+ use_checkpoint=False,
1037
+ upscale=4,
1038
+ img_range=1.,
1039
+ upsampler='pixelshuffle',
1040
+ resi_connection='1conv',
1041
+ **kwargs):
1042
+ super(SwinIRNOUP, self).__init__()
1043
+ num_in_ch = in_chans
1044
+ num_out_ch = in_chans
1045
+ num_feat = 64
1046
+ self.img_range = img_range
1047
+ self.upsampler = upsampler
+ self.upscale = upscale  # referenced by the 'nearest+conv' branch below
1048
+
1049
+
1050
+ # ------------------------- 1, shallow feature extraction ------------------------- #
1051
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
1052
+
1053
+ # ------------------------- 2, deep feature extraction ------------------------- #
1054
+ self.num_layers = len(depths)
1055
+ self.embed_dim = embed_dim
1056
+ self.ape = ape
1057
+ self.patch_norm = patch_norm
1058
+ self.num_features = embed_dim
1059
+ self.mlp_ratio = mlp_ratio
1060
+
1061
+ # split image into non-overlapping patches
1062
+ self.patch_embed = PatchEmbed(
1063
+ img_size=img_size,
1064
+ patch_size=patch_size,
1065
+ in_chans=embed_dim,
1066
+ embed_dim=embed_dim,
1067
+ norm_layer=norm_layer if self.patch_norm else None)
1068
+ num_patches = self.patch_embed.num_patches
1069
+ patches_resolution = self.patch_embed.patches_resolution
1070
+ self.patches_resolution = patches_resolution
1071
+
1072
+ # merge non-overlapping patches into image
1073
+ self.patch_unembed = PatchUnEmbed(
1074
+ img_size=img_size,
1075
+ patch_size=patch_size,
1076
+ in_chans=embed_dim,
1077
+ embed_dim=embed_dim,
1078
+ norm_layer=norm_layer if self.patch_norm else None)
1079
+
1080
+ # absolute position embedding
1081
+ if self.ape:
1082
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
1083
+ trunc_normal_(self.absolute_pos_embed, std=.02)
1084
+
1085
+ self.pos_drop = nn.Dropout(p=drop_rate)
1086
+
1087
+ # stochastic depth
1088
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
1089
+
1090
+ # build Residual Swin Transformer blocks (RSTB)
1091
+ self.layers = nn.ModuleList()
1092
+ for i_layer in range(self.num_layers):
1093
+ layer = RSTB(
1094
+ dim=embed_dim,
1095
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
1096
+ depth=depths[i_layer],
1097
+ num_heads=num_heads[i_layer],
1098
+ window_size=window_size,
1099
+ mlp_ratio=self.mlp_ratio,
1100
+ qkv_bias=qkv_bias,
1101
+ qk_scale=qk_scale,
1102
+ drop=drop_rate,
1103
+ attn_drop=attn_drop_rate,
1104
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
1105
+ norm_layer=norm_layer,
1106
+ downsample=None,
1107
+ use_checkpoint=use_checkpoint,
1108
+ img_size=img_size,
1109
+ patch_size=patch_size,
1110
+ resi_connection=resi_connection)
1111
+ self.layers.append(layer)
1112
+ self.norm = norm_layer(self.num_features)
1113
+
1114
+ # build the last conv layer in deep feature extraction
1115
+ if resi_connection == '1conv':
1116
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
1117
+ elif resi_connection == '3conv':
1118
+ # to save parameters and memory
1119
+ self.conv_after_body = nn.Sequential(
1120
+ nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
1121
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
1122
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
1123
+
1124
+ # ------------------------- 3, high quality image reconstruction ------------------------- #
1125
+ if self.upsampler == 'pixelshuffle':
1126
+ # for classical SR
1127
+ self.conv_before_upsample = nn.Sequential(
1128
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
1129
+
1130
+ elif self.upsampler == 'pixelshuffledirect':
1131
+ # for lightweight SR (to save parameters)
1132
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
1133
+ (patches_resolution[0], patches_resolution[1]))
1134
+ elif self.upsampler == 'nearest+conv':
1135
+ # for real-world SR (fewer artifacts)
1136
+ assert self.upscale == 4, 'only support x4 now.'
1137
+ self.conv_before_upsample = nn.Sequential(
1138
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
1139
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1140
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1141
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1142
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1143
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
1144
+ else:
1145
+ # for image denoising and JPEG compression artifact reduction
1146
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
1147
+
1148
+ self.apply(self._init_weights)
1149
+
1150
+ def _init_weights(self, m):
1151
+ if isinstance(m, nn.Linear):
1152
+ trunc_normal_(m.weight, std=.02)
1153
+ if isinstance(m, nn.Linear) and m.bias is not None:
1154
+ nn.init.constant_(m.bias, 0)
1155
+ elif isinstance(m, nn.LayerNorm):
1156
+ nn.init.constant_(m.bias, 0)
1157
+ nn.init.constant_(m.weight, 1.0)
1158
+
1159
+ @torch.jit.ignore
1160
+ def no_weight_decay(self):
1161
+ return {'absolute_pos_embed'}
1162
+
1163
+ @torch.jit.ignore
1164
+ def no_weight_decay_keywords(self):
1165
+ return {'relative_position_bias_table'}
1166
+
1167
+ def forward_features(self, x):
1168
+ x_size = (x.shape[2], x.shape[3])
1169
+ x = self.patch_embed(x)
1170
+ if self.ape:
1171
+ x = x + self.absolute_pos_embed
1172
+ x = self.pos_drop(x)
1173
+
1174
+ for layer in self.layers:
1175
+ x = layer(x, x_size)
1176
+
1177
+ x = self.norm(x) # b seq_len c
1178
+ x = self.patch_unembed(x, x_size)
1179
+
1180
+ return x
1181
+
1182
+ def forward(self, x):
1183
+
1184
+ if self.upsampler == 'pixelshuffle':
1185
+ # for classical SR
1186
+ x = self.conv_first(x)
1187
+ x = self.conv_after_body(self.forward_features(x)) + x
1188
+ x = self.conv_before_upsample(x)
1189
+
1190
+ elif self.upsampler == 'pixelshuffledirect':
1191
+ # for lightweight SR
1192
+ x = self.conv_first(x)
1193
+ x = self.conv_after_body(self.forward_features(x)) + x
1194
+
1195
+ elif self.upsampler == 'nearest+conv':
1196
+ # for real-world SR
1197
+ x = self.conv_first(x)
1198
+ x = self.conv_after_body(self.forward_features(x)) + x
1199
+ x = self.conv_before_upsample(x)
1200
+
1201
+ else:
1202
+ # for image denoising and JPEG compression artifact reduction
1203
+ x_first = self.conv_first(x)
1204
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
1205
+ x = x + self.conv_last(res)
1206
+
1207
+ return x
1208
+
1209
+ def flops(self):
1210
+ flops = 0
1211
+ h, w = self.patches_resolution
1212
+ flops += h * w * 3 * self.embed_dim * 9
1213
+ flops += self.patch_embed.flops()
1214
+ for layer in self.layers:
1215
+ flops += layer.flops()
1216
+ flops += h * w * 3 * self.embed_dim * self.embed_dim
1217
+ flops += self.upsample.flops()
1218
+ return flops
1219
+
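# A minimal usage sketch for SwinIRNOUP (hypothetical helper, illustrative
# config): the backbone stops before pixel-shuffle upsampling, so with the
# default 'pixelshuffle' setting it returns a 64-channel feature map at the
# input resolution rather than an upscaled RGB image.
def _swinir_noup_example():
    model = SwinIRNOUP(embed_dim=60, depths=(2, 2), num_heads=(6, 6),
                       window_size=8)
    lr = torch.randn(1, 3, 48, 48)
    return model(lr).shape                # torch.Size([1, 64, 48, 48])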
1220
+
1221
+
1222
+
1223
+ if __name__ == '__main__':
1224
+ upscale = 4
1225
+ window_size = 8
1226
+ height = (1024 // upscale // window_size + 1) * window_size
1227
+ width = (720 // upscale // window_size + 1) * window_size
1228
+ model = SwinIR(
1229
+ upscale=2,
1230
+ img_size=(height, width),
1231
+ window_size=window_size,
1232
+ img_range=1.,
1233
+ depths=[6, 6, 6, 6],
1234
+ embed_dim=60,
1235
+ num_heads=[6, 6, 6, 6],
1236
+ mlp_ratio=2,
1237
+ upsampler='pixelshuffledirect')
1238
+ print(model)
1239
+ print(height, width, model.flops() / 1e9)
1240
+
1241
+ x = torch.randn((1, 3, height, width))
1242
+ x = model(x)
1243
+ print(x.shape)