luoweibetter committed on
Commit 1a92924 · verified · 1 parent: 87e5ca7

Update Zero_Shot_App.py

Files changed (1)
  1. Zero_Shot_App.py +186 -193
Zero_Shot_App.py CHANGED
@@ -1,195 +1,188 @@
- # import argparse
- # from functools import partial
- # import gradio as gr
- # from torch.nn import functional as F
- # from torch import nn
- # from dataset import get_data_transforms
- # from PIL import Image
- # import os
-
- # from utils import get_gaussian_kernel
-
- # os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
- # import os
- # import torch
- # import cv2
- # import numpy as np
-
- # # # Model-Related Modules
- # from models import vit_encoder
- # from models.uad import INP_Former
- # from models.vision_transformer import Mlp, Aggregation_Block, Prototype_Block
-
-
- # # Configurations
- # os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
- # parser = argparse.ArgumentParser(description='')
- # # model info
- # parser.add_argument('--encoder', type=str, default='dinov2reg_vit_base_14')
- # parser.add_argument('--input_size', type=int, default=448)
- # parser.add_argument('--crop_size', type=int, default=392)
- # parser.add_argument('--INP_num', type=int, default=6)
-
- # args = parser.parse_args()
-
-
- # ############ Init Model
- # ckt_path1 = 'saved_results/INP-Former-Multi-Class_dataset=Real-IAD_Encoder=dinov2reg_vit_base_14_Resize=448_Crop=392_INP_num=6/model.pth'
- # ckt_path2 = "saved_results/INP-Former-Multi-Class_dataset=VisA_Encoder=dinov2reg_vit_base_14_Resize=448_Crop=392_INP_num=6/model.pth"
-
- # #
- # data_transform, _ = get_data_transforms(args.input_size, args.crop_size)
-
- # # device
- # device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
- # # Adopting a grouping-based reconstruction strategy similar to Dinomaly
- # target_layers = [2, 3, 4, 5, 6, 7, 8, 9]
- # fuse_layer_encoder = [[0, 1, 2, 3], [4, 5, 6, 7]]
- # fuse_layer_decoder = [[0, 1, 2, 3], [4, 5, 6, 7]]
-
- # # Encoder info
- # encoder = vit_encoder.load(args.encoder)
- # if 'small' in args.encoder:
- #     embed_dim, num_heads = 384, 6
- # elif 'base' in args.encoder:
- #     embed_dim, num_heads = 768, 12
- # elif 'large' in args.encoder:
- #     embed_dim, num_heads = 1024, 16
- #     target_layers = [4, 6, 8, 10, 12, 14, 16, 18]
- # else:
- #     raise "Architecture not in small, base, large."
-
- # # Model Preparation
- # Bottleneck = []
- # INP_Guided_Decoder = []
- # INP_Extractor = []
-
- # # bottleneck
- # Bottleneck.append(Mlp(embed_dim, embed_dim * 4, embed_dim, drop=0.))
- # Bottleneck = nn.ModuleList(Bottleneck)
-
- # # INP
- # INP = nn.ParameterList(
- #     [nn.Parameter(torch.randn(args.INP_num, embed_dim))
- #      for _ in range(1)])
-
- # # INP Extractor
- # for i in range(1):
- #     blk = Aggregation_Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=4.,
- #                             qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-8))
- #     INP_Extractor.append(blk)
- # INP_Extractor = nn.ModuleList(INP_Extractor)
-
- # # INP_Guided_Decoder
- # for i in range(8):
- #     blk = Prototype_Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=4.,
- #                           qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-8))
- #     INP_Guided_Decoder.append(blk)
- # INP_Guided_Decoder = nn.ModuleList(INP_Guided_Decoder)
-
- # model = INP_Former(encoder=encoder, bottleneck=Bottleneck, aggregation=INP_Extractor, decoder=INP_Guided_Decoder,
- #                    target_layers=target_layers, remove_class_token=True, fuse_layer_encoder=fuse_layer_encoder,
- #                    fuse_layer_decoder=fuse_layer_decoder, prototype_token=INP)
- # model = model.to(device)
-
- # gaussian_kernel = get_gaussian_kernel(kernel_size=5, sigma=4).to(device)
-
-
- # def resize_and_center_crop(image, resize_size=448, crop_size=392):
- #     # Resize to 448x448
- #     image_resized = cv2.resize(image, (resize_size, resize_size), interpolation=cv2.INTER_LINEAR)
-
- #     # Compute crop coordinates
- #     start = (resize_size - crop_size) // 2
- #     end = start + crop_size
-
- #     # Center crop to 392x392
- #     image_cropped = image_resized[start:end, start:end, :]
-
- #     return image_cropped
-
- # def process_image(image, options):
- #     # Load the model based on selected options
- #     if 'Real-IAD' in options:
- #         model.load_state_dict(torch.load(ckt_path1), strict=True)
- #     elif 'VisA' in options:
- #         model.load_state_dict(torch.load(ckt_path2), strict=True)
- #     else:
- #         # Default to 'All' if no valid option is provided
- #         model.load_state_dict(torch.load(ckt_path1), strict=True)
- #         print('Invalid option. Defaulting to All.')
-
- #     # Ensure image is in RGB mode
- #     image = image.convert('RGB')
-
-
-
- #     # Convert PIL image to NumPy array
- #     np_image = np.array(image)
- #     image_shape = np_image.shape[0]
-
- #     # Convert RGB to BGR for OpenCV
- #     np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)
- #     np_image = resize_and_center_crop(np_image, resize_size=args.input_size, crop_size=args.crop_size)
-
- #     # Preprocess the image and run the model
- #     input_image = data_transform(image)
- #     input_image = input_image.to(device)
-
- #     with torch.no_grad():
- #         _ = model(input_image.unsqueeze(0))
- #         anomaly_map = model.distance
- #         side = int(model.distance.shape[1] ** 0.5)
- #         anomaly_map = anomaly_map.reshape([anomaly_map.shape[0], side, side]).contiguous()
- #         anomaly_map = torch.unsqueeze(anomaly_map, dim=1)
- #         anomaly_map = F.interpolate(anomaly_map, size=input_image.shape[-1], mode='bilinear', align_corners=True)
- #         anomaly_map = gaussian_kernel(anomaly_map)
-
- #     # Process anomaly map
- #     anomaly_map = anomaly_map.squeeze().cpu().numpy()
- #     anomaly_map = (anomaly_map * 255).astype(np.uint8)
-
- #     # Apply color map and blend with original image
- #     heat_map = cv2.applyColorMap(anomaly_map, cv2.COLORMAP_JET)
- #     vis_map = cv2.addWeighted(heat_map, 0.5, np_image, 0.5, 0)
-
- #     # Convert OpenCV image back to PIL image for Gradio
- #     vis_map_pil = Image.fromarray(cv2.resize(cv2.cvtColor(vis_map, cv2.COLOR_BGR2RGB), (image_shape, image_shape)))
-
- #     return vis_map_pil
-
- # # Define examples
- # examples = [
- #     ["assets/img2.png", "Real-IAD"],
- #     ["assets/img.png", "VisA"]
- # ]
-
- # # Gradio interface layout
- # demo = gr.Interface(
- #     fn=process_image,
- #     inputs=[
- #         gr.Image(type="pil", label="Upload Image"),
- #         gr.Radio(["Real-IAD",
- #                   "VisA"],
- #                  label="Pre-trained Datasets")
- #     ],
- #     outputs=[
- #         gr.Image(type="pil", label="Output Image")
- #     ],
- #     examples=examples,
- #     title="INP-Former -- Zero-shot Anomaly Detection",
- #     description="Upload an image and select pre-trained datasets to do zero-shot anomaly detection"
- # )
-
- # # Launch the demo
- # demo.launch()
- # # demo.launch(server_name="0.0.0.0", server_port=10002)
-
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import argparse
+ from functools import partial
  import gradio as gr
+ from torch.nn import functional as F
+ from torch import nn
+ from dataset import get_data_transforms
+ from PIL import Image
+ import os
+
+ from utils import get_gaussian_kernel
+
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
+ import os
+ import torch
+ import cv2
+ import numpy as np
+
+ # # Model-Related Modules
+ from models import vit_encoder
+ from models.uad import INP_Former
+ from models.vision_transformer import Mlp, Aggregation_Block, Prototype_Block
+
+
+ # Configurations
+ os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
+ parser = argparse.ArgumentParser(description='')
+ # model info
+ parser.add_argument('--encoder', type=str, default='dinov2reg_vit_base_14')
+ parser.add_argument('--input_size', type=int, default=448)
+ parser.add_argument('--crop_size', type=int, default=392)
+ parser.add_argument('--INP_num', type=int, default=6)
+
+ args = parser.parse_args()
+
+
+ ############ Init Model
+ ckt_path1 = 'saved_results/INP-Former-Multi-Class_dataset=Real-IAD_Encoder=dinov2reg_vit_base_14_Resize=448_Crop=392_INP_num=6/model.pth'
+ ckt_path2 = "saved_results/INP-Former-Multi-Class_dataset=VisA_Encoder=dinov2reg_vit_base_14_Resize=448_Crop=392_INP_num=6/model.pth"
+
+ #
+ data_transform, _ = get_data_transforms(args.input_size, args.crop_size)
+
+ # device
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ # Adopting a grouping-based reconstruction strategy similar to Dinomaly
+ target_layers = [2, 3, 4, 5, 6, 7, 8, 9]
+ fuse_layer_encoder = [[0, 1, 2, 3], [4, 5, 6, 7]]
+ fuse_layer_decoder = [[0, 1, 2, 3], [4, 5, 6, 7]]
+
+ # Encoder info
+ encoder = vit_encoder.load(args.encoder)
+ if 'small' in args.encoder:
+     embed_dim, num_heads = 384, 6
+ elif 'base' in args.encoder:
+     embed_dim, num_heads = 768, 12
+ elif 'large' in args.encoder:
+     embed_dim, num_heads = 1024, 16
+     target_layers = [4, 6, 8, 10, 12, 14, 16, 18]
+ else:
+     raise ValueError("Architecture not in small, base, large.")
+
+ # Model Preparation
+ Bottleneck = []
+ INP_Guided_Decoder = []
+ INP_Extractor = []
+
+ # bottleneck
+ Bottleneck.append(Mlp(embed_dim, embed_dim * 4, embed_dim, drop=0.))
+ Bottleneck = nn.ModuleList(Bottleneck)
+
+ # INP
+ INP = nn.ParameterList(
+     [nn.Parameter(torch.randn(args.INP_num, embed_dim))
+      for _ in range(1)])
+
+ # INP Extractor
+ for i in range(1):
+     blk = Aggregation_Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=4.,
+                             qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-8))
+     INP_Extractor.append(blk)
+ INP_Extractor = nn.ModuleList(INP_Extractor)
+
+ # INP_Guided_Decoder
+ for i in range(8):
+     blk = Prototype_Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=4.,
+                           qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-8))
+     INP_Guided_Decoder.append(blk)
+ INP_Guided_Decoder = nn.ModuleList(INP_Guided_Decoder)
+
+ model = INP_Former(encoder=encoder, bottleneck=Bottleneck, aggregation=INP_Extractor, decoder=INP_Guided_Decoder,
+                    target_layers=target_layers, remove_class_token=True, fuse_layer_encoder=fuse_layer_encoder,
+                    fuse_layer_decoder=fuse_layer_decoder, prototype_token=INP)
+ model = model.to(device)
+
+ gaussian_kernel = get_gaussian_kernel(kernel_size=5, sigma=4).to(device)
+
+
+ def resize_and_center_crop(image, resize_size=448, crop_size=392):
+     # Resize to 448x448
+     image_resized = cv2.resize(image, (resize_size, resize_size), interpolation=cv2.INTER_LINEAR)
+
+     # Compute crop coordinates
+     start = (resize_size - crop_size) // 2
+     end = start + crop_size
+
+     # Center crop to 392x392
+     image_cropped = image_resized[start:end, start:end, :]
+
+     return image_cropped
+
+ def process_image(image, options):
+     # Load the checkpoint that matches the selected dataset option
+     # (map_location keeps loading working on CPU-only machines)
+     if 'Real-IAD' in options:
+         model.load_state_dict(torch.load(ckt_path1, map_location=device), strict=True)
+     elif 'VisA' in options:
+         model.load_state_dict(torch.load(ckt_path2, map_location=device), strict=True)
+     else:
+         # Default to the Real-IAD checkpoint if no valid option is provided
+         model.load_state_dict(torch.load(ckt_path1, map_location=device), strict=True)
+         print('Invalid option. Defaulting to Real-IAD.')
+
+     # Ensure image is in RGB mode
+     image = image.convert('RGB')
+
+     # Convert PIL image to NumPy array
+     np_image = np.array(image)
+     image_shape = np_image.shape[0]
+
+     # Convert RGB to BGR for OpenCV
+     np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)
+     np_image = resize_and_center_crop(np_image, resize_size=args.input_size, crop_size=args.crop_size)
+
+     # Preprocess the image and run the model
+     input_image = data_transform(image)
+     input_image = input_image.to(device)
+
+     with torch.no_grad():
+         _ = model(input_image.unsqueeze(0))
+         anomaly_map = model.distance
+         side = int(model.distance.shape[1] ** 0.5)
+         anomaly_map = anomaly_map.reshape([anomaly_map.shape[0], side, side]).contiguous()
+         anomaly_map = torch.unsqueeze(anomaly_map, dim=1)
+         anomaly_map = F.interpolate(anomaly_map, size=input_image.shape[-1], mode='bilinear', align_corners=True)
+         anomaly_map = gaussian_kernel(anomaly_map)
+
+     # Process anomaly map
+     anomaly_map = anomaly_map.squeeze().cpu().numpy()
+     anomaly_map = (anomaly_map * 255).astype(np.uint8)
+
+     # Apply color map and blend with original image
+     heat_map = cv2.applyColorMap(anomaly_map, cv2.COLORMAP_JET)
+     vis_map = cv2.addWeighted(heat_map, 0.5, np_image, 0.5, 0)
+
+     # Convert OpenCV image back to PIL image for Gradio
+     vis_map_pil = Image.fromarray(cv2.resize(cv2.cvtColor(vis_map, cv2.COLOR_BGR2RGB), (image_shape, image_shape)))
+
+     return vis_map_pil
+
+ # Define examples
+ examples = [
+     ["assets/img2.png", "Real-IAD"],
+     ["assets/img.png", "VisA"]
+ ]
+
+ # Gradio interface layout
+ demo = gr.Interface(
+     fn=process_image,
+     inputs=[
+         gr.Image(type="pil", label="Upload Image"),
+         gr.Radio(["Real-IAD",
+                   "VisA"],
+                  label="Pre-trained Datasets")
+     ],
+     outputs=[
+         gr.Image(type="pil", label="Output Image")
+     ],
+     examples=examples,
+     title="INP-Former -- Zero-shot Anomaly Detection",
+     description="Upload an image and select a pre-trained dataset to run zero-shot anomaly detection"
+ )
+
+ # Launch the demo
+ demo.launch()
+ # demo.launch(server_name="0.0.0.0", server_port=10002)