import sys

sys.path.insert(0, "./SMT")

from math import sqrt

import cv2
import gradio as gr
import numpy as np
import pandas as pd
import torch

from smt_trainer import SMT_Trainer
from smt_model.modeling_smt import SMTModelForCausalLM
# Per-layer cross-attention maps of the last prediction (filled by make_predictions)
CA_layers = list()

# One base color per attention layer (8 layers) plus one for the "Overall" view
colors = [
    (128,   0,   0),
    (128,  64,   0),
    (128, 128,   0),
    (  0, 128,   0),
    (  0, 128, 128),
    (  0,  64, 128),
    (  0,   0, 128),
    (128,   0, 128),
    (128,   0,   0),
]

def contrast(elem):
    """Boolean mask of the non-zero elements (helper, currently unused)."""
    return elem != 0

def overlay(background: np.ndarray, foreground: np.ndarray, alpha=1):
    """
    Alpha-composite a translucent mask over an image.

    :param background: H x W x 3 image (np.uint8)
    :param foreground: H x W x 4 image with an alpha channel (np.uint8)
    :param alpha: extra transparency factor applied to the foreground
    :returns: H x W x 3 combined image with values in 0-1 (np.float32)
    """
    # add an (opaque) alpha channel to the background
    background = np.concatenate([background, np.full([*background.shape[:2], 1], 1.0)], axis=-1)
    # normalize the foreground alpha channel from 0-255 to 0.0-1.0
    alpha_background = 1.0
    alpha_foreground = foreground[:, :, 3] / 255.0 * alpha
    # standard "over" compositing, channel by channel
    for channel in range(3):
        background[:, :, channel] = alpha_foreground * foreground[:, :, channel] + \
            alpha_background * background[:, :, channel] * (1 - alpha_foreground)
    background[:, :, 3] = (1 - (1 - alpha_foreground) * (1 - alpha_background)) * 255
    # drop the alpha channel because Gradio ignores it; divide by 255 because
    # gr.Image expects float images in 0-1 even though it hands us uint8 images
    return (background[:, :, :3] / 255.0).astype(np.float32)
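
# Quick sanity check for overlay() -- a minimal sketch, not part of the app.
# Compositing a ~50%-transparent red square over a white canvas should give pink:
#
#     bg = np.full((4, 4, 3), 255, dtype=np.uint8)   # white background
#     fg = np.zeros((4, 4, 4), dtype=np.uint8)
#     fg[:, :, 0] = 255                              # red channel
#     fg[:, :, 3] = 128                              # ~50% opacity
#     blended = overlay(bg, fg)                      # float32 RGB in 0-1
#     assert np.allclose(blended[0, 0], [1.0, 0.498, 0.498], atol=0.01)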

def generate_CA_images(token_idx, image, multiplier=1):
    """Build one colored attention overlay per layer for the selected token."""
    global CA_layers
    CA_final_images = []
    # resize each attention map to the input image size (values stay in 0-1)
    masks = [cv2.resize(CA_layers[layer_idx][token_idx],
                        interpolation=cv2.INTER_NEAREST,
                        dsize=(image.shape[1], image.shape[0])) for layer_idx in range(len(CA_layers))]
    for i, mask in enumerate(masks):
        # apply the user-chosen intensity multiplier
        mask *= multiplier
        # renormalize if the multiplier pushed values above 1
        max_pixel = np.max(mask)
        if max_pixel > 1:
            mask /= max_pixel
        # convert to values in 0-255
        mask = np.round(mask * 255.0).astype(np.uint8)
        # add a singleton dimension as the channel
        mask = np.expand_dims(mask, axis=-1)
        # base color + transparency mask = 4-channel image for overlay()
        ca = np.concatenate((np.full(shape=image.shape, fill_value=colors[i]), mask), axis=-1)
        CA_final_images.append(overlay(image, ca))
    return CA_final_images
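
# Shape walkthrough (a sketch, assuming the 8 attention layers and a prediction
# of seq_len tokens): after make_predictions() runs, CA_layers holds nine arrays
# of shape (seq_len, att_h, att_w) -- eight per-layer maps plus their sum as the
# "Overall" view -- so generate_CA_images(t, img) returns nine H x W x 3 float32
# images for token index t, one per tab.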

def make_predictions(checkpoint, input_image, input_type: int):
    global CA_layers
    # load the pretrained weights from the Hugging Face Hub
    if input_type == 0:
        # TODO this doesn't work yet because the Hugging Face weights aren't updated
        model = SMTModelForCausalLM.from_pretrained("antoniorv6/smt-grandstaff")
        model.to(device=model.positional_2D.pe.device)
        input_image = np.mean(input_image, axis=2, keepdims=True)    # collapse 3 channels to one
        input_image = np.transpose(input_image, (2, 0, 1))[None, :]  # add a batch dimension: [B, C, H, W]
        input_image = torch.from_numpy(input_image)  # .to(device=model.positional_2D.pe.device)
    # load the weights from the user-provided checkpoint file
    elif input_type == 1:
        model = SMT_Trainer.load_from_checkpoint(checkpoint).model
        model.to(device=model.pos2D.pe.device)
        input_image = np.mean(input_image, axis=2, keepdims=True)    # collapse 3 channels to one
        input_image = np.transpose(input_image, (2, 0, 1))[None, :]  # add a batch dimension: [B, C, H, W]
        input_image = torch.from_numpy(input_image).to(device=model.pos2D.pe.device)
    input_image = input_image.to(torch.float32)
    # width / height of the input image
    aspect_ratio = input_image.shape[3] / input_image.shape[2]
    # predictions.cross_attentions: 8 attention layers, each attending over the
    # FLATTENED feature map the encoder extracts from the input image
    predicted_seq, predictions = model.predict(input_image, return_weights=True)
    # average over attention heads -> one (seq_len, reduced_h * reduced_w) map per layer
    CA_layers = [ca_layer.mean(dim=1).squeeze() for ca_layer in predictions.cross_attentions]
    seq_len = CA_layers[0].shape[0]
    # recover the 2-D feature-map size from its flat length and the aspect ratio
    att_w = round(sqrt(CA_layers[0].shape[1] * aspect_ratio))
    att_h = round(sqrt(CA_layers[0].shape[1] / aspect_ratio))
    # make the attention maps 2-D again
    CA_layers = [att.reshape(seq_len, att_h, att_w) for att in CA_layers]
    # convert to numpy; we store these and only render the displayed images
    # whenever the token slider is moved
    CA_layers = [att.cpu().detach().numpy() for att in CA_layers]
    # the "Overall" view is the sum over all layers, normalized back into 0-1
    overall = np.stack(CA_layers).sum(axis=0)
    overall_max_value = np.max(overall)
    if overall_max_value > 1.0:
        overall /= overall_max_value
    CA_layers.append(overall)
    return pd.DataFrame([predicted_seq])
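
# The reshape math above, as a worked sketch: the encoder flattens an
# (att_h, att_w) feature map into att_h * att_w positions, and the image's
# aspect ratio r = w / h is preserved, so att_w / att_h ~= r. From
# N = att_h * att_w it follows that att_w = sqrt(N * r) and att_h = sqrt(N / r).
# E.g. with N = 2048 flat positions and r = 2.0: att_w = sqrt(4096) = 64,
# att_h = sqrt(1024) = 32, and 64 * 32 = 2048 positions, as required.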

def define_input_source(choice: gr.SelectData):
    """
    Adjusts the interface according to the weights source the user picked
    """
    if choice.index == 0:    # pretrained weights
        return gr.update(visible=False), 0  # hide the file input, update the input-type state
    elif choice.index == 1:  # your own weights
        return gr.update(visible=True), 1   # show the file input, update the input-type state

def define_interface():
    # main components
    file_input = gr.File(label="Model Checkpoint File", visible=False, interactive=True)
    image_input = gr.Image(label="Input Image")
    tabs = gr.Tabs()
    # knob components
    token_slider = gr.Slider(minimum=0, maximum=0, step=1,
                             label="Pick a token",
                             info="Select a predicted token to visualize the attention it pays to the input sample",
                             visible=False)
    intensifier_slider = gr.Slider(minimum=1, maximum=100, step=1,
                                   label="Intensify attention",
                                   info="Use this slider to intensify the attention values and make differences easier to see",
                                   value=10,
                                   visible=False)
    token_table = gr.DataFrame(interactive=False, value=pd.DataFrame(["The predicted sequence will appear here"]))

    def intensifier_visibility():
        """
        Makes the intensifier slider visible whenever the token slider is released
        """
        return gr.update(visible=True)

    with gr.Blocks() as page:
        ###
        token_slider.release(fn=intensifier_visibility, outputs=intensifier_slider)
        ###
        gr.Markdown("# SMT Demonstrator")
        with gr.Row():
            with gr.Column():
                '''
                model_interface = gr.Interface(make_predictions,
                                               inputs=[file_input, image_input, input_type],
                                               outputs=[token_table],
                                               flagging_mode='never')
                '''
                # input area
                with gr.Blocks():
                    select_src_weights = gr.Dropdown(["Test pretrained weights (default)", "Test your own weights"],
                                                     label="Pick which weights to test out",
                                                     interactive=True)
                    # State variable -- weights source picked by the user
                    input_type = gr.Number(value=0, visible=False)
                    select_src_weights.select(define_input_source, outputs=[file_input, input_type])
                    file_input.render()
                    image_input.render()
                with gr.Row():
                    def submit_logic(file, image, type):
                        return make_predictions(file, image, type), gr.update(visible=True), gr.update(visible=True)
                    clear_btn = gr.ClearButton(components=[file_input, image_input])
                    submit_btn = gr.Button(value="Submit", variant="primary")
                    submit_btn.click(fn=submit_logic,
                                     inputs=[file_input, image_input, input_type],
                                     outputs=[token_table, token_slider, intensifier_slider])
            with gr.Column(scale=2):
                token_slider.render()
                # State variable -- tab the user left off on (8 = the Overall Attention tab)
                tab_selected = gr.Number(value=8, visible=False)
                # regenerate the overlay images every time a slider is moved;
                # @gr.render is assumed wiring here: it re-runs this block on
                # load and whenever any of the listed inputs changes
                @gr.render(inputs=[token_table, token_slider, image_input, intensifier_slider, tab_selected])
                def render_images_display(prediction, slider, image, intensifier, tab_no):
                    # only render once a prediction has populated CA_layers
                    if prediction.shape[0] > 0 and CA_layers and image is not None:
                        images = generate_CA_images(slider, image, intensifier)
                        gr.Markdown(value="## Contents of the Cross-Attention layers")
                        # int() because the Number state may come back as a float
                        with gr.Tabs(selected=f"{int(tab_no)}") as tabs:
                            with gr.Tab("Overall", id="8") as tab_overall:
                                tab_overall.select((lambda: gr.Number(8)), outputs=[tab_selected])
                                gr.Image(value=images[8])
                            # one tab per attention layer; bind the layer index as
                            # a default argument so every lambda keeps its own value
                            for layer in range(8):
                                with gr.Tab(f"Layer {layer + 1}", id=f"{layer}") as tab:
                                    tab.select((lambda layer=layer: gr.Number(layer)), outputs=[tab_selected])
                                    gr.Image(value=images[layer])
                intensifier_slider.render()
            with gr.Column():
                gr.Markdown("## Predicted Sequence")
                def render_prediction_display(tokens):
                    # tokens arrives as a 1 x seq_len DataFrame, so the last valid
                    # token index for the slider is shape[1] - 1, not shape[0]
                    return gr.Slider(maximum=tokens.shape[1] - 1, visible=True), gr.update(visible=True)
                token_table.render()
                token_table.change(render_prediction_display, inputs=[token_table], outputs=[token_slider, token_table])
    return page

if __name__ == "__main__":
    page = define_interface()
    page.launch(share=False)