pikween committed on
Commit 925ba47
Parent(s): f09d157

rename to fit

Files changed (1)
  1. app.py +308 -0
app.py ADDED
@@ -0,0 +1,308 @@
+ import sys
+ sys.path.insert(0, "./SMT")
+
+ from smt_trainer import SMT_Trainer
+ from smt_model.modeling_smt import SMTModelForCausalLM
+
+ import torch
+ import gradio as gr
+ import numpy as np
+ import pandas as pd
+ import cv2
+
+ from math import sqrt
+
+ # cross-attention maps of the latest prediction, filled in by make_predictions()
+ CA_layers = list()
+
+ # base BGR colors used to tint each of the 8 cross-attention layers plus the overall map
+ colors = [ (128,   0,   0),
+            (128,  64,   0),
+            (128, 128,   0),
+            (  0, 128,   0),
+            (  0, 128, 128),
+            (  0,  64, 128),
+            (  0,   0, 128),
+            (128,   0, 128),
+            (128,   0,   0)
+          ]
+
+ def contrast(elem):
+     return elem != 0
+
+ def overlay(background:np.ndarray, overlay:np.ndarray, alpha=1):
+     """
+     Alpha-composites `overlay` on top of `background`.
+
+     :param background: BGR image (np.uint8)
+     :param overlay: BGRA image (np.uint8)
+     :param alpha: transparency of the overlay over the background
+
+     :return: BGR image of the combined images (np.float32)
+     """
+
+     # add an alpha channel to the background
+     background = np.concatenate([background, np.full([*background.shape[:2], 1], 1.0)], axis=-1)
+
+     # normalize the overlay alpha channel from 0-255 to 0.0-1.0 and apply the global alpha
+     alpha_background = 1.0
+     alpha_overlay = overlay[:,:,3] / 255.0 * alpha
+
+     for channel in range(3):
+
+         background[:,:,channel] = alpha_overlay * overlay[:,:,channel] + \
+                                   alpha_background * background[:,:,channel] * ( 1 - alpha_overlay )
+
+     background[:,:,3] = ( 1 - ( 1 - alpha_overlay ) * ( 1 - alpha_background ) ) * 255
+
+     # drop the alpha channel because Gradio ignores it,
+     # and divide by 255 because Gradio expects a float image even though it provides int images
+     return (background[:,:,:3]/255.0).astype(np.float32)
+
+ def generate_CA_images(token_idx, image, multiplier=1):
+     """
+     Builds one heatmap overlay per cross-attention layer (plus the overall map)
+     for the given token, composited on top of the input image.
+     """
+
+     global CA_layers
+
+     CA_final_images = []
+
+     # resize every attention map to fit the input image (values in 0-1)
+     masks = [ cv2.resize(CA_layers[layer_idx][token_idx],
+                          interpolation=cv2.INTER_NEAREST,
+                          dsize=(image.shape[1], image.shape[0])) for layer_idx in range(0, len(CA_layers)) ]
+
+     for i, mask in enumerate(masks):
+
+         # apply the intensity multiplier
+         mask *= multiplier
+
+         # renormalize if any value went above 1
+         max_pixel = np.max(mask)
+         if max_pixel > 1:
+             mask /= max_pixel
+
+         # convert to values in 0-255
+         mask = np.round(mask*255.0).astype(np.uint8)
+
+         # add a singleton dimension as the alpha channel
+         mask = np.expand_dims(mask, axis=-1)
+
+         # base color + transparency mask = BGRA overlay
+         ca = np.concatenate( (np.full(shape=image.shape, fill_value=colors[i]), mask), axis=-1)
+
+         CA_final_images.append(overlay(image, ca))
+
+     return CA_final_images
+
+ def make_predictions(checkpoint, input_image, input_type:int):
+
+     global CA_layers
+
+     # load the pretrained weights from the Hugging Face Hub
+     if input_type == 0:
+         # TODO this doesn't work because the Hugging Face weights aren't updated
+         model = SMTModelForCausalLM.from_pretrained("antoniorv6/smt-grandstaff")
+         model.to(device=model.positional_2D.pe.device)
+         input_image = np.mean(input_image, axis=2, keepdims=True)   # 3 channels to one
+         input_image = np.transpose(input_image, (2,0,1))[None, :]   # add the batch dimension as well, [B, C, H, W]
+         input_image = torch.from_numpy(input_image)#.to(device=model.positional_2D.pe.device)
+
+     # load the weights from the user-provided checkpoint file
+     elif input_type == 1:
+         model = SMT_Trainer.load_from_checkpoint(checkpoint).model
+         model.to(device=model.pos2D.pe.device)
+         input_image = np.mean(input_image, axis=2, keepdims=True)   # 3 channels to one
+         input_image = np.transpose(input_image, (2,0,1))[None, :]   # add the batch dimension as well, [B, C, H, W]
+         input_image = torch.from_numpy(input_image).to(device=model.pos2D.pe.device)
+
+     input_image = input_image.to(torch.float32)
+
+     # width / height
+     aspect_ratio = input_image.shape[3]/input_image.shape[2]
+
+     # 8 attention layers * [channels | seq_len | extracted_features],
+     # where extracted_features is the FLAT input_image shape divided by 16
+     predicted_seq, predictions = model.predict(input_image, return_weights=True)
+
+     # average away the channel dimension: seq_len | reduced_h * reduced_w
+     CA_layers = [ ca_layer.mean(dim=1).squeeze() for ca_layer in predictions.cross_attentions ]
+
+     # recover the 2-D attention grid: att_w * att_h matches the flat size and att_w / att_h the aspect ratio
+     seq_len = CA_layers[0].shape[0]
+     att_w = round(sqrt(CA_layers[0].shape[1] * aspect_ratio))
+     att_h = round(sqrt(CA_layers[0].shape[1] / aspect_ratio))
+
+     # make the attention 2-D
+     CA_layers = [ att.reshape( seq_len, att_h, att_w ) for att in CA_layers ]
+
+     # convert to numpy
+     CA_layers = [ att.cpu().detach().numpy() for att in CA_layers ]
+     # ^^^ we store this, then generate the actual images to display ONLY when the token slider is moved
+
+     # the "overall" map is the sum of all layers, normalized back into 0-1
+     overall = np.stack(CA_layers).sum(axis=0)
+
+     overall_max_value = np.max(overall)
+     if overall_max_value > 1.0:
+         overall /= overall_max_value
+
+     CA_layers.append(overall)
+
+     return pd.DataFrame([predicted_seq])
+
+ def define_input_source( choice:gr.SelectData ):
+     """
+     Adjusts the interface according to the weights source the user has chosen to work with
+     """
+
+     if choice.index == 0:    # pretrained weights
+         return gr.update(visible=False), 0   # file input invisible, input type state update
+
+     elif choice.index == 1:  # your own weights
+         return gr.update(visible=True), 1    # file input visible, input type state update
+
+ def define_interface():
+
+     # main components
+     file_input = gr.File(label="Model Checkpoint File", visible=False, interactive=True)
+     image_input = gr.Image(label="Input Image")
+     tabs = gr.Tabs()
+
+     # knob components
+     token_slider = gr.Slider(minimum=0, maximum=0, step=1,
+                              label="Pick a token",
+                              info="Select a predicted token to visualize the attention it pays to the input sample",
+                              visible=False)
+
+     intensifier_slider = gr.Slider(minimum=1, maximum=100, step=1,
+                                    label="Intensify attention",
+                                    info="Use this slider to intensify the attention values to better see differences",
+                                    value=10,
+                                    visible=False)
+
+     token_table = gr.DataFrame(interactive=False, value=pd.DataFrame(["The predicted sequence will appear here"]))
+
+     def intensifier_visibility():
+         """
+         Makes the intensifier slider visible whenever the token slider is changed
+         """
+         return gr.update(visible=True)
+
+     with gr.Blocks() as page:
+
+         # reveal the intensifier slider the first time the token slider is used
+         token_slider.release( fn=intensifier_visibility, outputs=intensifier_slider )
+
+         gr.Markdown("# SMT Demonstrator")
+
+         with gr.Row():
+
+             with gr.Column():
+
+                 '''
+                 model_interface = gr.Interface(make_predictions,
+                                                inputs=[file_input, image_input, input_type],
+                                                outputs=[token_table],
+                                                flagging_mode='never')
+                 '''
+
+                 # input area
+                 with gr.Blocks():
+                     select_src_weights = gr.Dropdown(["Test pretrained weights (default)", "Test your own weights"],
+                                                      label="Pick which weights to test out",
+                                                      interactive=True)
+
+                     # State variable -- weights source picked by the user
+                     input_type = gr.Number(value=0, visible=False)
+
+                     select_src_weights.select( define_input_source, outputs=[file_input, input_type] )
+
+                     file_input.render()
+                     image_input.render()
+
+                 with gr.Row():
+
+                     def submit_logic(file, image, input_src):
+                         return make_predictions(file, image, input_src), gr.update(visible=True), gr.update(visible=True)
+
+                     clear_btn = gr.ClearButton( components=[file_input, image_input] )
+
+                     submit_btn = gr.Button( value="Submit", variant="primary")
+                     submit_btn.click( fn=submit_logic,
+                                       inputs=[file_input, image_input, input_type],
+                                       outputs=[token_table, token_slider, intensifier_slider] )
+
+             with gr.Column(scale=2):
+
+                 token_slider.render()
+
+                 # State variable -- tab the user left off on
+                 tab_selected = gr.Number(value=8, visible=False)  # on the Overall Attention tab by default
+
+                 # regenerates the images whenever a slider is released or the prediction table changes
+                 @gr.render( inputs=[token_table, token_slider, image_input, intensifier_slider, tab_selected],
+                             triggers=[token_slider.release, intensifier_slider.release, token_table.change])
+                 def render_images_display(prediction, slider, image, intensifier, tab_no):
+
+                     if prediction.shape[0] > 0:
+
+                         images = generate_CA_images(slider, image, intensifier)
+
+                         gr.Markdown(value="## Contents of the Cross-Attention layers")
+
+                         with gr.Tabs(selected=f"{tab_no}") as tabs:
+
+                             with gr.Tab("Overall", id="8") as tab_overall:
+                                 tab_overall.select( (lambda : gr.Number(8)), outputs=[tab_selected] )
+                                 gr.Image(value=images[8])
+
+                             with gr.Tab("Layer 1", id="0") as tab_1:
+                                 tab_1.select( (lambda : gr.Number(0)), outputs=[tab_selected] )
+                                 gr.Image(value=images[0])
+
+                             with gr.Tab("Layer 2", id="1") as tab_2:
+                                 tab_2.select( (lambda : gr.Number(1)), outputs=[tab_selected] )
+                                 gr.Image(value=images[1])
+
+                             with gr.Tab("Layer 3", id="2") as tab_3:
+                                 tab_3.select( (lambda : gr.Number(2)), outputs=[tab_selected] )
+                                 gr.Image(value=images[2])
+
+                             with gr.Tab("Layer 4", id="3") as tab_4:
+                                 tab_4.select( (lambda : gr.Number(3)), outputs=[tab_selected] )
+                                 gr.Image(value=images[3])
+
+                             with gr.Tab("Layer 5", id="4") as tab_5:
+                                 tab_5.select( (lambda : gr.Number(4)), outputs=[tab_selected] )
+                                 gr.Image(value=images[4])
+
+                             with gr.Tab("Layer 6", id="5") as tab_6:
+                                 tab_6.select( (lambda : gr.Number(5)), outputs=[tab_selected] )
+                                 gr.Image(value=images[5])
+
+                             with gr.Tab("Layer 7", id="6") as tab_7:
+                                 tab_7.select( (lambda : gr.Number(6)), outputs=[tab_selected] )
+                                 gr.Image(value=images[6])
+
+                             with gr.Tab("Layer 8", id="7") as tab_8:
+                                 tab_8.select( (lambda : gr.Number(7)), outputs=[tab_selected] )
+                                 gr.Image(value=images[7])
+
+                 intensifier_slider.render()
+
+             with gr.Column():
+
+                 gr.Markdown("## Predicted Sequence")
+
+                 def render_prediction_display(tokens):
+                     return gr.Slider(maximum=tokens.shape[0], visible=True), gr.update(visible=True)
+
+                 token_table.render()
+                 token_table.change(render_prediction_display, inputs=[token_table], outputs=[token_slider, token_table])
+
+     return page
+
+ if __name__ == "__main__":
+     page = define_interface()
+     page.launch(share=False)
+