"""Helper utilities for anomaly-detection experiments.

Covers logging/seeding, a cosine feature-matching loss with gradient re-weighting,
anomaly-map construction, image- and pixel-level evaluation metrics, result
visualisation and a warmup + cosine learning-rate scheduler.
"""
import torch
from numpy.random import normal
import random
import logging
import numpy as np
from torch.nn import functional as F
from sklearn.metrics import roc_auc_score, precision_recall_curve, average_precision_score
import cv2
import matplotlib.pyplot as plt
from sklearn.metrics import auc
from skimage import measure
import pandas as pd
from numpy import ndarray
from statistics import mean
import os
from functools import partial
import math
from tqdm import tqdm
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau


def get_logger(name, save_path=None, level='INFO'):
    """Create (or fetch) the logger ``name`` and attach message-only handlers.

    Args:
        name: Name passed to ``logging.getLogger``.
        save_path: Optional directory. When given, it is created if needed and
            messages are also appended to ``<save_path>/log.txt``.
        level: Name of a ``logging`` level constant, e.g. ``'INFO'`` or ``'DEBUG'``.

    Returns:
        The configured ``logging.Logger``.

    Note:
        Handlers are added on every call. Calling this repeatedly with the same
        ``name`` therefore makes each message appear multiple times.
    """
    logger = logging.getLogger(name)
    logger.setLevel(getattr(logging, level))

    log_format = logging.Formatter('%(message)s')
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(log_format)
    logger.addHandler(stream_handler)

    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)
        file_handler = logging.FileHandler(os.path.join(save_path, 'log.txt'))
        file_handler.setFormatter(log_format)
        logger.addHandler(file_handler)
    return logger


def setup_seed(seed):
    """Seed torch (CPU and all CUDA devices), NumPy and ``random``, and make cuDNN deterministic."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def modify_grad(x, inds, factor=0.):
    """Gradient hook: scale the entries of ``x`` selected by the mask ``inds`` by ``factor``.

    ``inds`` is broadcast to ``x``'s shape. It is assumed to be a boolean mask
    (``torch.where`` requires one). A new tensor is returned and ``x`` is left
    untouched, because PyTorch tensor hooks must not modify their argument.
    """
    inds = inds.expand_as(x)
    return torch.where(inds, x * factor, x)


def modify_grad_v2(x, factor):
    """Gradient hook: multiply ``x`` element-wise by ``factor`` (broadcast to ``x``'s shape).

    Returns a new tensor rather than scaling ``x`` in place, as required for
    ``Tensor.register_hook`` hooks.
    """
    return x * factor.expand_as(x)


def global_cosine_hm_adaptive(a, b, y=3):
    """Cosine-distance loss between paired feature maps with adaptive gradient re-weighting.

    For each pair ``(a[i], b[i])``:

    - The loss term is ``mean(1 - cos)`` over the per-sample flattened features.
      ``a[i]`` is detached, so gradients flow only through ``b[i]``.
    - A backward hook on ``b[i]`` scales its gradient by ``(d / mean(d)) ** y``.
      Here ``d`` is the cosine distance taken along dim 1. With ``y > 0``,
      positions farther apart than average get larger updates.

    Returns:
        The loss averaged over the number of pairs.
    """
    cos_loss = torch.nn.CosineSimilarity()
    loss = 0
    for a_item, b_item in zip(a, b):
        a_ = a_item.detach()
        b_ = b_item
        with torch.no_grad():
            point_dist = 1 - cos_loss(a_, b_).unsqueeze(1)
            mean_dist = point_dist.mean()
            factor = (point_dist / mean_dist) ** y
        loss += torch.mean(1 - cos_loss(a_.reshape(a_.shape[0], -1), b_.reshape(b_.shape[0], -1)))
        b_.register_hook(partial(modify_grad_v2, factor=factor))
    loss = loss / len(a)
    return loss


def cal_anomaly_maps(fs_list, ft_list, out_size=224):
    """Fuse per-level cosine-distance maps into one anomaly map.

    For each level ``i``, ``1 - cos_sim(fs_list[i], ft_list[i])`` is taken along
    dim 1. Features must be 4-D for the bilinear resize; presumably they are
    (B, C, h, w), which should be confirmed. Each map is upsampled to
    ``out_size`` and the levels are averaged.

    Args:
        fs_list, ft_list: Sequences of feature tensors. ``len(ft_list)`` levels are used.
        out_size: Output spatial size, either an int (square) or an ``(H, W)`` tuple.

    Returns:
        ``(anomaly_map, a_map_list)``: the averaged (B, 1, H, W) map and the list
        of per-level (B, 1, H, W) maps.
    """
    if not isinstance(out_size, tuple):
        out_size = (out_size, out_size)
    a_map_list = []
    for fs, ft in zip(fs_list, ft_list):
        a_map = 1 - F.cosine_similarity(fs, ft)
        a_map = torch.unsqueeze(a_map, dim=1)
        a_map = F.interpolate(a_map, size=out_size, mode='bilinear', align_corners=True)
        a_map_list.append(a_map)
    anomaly_map = torch.cat(a_map_list, dim=1).mean(dim=1, keepdim=True)
    return anomaly_map, a_map_list


def min_max_norm(image):
    """Linearly rescale ``image`` to [0, 1]. A constant input yields all zeros instead of NaN."""
    a_min, a_max = image.min(), image.max()
    if a_max == a_min:
        # Avoid 0/0; every value equals the minimum, so the normalised value is 0.
        return image - a_min
    return (image - a_min) / (a_max - a_min)


def return_best_thr(y_true, y_score):
    """Return the score threshold that maximises F1 on the precision-recall curve."""
    precs, recs, thrs = precision_recall_curve(y_true, y_score)
    f1s = 2 * precs * recs / (precs + recs + 1e-7)
    # The last (precision=1, recall=0) point has no threshold; drop it to align with `thrs`.
    f1s = f1s[:-1]
    valid = ~np.isnan(f1s)
    thrs, f1s = thrs[valid], f1s[valid]
    return thrs[np.argmax(f1s)]


def f1_score_max(y_true, y_score):
    """Return the maximum F1 score over all thresholds of the precision-recall curve."""
    precs, recs, thrs = precision_recall_curve(y_true, y_score)
    f1s = 2 * precs * recs / (precs + recs + 1e-7)
    # Drop the final point, which has no associated threshold.
    f1s = f1s[:-1]
    return f1s.max()


def specificity_score(y_true, y_score):
    """Return TN / N for binary predictions ``y_score`` (0 = predicted negative) against ``y_true``."""
    y_true = np.array(y_true)
    y_score = np.array(y_score)
    TN = (y_true[y_score == 0] == 0).sum()
    N = (y_true == 0).sum()
    return TN / N


def denormalize(img):
    """Undo ImageNet mean/std normalisation of a (C, H, W) array and return an HWC uint8 image.

    Values are clipped to [0, 255] before the uint8 cast, so out-of-range pixels
    saturate instead of wrapping around.
    """
    img_std = np.array([0.229, 0.224, 0.225])
    img_mean = np.array([0.485, 0.456, 0.406])
    x = ((img.transpose(1, 2, 0) * img_std) + img_mean) * 255.
    return np.clip(x, 0, 255).astype(np.uint8)


def save_imag_ZS(imgs, anomaly_map, gt, prototype_map, save_root, img_path):
    """Save per-sample visualisations as PNGs under ``save_root/<class_name>/<category>/``.

    For sample ``i``, four files named ``<file name>_<k>.png`` are written:

    - ``k=0``: the de-normalised input.
    - ``k=1``: the anomaly map (jet colormap).
    - ``k=2``: the ground-truth mask (gray colormap).
    - ``k=3``: ``prototype_map[i]`` reshaped to a square grid and resized to
      392x392 (jet colormap).

    ``class_name`` and ``category`` are the 4th-last and 2nd-last components of
    ``img_path[i]``. Both ``\\`` and ``/`` are accepted as path separators.
    """
    for i in range(imgs.shape[0]):
        parts = img_path[i].replace('\\', '/').split('/')
        class_name, category, idx_name = parts[-4], parts[-2], parts[-1]
        out_dir = os.path.join(save_root, class_name, category)
        os.makedirs(out_dir, exist_ok=True)

        input_frame = denormalize(imgs[i].clone().squeeze(0).cpu().detach().numpy())
        cv2_input = np.array(input_frame, dtype=np.uint8)
        plt.imsave(os.path.join(out_dir, f'{idx_name}_0.png'), cv2_input)

        ano_map = anomaly_map[i].squeeze(0).cpu().detach().numpy()
        plt.imsave(os.path.join(out_dir, f'{idx_name}_1.png'), ano_map, cmap='jet')

        gt_map = gt[i].squeeze(0).cpu().detach().numpy()
        plt.imsave(os.path.join(out_dir, f'{idx_name}_2.png'), gt_map, cmap='gray')

        # The prototype map is a flattened square grid (784 elements -> 28x28).
        side = int(prototype_map[i].numel() ** 0.5)
        distance = prototype_map[i].view((side, side)).cpu().detach().numpy()
        distance = cv2.resize(distance, (392, 392), interpolation=cv2.INTER_AREA)
        plt.imsave(os.path.join(out_dir, f'{idx_name}_3.png'), distance, cmap='jet')
        plt.close()


def _binarize_gt(gt):
    """Threshold ``gt`` at 0.5 in place and merge multiple channels with a max."""
    gt[gt > 0.5] = 1
    gt[gt <= 0.5] = 0
    if gt.shape[1] > 1:
        gt = torch.max(gt, dim=1, keepdim=True)[0]
    return gt


def _image_scores(anomaly_map, max_ratio):
    """Reduce per-pixel maps to one score per sample: the max, or the mean of the top ``max_ratio`` fraction."""
    flat = anomaly_map.flatten(1)
    if max_ratio == 0:
        return torch.max(flat, dim=1)[0]
    # Keep at least one pixel so a tiny ratio cannot produce the mean of an empty slice (NaN).
    k = max(1, int(flat.shape[1] * max_ratio))
    return torch.topk(flat, k, dim=1)[0].mean(dim=1)


def _compute_metrics(gt_list_px, pr_list_px, gt_list_sp, pr_list_sp):
    """Concatenate the collected batches and compute AUROC / AP / max-F1 at image and pixel level."""
    gt_px = torch.cat(gt_list_px, dim=0)[:, 0].cpu().numpy().ravel()
    pr_px = torch.cat(pr_list_px, dim=0)[:, 0].cpu().numpy().ravel()
    gt_sp = torch.cat(gt_list_sp).flatten().cpu().numpy()
    pr_sp = torch.cat(pr_list_sp).flatten().cpu().numpy()

    auroc_px = roc_auc_score(gt_px, pr_px)
    auroc_sp = roc_auc_score(gt_sp, pr_sp)
    ap_px = average_precision_score(gt_px, pr_px)
    ap_sp = average_precision_score(gt_sp, pr_sp)
    f1_sp = f1_score_max(gt_sp, pr_sp)
    f1_px = f1_score_max(gt_px, pr_px)
    # AUPRO (compute_pro) is not computed here; 0. is a placeholder in the last slot.
    return [auroc_sp, ap_sp, f1_sp, auroc_px, ap_px, f1_px, 0.]


def evaluation_batch(model, dataloader, device, _class_=None, max_ratio=0, resize_mask=None):
    """Evaluate image- and pixel-level anomaly-detection metrics over ``dataloader``.

    ``model(img)`` must return a sequence whose first two items are the feature
    lists compared by ``cal_anomaly_maps``. The fused map is optionally resized
    to ``resize_mask`` and then smoothed with a 5x5 Gaussian (sigma=4).

    Args:
        model: Network to evaluate; it is switched to ``eval()`` mode.
        dataloader: Yields ``(img, gt, label, img_path)`` batches.
        device: Device for the images and the smoothing filter.
        _class_: Unused; kept for interface compatibility.
        max_ratio: If 0, the image score is the max pixel score. Otherwise it is
            the mean of the top ``max_ratio`` fraction of pixel scores.
        resize_mask: Optional size that both anomaly maps and masks are resized to.

    Returns:
        ``[auroc_sp, ap_sp, f1_sp, auroc_px, ap_px, f1_px, 0.]``.
    """
    model.eval()
    gt_list_px, pr_list_px, gt_list_sp, pr_list_sp = [], [], [], []
    gaussian_kernel = get_gaussian_kernel(kernel_size=5, sigma=4).to(device)

    with torch.no_grad():
        for img, gt, label, img_path in tqdm(dataloader, ncols=80):
            img = img.to(device)
            output = model(img)
            en, de = output[0], output[1]
            anomaly_map, _ = cal_anomaly_maps(en, de, img.shape[-1])
            if resize_mask is not None:
                anomaly_map = F.interpolate(anomaly_map, size=resize_mask, mode='bilinear', align_corners=False)
                gt = F.interpolate(gt, size=resize_mask, mode='nearest')
            anomaly_map = gaussian_kernel(anomaly_map)

            gt_list_px.append(_binarize_gt(gt))
            pr_list_px.append(anomaly_map)
            gt_list_sp.append(label)
            pr_list_sp.append(_image_scores(anomaly_map, max_ratio))

    return _compute_metrics(gt_list_px, pr_list_px, gt_list_sp, pr_list_sp)


def evaluation_batch_vis_ZS(model, dataloader, device, _class_=None, max_ratio=0, resize_mask=None, save_root=None):
    """Like ``evaluation_batch``, but scores from ``model.distance`` and saves visualisations.

    After each forward pass, ``model.distance`` is read. It is assumed to be
    (B, N) with N a perfect square, set by the forward pass; this should be
    confirmed. It is reshaped to a square grid, upsampled to the image size,
    optionally resized to ``resize_mask`` and smoothed. Each batch is then
    written with ``save_imag_ZS`` under ``save_root`` before the masks are
    binarised.

    Returns:
        ``[auroc_sp, ap_sp, f1_sp, auroc_px, ap_px, f1_px, 0.]``.
    """
    model.eval()
    gt_list_px, pr_list_px, gt_list_sp, pr_list_sp = [], [], [], []
    gaussian_kernel = get_gaussian_kernel(kernel_size=5, sigma=4).to(device)

    with torch.no_grad():
        for img, gt, label, img_path in tqdm(dataloader, ncols=80):
            img = img.to(device)
            _ = model(img)
            distance = model.distance
            side = int(distance.shape[1] ** 0.5)
            anomaly_map = distance.reshape([distance.shape[0], side, side]).contiguous()
            anomaly_map = torch.unsqueeze(anomaly_map, dim=1)
            anomaly_map = F.interpolate(anomaly_map, size=img.shape[-1], mode='bilinear', align_corners=True)
            if resize_mask is not None:
                anomaly_map = F.interpolate(anomaly_map, size=resize_mask, mode='bilinear', align_corners=False)
                gt = F.interpolate(gt, size=resize_mask, mode='nearest')
            anomaly_map = gaussian_kernel(anomaly_map)
            # Saved before binarisation so the raw (resized) mask is written.
            save_imag_ZS(img, anomaly_map, gt, distance, save_root, img_path)

            gt_list_px.append(_binarize_gt(gt))
            pr_list_px.append(anomaly_map)
            gt_list_sp.append(label)
            pr_list_sp.append(_image_scores(anomaly_map, max_ratio))

    return _compute_metrics(gt_list_px, pr_list_px, gt_list_sp, pr_list_sp)


def compute_pro(masks: ndarray, amaps: ndarray, num_th: int = 200) -> float:
    """Compute the area under the per-region-overlap (PRO) curve for FPR in [0, 0.3).

    Args:
        masks (ndarray): All binary ground-truth masks, shape (num_test_data, h, w).
            Values must be in {0, 1}.
        amaps (ndarray): All anomaly maps, shape (num_test_data, h, w).
        num_th (int, optional): Number of evenly spaced thresholds between
            ``amaps.min()`` and ``amaps.max()``.

    Returns:
        float: The AUC of mean PRO versus FPR. The FPR is restricted to values
        below 0.3 and rescaled to [0, 1].
    """
    assert isinstance(amaps, ndarray), "type(amaps) must be ndarray"
    assert isinstance(masks, ndarray), "type(masks) must be ndarray"
    assert amaps.ndim == 3, "amaps.ndim must be 3 (num_test_data, h, w)"
    assert masks.ndim == 3, "masks.ndim must be 3 (num_test_data, h, w)"
    assert amaps.shape == masks.shape, "amaps.shape and masks.shape must be same"
    assert set(masks.flatten()) == {0, 1}, "set(masks.flatten()) must be {0, 1}"
    assert isinstance(num_th, int), "type(num_th) must be int"

    # `np.bool` was removed in NumPy 1.24; the builtin `bool` works on every version.
    binary_amaps = np.zeros_like(amaps, dtype=bool)
    min_th = amaps.min()
    max_th = amaps.max()
    delta = (max_th - min_th) / num_th

    # The masks do not depend on the threshold, so label their regions once.
    regions_per_mask = [measure.regionprops(measure.label(mask)) for mask in masks]
    inverse_masks = 1 - masks
    num_negatives = inverse_masks.sum()

    records = []
    for th in np.arange(min_th, max_th, delta):
        binary_amaps[amaps <= th] = 0
        binary_amaps[amaps > th] = 1

        pros = []
        for binary_amap, regions in zip(binary_amaps, regions_per_mask):
            for region in regions:
                tp_pixels = binary_amap[region.coords[:, 0], region.coords[:, 1]].sum()
                pros.append(tp_pixels / region.area)

        fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()
        fpr = fp_pixels / num_negatives
        records.append({"pro": mean(pros), "fpr": fpr, "threshold": th})

    # `DataFrame.append` was removed in pandas 2.0; build the frame in one go.
    df = pd.DataFrame(records, columns=["pro", "fpr", "threshold"])

    # Keep FPR < 0.3 and rescale that range to 0 ~ 1.
    df = df[df["fpr"] < 0.3].copy()
    df["fpr"] = df["fpr"] / df["fpr"].max()

    pro_auc = auc(df["fpr"], df["pro"])
    return pro_auc


def get_gaussian_kernel(kernel_size=3, sigma=2, channels=1):
    """Build a fixed depthwise ``Conv2d`` that applies a normalised 2-D Gaussian blur.

    Args:
        kernel_size: Side length of the square kernel. Padding is
            ``kernel_size // 2``, so only odd sizes preserve the spatial size.
        sigma: Standard deviation of the Gaussian, in pixels.
        channels: Number of channels; each is blurred independently (``groups=channels``).

    Returns:
        A bias-free ``torch.nn.Conv2d`` whose weight does not require gradients.
    """
    # Coordinate grid of shape (kernel_size, kernel_size, 2).
    x_coord = torch.arange(kernel_size)
    x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
    y_grid = x_grid.t()
    xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()

    center = (kernel_size - 1) / 2.
    variance = sigma ** 2.

    # The 2-D Gaussian is the product of two independent 1-D Gaussians in x and y.
    gaussian_kernel = (1. / (2. * math.pi * variance)) * \
        torch.exp(
            -torch.sum((xy_grid - center) ** 2., dim=-1) /
            (2 * variance)
        )

    # Normalise so the kernel sums to 1 (the blur preserves mean intensity).
    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)

    # Reshape to a depthwise convolution weight of shape (channels, 1, k, k).
    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
    gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)

    gaussian_filter = torch.nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size,
                                      groups=channels, bias=False, padding=kernel_size // 2)

    gaussian_filter.weight.data = gaussian_kernel
    gaussian_filter.weight.requires_grad = False

    return gaussian_filter


class WarmCosineScheduler(_LRScheduler):
    """Per-step LR schedule: linear warmup followed by cosine decay.

    The schedule has two phases:

    - Warmup: over ``warmup_iters`` steps the LR rises linearly from
      ``start_warmup_value`` to ``base_value``.
    - Decay: over the remaining ``total_iters - warmup_iters`` steps it follows
      a cosine from ``base_value`` down towards ``final_value``.

    From step ``total_iters`` onward the LR stays at ``final_value``. Every
    param group receives the same value; the optimizer's own initial LRs are
    ignored.
    """

    def __init__(self, optimizer, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, ):
        # These attributes must exist before super().__init__, which calls get_lr().
        self.final_value = final_value
        self.total_iters = total_iters
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

        iters = np.arange(total_iters - warmup_iters)
        schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
        self.schedule = np.concatenate((warmup_schedule, schedule))

        super(WarmCosineScheduler, self).__init__(optimizer)

    def get_lr(self):
        """Return the scheduled LR for the current step, repeated once per param group."""
        if self.last_epoch >= self.total_iters:
            return [self.final_value for _ in self.base_lrs]
        return [self.schedule[self.last_epoch] for _ in self.base_lrs]