mutou0308 committed
Commit 67816f5 · verified · 1 Parent(s): af3f1a2

Upload folder using huggingface_hub
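For context, a minimal sketch of how a folder upload like this one is typically issued with huggingface_hub; the local path is hypothetical and the repo type is an assumption, but `HfApi.upload_folder` is the standard API for this:

    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_folder(
        folder_path="./gsasr_demo",   # hypothetical local folder containing gradio_demo.py
        repo_id="mutou0308/GSASR",
        repo_type="model",            # assumption; adjust if the target is a Space
        commit_message="Upload folder using huggingface_hub",
    )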

Files changed (1)
  1. gradio_demo.py +359 -0
gradio_demo.py ADDED
import torch
import numpy as np
import gradio as gr
from PIL import Image
import math
import torch.nn.functional as F
import os
import sys
import tempfile
import time
import threading

from utils.hatropeamp import HATNOUP_ROPE_AMP
from utils.fea2gsropeamp import Fea2GS_ROPE_AMP
from utils.edsrbaseline import EDSRNOUP
from utils.rdn import RDNNOUP
from utils.swinir import SwinIRNOUP
from utils.gaussian_splatting import generate_2D_gaussian_splatting_step
from utils.split_and_joint_image import split_and_joint_image
from huggingface_hub import hf_hub_download
import subprocess
import spaces

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Global stop flag for interrupting inference
stop_inference = False
inference_lock = threading.Lock()

def load_model(
    pretrained_model_name_or_path: str = "mutou0308/GSASR",
    model_name: str = "HATL_SA1B",
    device: str | torch.device = "cuda",
):
    # hf_hub_download expects POSIX-style repo paths, so build them with '/' explicitly
    enc_path = hf_hub_download(
        repo_id=pretrained_model_name_or_path, filename=f"GSASR_enhenced_ultra/{model_name}/encoder.pth"
    )
    dec_path = hf_hub_download(
        repo_id=pretrained_model_name_or_path, filename=f"GSASR_enhenced_ultra/{model_name}/decoder.pth"
    )

    enc_weight = torch.load(enc_path, weights_only=True)['params_ema']
    dec_weight = torch.load(dec_path, weights_only=True)['params_ema']

    if model_name in ['EDSR_DIV2K', 'EDSR_DF2K']:
        encoder = EDSRNOUP()
        decoder = Fea2GS_ROPE_AMP()
    elif model_name in ['RDN_DIV2K', 'RDN_DF2K']:
        encoder = RDNNOUP()
        decoder = Fea2GS_ROPE_AMP(num_crossattn_blocks=2)
    elif model_name in ['SwinIR_DIV2K', 'SwinIR_DF2K']:
        encoder = SwinIRNOUP()
        decoder = Fea2GS_ROPE_AMP(num_crossattn_blocks=2, num_crossattn_layers=4, num_gs_seed=256, window_size=16)
    elif model_name == 'HATL_SA1B':
        encoder = HATNOUP_ROPE_AMP()
        decoder = Fea2GS_ROPE_AMP(channel=192, num_crossattn_blocks=4, num_crossattn_layers=4,
                                  num_selfattn_blocks=8, num_selfattn_layers=6,
                                  num_gs_seed=256, window_size=16)
    else:
        raise ValueError(
            f"model_name {model_name!r} must be one of "
            "['EDSR_DIV2K', 'EDSR_DF2K', 'RDN_DIV2K', 'RDN_DF2K', 'SwinIR_DIV2K', 'SwinIR_DF2K', 'HATL_SA1B']"
        )

    encoder.load_state_dict(enc_weight, strict=True)
    decoder.load_state_dict(dec_weight, strict=True)
    encoder.eval()
    decoder.eval()
    encoder = encoder.to(device)
    decoder = decoder.to(device)
    return encoder, decoder

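For reference, a minimal usage sketch of the loader above; the demo itself calls it with the HAT-L weights, and it reuses the module-level `device`, which already falls back to CPU:

    # Build the HAT-L encoder/decoder pair once at startup; weights are fetched
    # from the mutou0308/GSASR repo on first call and cached by huggingface_hub.
    encoder, decoder = load_model(model_name="HATL_SA1B", device=device)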
def preprocess(x, denominator=16):
    """Reflect-pad the image so height and width are multiples of `denominator`."""
    _, c, h, w = x.shape
    pad_h = (denominator - h % denominator) % denominator
    pad_w = (denominator - w % denominator) % denominator
    x_new = F.pad(x, (0, pad_w, 0, pad_h), 'reflect')
    return x_new


def postprocess(x, gt_size_h, gt_size_w):
    """Crop the output back to the target size."""
    return x[:, :, :gt_size_h, :gt_size_w]


def should_use_tile(image_height, image_width, threshold=1024):
    """Use tiled processing when the longer image side exceeds `threshold` pixels."""
    return max(image_height, image_width) > threshold

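A quick sanity check of the pad/crop round-trip, as a standalone sketch with hypothetical sizes (only `torch` and the two helpers above are assumed):

    x = torch.randn(1, 3, 30, 45)           # 30 and 45 are not multiples of 16
    padded = preprocess(x, denominator=16)  # -> shape (1, 3, 32, 48)
    restored = postprocess(padded, 30, 45)  # crop the reflect-padded borders away
    assert torch.equal(restored, x)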
def set_stop_flag():
    """Set the global stop flag to interrupt inference."""
    global stop_inference
    with inference_lock:
        stop_inference = True
    return "🛑 Stopping inference...", gr.update(interactive=False)


def reset_stop_flag():
    """Reset the global stop flag."""
    global stop_inference
    with inference_lock:
        stop_inference = False


def check_stop_flag():
    """Check whether inference should be stopped."""
    with inference_lock:
        return stop_inference

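These three helpers implement cooperative cancellation: the worker is never killed mid-step, it simply polls the flag between pipeline stages. A standalone sketch of the same pattern, with hypothetical names and plain threading only:

    import threading
    import time

    flag = False
    lock = threading.Lock()

    def worker():
        for step in range(100):
            with lock:                  # read under the lock, like check_stop_flag()
                if flag:
                    print(f"stopped cleanly at step {step}")
                    return
            time.sleep(0.01)            # stand-in for one unit of inference work

    t = threading.Thread(target=worker)
    t.start()
    time.sleep(0.05)
    with lock:                          # write under the lock, like set_stop_flag()
        flag = True
    t.join()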
@spaces.GPU
def super_resolution_inference(image, scale=4.0):
    """Super-resolution inference with automatic tiled processing for large inputs."""

    # On the first run, install the prebuilt gscuda rendering wheel,
    # then drop a marker file so later runs skip the install.
    setup_marker = ".setup_complete"
    if not os.path.exists(setup_marker):
        print("First run detected, installing dependencies...")
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install",
                                   "dist/gscuda-0.0.0-cp310-cp310-linux_x86_64.whl"])
            with open(setup_marker, "w") as f:
                f.write("Setup completed")
            print("Setup completed successfully!")
        except subprocess.CalledProcessError as e:
            return None, f"❌ Setup failed with error: {e}", None

    if image is None:
        return None, "Please upload an image", None

    # Load model
    encoder, decoder = load_model(model_name="HATL_SA1B")

    # Reset stop flag at the beginning
    reset_stop_flag()

    # Fixed parameters
    tile_overlap = 16   # Fixed overlap size
    crop_size = 8       # Fixed crop size
    tile_size = 1024    # Fixed tile size for large images

    try:
        # Check for interruption
        if check_stop_flag():
            return None, "❌ Inference interrupted", None

        # Convert PIL image to a numpy array; force RGB so RGBA or grayscale
        # uploads do not break the channel handling below
        image = image.convert('RGB')
        img_np = np.array(image)
        img_np = img_np[:, :, [2, 1, 0]]  # RGB to BGR

        # Convert to tensor
        img = torch.from_numpy(np.transpose(img_np.astype(np.float32) / 255., (2, 0, 1))).float()
        img = img.unsqueeze(0).to(device)

        # Check for interruption
        if check_stop_flag():
            return None, "❌ Inference interrupted", None

        # Calculate target size
        gt_size = [math.floor(scale * img.shape[2]), math.floor(scale * img.shape[3])]

        # Determine if tile processing should be used
        use_tile = should_use_tile(img.shape[2], img.shape[3])

        # Force AMP mixed precision (inference_mode already disables gradients)
        with torch.inference_mode():
            with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
                # Check for interruption before main processing
                if check_stop_flag():
                    return None, "❌ Inference interrupted", None

                if use_tile:
                    # Use tile processing
                    assert tile_size % 16 == 0, f"tile_size-{tile_size} must be divisible by 16"
                    assert 2 * tile_overlap < tile_size, "2 * tile_overlap must be less than tile_size"
                    assert 2 * crop_size <= tile_overlap, "2 * crop_size must be less than or equal to tile_overlap"

                    output = split_and_joint_image(
                        lq=img,
                        scale_factor=scale,
                        split_size=tile_size,
                        overlap_size=tile_overlap,
                        model_g=encoder,
                        model_fea2gs=decoder,
                        crop_size=crop_size,
                        scale_modify=torch.tensor([scale, scale]),
                        default_step_size=1.2,
                        cuda_rendering=True,
                        mode='scale_modify',
                        if_dmax=True,
                        dmax_mode='fix',
                        dmax=0.1
                    )
                else:
                    # Direct processing without tiles
                    lq_pad = preprocess(img, 16)  # denominator=16 for HATL
                    gt_size_pad = torch.tensor([math.floor(scale * lq_pad.shape[2]),
                                                math.floor(scale * lq_pad.shape[3])])
                    gt_size_pad = gt_size_pad.unsqueeze(0)

                    # Check for interruption before encoder
                    if check_stop_flag():
                        return None, "❌ Inference interrupted", None

                    # Encoder output
                    encoder_output = encoder(lq_pad)  # b,c,h,w

                    # Check for interruption before decoder
                    if check_stop_flag():
                        return None, "❌ Inference interrupted", None

                    scale_vector = torch.tensor(scale, dtype=torch.float32).unsqueeze(0).to(device)

                    # Decoder output: per-image 2D Gaussian parameters
                    batch_gs_parameters = decoder(encoder_output, scale_vector)
                    gs_parameters = batch_gs_parameters[0, :]

                    # Check for interruption before Gaussian rendering
                    if check_stop_flag():
                        return None, "❌ Inference interrupted", None

                    # Gaussian rendering
                    b_output = generate_2D_gaussian_splatting_step(
                        gs_parameters=gs_parameters,
                        sr_size=gt_size_pad[0],
                        scale=scale,
                        sample_coords=None,
                        scale_modify=torch.tensor([scale, scale]),
                        default_step_size=1.2,
                        cuda_rendering=True,
                        mode='scale_modify',
                        if_dmax=True,
                        dmax_mode='fix',
                        dmax=0.1
                    )
                    output = b_output.unsqueeze(0)
        # Check for interruption before post-processing
        if check_stop_flag():
            return None, "❌ Inference interrupted", None

        # Post-processing: crop padding back to the exact target size
        output = postprocess(output, gt_size[0], gt_size[1])

        # Convert back to PIL image format
        output = output.detach().squeeze(0).float().cpu().clamp_(0, 1).numpy()
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # BGR to RGB
        output = (output * 255.0).round().astype(np.uint8)

        # Convert to PIL image
        output_pil = Image.fromarray(output)

        # Generate result information
        original_size = f"{img.shape[3]}x{img.shape[2]}"
        output_size = f"{output.shape[1]}x{output.shape[0]}"
        tile_info = f"Tile processing enabled (size: {tile_size})" if use_tile else "Direct processing (no tiles)"
        result_info = (f"✅ Processing completed successfully!\n"
                       f"Original size: {original_size}\n"
                       f"Super-resolution size: {output_size}\n"
                       f"Scale factor: {scale:.2f}x\n"
                       f"Processing mode: {tile_info}\n"
                       f"AMP acceleration: Force enabled\n"
                       f"Overlap size: {tile_overlap}\n"
                       f"Crop size: {crop_size}")

        return output_pil, result_info, output_pil

    except Exception as e:
        if check_stop_flag():
            return None, "❌ Inference interrupted", None
        return None, f"❌ Error during processing: {e}", None

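For reference, a hedged sketch of calling the inference function directly, outside the UI; it assumes a CUDA device, that the `spaces.GPU` decorator is a no-op outside a Space, and one of the bundled example images (the output filename is hypothetical):

    img = Image.open("assets/0846x4.png")
    sr_img, info, _ = super_resolution_inference(img, scale=2.0)
    print(info)
    if sr_img is not None:
        sr_img.save("sr_result_x2.png")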
def predict(image, scale):
    """Gradio prediction function."""
    output_image, info, download_image = super_resolution_inference(image, scale)

    # If processing succeeded, save the image for download
    if output_image is not None:
        # Create a temporary filename
        timestamp = int(time.time())
        temp_filename = f"GSASR_SR_result_{scale}x_{timestamp}.png"
        temp_path = os.path.join(tempfile.gettempdir(), temp_filename)

        # Save image
        output_image.save(temp_path, "PNG")

        return output_image, temp_path, "✅ Ready", gr.update(interactive=True)
    else:
        return output_image, None, info if info else "❌ Processing failed", gr.update(interactive=True)

# Create Gradio interface
with gr.Blocks(title="🚀 GSASR (2D Gaussian Splatting Super-Resolution)") as demo:
    gr.Markdown("# **🚀 GSASR (Generalized and Efficient 2D Gaussian Splatting for Arbitrary-Scale Super-Resolution)**")
    gr.Markdown("Official demo for GSASR. Please refer to our [paper](https://arxiv.org/pdf/2501.06838), [project page](https://mt-cly.github.io/GSASR.github.io/), and [GitHub repo](https://github.com/ChrisDud0257/GSASR) for more details.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")

            # Scale parameters
            with gr.Group():
                gr.Markdown("### SR Scale")
                scale_slider = gr.Slider(minimum=1.0, maximum=30.0, value=4.0, step=0.1, label="SR Scale")

            # Control buttons
            with gr.Row():
                submit_btn = gr.Button("🚀 Start Super-Resolution", variant="primary")
                stop_btn = gr.Button("🛑 Stop Inference", variant="stop")

        with gr.Column():
            output_image = gr.Image(type="pil", label="Super-Resolution Result")

            # Status display
            status_text = gr.Textbox(label="Status", value="✅ Ready", interactive=False)

            # Download component
            with gr.Group():
                gr.Markdown("### 📥 Download Super-Resolution Result")
                download_btn = gr.File(visible=True)

    # Event handlers
    submit_event = submit_btn.click(
        fn=predict,
        inputs=[input_image, scale_slider],
        outputs=[output_image, download_btn, status_text, stop_btn]
    )

    stop_btn.click(
        fn=set_stop_flag,
        inputs=[],
        outputs=[status_text, stop_btn],
        cancels=[submit_event]
    )

    # Example images
    gr.Markdown("### 📚 Example Images")
    gr.Markdown("Try these examples with different scales:")

    gr.Examples(
        examples=[
            ["assets/0846x4.png", 1.5],
            ["assets/0892x4.png", 2.8],
            ["assets/0873x4_cropped_120x120.png", 30.0]
        ],
        inputs=[input_image, scale_slider],
        examples_per_page=3,
        cache_examples=False,
        label="Examples"
    )

if __name__ == "__main__":
    demo.launch(share=True, server_name="0.0.0.0")
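As a usage note: `share=True` asks Gradio to open a temporary public gradio.live tunnel in addition to binding on 0.0.0.0; on Hugging Face Spaces the flag is typically ignored, since the Space itself exposes the app. A minimal local-only alternative, with hypothetical port settings:

    # Local testing: bind to localhost only and skip the public tunnel
    demo.queue().launch(server_name="127.0.0.1", server_port=7860, share=False)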