import gradio as gr
import numpy as np
from PIL import Image
from pipeline import Pipeline
from replacements import get_foreground_estimation, sky_replacement
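
# Set SHARE_REPO to True to expose the demo via a temporary public Gradio link.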
SHARE_REPO = False
pipeline = Pipeline(model_name="swin_small_patch4_window7_224")
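# The pipeline is constructed once at startup so the first request does not pay the model-loading cost.
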
def predict(image):
    # Run inference to get the predicted alpha matte
    predicted_alpha = pipeline.inference(image)

    # Estimate the foreground and composite it onto the replacement sky
    foreground = get_foreground_estimation(image, predicted_alpha)
    replaced_sky = sky_replacement(foreground, predicted_alpha)

    # Resize the predicted alpha and replaced sky back to the original image
    # dimensions; note that PIL's resize expects (width, height)
    w, h = image.size
    predicted_alpha_pil = Image.fromarray((predicted_alpha * 255).astype(np.uint8), mode='L')
    predicted_alpha_pil = predicted_alpha_pil.resize((w, h), Image.Resampling.LANCZOS)
    replaced_sky_pil = Image.fromarray((replaced_sky * 255).astype(np.uint8))
    replaced_sky_pil = replaced_sky_pil.resize((w, h), Image.Resampling.LANCZOS)

    return predicted_alpha_pil, replaced_sky_pil
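
# Example rows: [image path, data-type label, qualitative note about the result].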
real_example_list = [
["examples/real/1901.jpg", "Real", "Good"],
["examples/real/2022.jpg", "Real", "Good"],
["examples/real/2041.jpg", "Real", "Good"],
["examples/real/2196.jpg", "Real", "Good"],
["examples/real/2188.jpg", "Real", "Good"],
["examples/real/0001.jpg", "Real", "Acceptable, missing minor detail around the lamppost"],
["examples/real/0054.jpg", "Real", "Acceptable, missing sky details between the houses"],
["examples/real/2043.jpg", "Real", "Acceptable, missing minor detail in the window in the background"],
["examples/real/0211.jpg", "Real", "Okay, misclassified a cloud in the left corner as the sky"],
["examples/real/0894.jpg", "Real", "Okay, missing details in the trees"],
["examples/real/2184.jpg", "Real", "Okay, lacks tree details in the background"],
["examples/real/2026.jpg", "Real", "Okay, lacks tree details in the left background"],
["examples/real/1975.jpg", "Real", "Okay, lacks tree branch details"],
["examples/real/0069.jpg", "Real", "Bad, didn't replace the sky between the houses"],
["examples/real/2079.jpg", "Real", "Bad, couldn't get the complete details of the tree"],
["examples/real/2038.jpg", "Real", "Bad, lacks overall details in both trees and tree branches"],
]
synthetic_example_list = [
["examples/synthetic/0055.jpg", "Synthetic", "Good"],
["examples/synthetic/0059.jpg", "Synthetic", "Good"],
["examples/synthetic/0086.jpg", "Synthetic", "Good"],
["examples/synthetic/10406.jpg", "Synthetic", "Good"],
["examples/synthetic/10515.jpg", "Synthetic", "Good"],
["examples/synthetic/10416.jpg", "Synthetic", "Acceptable, missing minor detail in the tree leaves"],
["examples/synthetic/0150.jpg", "Synthetic", "Acceptable, missing minor detail in the tree"],
["examples/synthetic/0097.jpg", "Synthetic", "Okay, missing minor detail in the trees"],
["examples/synthetic/0124.jpg", "Synthetic", "Okay, missing minor detail in the trees"],
["examples/synthetic/0127.jpg", "Synthetic", "Bad, missing many details in the trees"],
["examples/synthetic/10467.jpg", "Synthetic", "Bad, misclassified the windows as sky"],
]

with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown(
        """
# Demo: Sky Replacement with Alpha Matting
This demo performs alpha matting and sky replacement for houses using a U-Net architecture with a Swin backbone.
The model is trained solely on synthetic data generated using Blender.

Upload an image to perform sky replacement.
"""
    )
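
    # Radio control selecting which example gallery (real or synthetic) is shown below.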
    data_type = gr.Radio(choices=["Real", "Synthetic"], value="Real", label="Select Data Type for Examples")
    with gr.Row():
        # Left Column: Input Image and Run/Clear Buttons
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image")
            with gr.Row():
                clear_button = gr.Button("Clear")
                run_button = gr.Button("Submit", variant="primary")

        # Right Column: Output Images
        with gr.Column(scale=1):
            output_mask = gr.Image(type="pil", label="Predicted Mask")
            output_sky = gr.Image(type="pil", label="Sky Replacement")

    metadata_display = gr.Markdown(None)
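
    # Example galleries for real and synthetic photos; clicking an example loads the image and shows
    # its data type and qualitative note in metadata_display. Only one gallery is visible at a time.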
    with gr.Column(visible=True) as real_examples_container:
        real_examples_component = gr.Examples(
            examples=real_example_list,
            inputs=[input_image,
                    gr.Textbox(label="Data Type", value="", interactive=False, visible=False),
                    gr.Textbox(label="Result", value="", interactive=False, visible=False)],
            outputs=[input_image, metadata_display],
            fn=lambda example, dtype, desc: (example, f"**Type:** {dtype}\n\n**Result:** {desc}"),
            cache_examples=False,
            run_on_click=True,  # run fn on click so metadata_display updates; with caching off it would otherwise only fill the inputs
            label="Real Data Examples"
        )

    with gr.Column(visible=False) as synthetic_examples_container:
        synthetic_examples_component = gr.Examples(
            examples=synthetic_example_list,
            inputs=[input_image,
                    gr.Textbox(label="Data Type", value="", interactive=False, visible=False),
                    gr.Textbox(label="Result", value="", interactive=False, visible=False)],
            outputs=[input_image, metadata_display],
            fn=lambda example, dtype, desc: (example, f"**Type:** {dtype}\n\n**Result:** {desc}"),
            cache_examples=False,
            run_on_click=True,  # run fn on click so metadata_display updates; with caching off it would otherwise only fill the inputs
            label="Synthetic Data Examples"
        )

    # Callback to toggle the container visibility based on selection.
    def switch_examples(selected):
        if selected == "Real":
            return gr.update(visible=True), gr.update(visible=False)
        else:
            return gr.update(visible=False), gr.update(visible=True)

    data_type.change(
        fn=switch_examples,
        inputs=data_type,
        outputs=[real_examples_container, synthetic_examples_container]
    )
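
    # Clear resets the input and both outputs; Submit runs the full matting and sky-replacement pipeline.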
    def clear_all():
        return gr.update(value=None), gr.update(value=None), gr.update(value=None)

    clear_button.click(fn=clear_all, inputs=[], outputs=[input_image, output_mask, output_sky])
    run_button.click(fn=predict, inputs=input_image, outputs=[output_mask, output_sky])

# Launch the interface
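# share=True would create a temporary public link; ssr_mode=False disables Gradio's server-side rendering.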
demo.launch(share=SHARE_REPO, ssr_mode=False)