pawlo2013 committed on
Commit 56c6b0d · 1 Parent(s): 8ef1b85

latest bug fix

Files changed (1)
1. app.py +4 -20
app.py CHANGED
@@ -8,7 +8,6 @@ import matplotlib.pyplot as plt
 import numpy as np
 import cv2
 
-
 # Model and processor configuration
 model_name_or_path = "google/vit-base-patch16-224-in21k"
 processor = ViTImageProcessor.from_pretrained(model_name_or_path)
@@ -30,10 +29,8 @@ model = ViTForImageClassification.from_pretrained(
 model.eval()
 
 
-# Define the classification function
 # Define the classification function
 def classify_and_visualize(img, device="cpu", discard_ratio=0.9, head_fusion="mean"):
-    # filename = img.filename
     img = img.convert("RGB")
     processed_input = processor(images=img, return_tensors="pt").to(device)
 
@@ -45,7 +42,6 @@ def classify_and_visualize(img, device="cpu", discard_ratio=0.9, head_fusion="me
     predicted_class = class_names[prediction]
 
     result = {class_name: prob for class_name, prob in zip(class_names, probabilities)}
-    # get the filename from the image object
 
     # Generate attention heatmap
     heatmap_img = show_final_layer_attention_maps(
@@ -75,24 +71,18 @@ def load_examples_from_folder(folder_path):
 def show_final_layer_attention_maps(
     model, tensor, device, discard_ratio=0.6, head_fusion="max", only_last_layer=False
 ):
-    # Create a DataLoader with batch size equal to the number of images
     image = tensor["pixel_values"].to(device).squeeze(0)
 
-    # Iterate over the samples
     with torch.no_grad():
-        # Forward pass through the model
         outputs = model(**tensor, output_attentions=True)
 
-        print(type(outputs.attentions[0]))
         if outputs.attentions[0] is None:
             print("Attention outputs are None.")
             return None
 
-        # Scale image to [0, 1]
         image = image - image.min()
         image = image / image.max()
 
-        # Initialize the result tensor and recursively fuse the attention maps
         result = torch.eye(outputs.attentions[0].size(-1)).to(device)
         if only_last_layer:
             attention_list = outputs.attentions[-1].unsqueeze(0).to(device)
@@ -119,27 +109,21 @@ def show_final_layer_attention_maps(
             result = torch.matmul(a, result)
 
         mask = result[0, 0, 1:]
-        # In case of 224x224 image, this brings us from 196 to 14
         width = int(mask.size(-1) ** 0.5)
         mask = mask.reshape(width, width).cpu().numpy()
         mask = mask / np.max(mask)
 
         mask = cv2.resize(mask, (224, 224))
 
-        # Normalize mask to [0, 1] for visualization
         mask = (mask - np.min(mask)) / (np.max(mask) - np.min(mask))
-        heatmap = plt.cm.jet(mask)[:, :, :3]  # Apply colormap
+        heatmap = plt.cm.jet(mask)[:, :, :3]
 
-        # Superimpose heatmap on the original image
         showed_img = image.permute(1, 2, 0).detach().cpu().numpy()
         showed_img = (showed_img - np.min(showed_img)) / (
             np.max(showed_img) - np.min(showed_img)
-        )  # Normalize image
-        superimposed_img = (
-            heatmap * 0.4 + showed_img * 0.6
-        )  # Combine heatmap with original image
+        )
+        superimposed_img = heatmap * 0.4 + showed_img * 0.6
 
-        # Plot attention map
         superimposed_img_pil = Image.fromarray(
             (superimposed_img * 255).astype(np.uint8)
         )
@@ -165,4 +149,4 @@ iface = gr.Interface(
 )
 # Launch the app
 if __name__ == "__main__":
-    iface.launch()
+    iface.launch(debug=True)
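
For context on the function this commit cleans up: show_final_layer_attention_maps follows the standard attention-rollout recipe (fuse each layer's heads, discard the weakest weights, add back the residual connection, and multiply the per-layer maps together). Below is a minimal self-contained sketch of that loop, assuming attentions shaped (1, heads, tokens, tokens) as returned by ViTForImageClassification with output_attentions=True; the function name and the exact discard handling are illustrative, not the app's code.

import torch

def attention_rollout(attentions, discard_ratio=0.9, head_fusion="mean"):
    # attentions: one (1, heads, tokens, tokens) tensor per layer
    result = torch.eye(attentions[0].size(-1))
    with torch.no_grad():
        for attention in attentions:
            if head_fusion == "mean":
                fused = attention.mean(dim=1)
            else:  # "max"
                fused = attention.max(dim=1)[0]
            # Zero out the weakest attention weights, sparing the CLS token
            flat = fused.view(fused.size(0), -1)
            _, idx = flat.topk(
                int(flat.size(-1) * discard_ratio), dim=-1, largest=False
            )
            flat[0, idx[idx != 0]] = 0
            # Re-add the residual connection, renormalize rows, and accumulate
            a = (fused + torch.eye(fused.size(-1))) / 2
            a = a / a.sum(dim=-1, keepdim=True)
            result = torch.matmul(a, result)
    # CLS-to-patch relevance; 196 patch tokens reshape to a 14x14 grid
    mask = result[0, 0, 1:]
    width = int(mask.size(-1) ** 0.5)
    return (mask / mask.max()).reshape(width, width)

With the 12 layers of google/vit-base-patch16-224-in21k this yields the 14x14 mask that app.py then upsamples to 224x224 with cv2.resize and blends into the input as heatmap * 0.4 + image * 0.6.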
 