import gradio as gr
import numpy as np
from PIL import Image

from pipeline import Pipeline
from replacements import get_foreground_estimation, sky_replacement

SHARE_REPO = False

pipeline = Pipeline(model_name="swin_small_patch4_window7_224")


def _to_uint8(array):
    """Scale a float array expected to lie in [0, 1] to uint8 in [0, 255].

    Values are clipped first: floats outside the uint8 range do not cast
    reliably (they can wrap around), which would show up as speckle noise.
    """
    return (np.clip(array, 0.0, 1.0) * 255).astype(np.uint8)


def predict(image):
    """Predict the sky alpha mask for ``image`` and composite a replacement sky.

    Args:
        image: The PIL image from the ``Input Image`` component, or ``None``
            when Submit is clicked before an image is provided.

    Returns:
        A ``(mask, sky)`` tuple of PIL images, both resized to ``image.size``.

    Raises:
        gr.Error: If no image was provided (shown to the user in the UI).
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Run inference to get the predicted alpha mask
    predicted_alpha = pipeline.inference(image)

    # Estimate foreground and run sky_replacement
    foreground = get_foreground_estimation(image, predicted_alpha)
    replaced_sky = sky_replacement(foreground, predicted_alpha)

    # Resize both outputs to the input image's dimensions. PIL sizes are
    # (width, height), unlike numpy's (height, width) shape order.
    target_size = image.size

    # The alpha mask is 2-D, so a uint8 array becomes an 8-bit grayscale image.
    predicted_alpha_pil = Image.fromarray(_to_uint8(predicted_alpha))
    predicted_alpha_pil = predicted_alpha_pil.resize(target_size, Image.Resampling.LANCZOS)

    replaced_sky_pil = Image.fromarray(_to_uint8(replaced_sky))
    replaced_sky_pil = replaced_sky_pil.resize(target_size, Image.Resampling.LANCZOS)

    return predicted_alpha_pil, replaced_sky_pil


real_example_list = [
    ["examples/real/1901.jpg", "Real", "Good"],
    ["examples/real/2022.jpg", "Real", "Good"],
    ["examples/real/2041.jpg", "Real", "Good"],
    ["examples/real/2196.jpg", "Real", "Good"],
    ["examples/real/2188.jpg", "Real", "Good"],
    ["examples/real/0001.jpg", "Real", "Acceptable, missing minor detail around the lamppost"],
    ["examples/real/0054.jpg", "Real", "Acceptable, missing sky details between the houses"],
    ["examples/real/2043.jpg", "Real", "Acceptable, missing minor detail in the window in the background"],
    ["examples/real/0211.jpg", "Real", "Okay, misclassified a cloud in the left corner as the sky"],
    ["examples/real/0894.jpg", "Real", "Okay, missing details in the trees"],
    ["examples/real/2184.jpg", "Real", "Okay, lacks tree details in the background"],
    ["examples/real/2026.jpg", "Real", "Okay, lacks tree details in the left background"],
    ["examples/real/1975.jpg", "Real", "Okay, lacks tree branch details"],
    ["examples/real/0069.jpg", "Real", "Bad, didn't replace the sky between the houses"],
    ["examples/real/2079.jpg", "Real", "Bad, couldn't get the complete details of the tree"],
    ["examples/real/2038.jpg", "Real", "Bad, lacks overall details in both trees and tree branches"],
]

synthetic_example_list = [
    ["examples/synthetic/0055.jpg", "Synthetic", "Good"],
    ["examples/synthetic/0059.jpg", "Synthetic", "Good"],
    ["examples/synthetic/0086.jpg", "Synthetic", "Good"],
    ["examples/synthetic/10406.jpg", "Synthetic", "Good"],
    ["examples/synthetic/10515.jpg", "Synthetic", "Good"],
    ["examples/synthetic/10416.jpg", "Synthetic", "Acceptable, missing minor detail in the tree leaves"],
    ["examples/synthetic/0150.jpg", "Synthetic", "Acceptable, missing minor detail in the tree"],
    ["examples/synthetic/0097.jpg", "Synthetic", "Okay, missing minor detail in the trees"],
    ["examples/synthetic/0124.jpg", "Synthetic", "Okay, missing minor detail in the trees"],
    ["examples/synthetic/0127.jpg", "Synthetic", "Bad, missing many details in the trees"],
    ["examples/synthetic/10467.jpg", "Synthetic", "Bad, misclassified the windows as sky"],
]


def _format_example_metadata(example, dtype, desc):
    """Echo the example image back and render its type/result as Markdown."""
    return example, f"**Type:** {dtype}\n\n**Result:** {desc}"


def _build_examples(example_list, label, image_component, metadata_component):
    """Create a ``gr.Examples`` table that loads an image and shows its metadata.

    Must be called inside a Blocks layout context. Each row of ``example_list``
    is ``[image_path, data_type, result_description]``; the last two columns
    feed hidden textboxes so they can be rendered into ``metadata_component``.
    """
    return gr.Examples(
        examples=example_list,
        inputs=[
            image_component,
            gr.Textbox(label="Data Type", value="", interactive=False, visible=False),
            gr.Textbox(label="Result", value="", interactive=False, visible=False),
        ],
        outputs=[image_component, metadata_component],
        fn=_format_example_metadata,
        cache_examples=False,
        # With caching disabled, Gradio only fills the inputs on click and never
        # calls ``fn`` unless run_on_click is set; without it the metadata
        # Markdown would never be populated.
        run_on_click=True,
        label=label,
    )


with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown(
        """
        # Demo: Sky Replacement with Alpha Matting
        This demo performs alpha matting and sky replacements for houses using a U-Net architecture with a Swin backbone. \t
        This model is trained solely on synthetic data generated using Blender. \n
        Upload an image to perform sky replacement.
        """
    )

    data_type = gr.Radio(
        choices=["Real", "Synthetic"], value="Real", label="Select Data Type for Examples"
    )

    with gr.Row():
        # Left Column: Input Image and Run/Clear Buttons
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image")
            with gr.Row():
                clear_button = gr.Button("Clear")
                run_button = gr.Button("Submit", variant="primary")

        # Right Column: Output Images
        with gr.Column(scale=1):
            output_mask = gr.Image(type="pil", label="Predicted Mask")
            output_sky = gr.Image(type="pil", label="Sky Replacement")

    # Shows the type and quality notes of the most recently clicked example.
    metadata_display = gr.Markdown(None)

    # Only one example table is visible at a time; ``data_type`` toggles them.
    with gr.Column(visible=True) as real_examples_container:
        real_examples_component = _build_examples(
            real_example_list, "Real Data Examples", input_image, metadata_display
        )

    with gr.Column(visible=False) as synthetic_examples_container:
        synthetic_examples_component = _build_examples(
            synthetic_example_list, "Synthetic Data Examples", input_image, metadata_display
        )

    # Callback to toggle the container visibility based on selection.
    def switch_examples(selected):
        """Show the example table matching ``selected`` and hide the other one."""
        show_real = selected == "Real"
        return gr.update(visible=show_real), gr.update(visible=not show_real)

    data_type.change(
        fn=switch_examples,
        inputs=data_type,
        outputs=[real_examples_container, synthetic_examples_container],
    )

    def clear_all():
        """Reset the input image, both output images, and the example metadata."""
        return (
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value=""),
        )

    clear_button.click(
        fn=clear_all,
        inputs=[],
        outputs=[input_image, output_mask, output_sky, metadata_display],
    )
    run_button.click(fn=predict, inputs=input_image, outputs=[output_mask, output_sky])

# Launch the interface
demo.launch(share=SHARE_REPO, ssr_mode=False)