

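"""Gradio demo for GSASR (arbitrary-scale super-resolution with 2D Gaussian Splatting).

The app downloads pretrained encoder/decoder weights from the Hugging Face Hub, runs
inference with automatic tile-based processing for large inputs, and exposes a simple
web UI with a user-selectable SR scale.
"""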
import torch
import numpy as np
import gradio as gr
from PIL import Image
import math
import torch.nn.functional as F
import os
import tempfile
import time
import threading

from utils.edsrbaseline import EDSRNOUP
from utils.rdn import RDNNOUP
from utils.swinir import SwinIRNOUP
from utils.hatropeamp import HATNOUP_ROPE_AMP
from utils.fea2gsropeamp import Fea2GS_ROPE_AMP
from utils.gaussian_splatting import generate_2D_gaussian_splatting_step
from utils.split_and_joint_image import split_and_joint_image
from huggingface_hub import hf_hub_download
import subprocess
import sys
import spaces



# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Global stop flag for interrupting inference
stop_inference = False
inference_lock = threading.Lock()

def load_model(
    pretrained_model_name_or_path: str = "mutou0308/GSASR",
    model_name: str = "HATL_SA1B",
    device: str | torch.device = "cuda"
):
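    """Download the GSASR encoder/decoder weights from the Hugging Face Hub and
    build the matching encoder/decoder pair for the requested backbone."""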
    # Hub filenames are repo-relative POSIX paths, so join them with '/' rather than os.path.join.
    enc_path = hf_hub_download(
        repo_id=pretrained_model_name_or_path,
        filename=f'GSASR_enhenced_ultra/{model_name}/encoder.pth'
    )
    dec_path = hf_hub_download(
        repo_id=pretrained_model_name_or_path,
        filename=f'GSASR_enhenced_ultra/{model_name}/decoder.pth'
    )

    enc_weight = torch.load(enc_path, weights_only=True)['params_ema']
    dec_weight = torch.load(dec_path, weights_only=True)['params_ema']
    
    if model_name in ['EDSR_DIV2K', 'EDSR_DF2K']:
        encoder = EDSRNOUP()
        decoder = Fea2GS_ROPE_AMP()
    elif model_name in ['RDN_DIV2K', 'RDN_DF2K']:
        encoder = RDNNOUP()
        decoder = Fea2GS_ROPE_AMP(num_crossattn_blocks = 2)
    elif model_name in ['SwinIR_DIV2K', 'SwinIR_DF2K']:
        encoder = SwinIRNOUP()
        decoder = Fea2GS_ROPE_AMP(num_crossattn_blocks=2, num_crossattn_layers=4, num_gs_seed=256, window_size=16)
    elif model_name in ['HATL_SA1B']:
        encoder = HATNOUP_ROPE_AMP()
        decoder = Fea2GS_ROPE_AMP(channel=192, num_crossattn_blocks=4, num_crossattn_layers=4, num_selfattn_blocks=8, num_selfattn_layers=6,
                                  num_gs_seed=256, window_size=16)
    else:
        raise ValueError(f"args.model-{model_name} must be in ['EDSR_DIV2K', 'EDSR_DF2K', 'RDN_DIV2K', 'RDN_DF2K', 'SwinIR_DIV2K', 'SwinIR_DF2K', 'HATL_SA1B']")

    encoder.load_state_dict(enc_weight, strict=True)
    decoder.load_state_dict(dec_weight, strict=True)
    encoder.eval()
    decoder.eval()
    encoder = encoder.to(device)
    decoder = decoder.to(device)
    return encoder, decoder


def preprocess(x, denominator=16):
    """Preprocess image to ensure dimensions are multiples of denominator"""
    _, c, h, w = x.shape
    if h % denominator > 0:
        pad_h = denominator - h % denominator
    else:
        pad_h = 0
    if w % denominator > 0:
        pad_w = denominator - w % denominator
    else:
        pad_w = 0
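    # Reflect-pad only on the bottom/right, e.g. a 250x333 input becomes 256x336 for
    # denominator=16; postprocess() later crops the output back to the exact target size.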
    x_new = F.pad(x, (0, pad_w, 0, pad_h), 'reflect')
    return x_new

def postprocess(x, gt_size_h, gt_size_w):
    """Post-process by cropping to target size"""
    x_new = x[:, :, :gt_size_h, :gt_size_w]
    return x_new

def should_use_tile(image_height, image_width, threshold=1024):
    """Determine if tile processing should be used based on image resolution"""
    return max(image_height, image_width) > threshold

def set_stop_flag():
    """Set the global stop flag to interrupt inference"""
    global stop_inference
    with inference_lock:
        stop_inference = True
    return "πŸ›‘ Stopping inference...", gr.update(interactive=False)

def reset_stop_flag():
    """Reset the global stop flag"""
    global stop_inference
    with inference_lock:
        stop_inference = False

def check_stop_flag():
    """Check if inference should be stopped"""
    global stop_inference
    with inference_lock:
        return stop_inference
    
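# On ZeroGPU Spaces, @spaces.GPU requests a GPU for the duration of this call.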
@spaces.GPU
def super_resolution_inference(image, scale=4.0):
    """Super-resolution inference function with automatic tile processing"""

    # Check if gscuda setup has been run
    setup_marker = ".setup_complete"
    if not os.path.exists(setup_marker):
        print("First run detected, installing dependencies...")
        try:
            # subprocess.check_call(["pip", "install", "-e", "."])
            subprocess.check_call(["pip", "install", "dist/gscuda-0.0.0-cp310-cp310-linux_x86_64.whl"])
            # Create marker file to indicate setup is complete
            with open(setup_marker, "w") as f:
                f.write("Setup completed")
            print("Setup completed successfully!")
        except subprocess.CalledProcessError as e:
            return None, f"❌ Setup failed with error: {e}", None


        
    if image is None:
        return None, "Please upload an image", None
    
    # Load model
    encoder, decoder = load_model(model_name="HATL_SA1B")

    # Reset stop flag at the beginning
    reset_stop_flag()
    
    # Fixed parameters
    tile_overlap = 16  # Fixed overlap size
    crop_size = 8     # Fixed crop size
    tile_size = 1024   # Fixed tile size for large images
    
    try:
        # Check for interruption
        if check_stop_flag():
            return None, "❌ Inference interrupted", None
            
        # Convert PIL image to numpy array
        img_np = np.array(image)
        if len(img_np.shape) == 3:
            img_np = img_np[:, :, [2, 1, 0]]  # RGB to BGR
        
        # Convert to tensor
        img = torch.from_numpy(np.transpose(img_np.astype(np.float32) / 255., (2, 0, 1))).float()
        img = img.unsqueeze(0).to(device)
        
        # Check for interruption
        if check_stop_flag():
            return None, "❌ Inference interrupted", None
        
        # Calculate target size
        gt_size = [math.floor(scale * img.shape[2]), math.floor(scale * img.shape[3])]
        
        # Determine if tile processing should be used
        use_tile = should_use_tile(img.shape[2], img.shape[3])
        
        # AMP mixed precision (bfloat16) is always enabled
        with torch.inference_mode():
            with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
                # Check for interruption before main processing
                if check_stop_flag():
                    return None, "❌ Inference interrupted", None
                    
                if use_tile:
                    # Use tile processing
                    assert tile_size % 16 == 0, f"tile_size-{tile_size} must be divisible by 16"
                    assert 2 * tile_overlap < tile_size, f"2 * tile_overlap must be less than tile_size"
                    assert 2 * crop_size <= tile_overlap, f"2 * crop_size must be less than or equal to tile_overlap"
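                    # Sanity checks: tiles stay 16-px aligned, and the cropped borders fit inside the overlap when tiles are stitched back together.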
                    
                    with torch.no_grad():
                        output = split_and_joint_image(
                            lq=img, 
                            scale_factor=scale,
                            split_size=tile_size,
                            overlap_size=tile_overlap,
                            model_g=encoder,
                            model_fea2gs=decoder,
                            crop_size=crop_size,
                            scale_modify=torch.tensor([scale, scale]),
                            default_step_size=1.2,
                            cuda_rendering=True,
                            mode='scale_modify',
                            if_dmax=True,
                            dmax_mode='fix',
                            dmax=0.1
                        )
                else:
                    # Direct processing without tiles
                    lq_pad = preprocess(img, 16)  # denominator=16 for HATL
                    gt_size_pad = torch.tensor([math.floor(scale * lq_pad.shape[2]), 
                                            math.floor(scale * lq_pad.shape[3])])
                    gt_size_pad = gt_size_pad.unsqueeze(0)
                    
                    with torch.no_grad():
                        # Check for interruption before encoder
                        if check_stop_flag():
                            return None, "❌ Inference interrupted", None
                            
                        # Encoder output
                        encoder_output = encoder(lq_pad)  # b,c,h,w
                        
                        # Check for interruption before decoder
                        if check_stop_flag():
                            return None, "❌ Inference interrupted", None
                            
                        scale_vector = torch.tensor(scale, dtype=torch.float32).unsqueeze(0).to(device)
                        
                        # Decoder output
                        batch_gs_parameters = decoder(encoder_output, scale_vector)
                        gs_parameters = batch_gs_parameters[0, :]
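                        # gs_parameters holds the 2D Gaussian parameters predicted for this image;
                        # they are rasterised into the SR output by the splatting step below.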
                        
                        # Check for interruption before gaussian rendering
                        if check_stop_flag():
                            return None, "❌ Inference interrupted", None
                        
                        # Gaussian rendering
                        b_output = generate_2D_gaussian_splatting_step(
                            gs_parameters=gs_parameters,
                            sr_size=gt_size_pad[0],
                            scale=scale,
                            sample_coords=None,
                            scale_modify=torch.tensor([scale, scale]),
                            default_step_size=1.2,
                            cuda_rendering=True,
                            mode='scale_modify',
                            if_dmax=True,
                            dmax_mode='fix',
                            dmax=0.1
                        )
                        output = b_output.unsqueeze(0)
        
        # Check for interruption before post-processing
        if check_stop_flag():
            return None, "❌ Inference interrupted", None
        
        # Post-processing
        output = postprocess(output, gt_size[0], gt_size[1])
        
        # Convert back to PIL image format
        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # BGR to RGB
        output = (output * 255.0).round().astype(np.uint8)
        
        # Convert to PIL image
        output_pil = Image.fromarray(output)
        
        # Generate result information
        original_size = f"{img.shape[3]}x{img.shape[2]}"
        output_size = f"{output.shape[1]}x{output.shape[0]}"
        tile_info = f"Tile processing enabled (size: {tile_size})" if use_tile else "Direct processing (no tiles)"
        result_info = f"βœ… Processing completed successfully!\nOriginal size: {original_size}\nSuper-resolution size: {output_size}\nScale factor: {scale:.2f}x\nProcessing mode: {tile_info}\nAMP acceleration: Force enabled\nOverlap size: {tile_overlap}\nCrop size: {crop_size}"
        
        return output_pil, result_info, output_pil
        
    except Exception as e:
        if check_stop_flag():
            return None, "❌ Inference interrupted", None
        return None, f"❌ Error during processing: {str(e)}", None

def predict(image, scale):
    """Gradio prediction function"""
    output_image, info, download_image = super_resolution_inference(image, scale)
    
    # If processing successful, save image for download
    if output_image is not None:
        # Create temporary filename
        timestamp = int(time.time())
        temp_filename = f"GSASR_SR_result_{scale}x_{timestamp}.png"
        temp_path = os.path.join(tempfile.gettempdir(), temp_filename)
        
        # Save image
        output_image.save(temp_path, "PNG")
        
        return output_image, temp_path, "βœ… Ready", gr.update(interactive=True)
    else:
        return output_image, None, info if info else "❌ Processing failed", gr.update(interactive=True)

# Create Gradio interface
with gr.Blocks(title="πŸš€ GSASR (2D Gaussian Splatting Super-Resolution)") as demo:
    gr.Markdown("# **πŸš€ GSASR (Generalized and efficient 2d gaussian splatting for arbitrary-scale super-resolution)**")
    gr.Markdown("Official demo for GSASR. Please refer to our [paper](https://arxiv.org/pdf/2501.06838), [project page](https://mt-cly.github.io/GSASR.github.io/), and [github](https://github.com/ChrisDud0257/GSASR) for more details.")
    
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            
            # Scale parameters
            with gr.Group():
                gr.Markdown("### SR Scale")
                scale_slider = gr.Slider(minimum=1.0, maximum=30.0, value=4.0, step=0.1, label="SR Scale")
            
            # Control buttons
            with gr.Row():
                submit_btn = gr.Button("πŸš€ Start Super-Resolution", variant="primary")
                stop_btn = gr.Button("πŸ›‘ Stop Inference", variant="stop")
        
        with gr.Column():
            output_image = gr.Image(type="pil", label="Super-Resolution Result")
            
            # Status display
            status_text = gr.Textbox(label="Status", value="✅ Ready", interactive=False)
            
            # Download component
            with gr.Group():
                gr.Markdown("### πŸ“₯ Download Super-Resolution Result")
                download_btn = gr.File(visible=True)
    
    # Event handlers
    submit_event = submit_btn.click(
        fn=predict,
        inputs=[input_image, scale_slider],
        outputs=[output_image, download_btn, status_text, stop_btn]
    )
    
    stop_btn.click(
        fn=set_stop_flag,
        inputs=[],
        outputs=[status_text, stop_btn],
        cancels=[submit_event]
    )
    
    # Example images
    gr.Markdown("### πŸ“š Example Images")
    gr.Markdown("Try these examples with different scales:")
    
    gr.Examples(
        examples=[
            ["assets/0846x4.png", 1.5],
            ["assets/0892x4.png", 2.8],
            ["assets/0873x4_cropped_120x120.png", 30.0]
        ],
        inputs=[input_image, scale_slider],
        examples_per_page=3,
        cache_examples=False,
        label="Examples"
    )

if __name__ == "__main__":
    demo.launch(share=True, server_name="0.0.0.0")