File size: 4,346 Bytes
b7f710c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import gradio as gr
from PIL import Image
from gradio_app.inference import run_inference
from gradio_app.components import (
    CONTENT_DESCRIPTION, CONTENT_IN, CONTENT_OUT,
     list_reference_files, list_mapping_files,
      list_classifier_files, list_edgeface_files
)

def create_image_input_column():
    """Build the column that holds the image upload and the results pane.

    Returns:
        A ``(image_input, output)`` pair: the ``gr.Image`` upload component
        (created with ``type="pil"``) and the ``gr.HTML`` component labelled
        "Inference Results".
    """
    with gr.Column():
        # The upload widget is created first, then the results pane.
        upload_widget = gr.Image(label="Upload Image", type="pil")
        results_pane = gr.HTML(
            elem_classes=["results-container"],
            label="Inference Results",
        )
    return upload_widget, results_pane

def create_model_files_column():
    """Build the column of dropdowns used to pick the model-related files.

    Each dropdown offers the placeholder entry "Select a file" followed by
    whatever the matching ``list_*_files`` helper returns, and starts out
    with a hard-coded default path selected.

    Returns:
        A 4-tuple of ``gr.Dropdown`` components, in this order:
        ``(ref_dict, index_map, classifier_model, edgeface_model)``.
    """
    placeholder = "Select a file"
    # One row per dropdown: (label, helper listing the files, pre-selected path).
    # Row order is the creation order and the order of the returned tuple.
    dropdown_specs = (
        (
            "Reference Dict JSON",
            list_reference_files,
            "data/reference_data/reference_image_data.json",
        ),
        (
            "Index to Class Mapping JSON",
            list_mapping_files,
            "ckpts/index_to_class_mapping.json",
        ),
        (
            "Classifier Model (.pth)",
            list_classifier_files,
            "ckpts/SlimFace_efficientnet_b3_full_model.pth",
        ),
        (
            "EdgeFace Model (.pt)",
            list_edgeface_files,
            "ckpts/idiap/edgeface_s_gamma_05.pt",
        ),
    )

    with gr.Column(), gr.Group(elem_classes=["section-group"]):
        gr.Markdown("### Model Files", elem_classes=["section-title"])
        dropdowns = [
            gr.Dropdown(
                choices=[placeholder] + list_files(),
                label=label,
                value=default_path,
            )
            for label, list_files, default_path in dropdown_specs
        ]

    ref_dict, index_map, classifier_model, edgeface_model = dropdowns
    return ref_dict, index_map, classifier_model, edgeface_model

def create_settings_column():
    """Build the "Advanced Settings" column of inference options.

    Returns:
        A 4-tuple of components, in this order:
        ``(algorithm, accelerator, resolution, similarity_threshold)`` --
        two ``gr.Dropdown`` components followed by two ``gr.Slider``
        components.
    """
    with gr.Column(), gr.Group(elem_classes=["section-group"]):
        gr.Markdown("### Advanced Settings", elem_classes=["section-title"])

        detector_dropdown = gr.Dropdown(
            label="Detection Algorithm",
            choices=["yolo", "mtcnn", "retinaface"],
            value="yolo",
        )
        device_dropdown = gr.Dropdown(
            label="Accelerator",
            choices=["auto", "cpu", "cuda", "mps"],
            value="auto",
        )
        # NOTE: the default of 300 does not lie on the 128 + 32*k step grid
        # (nearest steps are 288 and 320); it is kept exactly as configured.
        resolution_slider = gr.Slider(
            label="Image Resolution",
            minimum=128,
            maximum=512,
            step=32,
            value=300,
        )
        threshold_slider = gr.Slider(
            label="Similarity Threshold",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.3,
        )

    return detector_dropdown, device_dropdown, resolution_slider, threshold_slider

def create_interface():
    """Assemble the SlimFace demo as a ``gr.Blocks`` application.

    Layout, top to bottom: the title and ``CONTENT_DESCRIPTION`` markdown,
    the ``CONTENT_IN`` HTML, a row with the image/results column next to the
    model-file column, a row with the advanced settings, the "Run Inference"
    button, and finally the ``CONTENT_OUT`` markdown.

    The stylesheet is read here and handed to ``gr.Blocks`` as CSS text
    instead of as a file path. Recent Gradio releases document ``css`` as a
    CSS code string (file paths go through ``css_paths``), so passing the
    path itself would leave the custom styles unapplied there; passing the
    file's contents works with either interpretation.

    Returns:
        The constructed ``gr.Blocks`` instance (not launched).
    """
    # Relative path: resolved against the current working directory.
    stylesheet_path = "gradio_app/static/styles.css"
    try:
        with open(stylesheet_path, encoding="utf-8") as stylesheet:
            custom_css = stylesheet.read()
    except OSError:
        # A missing or unreadable stylesheet should not prevent the demo
        # from starting; it just renders without the custom styles.
        custom_css = None

    with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
        gr.Markdown("# SlimFace Demonstration")
        gr.Markdown(CONTENT_DESCRIPTION)
        gr.HTML(CONTENT_IN)

        with gr.Row():
            image_input, output = create_image_input_column()
            ref_dict, index_map, classifier_model, edgeface_model = create_model_files_column()

        with gr.Row():
            algorithm, accelerator, resolution, similarity_threshold = create_settings_column()

        with gr.Row():
            submit_btn = gr.Button("Run Inference", variant="primary", elem_classes=["centered-button"])

        # The component values are supplied to run_inference in this list
        # order, so it has to match that function's parameter order
        # (run_inference is defined in gradio_app.inference -- verify there).
        submit_btn.click(
            fn=run_inference,
            inputs=[
                image_input,
                ref_dict,
                index_map,
                classifier_model,
                edgeface_model,
                algorithm,
                accelerator,
                resolution,
                similarity_threshold,
            ],
            outputs=output,
        )
        gr.Markdown(CONTENT_OUT)
    return demo

def main():
    """Build the SlimFace interface and call its ``launch()`` method."""
    create_interface().launch()

# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()