import sys
sys.path.insert(0, "./SMT")
from smt_trainer import SMT_Trainer
from smt_model.modeling_smt import SMTModelForCausalLM
import torch
import gradio as gr
import numpy as np
import pandas as pd
import cv2
from math import sqrt

# Cross-attention maps of the most recent prediction: one (seq_len, att_h, att_w)
# numpy array per decoder layer, followed by their normalized sum ("overall").
# Written by make_predictions() and read by generate_CA_images().
# NOTE(review): module-level state, so it is shared by every session served by
# this process.
CA_layers = []

# Overlay colour used for each entry of CA_layers (layers 1-8, then overall), in
# the same channel order as the input image.
colors = [
    (128, 0, 0),
    (128, 64, 0),
    (128, 128, 0),
    (0, 128, 0),
    (0, 128, 128),
    (0, 64, 128),
    (0, 0, 128),
    (128, 0, 128),
    (128, 0, 0),
]


def contrast(elem):
    """Return ``elem != 0`` (element-wise for numpy arrays). Not used in this file."""
    return elem != 0


def overlay(background: np.ndarray, overlay: np.ndarray, alpha=1):
    """Alpha-blend a 4-channel colour+alpha image on top of a 3-channel image.

    :param background: 3-channel image (H x W x 3) with values in 0-255
    :param overlay: 4-channel image (H x W x 4); channel 3 is the per-pixel
        opacity in 0-255
    :param alpha: extra opacity factor applied to the overlay's alpha channel
    :returns: blended 3-channel image as np.float32 scaled to 0-1
    """
    # Give the background a fully-opaque alpha channel. np.concatenate makes a
    # float copy, so the caller's array is never modified.
    background = np.concatenate(
        [background, np.full([*background.shape[:2], 1], 1.0)], axis=-1
    )
    # Normalize the overlay alpha channel from 0-255 to 0.-1.
    alpha_background = 1.0
    alpha_overlay = overlay[:, :, 3] / 255.0 * alpha
    for channel in range(3):
        background[:, :, channel] = (
            alpha_overlay * overlay[:, :, channel]
            + alpha_background * background[:, :, channel] * (1 - alpha_overlay)
        )
    background[:, :, 3] = (1 - (1 - alpha_overlay) * (1 - alpha_background)) * 255
    # Drop the alpha channel (gradio does not need it) and divide by 255 because
    # the image is displayed as a float image in 0-1.
    return (background[:, :, :3] / 255.0).astype(np.float32)


def generate_CA_images(token_idx, image, multiplier=1):
    """Render one overlay image per stored cross-attention map for one token.

    :param token_idx: row of each CA_layers entry to visualize (predicted token index)
    :param image: the input image the model was run on (H x W x 3)
    :param multiplier: factor applied to the attention values before they are
        clipped back into 0-1, to make weak attention visible
    :returns: list of np.float32 images in 0-1, one per entry of CA_layers
    """
    height, width = image.shape[:2]
    CA_final_images = []
    for i, layer in enumerate(CA_layers):
        # Resize the (att_h, att_w) map to the input image size (values in 0-1).
        mask = cv2.resize(
            layer[token_idx], dsize=(width, height), interpolation=cv2.INTER_NEAREST
        )
        # Apply the multiplier, then rescale if any value now exceeds 1.
        mask = mask * multiplier
        max_pixel = np.max(mask)
        if max_pixel > 1:
            mask /= max_pixel
        # Convert to 0-255 and add a singleton channel axis -> used as alpha.
        mask = np.round(mask * 255.0).astype(np.uint8)[..., None]
        # Solid layer colour + attention as transparency = 4-channel overlay.
        # The modulo keeps this working if a model yields more maps than colours.
        color = np.full(shape=image.shape, fill_value=colors[i % len(colors)])
        ca = np.concatenate((color, mask), axis=-1)
        CA_final_images.append(overlay(image, ca))
    return CA_final_images


def _load_model(checkpoint, input_type):
    """Load the model picked in the UI and return ``(model, device)``.

    input_type 0 -> pretrained weights from the HuggingFace hub;
    input_type 1 -> the user-uploaded trainer checkpoint.
    """
    if input_type == 0:
        # TODO this doesnt work because the HuggingFace weights aren't updated
        model = SMTModelForCausalLM.from_pretrained("antoniorv6/smt-grandstaff")
        device = model.positional_2D.pe.device
    elif input_type == 1:
        if checkpoint is None:
            raise gr.Error("Please upload a model checkpoint file.")
        model = SMT_Trainer.load_from_checkpoint(checkpoint).model
        device = model.pos2D.pe.device
    else:
        # Previously this fell through and crashed with an UnboundLocalError.
        raise ValueError(f"Unknown input_type {input_type!r}; expected 0 or 1")
    model.to(device=device)
    return model, device


def _image_to_tensor(image, device):
    """Convert an H x W (x C) image into a float32 [1, 1, H, W] tensor on ``device``."""
    if image.ndim == 2:
        image = image[:, :, None]
    gray = np.mean(image, axis=2, keepdims=True)  # channels -> one channel
    batch = np.transpose(gray, (2, 0, 1))[None, :]  # add batch dim: [B, C, H, W]
    return torch.from_numpy(batch).to(device=device, dtype=torch.float32)


def _attention_grid(num_features, aspect_ratio):
    """Return ``(h, w)`` with ``h * w == num_features`` and ``w / h`` close to aspect_ratio."""
    att_h = round(sqrt(num_features / aspect_ratio))
    att_w = round(sqrt(num_features * aspect_ratio))
    if att_h * att_w == num_features:
        return att_h, att_w
    # Independent rounding did not give an exact factorisation (e.g. image sides
    # not multiples of the encoder stride): pick the divisor closest to the
    # estimated height instead of letting reshape() fail.
    target_h = sqrt(num_features / aspect_ratio)
    divisors = [d for d in range(1, num_features + 1) if num_features % d == 0]
    att_h = min(divisors, key=lambda d: abs(d - target_h))
    return att_h, num_features // att_h


def make_predictions(checkpoint, input_image, input_type: int):
    """Run the model on input_image and store its cross-attention maps in CA_layers.

    :param checkpoint: checkpoint file from the gr.File input (used when input_type == 1)
    :param input_image: numpy image from the gr.Image input
    :param input_type: 0 for the hub weights, 1 for the uploaded checkpoint
    :returns: single-row DataFrame whose columns are the predicted tokens
    :raises gr.Error: if no image (or no checkpoint for input_type 1) was given
    :raises ValueError: if input_type is not 0 or 1
    """
    global CA_layers
    if input_image is None:
        raise gr.Error("Please upload an input image.")

    model, device = _load_model(checkpoint, input_type)
    input_tensor = _image_to_tensor(input_image, device)
    # width / height
    aspect_ratio = input_tensor.shape[3] / input_tensor.shape[2]

    with torch.no_grad():
        predicted_seq, predictions = model.predict(input_tensor, return_weights=True)

    # Average each layer's cross-attention over dim 1 and drop size-1 dims,
    # leaving [seq_len, reduced_h * reduced_w] per layer.
    flat_layers = []
    for ca_layer in predictions.cross_attentions:
        att = ca_layer.mean(dim=1).squeeze()
        if att.dim() == 1:
            # squeeze() also dropped seq_len when a single step was decoded
            # (assumes the feature map has more than one position).
            att = att.unsqueeze(0)
        flat_layers.append(att)

    seq_len = flat_layers[0].shape[0]
    att_h, att_w = _attention_grid(flat_layers[0].shape[1], aspect_ratio)
    # Make the attention 2-D and move it to numpy. Only these maps are stored;
    # the overlay images are generated later, whenever the token slider moves.
    layers = [
        att.reshape(seq_len, att_h, att_w).cpu().detach().numpy() for att in flat_layers
    ]

    # "Overall" map: sum of all layers, normalized to at most 1.
    overall = np.stack(layers).sum(axis=0)
    overall_max_value = np.max(overall)
    if overall_max_value > 1.0:
        overall /= overall_max_value
    layers.append(overall)

    # Publish only once everything succeeded, so a failure never leaves a
    # half-built CA_layers behind.
    CA_layers = layers
    return pd.DataFrame([predicted_seq])


def define_input_source(choice: gr.SelectData):
    """Show/hide the checkpoint upload depending on the weights source picked.

    Returns the visibility update for the file input and the new input_type
    state value (0 = pretrained weights, 1 = user weights).
    """
    if choice.index == 0:  # pretrained weights
        return gr.update(visible=False), 0  # file input invisible, input type state update
    elif choice.index == 1:  # your own weights
        return gr.update(visible=True), 1  # file input visible, input type state update


def define_interface():
    """Build and return the gr.Blocks page of the SMT demonstrator."""
    # Index of the "Overall" map in CA_layers / the rendered images; indices
    # below it are the individual decoder layers.
    OVERALL_TAB = 8
    tab_specs = [("Overall", OVERALL_TAB)] + [
        (f"Layer {layer + 1}", layer) for layer in range(OVERALL_TAB)
    ]

    # main components
    file_input = gr.File(label="Model Checkpoint File", visible=False, interactive=True)
    image_input = gr.Image(label="Input Image")
    # knob components
    token_slider = gr.Slider(minimum=0, maximum=0, step=1, label="Pick a token", info="Select a predicted token to visualize the attention it pays in the input sample", visible=False)
    intensifier_slider = gr.Slider(minimum=1, maximum=100, step=1, label="Intensify attention", info="Use this slider to intensify the attention values to better see differences", value=10, visible=False)
    token_table = gr.DataFrame(interactive=False, value=pd.DataFrame(["The predicted sequence will appear here"]))

    def intensifier_visibility():
        """Make the intensifier slider visible whenever the token slider is changed."""
        return gr.update(visible=True)

    def _tab_selector(tab_idx):
        """Build a tab .select() callback that records tab_idx in the tab_selected state."""
        return lambda: gr.Number(tab_idx)

    with gr.Blocks() as page:
        token_slider.release(fn=intensifier_visibility, outputs=intensifier_slider)
        gr.Markdown("# SMT Demonstrator")
        with gr.Row():
            with gr.Column():
                # input area
                with gr.Blocks():
                    select_src_weights = gr.Dropdown(["Test pretrained weights (default)", "Test your own weights"], label="Pick which weights to test out", interactive=True)
                    # State variable -- Weights source picked by user
                    input_type = gr.Number(value=0, visible=False)
                    select_src_weights.select(define_input_source, outputs=[file_input, input_type])
                    file_input.render()
                    image_input.render()
                    with gr.Row():
                        def submit_logic(file, image, weights_source):
                            return make_predictions(file, image, weights_source), gr.update(visible=True), gr.update(visible=True)

                        gr.ClearButton(components=[file_input, image_input])
                        submit_btn = gr.Button(value="Submit", variant="primary")
                        submit_btn.click(fn=submit_logic, inputs=[file_input, image_input, input_type], outputs=[token_table, token_slider, intensifier_slider])
            with gr.Column(scale=2):
                token_slider.render()
                # State variable -- Tab the user left off on
                tab_selected = gr.Number(value="8", visible=False)  # on Overall Attention tab by default

                # Regenerate the images every time the slider is moved.
                @gr.render(inputs=[token_table, token_slider, image_input, intensifier_slider, tab_selected], triggers=[token_slider.release, intensifier_slider.release, token_table.change])
                def render_images_display(prediction, slider, image, intensifier, tab_no):
                    # Nothing to draw until a prediction exists and an image is loaded.
                    if prediction.shape[0] == 0 or not CA_layers or image is None:
                        return
                    # Slider/Number values may arrive as floats; indices and
                    # tab ids ("0".."8") need integers.
                    images = generate_CA_images(int(slider), image, intensifier)
                    gr.Markdown(value="## Contents of the Cross-Attention layers")
                    with gr.Tabs(selected=str(int(float(tab_no)))):
                        for label, tab_idx in tab_specs:
                            with gr.Tab(label, id=str(tab_idx)) as tab:
                                tab.select(_tab_selector(tab_idx), outputs=[tab_selected])
                                gr.Image(value=images[tab_idx])

                intensifier_slider.render()
            with gr.Column():
                gr.Markdown("## Predicted Sequence")

                def render_prediction_display(tokens):
                    # The table has ONE row whose columns are the predicted
                    # tokens, so the token count is shape[1] (shape[0] is always 1).
                    num_tokens = tokens.shape[1]
                    if CA_layers:
                        # never let the slider index past the stored attention rows
                        num_tokens = min(num_tokens, CA_layers[0].shape[0])
                    return gr.Slider(maximum=max(num_tokens - 1, 0), visible=True), gr.update(visible=True)

                token_table.render()
                token_table.change(render_prediction_display, inputs=[token_table], outputs=[token_slider, token_table])
    return page


if __name__ == "__main__":
    page = define_interface()
    page.launch(share=False)