|
|
|
|
|
import torch |
|
import numpy as np |
|
import gradio as gr |
|
from PIL import Image |
|
import math |
|
import torch.nn.functional as F |
|
import os |
|
import tempfile |
|
import time |
|
import threading |
|
|
|
from utils.hatropeamp import HATNOUP_ROPE_AMP |
|
from utils.fea2gsropeamp import Fea2GS_ROPE_AMP |
|
from utils.edsrbaseline import EDSRNOUP |
|
from utils.rdn import RDNNOUP |
|
from utils.swinir import SwinIRNOUP |
|
from utils.gaussian_splatting import generate_2D_gaussian_splatting_step |
|
from utils.split_and_joint_image import split_and_joint_image |
|
from huggingface_hub import hf_hub_download |
|
import subprocess |
|
import sys |
|
import spaces |
|
|
|
|
|
|
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
|
|
|
stop_inference = False |
|
inference_lock = threading.Lock() |
|
|
|
def load_model( |
|
pretrained_model_name_or_path: str = "mutou0308/GSASR", |
|
model_name: str = "HATL_SA1B", |
|
device: str | torch.device = "cuda" |
|
): |
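    """Download the pretrained encoder/decoder weights from the Hub and build the matching model pair."""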
|
    enc_path = hf_hub_download(
        repo_id=pretrained_model_name_or_path,
        filename=f'GSASR_enhenced_ultra/{model_name}/encoder.pth',  # Hub paths always use forward slashes
    )
    dec_path = hf_hub_download(
        repo_id=pretrained_model_name_or_path,
        filename=f'GSASR_enhenced_ultra/{model_name}/decoder.pth',
    )
|
|
|
    enc_weight = torch.load(enc_path, map_location='cpu', weights_only=True)['params_ema']
    dec_weight = torch.load(dec_path, map_location='cpu', weights_only=True)['params_ema']
|
|
|
    # Each released checkpoint pairs an encoder backbone with a Fea2GS decoder;
    # the decoder hyper-parameters below must match the training configuration.
    if model_name in ['EDSR_DIV2K', 'EDSR_DF2K']:
|
encoder = EDSRNOUP() |
|
decoder = Fea2GS_ROPE_AMP() |
|
elif model_name in ['RDN_DIV2K', 'RDN_DF2K']: |
|
encoder = RDNNOUP() |
|
        decoder = Fea2GS_ROPE_AMP(num_crossattn_blocks=2)
|
elif model_name in ['SwinIR_DIV2K', 'SwinIR_DF2K']: |
|
encoder = SwinIRNOUP() |
|
decoder = Fea2GS_ROPE_AMP(num_crossattn_blocks=2, num_crossattn_layers=4, num_gs_seed=256, window_size=16) |
|
elif model_name in ['HATL_SA1B']: |
|
encoder = HATNOUP_ROPE_AMP() |
|
        decoder = Fea2GS_ROPE_AMP(channel=192, num_crossattn_blocks=4, num_crossattn_layers=4,
                                  num_selfattn_blocks=8, num_selfattn_layers=6,
                                  num_gs_seed=256, window_size=16)
|
    else:
        raise ValueError(
            f"model_name '{model_name}' must be one of ['EDSR_DIV2K', 'EDSR_DF2K', "
            "'RDN_DIV2K', 'RDN_DF2K', 'SwinIR_DIV2K', 'SwinIR_DF2K', 'HATL_SA1B']"
        )
|
|
|
encoder.load_state_dict(enc_weight, strict=True) |
|
decoder.load_state_dict(dec_weight, strict=True) |
|
encoder.eval() |
|
decoder.eval() |
|
encoder = encoder.to(device) |
|
decoder = decoder.to(device) |
|
return encoder, decoder |
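
# Example (sketch): the demo below builds the released HAT-L pair per request;
# a long-running server could instead build it once at startup:
#   encoder, decoder = load_model(model_name="HATL_SA1B", device=device)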
|
|
|
|
|
def preprocess(x, denominator=16):
    """Reflect-pad an NCHW tensor so its height and width are multiples of `denominator`."""
    _, _, h, w = x.shape
    pad_h = (denominator - h % denominator) % denominator
    pad_w = (denominator - w % denominator) % denominator
    x_new = F.pad(x, (0, pad_w, 0, pad_h), 'reflect')
    return x_new
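
# Example: h=100, denominator=16 -> 100 % 16 = 4, so pad_h = 12 and the padded
# height is 112 = 7 * 16; dimensions already divisible by 16 get zero padding.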
|
|
|
def postprocess(x, gt_size_h, gt_size_w): |
|
"""Post-process by cropping to target size""" |
|
x_new = x[:, :, :gt_size_h, :gt_size_w] |
|
return x_new |
|
|
|
def should_use_tile(image_height, image_width, threshold=1024): |
|
"""Determine if tile processing should be used based on image resolution""" |
|
return max(image_height, image_width) > threshold |
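
# Anything over 1024 px on its longer side goes through tiled inference to bound GPU memory.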
|
|
|
def set_stop_flag(): |
|
"""Set the global stop flag to interrupt inference""" |
|
global stop_inference |
|
with inference_lock: |
|
stop_inference = True |
|
return "π Stopping inference...", gr.update(interactive=False) |
|
|
|
def reset_stop_flag(): |
|
"""Reset the global stop flag""" |
|
global stop_inference |
|
with inference_lock: |
|
stop_inference = False |
|
|
|
def check_stop_flag(): |
|
"""Check if inference should be stopped""" |
|
global stop_inference |
|
with inference_lock: |
|
return stop_inference |
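
# The stop flag is polled between pipeline stages, so cancellation is cooperative:
# the stage currently running finishes before the request is abandoned.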
|
|
|
@spaces.GPU  # request a GPU slice from Hugging Face ZeroGPU for the duration of this call
|
def super_resolution_inference(image, scale=4.0): |
|
"""Super-resolution inference function with automatic tile processing""" |
|
|
|
|
|
setup_marker = ".setup_complete" |
|
if not os.path.exists(setup_marker): |
|
print("First run detected, installing dependencies...") |
|
try: |
|
|
|
subprocess.check_call(["pip", "install", "dist/gscuda-0.0.0-cp310-cp310-linux_x86_64.whl"]) |
|
|
|
with open(setup_marker, "w") as f: |
|
f.write("Setup completed") |
|
print("Setup completed successfully!") |
|
except subprocess.CalledProcessError as e: |
|
return None, f"β Setup failed with error: {e}", None |
|
|
|
|
|
|
|
if image is None: |
|
return None, "Please upload an image", None |
|
|
|
|
|
encoder, decoder = load_model(model_name="HATL_SA1B") |
|
|
|
|
|
reset_stop_flag() |
|
|
|
|
|
    # Tiled-inference settings
    tile_overlap = 16
    crop_size = 8
    tile_size = 1024
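    # Overlapping tile borders are blended and `crop_size` pixels are trimmed
    # from each tile edge when tiles are stitched, hiding boundary artifacts.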
|
|
|
try: |
|
|
|
        if check_stop_flag():
            return None, "⏹️ Inference interrupted", None
|
|
|
|
|
        # PIL delivers RGB (convert() also normalizes grayscale/RGBA inputs);
        # the pretrained weights expect BGR channel order.
        img_np = np.array(image.convert('RGB'))
        img_np = img_np[:, :, [2, 1, 0]]

        # HWC uint8 in [0, 255] -> NCHW float in [0, 1]
        img = torch.from_numpy(np.transpose(img_np.astype(np.float32) / 255., (2, 0, 1))).float()
        img = img.unsqueeze(0).to(device)
|
|
|
|
|
        if check_stop_flag():
            return None, "⏹️ Inference interrupted", None
|
|
|
|
|
gt_size = [math.floor(scale * img.shape[2]), math.floor(scale * img.shape[3])] |
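        # (fractional scales are floored to whole output pixels)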
|
|
|
|
|
use_tile = should_use_tile(img.shape[2], img.shape[3]) |
|
|
|
|
|
with torch.inference_mode(): |
|
with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16): |
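                # Everything below runs without autograd tracking and in bfloat16 where safe.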
|
|
|
                if check_stop_flag():
                    return None, "⏹️ Inference interrupted", None
|
|
|
if use_tile: |
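
                    # Tiled path: split the input, super-resolve each tile, and
                    # blend the overlaps back into a single image.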
|
|
|
                    assert tile_size % 16 == 0, f"tile_size={tile_size} must be divisible by 16"
                    assert 2 * tile_overlap < tile_size, f"2 * tile_overlap={2 * tile_overlap} must be less than tile_size={tile_size}"
                    assert 2 * crop_size <= tile_overlap, f"2 * crop_size={2 * crop_size} must be at most tile_overlap={tile_overlap}"
|
|
|
with torch.no_grad(): |
|
output = split_and_joint_image( |
|
lq=img, |
|
scale_factor=scale, |
|
split_size=tile_size, |
|
overlap_size=tile_overlap, |
|
model_g=encoder, |
|
model_fea2gs=decoder, |
|
crop_size=crop_size, |
|
scale_modify=torch.tensor([scale, scale]), |
|
default_step_size=1.2, |
|
cuda_rendering=True, |
|
mode='scale_modify', |
|
if_dmax=True, |
|
dmax_mode='fix', |
|
dmax=0.1 |
|
) |
|
else: |
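                    # Direct path: reflect-pad to multiples of 16, compute the
                    # padded SR size, then encode, decode, and render in one shot.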
|
|
|
lq_pad = preprocess(img, 16) |
|
gt_size_pad = torch.tensor([math.floor(scale * lq_pad.shape[2]), |
|
math.floor(scale * lq_pad.shape[3])]) |
|
gt_size_pad = gt_size_pad.unsqueeze(0) |
|
|
|
with torch.no_grad(): |
|
|
|
                        if check_stop_flag():
                            return None, "⏹️ Inference interrupted", None
|
|
|
|
|
                        # Encode the low-quality input into deep features.
                        encoder_output = encoder(lq_pad)
|
|
|
|
|
                        if check_stop_flag():
                            return None, "⏹️ Inference interrupted", None
|
|
|
                        # The decoder is conditioned on the requested scale factor,
                        # which is what enables arbitrary-scale SR.
                        scale_vector = torch.tensor(scale, dtype=torch.float32).unsqueeze(0).to(device)

                        # Predict 2D Gaussian parameters from the encoded features.
                        batch_gs_parameters = decoder(encoder_output, scale_vector)
                        gs_parameters = batch_gs_parameters[0, :]
|
|
|
|
|
                        if check_stop_flag():
                            return None, "⏹️ Inference interrupted", None
|
|
|
|
|
                        # Rasterize the predicted 2D Gaussians into the SR image.
                        b_output = generate_2D_gaussian_splatting_step(
|
gs_parameters=gs_parameters, |
|
sr_size=gt_size_pad[0], |
|
scale=scale, |
|
sample_coords=None, |
|
scale_modify=torch.tensor([scale, scale]), |
|
default_step_size=1.2, |
|
cuda_rendering=True, |
|
mode='scale_modify', |
|
if_dmax=True, |
|
dmax_mode='fix', |
|
dmax=0.1 |
|
) |
|
                        output = b_output.unsqueeze(0)  # restore the batch dimension
|
|
|
|
|
        if check_stop_flag():
            return None, "⏹️ Inference interrupted", None
|
|
|
|
|
        # Crop back to the exact target size (removes any reflect-padding contribution).
        output = postprocess(output, gt_size[0], gt_size[1])
|
|
|
|
|
        # NCHW BGR float in [0, 1] -> HWC RGB uint8 in [0, 255]
        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
        output = (output * 255.0).round().astype(np.uint8)
|
|
|
|
|
output_pil = Image.fromarray(output) |
|
|
|
|
|
original_size = f"{img.shape[3]}x{img.shape[2]}" |
|
output_size = f"{output.shape[1]}x{output.shape[0]}" |
|
tile_info = f"Tile processing enabled (size: {tile_size})" if use_tile else "Direct processing (no tiles)" |
|
result_info = f"β
Processing completed successfully!\nOriginal size: {original_size}\nSuper-resolution size: {output_size}\nScale factor: {scale:.2f}x\nProcessing mode: {tile_info}\nAMP acceleration: Force enabled\nOverlap size: {tile_overlap}\nCrop size: {crop_size}" |
|
|
|
return output_pil, result_info, output_pil |
|
|
|
except Exception as e: |
|
        if check_stop_flag():
            return None, "⏹️ Inference interrupted", None
        return None, f"❌ Error during processing: {str(e)}", None
|
|
|
def predict(image, scale): |
|
"""Gradio prediction function""" |
|
output_image, info, download_image = super_resolution_inference(image, scale) |
|
|
|
|
|
if output_image is not None: |
|
|
|
timestamp = int(time.time()) |
|
temp_filename = f"GSASR_SR_result_{scale}x_{timestamp}.png" |
|
temp_path = os.path.join(tempfile.gettempdir(), temp_filename) |
|
|
|
|
|
output_image.save(temp_path, "PNG") |
|
|
|
return output_image, temp_path, "β
Ready", gr.update(interactive=True) |
|
else: |
|
        return output_image, None, info if info else "❌ Processing failed", gr.update(interactive=True)
|
|
|
|
|
with gr.Blocks(title="π GSASR (2D Gaussian Splatting Super-Resolution)") as demo: |
|
gr.Markdown("# **π GSASR (Generalized and efficient 2d gaussian splatting for arbitrary-scale super-resolution)**") |
|
gr.Markdown("Official demo for GSASR. Please refer to our [paper](https://arxiv.org/pdf/2501.06838), [project page](https://mt-cly.github.io/GSASR.github.io/), and [github](https://github.com/ChrisDud0257/GSASR) for more details.") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
input_image = gr.Image(type="pil", label="Input Image") |
|
|
|
|
|
with gr.Group(): |
|
gr.Markdown("### SR Scale") |
|
scale_slider = gr.Slider(minimum=1.0, maximum=30.0, value=4.0, step=0.1, label="SR Scale") |
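                # Continuous scales from 1.0x to 30.0x in steps of 0.1.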
|
|
|
|
|
with gr.Row(): |
|
submit_btn = gr.Button("π Start Super-Resolution", variant="primary") |
|
stop_btn = gr.Button("π Stop Inference", variant="stop") |
|
|
|
with gr.Column(): |
|
output_image = gr.Image(type="pil", label="Super-Resolution Result") |
|
|
|
|
|
status_text = gr.Textbox(label="Status", value="β
Ready", interactive=False) |
|
|
|
|
|
with gr.Group(): |
|
gr.Markdown("### π₯ Download Super-Resolution Result") |
|
                download_btn = gr.File(label="Super-resolution result (PNG)", visible=True)
|
|
|
|
|
submit_event = submit_btn.click( |
|
fn=predict, |
|
inputs=[input_image, scale_slider], |
|
outputs=[output_image, download_btn, status_text, stop_btn] |
|
) |
|
|
|
stop_btn.click( |
|
fn=set_stop_flag, |
|
inputs=[], |
|
outputs=[status_text, stop_btn], |
|
        cancels=[submit_event]  # abort the in-flight submit when Stop is pressed
|
) |
|
|
|
|
|
gr.Markdown("### π Example Images") |
|
gr.Markdown("Try these examples with different scales:") |
|
|
|
gr.Examples( |
|
examples=[ |
|
["assets/0846x4.png", 1.5], |
|
["assets/0892x4.png", 2.8], |
|
["assets/0873x4_cropped_120x120.png", 30.0] |
|
], |
|
inputs=[input_image, scale_slider], |
|
examples_per_page=3, |
|
cache_examples=False, |
|
label="Examples" |
|
) |
|
|
|
if __name__ == "__main__": |
|
demo.launch(share=True, server_name="0.0.0.0") |