danhtran2mind committed on
Commit
10a4a0a
·
verified ·
1 Parent(s): 9566994

Update apps/gradio_app.py

Browse files
Files changed (1) hide show
  1. apps/gradio_app.py +140 -145
apps/gradio_app.py CHANGED
@@ -1,146 +1,141 @@
1
- import gradio as gr
2
- from PIL import Image
3
- from gradio_app.inference import run_inference
4
- from gradio_app.components import (
5
- CONTENT_DESCRIPTION, CONTENT_OUTTRO,
6
- CONTENT_IN_1, CONTENT_IN_2,
7
- CONTENT_OUT_1, CONTENT_OUT_2,
8
- list_reference_files, list_mapping_files,
9
- list_classifier_files, list_edgeface_files
10
- )
11
- from glob import glob
12
- import os
13
-
14
- def create_image_io_row():
15
- """Create the row for image input and output display."""
16
- with gr.Row(elem_classes=["image-io-row"]):
17
- image_input = gr.Image(type="pil", label="Upload Image")
18
- output = gr.HTML(label="Inference Results", elem_classes=["results-container"])
19
- return image_input, output
20
-
21
- def create_model_settings_row():
22
- """Create the row for model files and settings."""
23
- with gr.Row():
24
- with gr.Column():
25
- with gr.Group(elem_classes=["section-group"]):
26
- gr.Markdown("### Model Files", elem_classes=["section-title"])
27
- ref_dict = gr.Dropdown(
28
- choices=["Select a file"] + list_reference_files(),
29
- label="Reference Dict JSON",
30
- value="data/reference_data/reference_image_data.json"
31
- )
32
- index_map = gr.Dropdown(
33
- choices=["Select a file"] + list_mapping_files(),
34
- label="Index to Class Mapping JSON",
35
- value="ckpts/index_to_class_mapping.json"
36
- )
37
- classifier_model = gr.Dropdown(
38
- choices=["Select a file"] + list_classifier_files(),
39
- label="Classifier Model (.pth)",
40
- value="ckpts/SlimFace_efficientnet_b3_full_model.pth"
41
- )
42
- edgeface_model = gr.Dropdown(
43
- choices=["Select a file"] + list_edgeface_files(),
44
- label="EdgeFace Model (.pt)",
45
- value="ckpts/idiap/edgeface_s_gamma_05.pt"
46
- )
47
- with gr.Column():
48
- with gr.Group(elem_classes=["section-group"]):
49
- gr.Markdown("### Advanced Settings", elem_classes=["section-title"])
50
- algorithm = gr.Dropdown(
51
- choices=["yolo", "mtcnn", "retinaface"],
52
- label="Detection Algorithm",
53
- value="yolo"
54
- )
55
- accelerator = gr.Dropdown(
56
- choices=["auto", "cpu", "cuda", "mps"],
57
- label="Accelerator",
58
- value="auto"
59
- )
60
- resolution = gr.Slider(
61
- minimum=128,
62
- maximum=512,
63
- step=32,
64
- label="Image Resolution",
65
- value=300
66
- )
67
- similarity_threshold = gr.Slider(
68
- minimum=0.1,
69
- maximum=1.0,
70
- step=0.05,
71
- label="Similarity Threshold",
72
- value=0.3
73
- )
74
- return ref_dict, index_map, classifier_model, edgeface_model, algorithm, accelerator, resolution, similarity_threshold
75
-
76
- # Load local CSS file
77
- CSS = open("apps/gradio_app/static/styles.css").read()
78
-
79
- def create_interface():
80
- """Create the Gradio interface for SlimFace."""
81
- with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
82
- gr.Markdown("# SlimFace Demonstration")
83
- gr.Markdown(CONTENT_DESCRIPTION)
84
- gr.Markdown(CONTENT_IN_1)
85
- gr.HTML(CONTENT_IN_2)
86
-
87
- image_input, output = create_image_io_row()
88
- ref_dict, index_map, classifier_model, edgeface_model, algorithm, accelerator, resolution, similarity_threshold = create_model_settings_row()
89
-
90
- # Add example image gallery as a table
91
- with gr.Group():
92
- gr.Markdown("### Example Images")
93
- example_images = glob("apps/assets/examples/*.[jp][pn][gf]")
94
- if example_images:
95
- # Create a list of dictionaries for the table
96
- table_data = []
97
- for img_path in example_images:
98
- table_data.append({
99
- "Image": img_path, # Will be rendered as an image
100
- "Action": f"Use {os.path.basename(img_path)}" # Button text
101
- })
102
-
103
- # Create a table with images and buttons
104
- gr.Dataframe(
105
- value=table_data,
106
- headers=["Image", "Action"],
107
- datatype=["image", "str"],
108
- interactive=False,
109
- elem_classes=["example-table"],
110
- # Add click event for buttons
111
- row_click=lambda row: Image.open(row["Image"]),
112
- outputs=image_input
113
- )
114
- else:
115
- gr.Markdown("No example images found in apps/assets/examples/")
116
-
117
- with gr.Row():
118
- submit_btn = gr.Button("Run Inference", variant="primary", elem_classes=["centered-button"])
119
-
120
- submit_btn.click(
121
- fn=run_inference,
122
- inputs=[
123
- image_input,
124
- ref_dict,
125
- index_map,
126
- classifier_model,
127
- edgeface_model,
128
- algorithm,
129
- accelerator,
130
- resolution,
131
- similarity_threshold
132
- ],
133
- outputs=output
134
- )
135
- gr.Markdown(CONTENT_OUTTRO)
136
- gr.HTML(CONTENT_OUT_1)
137
- gr.Markdown(CONTENT_OUT_2)
138
- return demo
139
-
140
- def main():
141
- """Launch the Gradio interface."""
142
- demo = create_interface()
143
- demo.launch(share=True)
144
-
145
- if __name__ == "__main__":
146
  main()
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ from gradio_app.inference import run_inference
4
+ from gradio_app.components import (
5
+ CONTENT_DESCRIPTION, CONTENT_OUTTRO,
6
+ CONTENT_IN_1, CONTENT_IN_2,
7
+ CONTENT_OUT_1, CONTENT_OUT_2,
8
+ list_reference_files, list_mapping_files,
9
+ list_classifier_files, list_edgeface_files
10
+ )
11
+ from glob import glob
12
+ import os
13
+
14
+ def create_image_io_row():
15
+ """Create the row for image input and output display."""
16
+ with gr.Row(elem_classes=["image-io-row"]):
17
+ image_input = gr.Image(type="pil", label="Upload Image")
18
+ output = gr.HTML(label="Inference Results", elem_classes=["results-container"])
19
+ return image_input, output
20
+
21
+ def create_model_settings_row():
22
+ """Create the row for model files and settings."""
23
+ with gr.Row():
24
+ with gr.Column():
25
+ with gr.Group(elem_classes=["section-group"]):
26
+ gr.Markdown("### Model Files", elem_classes=["section-title"])
27
+ ref_dict = gr.Dropdown(
28
+ choices=["Select a file"] + list_reference_files(),
29
+ label="Reference Dict JSON",
30
+ value="data/reference_data/reference_image_data.json"
31
+ )
32
+ index_map = gr.Dropdown(
33
+ choices=["Select a file"] + list_mapping_files(),
34
+ label="Index to Class Mapping JSON",
35
+ value="ckpts/index_to_class_mapping.json"
36
+ )
37
+ classifier_model = gr.Dropdown(
38
+ choices=["Select a file"] + list_classifier_files(),
39
+ label="Classifier Model (.pth)",
40
+ value="ckpts/SlimFace_efficientnet_b3_full_model.pth"
41
+ )
42
+ edgeface_model = gr.Dropdown(
43
+ choices=["Select a file"] + list_edgeface_files(),
44
+ label="EdgeFace Model (.pt)",
45
+ value="ckpts/idiap/edgeface_s_gamma_05.pt"
46
+ )
47
+ with gr.Column():
48
+ with gr.Group(elem_classes=["section-group"]):
49
+ gr.Markdown("### Advanced Settings", elem_classes=["section-title"])
50
+ algorithm = gr.Dropdown(
51
+ choices=["yolo", "mtcnn", "retinaface"],
52
+ label="Detection Algorithm",
53
+ value="yolo"
54
+ )
55
+ accelerator = gr.Dropdown(
56
+ choices=["auto", "cpu", "cuda", "mps"],
57
+ label="Accelerator",
58
+ value="auto"
59
+ )
60
+ resolution = gr.Slider(
61
+ minimum=128,
62
+ maximum=512,
63
+ step=32,
64
+ label="Image Resolution",
65
+ value=300
66
+ )
67
+ similarity_threshold = gr.Slider(
68
+ minimum=0.1,
69
+ maximum=1.0,
70
+ step=0.05,
71
+ label="Similarity Threshold",
72
+ value=0.3
73
+ )
74
+ return ref_dict, index_map, classifier_model, edgeface_model, algorithm, accelerator, resolution, similarity_threshold
75
+
76
+ # Load local CSS file
77
+ CSS = open("apps/gradio_app/static/styles.css").read()
78
+
79
+ def create_interface():
80
+ """Create the Gradio interface for SlimFace."""
81
+ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
82
+ gr.Markdown("# SlimFace Demonstration")
83
+ gr.Markdown(CONTENT_DESCRIPTION)
84
+ gr.Markdown(CONTENT_IN_1)
85
+ gr.HTML(CONTENT_IN_2)
86
+
87
+ image_input, output = create_image_io_row()
88
+ ref_dict, index_map, classifier_model, edgeface_model, algorithm, accelerator, resolution, similarity_threshold = create_model_settings_row()
89
+
90
+ # Add example image gallery as a row of columns
91
+ with gr.Group():
92
+ gr.Markdown("### Example Images")
93
+ example_images = glob("apps/assets/examples/*.[jp][pn][gf]")
94
+ if example_images:
95
+ with gr.Row(elem_classes=["example-row"]):
96
+ for img_path in example_images:
97
+ with gr.Column(min_width=120):
98
+ gr.Image(
99
+ value=img_path,
100
+ label=os.path.basename(img_path),
101
+ type="filepath",
102
+ height=100,
103
+ elem_classes=["example-image"]
104
+ )
105
+ gr.Button(f"Use {os.path.basename(img_path)}").click(
106
+ fn=lambda x=img_path: Image.open(x),
107
+ outputs=image_input
108
+ )
109
+ else:
110
+ gr.Markdown("No example images found in apps/assets/examples/")
111
+
112
+ with gr.Row():
113
+ submit_btn = gr.Button("Run Inference", variant="primary", elem_classes=["centered-button"])
114
+
115
+ submit_btn.click(
116
+ fn=run_inference,
117
+ inputs=[
118
+ image_input,
119
+ ref_dict,
120
+ index_map,
121
+ classifier_model,
122
+ edgeface_model,
123
+ algorithm,
124
+ accelerator,
125
+ resolution,
126
+ similarity_threshold
127
+ ],
128
+ outputs=output
129
+ )
130
+ gr.Markdown(CONTENT_OUTTRO)
131
+ gr.HTML(CONTENT_OUT_1)
132
+ gr.Markdown(CONTENT_OUT_2)
133
+ return demo
134
+
135
+ def main():
136
+ """Launch the Gradio interface."""
137
+ demo = create_interface()
138
+ demo.launch()
139
+
140
+ if __name__ == "__main__":
 
 
 
 
 
141
  main()