|
import gradio as gr |
|
import numpy as np |
|
from PIL import Image |
|
|
|
from pipeline import Pipeline |
|
from replacements import get_foreground_estimation, sky_replacement |
|
|
|
# Passed to ``demo.launch(share=...)``: True asks Gradio for a public share link.
SHARE_REPO = False

# Backbone identifier handed to the project's Pipeline wrapper.
MODEL_NAME = "swin_small_patch4_window7_224"

# Built once at import time; every ``predict`` call reuses this instance.
pipeline = Pipeline(model_name=MODEL_NAME)
|
|
|
|
|
def _to_pil_image(array, size):
    """Convert a float image array to an 8-bit PIL image resized to ``size``.

    ``array`` is presumably in [0, 1] (hence the scaling by 255). Values outside
    that range are clipped rather than wrapping around in the uint8 cast.
    ``size`` is in PIL order, i.e. ``(width, height)``.
    """
    pixels = (np.clip(array, 0.0, 1.0) * 255).astype(np.uint8)
    # Pillow infers the image mode from the array shape (2-D uint8 -> "L").
    return Image.fromarray(pixels).resize(size, Image.Resampling.LANCZOS)


def predict(image):
    """Predict a sky alpha matte for ``image`` and composite a replacement sky.

    Args:
        image: Value of the ``gr.Image(type="pil")`` input; presumably ``None``
            when the user presses Submit without uploading anything.

    Returns:
        ``(mask, replaced_sky)`` as PIL images, both sized like the predicted
        alpha matte.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message
            instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    predicted_alpha = pipeline.inference(image)

    # NumPy shapes are (rows, cols) = (height, width), but PIL's resize expects
    # (width, height); passing (h, w) would transpose non-square outputs.
    h, w = predicted_alpha.shape
    size = (w, h)

    foreground = get_foreground_estimation(image, predicted_alpha)
    replaced_sky = sky_replacement(foreground, predicted_alpha)

    return _to_pil_image(predicted_alpha, size), _to_pil_image(replaced_sky, size)
|
|
|
|
|
# Gallery of real photos: each entry is [image path, data type, qualitative
# verdict on the model's output for that image].
real_example_list = [
    [f"examples/real/{stem}.jpg", "Real", verdict]
    for stem, verdict in (
        ("1901", "Good"),
        ("2022", "Good"),
        ("2041", "Good"),
        ("2196", "Good"),
        ("2188", "Good"),
        ("0001", "Acceptable, missing minor detail around the lamppost"),
        ("0054", "Acceptable, missing sky details between the houses"),
        ("2043", "Acceptable, missing minor detail in the window in the background"),
        ("0211", "Okay, misclassified a cloud in the left corner as the sky"),
        ("0894", "Okay, missing details in the trees"),
        ("2184", "Okay, lacks tree details in the background"),
        ("2026", "Okay, lacks tree details in the left background"),
        ("1975", "Okay, lacks tree branch details"),
        ("0069", "Bad, didn't replace the sky between the houses"),
        ("2079", "Bad, couldn't get the complete details of the tree"),
        ("2038", "Bad, lacks overall details in both trees and tree branches"),
    )
]
|
|
|
# Gallery of Blender-rendered images: each entry is [image path, data type,
# qualitative verdict on the model's output for that image].
synthetic_example_list = [
    [f"examples/synthetic/{stem}.jpg", "Synthetic", verdict]
    for stem, verdict in (
        ("0055", "Good"),
        ("0059", "Good"),
        ("0086", "Good"),
        ("10406", "Good"),
        ("10515", "Good"),
        ("10416", "Acceptable, missing minor detail in the tree leaves"),
        ("0150", "Acceptable, missing minor detail in the tree"),
        ("0097", "Okay, missing minor detail in the trees"),
        ("0124", "Okay, missing minor detail in the trees"),
        ("0127", "Bad, missing many details in the trees"),
        ("10467", "Bad, misclassified the windows as sky"),
    )
]
|
|
|
with gr.Blocks(theme=gr.themes.Default()) as demo:
    # Page header. Paragraphs are separated by blank lines ("\n\n"), which is
    # what Markdown needs to render them as separate paragraphs.
    gr.Markdown(
        "# Demo: Sky Replacement with Alpha Matting\n\n"
        "This demo performs alpha matting and sky replacement for houses using a "
        "U-Net architecture with a Swin backbone.\n\n"
        "This model is trained solely on synthetic data generated using Blender.\n\n"
        "Upload an image to perform sky replacement."
    )

    # Selects which example gallery (real or synthetic) is visible.
    data_type = gr.Radio(
        choices=["Real", "Synthetic"],
        value="Real",
        label="Select Data Type for Examples",
    )

    with gr.Row():
        # Left column: input image and controls.
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image")
            with gr.Row():
                clear_button = gr.Button("Clear")
                run_button = gr.Button("Submit", variant="primary")

        # Right column: the two images returned by predict().
        with gr.Column(scale=1):
            output_mask = gr.Image(type="pil", label="Predicted Mask")
            output_sky = gr.Image(type="pil", label="Sky Replacement")

    # Shows the data type and qualitative result of the last clicked example.
    metadata_display = gr.Markdown(None)

    def _show_example_metadata(example, dtype, desc):
        """Echo the clicked example image back and describe it as Markdown."""
        return example, f"**Type:** {dtype}\n\n**Result:** {desc}"

    def _build_examples(example_list, label, visible):
        """Create a column holding one example gallery.

        Each row of ``example_list`` is ``[image_path, data_type, result]``; the
        last two fields feed hidden textboxes so ``_show_example_metadata`` can
        render them into ``metadata_display``.

        Returns:
            ``(container, examples)``: the column whose visibility
            ``switch_examples`` toggles, and the ``gr.Examples`` component.
        """
        with gr.Column(visible=visible) as container:
            examples = gr.Examples(
                examples=example_list,
                inputs=[
                    input_image,
                    gr.Textbox(label="Data Type", value="", interactive=False, visible=False),
                    gr.Textbox(label="Result", value="", interactive=False, visible=False),
                ],
                outputs=[input_image, metadata_display],
                fn=_show_example_metadata,
                cache_examples=False,
                # With caching disabled, Gradio only runs ``fn`` when an example
                # is clicked if this is set; otherwise the metadata panel would
                # never be populated.
                run_on_click=True,
                label=label,
            )
        return container, examples

    real_examples_container, real_examples_component = _build_examples(
        real_example_list, "Real Data Examples", visible=True
    )
    synthetic_examples_container, synthetic_examples_component = _build_examples(
        synthetic_example_list, "Synthetic Data Examples", visible=False
    )

    def switch_examples(selected):
        """Show the real gallery for "Real", otherwise the synthetic one."""
        show_real = selected == "Real"
        return gr.update(visible=show_real), gr.update(visible=not show_real)

    data_type.change(
        fn=switch_examples,
        inputs=data_type,
        outputs=[real_examples_container, synthetic_examples_container],
    )

    def clear_all():
        """Reset the input image, both outputs and the example metadata."""
        return (
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value=None),
        )

    def _clear_metadata():
        """Drop the metadata left over from a previously clicked example."""
        return gr.update(value=None)

    clear_button.click(
        fn=clear_all,
        inputs=[],
        outputs=[input_image, output_mask, output_sky, metadata_display],
    )
    # A user-uploaded image is not one of the examples, so stale metadata goes.
    input_image.upload(fn=_clear_metadata, inputs=[], outputs=metadata_display)
    run_button.click(fn=predict, inputs=input_image, outputs=[output_mask, output_sky])
|
|
|
|
|
# Start the Gradio server. Server-side rendering is turned off, and the public
# share link follows the SHARE_REPO flag.
demo.launch(ssr_mode=False, share=SHARE_REPO)
|
|