# INP-Former / aug_funcs.py
# luoweibetter's picture
# Upload 205 files
# 4057a1f verified
# raw
# history blame
# 2.41 kB
import numpy as np
import torch
import torch.nn.functional as F
import kornia as K
def embedding_concat(x, y, use_cuda):
    """Concatenate a fine feature map with a coarser one at the fine resolution.

    ``x`` is cut into non-overlapping ``s x s`` patches (``s = H1 // H2``).
    Each patch is paired with the single ``y`` vector at the matching coarse
    location, and the patches are folded back. The result therefore equals
    ``torch.cat([x, y_nearest_upsampled_by_s], dim=1)``.

    Args:
        x: Fine feature map of shape ``(B, C1, H1, W1)``.
        y: Coarse feature map of shape ``(B, C2, H2, W2)``. ``H1``/``W1`` are
            expected to be ``s * H2`` / ``s * W2``; otherwise the ``view``
            below fails.
        use_cuda: Place the result on ``'cuda'`` when true, else on ``'cpu'``.

    Returns:
        Tensor of shape ``(B, C1 + C2, H1, W1)`` on the selected device. Its
        dtype follows ``torch.cat``'s promotion of ``x`` and ``y``; it is no
        longer forced to the default dtype.
    """
    device = torch.device('cuda' if use_cuda else 'cpu')
    B, C1, H1, W1 = x.size()
    _, C2, H2, W2 = y.size()
    s = H1 // H2
    # (B, C1*s*s, H2*W2) -> (B, C1, s*s, H2, W2): one slice per patch offset.
    x = F.unfold(x, kernel_size=s, dilation=1, stride=s)
    x = x.view(B, C1, -1, H2, W2)
    # Broadcast y across all s*s patch offsets instead of looping over them.
    y = y.unsqueeze(2).expand(-1, -1, x.size(2), -1, -1)
    # Build the result directly from the inputs; no CPU zeros buffer + copy.
    z = torch.cat((x, y), dim=1).to(device)
    z = z.view(B, -1, H2 * W2)
    # Stride == kernel, so fold places each value exactly once (no summing).
    return F.fold(z, kernel_size=s, output_size=(H1, W1), stride=s)
def mahalanobis_torch(u, v, cov):
    """Return ``sqrt((u - v) . (cov @ (u - v)))`` for 1-D tensors ``u`` and ``v``.

    ``torch.dot`` (used below) only accepts 1-D tensors, so ``u`` and ``v``
    must be vectors and ``cov`` a matching square matrix.

    NOTE(review): ``cov`` is applied directly, not inverted. A true
    Mahalanobis distance therefore presumably expects callers to pass the
    *inverse* covariance (precision matrix); confirm at call sites.
    """
    diff = u - v
    quad_form = torch.dot(diff, cov.matmul(diff))
    return quad_form.sqrt()
def get_rot_mat(theta):
    """Return the 2x3 affine matrix ``[[cos t, -sin t, 0], [sin t, cos t, 0]]``.

    This is a pure rotation by ``theta`` radians about the origin, with a
    zero translation column. It is suitable as a single ``F.affine_grid``
    matrix.

    Args:
        theta: Angle in radians, as a Python number or a 0-d tensor.
    """
    # as_tensor accepts Python scalars and tensors alike. torch.tensor(<tensor>)
    # emits a "copy construct" UserWarning.
    theta = torch.as_tensor(theta)
    cos_t, sin_t = torch.cos(theta), torch.sin(theta)
    return torch.tensor([[cos_t, -sin_t, 0],
                         [sin_t, cos_t, 0]])
def get_translation_mat(a, b):
    """Return the 2x3 affine matrix ``[[1, 0, a], [0, 1, b]]`` (pure translation).

    The left 2x2 part is the identity, and ``a`` / ``b`` form the last column.
    The dtype is whatever ``torch.tensor`` infers from the arguments (e.g.
    float32 for Python floats, int64 for Python ints).
    """
    identity_rows = ([1, 0], [0, 1])
    offsets = (a, b)
    rows = [list(row) + [offset] for row, offset in zip(identity_rows, offsets)]
    return torch.tensor(rows)
def rot_img(x, theta):
    """Rotate a batch of images by ``theta`` radians about the image centre.

    The rotation matrix from ``get_rot_mat`` has no translation, and
    ``affine_grid`` coordinates are centred, so the pivot is the image centre.

    Args:
        x: Image batch of shape ``(N, C, H, W)`` (the 4-D layout
            ``F.affine_grid`` expects for a batch of 2x3 matrices).
        theta: Rotation angle in radians (Python number or 0-d tensor).

    Returns:
        Tensor with the same shape, dtype and device as ``x``. Samples that
        fall outside the image are filled by reflection.
    """
    # Match x's device/dtype. A fixed torch.FloatTensor would force a CPU
    # float32 grid and make grid_sample fail for CUDA or float64 inputs.
    rot_mat = get_rot_mat(theta)[None, ...].to(device=x.device, dtype=x.dtype)
    rot_mat = rot_mat.repeat(x.shape[0], 1, 1)
    # align_corners=False is PyTorch's default. Passing it explicitly keeps
    # the same sampling and silences the "default changed" UserWarning.
    grid = F.affine_grid(rot_mat, x.size(), align_corners=False)
    return F.grid_sample(x, grid, padding_mode="reflection", align_corners=False)
def translation_img(x, a, b):
    """Translate a batch of images by ``(a, b)`` in normalised coordinates.

    Each output location samples the input at that location plus ``(a, b)``,
    measured in ``F.affine_grid``'s normalised ``[-1, 1]`` coordinates
    (``a`` along width, ``b`` along height).

    Args:
        x: Image batch of shape ``(N, C, H, W)``.
        a: Horizontal offset in normalised units.
        b: Vertical offset in normalised units.

    Returns:
        Tensor with the same shape, dtype and device as ``x``. Samples that
        fall outside the image are filled by reflection.
    """
    # Match x's device/dtype. A fixed torch.FloatTensor would force a CPU
    # float32 grid and make grid_sample fail for CUDA or float64 inputs.
    trans_mat = get_translation_mat(a, b)[None, ...].to(device=x.device, dtype=x.dtype)
    trans_mat = trans_mat.repeat(x.shape[0], 1, 1)
    # align_corners=False is PyTorch's default. Passing it explicitly keeps
    # the same sampling and silences the "default changed" UserWarning.
    grid = F.affine_grid(trans_mat, x.size(), align_corners=False)
    return F.grid_sample(x, grid, padding_mode="reflection", align_corners=False)
def hflip_img(x):
    """Horizontally flip an image tensor (or batch of images) using kornia."""
    return K.geometry.transform.hflip(x)
def rot90_img(x, k):
    """Rotate images by ``k`` quarter turns with kornia's ``rotate``.

    Args:
        x: Image tensor or batch accepted by ``K.geometry.transform.rotate``.
        k: Number of 90-degree steps. The intended values are 0, 1, 2 and 3
            (4 gives a full 360-degree turn). Other integers now rotate by
            ``90 * k`` degrees.

    Returns:
        The rotated tensor from kornia. Regions brought in from outside the
        frame are filled by reflection.
    """
    # 90 * k reproduces the old table [0, 90, 180, 270, 360] for k in 0..4.
    # It avoids IndexError for k > 4 and the silent table wrap-around for
    # negative k. The angle is created on x's device/dtype so CUDA and
    # float64 inputs do not mix with a CPU float32 angle inside kornia.
    degrees = torch.as_tensor(90.0 * k, dtype=x.dtype, device=x.device)
    return K.geometry.transform.rotate(x, angle=degrees, padding_mode='reflection')
def grey_img(x):
    """Convert RGB images to greyscale and replicate the result to 3 channels.

    ``repeat(1, 3, 1, 1)`` below implies a 4-D batch, presumably
    ``(B, 3, H, W)`` in and ``(B, 3, H, W)`` out. The 3-channel layout
    presumably lets downstream code that expects RGB accept the output;
    confirm with callers.
    """
    luminance = K.color.rgb_to_grayscale(x)
    return luminance.repeat(1, 3, 1, 1)
def denormalization(x, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo per-channel normalisation and return an HWC ``uint8`` image.

    Computes ``(x * std + mean) * 255`` per channel after moving channels
    last. With the default ``mean = std = 0.5`` this maps ``[-1, 1]`` to
    ``[0, 255]``.

    Args:
        x: NumPy array in CHW layout (it is transposed with ``(1, 2, 0)``).
        mean: Per-channel mean that was subtracted during normalisation.
        std: Per-channel std that was divided out during normalisation.

    Returns:
        ``np.uint8`` array in HWC layout. Values are clipped to ``[0, 255]``
        and then truncated toward zero.
    """
    mean = np.asarray(mean)
    std = np.asarray(std)
    img = ((x.transpose(1, 2, 0) * std) + mean) * 255.
    # Clip before the cast. Casting out-of-range floats to uint8 wraps around
    # (or is undefined for negatives) instead of saturating.
    return np.clip(img, 0, 255).astype(np.uint8)
def denorm(x):
    """Map a tensor from the [-1, 1] range to [0, 1].

    Values outside [-1, 1] saturate at 0 or 1. A new tensor is returned and
    ``x`` itself is left unchanged.
    """
    shifted = x.add(1).div(2)
    return torch.clamp(shifted, min=0, max=1)