# GSASR/utils/fea2gsropeamp.py
from __future__ import absolute_import, division, print_function
import math
import warnings
from typing import Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.utils.checkpoint import checkpoint
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
'The distribution of values may be incorrect.',
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
low = norm_cdf((a - mean) / std)
up = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [low, up], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * low - 1, 2 * up - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
r"""Fills the input Tensor with values drawn from a truncated
normal distribution.
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def init_t_xy(end_x: int, end_y: int, zero_center=False):
    # Return the (x, y) coordinates of every cell in a flattened end_x * end_y grid.
    # `zero_center` is accepted for API compatibility but currently unused.
    t = torch.arange(end_x * end_y, dtype=torch.float32)
    t_x = (t % end_x).float()
    t_y = torch.div(t, end_x, rounding_mode='floor').float()
    return t_x, t_y
def init_random_2d_freqs(head_dim: int, num_heads: int, theta: float = 10.0, rotate: bool = True):
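    # Initialize mixed 2D RoPE frequencies: magnitudes follow the usual RoPE
    # schedule theta^(-k / head_dim), and each head gets its own (random, when
    # `rotate` is True) angle splitting the frequencies between the x and y axes.
    # Output shape: (2, num_heads, head_dim // 2).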
freqs_x = []
freqs_y = []
mag = 1 / (theta ** (torch.arange(0, head_dim, 4)[: (head_dim // 4)].float() / head_dim))
for i in range(num_heads):
angles = torch.rand(1) * 2 * torch.pi if rotate else torch.zeros(1)
fx = torch.cat([mag * torch.cos(angles), mag * torch.cos(torch.pi/2 + angles)], dim=-1)
fy = torch.cat([mag * torch.sin(angles), mag * torch.sin(torch.pi/2 + angles)], dim=-1)
freqs_x.append(fx)
freqs_y.append(fy)
freqs_x = torch.stack(freqs_x, dim=0)
freqs_y = torch.stack(freqs_y, dim=0)
freqs = torch.stack([freqs_x, freqs_y], dim=0)
return freqs
def compute_cis(freqs, t_x, t_y):
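    # Combine grid coordinates with per-head frequencies and return the complex
    # rotations exp(i * (t_x * f_x + t_y * f_y)) of shape
    # (num_heads, N, head_dim // 2), with N = len(t_x).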
    # torch.polar and the complex products require full float32 precision, so this
    # block is kept out of (float16) autocast.
    with torch.cuda.amp.autocast(enabled=False):
freqs_x = (t_x.unsqueeze(-1) @ freqs[0].unsqueeze(-2))
freqs_y = (t_y.unsqueeze(-1) @ freqs[1].unsqueeze(-2))
freqs_cis = torch.polar(torch.ones_like(freqs_x), freqs_x + freqs_y)
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    # Reshape freqs_cis so it broadcasts against x: keep the trailing dims that
    # match and set every leading dim to 1.
    ndim = x.ndim
    assert ndim > 1
    if freqs_cis.shape == (x.shape[-2], x.shape[-1]):
        shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
    elif freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1]):
        shape = [d if i >= ndim - 3 else 1 for i, d in enumerate(x.shape)]
    else:
        raise ValueError(
            f'freqs_cis shape {tuple(freqs_cis.shape)} cannot be broadcast to x shape {tuple(x.shape)}')
    return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
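    # View consecutive channel pairs of xq/xk as complex numbers, rotate them by
    # freqs_cis, then flatten back to real values in the original dtype.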
# print(f"xq shape is {xq.shape}, xq.shape[:-1] is {xq.shape[:-1]}")
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
# print(f"xq_ shape is {xq_.shape}")
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
def apply_rotary_emb_single(x, freqs_cis):
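    # Single-tensor variant of apply_rotary_emb: rotate x by the first seq_len
    # positions of freqs_cis. Assumes the grid freqs_cis was computed on covers
    # at least seq_len positions.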
x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
seq_len = x_.shape[2]
freqs_cis = freqs_cis[:, :seq_len, :]
freqs_cis = freqs_cis.unsqueeze(0).expand_as(x_)
x_out = torch.view_as_real(x_ * freqs_cis).flatten(3)
return x_out.type_as(x).to(x.device)
def window_partition(x, window_size):
    # Split the (b, c, h, w) feature map from net_g into non-overlapping
    # window_size x window_size windows, flattened to
    # (b * h_count * w_count, window_size * window_size, c).
    b, c, h, w = x.shape
    windows = rearrange(x, 'b c (h_count dh) (w_count dw) -> (b h_count w_count) (dh dw) c', dh=window_size,
                        dw=window_size)
    return windows
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
class MLP(nn.Module):
def __init__(self, in_features, hidden_features, out_features, act_layer=nn.ReLU):
super(MLP, self).__init__()
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
class WindowCrossAttn(nn.Module):
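    """Cross-attention from Gaussian-seed queries to window-partitioned image
    features, with 2D rotary position embeddings (RoPE) applied to q and k."""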
def __init__(self, dim=180, num_heads=6, window_size=12, num_gs_seed=2304, rope_mixed = True, rope_theta = 10.0):
super(WindowCrossAttn, self).__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.num_gs_seed = num_gs_seed
self.num_gs_seed_sqrt = int(math.sqrt(num_gs_seed))
self.rope_mixed = rope_mixed
t_x, t_y = init_t_xy(end_x=max(self.num_gs_seed_sqrt, self.window_size), end_y=max(self.num_gs_seed_sqrt, self.window_size))
self.register_buffer('rope_t_x', t_x)
self.register_buffer('rope_t_y', t_y)
freqs = init_random_2d_freqs(
head_dim=self.dim // self.num_heads, num_heads=self.num_heads, theta=rope_theta,
rotate=self.rope_mixed
)
if self.rope_mixed:
self.rope_freqs = nn.Parameter(freqs, requires_grad=True)
else:
self.register_buffer('rope_freqs', freqs)
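        # Precomputed rotations for the fixed-frequency (non-mixed) case. Stored as
        # a plain attribute rather than a buffer, so forward() moves it to the
        # input's device explicitly.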
freqs_cis = compute_cis(self.rope_freqs, self.rope_t_x, self.rope_t_y)
self.rope_freqs_cis = freqs_cis
self.qhead = nn.Linear(dim, dim, bias=True)
self.khead = nn.Linear(dim, dim, bias=True)
self.vhead = nn.Linear(dim, dim, bias=True)
self.proj = nn.Linear(dim, dim)
def forward(self, gs, feat):
        # gs:   (b * h_count * w_count, num_gs, c); the input gs is expected to
        #       already include the positional and scale embeddings
        # feat: (b * h_count * w_count, dh * dw, c), with dh = dw = window_size
b_, num_gs, c = gs.shape
b_, n, c = feat.shape
        q = self.qhead(gs)  # (b_, num_gs, c)
        q = q.reshape(b_, num_gs, self.num_heads, c // self.num_heads)
        q = q.permute(0, 2, 1, 3)  # (b_, num_heads, num_gs, c // num_heads)
        k = self.khead(feat)  # (b_, n, c)
        k = k.reshape(b_, n, self.num_heads, c // self.num_heads)
        k = k.permute(0, 2, 1, 3)  # (b_, num_heads, n, c // num_heads)
        v = self.vhead(feat)  # (b_, n, c)
        v = v.reshape(b_, n, self.num_heads, c // self.num_heads)
        v = v.permute(0, 2, 1, 3)  # (b_, num_heads, n, c // num_heads)
###### Apply rotary position embedding
if self.rope_mixed:
freqs_cis = compute_cis(self.rope_freqs, self.rope_t_x, self.rope_t_y)
else:
freqs_cis = self.rope_freqs_cis.to(gs.device)
q = apply_rotary_emb_single(q, freqs_cis)
k = apply_rotary_emb_single(k, freqs_cis)
#########
attn = F.scaled_dot_product_attention(q, k, v)
x = attn.transpose(1, 2).reshape(b_, num_gs, c)
x = self.proj(x)
return x
class WindowCrossAttnLayer(nn.Module):
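    """One decoder layer: cross-attention of the Gaussian queries to the scale
    embedding, then (optionally shifted) window cross-attention to the image
    features, each followed by a pre-norm residual MLP."""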
    def __init__(self, dim=180, num_heads=6, window_size=12, shift_size=0, num_gs_seed=2304, rope_mixed=True, rope_theta=10.0):
super(WindowCrossAttnLayer, self).__init__()
self.gs_cross_attn_scale = nn.MultiheadAttention(dim, num_heads, batch_first=True)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.norm4 = nn.LayerNorm(dim)
self.shift_size = shift_size
self.window_size = window_size
self.window_cross_attn = WindowCrossAttn(dim=dim, num_heads=num_heads, window_size=window_size,
num_gs_seed=num_gs_seed, rope_mixed = rope_mixed, rope_theta = rope_theta)
self.mlp_crossattn_scale = MLP(in_features=dim, hidden_features=dim, out_features=dim)
self.mlp_crossattn_feature = MLP(in_features=dim, hidden_features=dim, out_features=dim)
def forward(self, x, query_pos, feat, scale_embedding):
        # x (Gaussian queries): (b * h_count * w_count, num_gs, c)
        # query_pos:            (b * h_count * w_count, num_gs, c)
        # feat:                 (b, c, h, w)
        # scale_embedding:      (b * h_count * w_count, 1, c)
###GS cross attn with scale embedding
resi = x
x = self.norm1(x)
# print(f"x: {x.shape} {x.device}, query_pos: {query_pos.shape}, {query_pos.device}, scale_embedding: {scale_embedding.shape}, {scale_embedding.device}")
x, _ = self.gs_cross_attn_scale(with_pos_embed(x, query_pos), scale_embedding, scale_embedding)
x = resi + x
###FFN
resi = x
x = self.norm2(x)
x = self.mlp_crossattn_scale(x)
x = resi + x
###cross attention for Q,K,V
resi = x
x = self.norm3(x)
if self.shift_size > 0:
shift_feat = torch.roll(feat, shifts=(-self.shift_size, -self.shift_size), dims=(2, 3))
else:
shift_feat = feat
        shift_feat = window_partition(shift_feat, self.window_size)  # (b * h_count * w_count, dh * dw, c), dh = dw = window_size
        x = self.window_cross_attn(with_pos_embed(x, query_pos),
                                   shift_feat)  # (b * h_count * w_count, num_gs, c)
x = resi + x
###FFN
resi = x
x = self.norm4(x)
x = self.mlp_crossattn_feature(x)
x = resi + x
return x
class WindowCrossAttnBlock(nn.Module):
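    """A stack of WindowCrossAttnLayer modules, alternating regular and shifted
    windows, followed by an MLP and a 3x3 conv on the re-assembled 2D seed grid,
    all wrapped in a long residual connection."""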
    def __init__(self, dim=180, window_size=12, num_heads=6, num_layers=4, num_gs_seed=2304, rope_mixed=True, rope_theta=10.0):
super(WindowCrossAttnBlock, self).__init__()
self.num_gs_seed_sqrt = int(math.sqrt(num_gs_seed))
self.mlp = nn.Sequential(
nn.Linear(dim, dim),
nn.ReLU(),
nn.Linear(dim, dim)
)
self.norm = nn.LayerNorm(dim)
self.blocks = nn.ModuleList([
WindowCrossAttnLayer(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if i % 2 == 0 else window_size // 2,
num_gs_seed=num_gs_seed,
rope_mixed = rope_mixed, rope_theta = rope_theta) for i in range(num_layers)
])
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
def forward(self, x, query_pos, feat, scale_embedding, h_count, w_count):
resi = x
x = self.norm(x)
for block in self.blocks:
x = block(x, query_pos, feat, scale_embedding)
x = self.mlp(x)
x = rearrange(x, '(b m n) (h w) c -> b c (m h) (n w)', m=h_count, n=w_count, h=self.num_gs_seed_sqrt)
x = self.conv(x)
x = rearrange(x, 'b c (m h) (n w) -> (b m n) (h w) c', m=h_count, n=w_count, h=self.num_gs_seed_sqrt)
x = resi + x
return x
class GSSelfAttn(nn.Module):
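    """Self-attention over the Gaussian-seed tokens with 2D rotary position
    embeddings (RoPE) applied to q and k."""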
def __init__(self, dim=180, num_heads=6, num_gs_seed_sqrt = 12, rope_mixed = True, rope_theta=10.0):
super(GSSelfAttn, self).__init__()
self.dim = dim
self.num_heads = num_heads
self.num_gs_seed_sqrt = num_gs_seed_sqrt
self.proj = nn.Linear(dim, dim)
self.rope_mixed = rope_mixed
t_x, t_y = init_t_xy(end_x=self.num_gs_seed_sqrt, end_y=self.num_gs_seed_sqrt)
self.register_buffer('rope_t_x', t_x)
self.register_buffer('rope_t_y', t_y)
freqs = init_random_2d_freqs(
head_dim=self.dim // self.num_heads, num_heads=self.num_heads, theta=rope_theta,
rotate=self.rope_mixed
)
if self.rope_mixed:
self.rope_freqs = nn.Parameter(freqs, requires_grad=True)
else:
self.register_buffer('rope_freqs', freqs)
freqs_cis = compute_cis(self.rope_freqs, self.rope_t_x, self.rope_t_y)
self.rope_freqs_cis = freqs_cis
self.qhead = nn.Linear(dim, dim, bias=True)
self.khead = nn.Linear(dim, dim, bias=True)
self.vhead = nn.Linear(dim, dim, bias=True)
def forward(self, gs):
        # gs: (b * h_count * w_count, num_gs, c)
b_, num_gs, c = gs.shape
q = self.qhead(gs)
q = q.reshape(b_, num_gs, self.num_heads, c // self.num_heads)
q = q.permute(0, 2, 1, 3) # b_, num_heads, n, c // num_heads
k = self.khead(gs)
k = k.reshape(b_, num_gs, self.num_heads, c // self.num_heads)
k = k.permute(0, 2, 1, 3) # b_, num_heads, n, c // num_heads
v = self.vhead(gs)
v = v.reshape(b_, num_gs, self.num_heads, c // self.num_heads)
v = v.permute(0, 2, 1, 3) # b_, num_heads, n, c // num_heads
###### Apply rotary position embedding
if self.rope_mixed:
freqs_cis = compute_cis(self.rope_freqs, self.rope_t_x, self.rope_t_y)
else:
freqs_cis = self.rope_freqs_cis.to(gs.device)
q, k = apply_rotary_emb(q, k, freqs_cis)
#########
attn = F.scaled_dot_product_attention(q, k, v)
attn = attn.transpose(1, 2).reshape(b_, num_gs, c)
attn = self.proj(attn)
return attn
class GSSelfAttnLayer(nn.Module):
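    """One self-attention layer over Gaussian tokens: cross-attention to the
    scale embedding, then (optionally shifted) self-attention on the assembled
    token grid, each followed by a pre-norm residual MLP."""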
def __init__(self, dim=180, num_heads=6, num_gs_seed_sqrt = 12, shift_size = 0, rope_mixed = True, rope_theta=10.0):
super(GSSelfAttnLayer, self).__init__()
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.norm4 = nn.LayerNorm(dim)
self.gs_self_attn = GSSelfAttn(dim = dim, num_heads = num_heads, num_gs_seed_sqrt = num_gs_seed_sqrt, rope_mixed = rope_mixed, rope_theta=rope_theta)
self.mlp_selfattn = MLP(in_features=dim, hidden_features=dim, out_features=dim)
self.num_gs_seed_sqrt = num_gs_seed_sqrt
self.shift_size = shift_size
self.gs_cross_attn_scale = nn.MultiheadAttention(dim, num_heads, batch_first=True)
self.mlp_crossattn = MLP(in_features=dim, hidden_features=dim, out_features=dim)
def forward(self, gs, pos, h_count, w_count, scale_embedding):
# gs shape:b*h_count*w_count, num_gs_seed, channel
# pos shape: b*h_count*w_count, num_gs_seed, channel
# scale_embedding shape: b*h_count*w_count, 1, channel
# gs cross attn with scale_embedding
resi = gs
gs = self.norm3(gs)
gs, _ = self.gs_cross_attn_scale(with_pos_embed(gs, pos), scale_embedding, scale_embedding)
gs = gs + resi
# FFN
resi = gs
gs = self.norm4(gs)
gs = self.mlp_crossattn(gs)
gs = gs + resi
resi = gs
gs = self.norm1(gs)
#### shift gs
if self.shift_size > 0:
shift_gs = rearrange(gs, '(b m n) (h w) c -> b (m h) (n w) c', m=h_count, n=w_count, h=self.num_gs_seed_sqrt, w = self.num_gs_seed_sqrt)
shift_gs = torch.roll(shift_gs, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
shift_gs = rearrange(shift_gs, 'b (m h) (n w) c -> (b m n) (h w) c', m=h_count, n=w_count, h=self.num_gs_seed_sqrt, w = self.num_gs_seed_sqrt)
else:
shift_gs = gs
#### gs self attention
gs = self.gs_self_attn(shift_gs)
#### shift gs back
if self.shift_size > 0:
shift_gs = rearrange(gs, '(b m n) (h w) c -> b (m h) (n w) c', m=h_count, n=w_count, h=self.num_gs_seed_sqrt, w = self.num_gs_seed_sqrt)
shift_gs = torch.roll(shift_gs, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
shift_gs = rearrange(shift_gs, 'b (m h) (n w) c -> (b m n) (h w) c', m=h_count, n=w_count, h=self.num_gs_seed_sqrt, w = self.num_gs_seed_sqrt)
else:
shift_gs = gs
gs = shift_gs + resi
#FFN
resi = gs
gs = self.norm2(gs)
gs = self.mlp_selfattn(gs)
gs = gs + resi
return gs
class GSSelfAttnBlock(nn.Module):
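    """A stack of GSSelfAttnLayer modules, alternating regular and shifted token
    grids, followed by an MLP and a 3x3 conv on the re-assembled 2D grid, all
    wrapped in a long residual connection."""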
def __init__(self, dim=180, num_heads=6, num_selfattn_layers=4, num_gs_seed_sqrt = 12, rope_mixed = True, rope_theta=10.0):
super(GSSelfAttnBlock, self).__init__()
self.num_gs_seed_sqrt = num_gs_seed_sqrt
self.mlp = nn.Sequential(
nn.Linear(dim, dim),
nn.ReLU(),
nn.Linear(dim, dim)
)
self.norm = nn.LayerNorm(dim)
self.blocks = nn.ModuleList([
GSSelfAttnLayer(
dim = dim,
num_heads = num_heads,
num_gs_seed_sqrt=num_gs_seed_sqrt,
shift_size=0 if i % 2 == 0 else num_gs_seed_sqrt // 2,
rope_mixed = rope_mixed, rope_theta=rope_theta
) for i in range(num_selfattn_layers)
])
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
def forward(self, gs, pos, h_count, w_count, scale_embedding):
resi = gs
gs = self.norm(gs)
for block in self.blocks:
gs = block(gs, pos, h_count, w_count, scale_embedding)
gs = self.mlp(gs)
gs = rearrange(gs, '(b m n) (h w) c -> b c (m h) (n w)', m=h_count, n=w_count, h=self.num_gs_seed_sqrt)
gs = self.conv(gs)
gs = rearrange(gs, 'b c (m h) (n w) -> (b m n) (h w) c', m=h_count, n=w_count, h=self.num_gs_seed_sqrt)
gs = gs + resi
return gs
class Fea2GS_ROPE_AMP(nn.Module):
def __init__(self, inchannel=64, channel=192, num_heads=6, num_crossattn_blocks=1, num_crossattn_layers=2, num_selfattn_blocks = 6, num_selfattn_layers = 6,
num_gs_seed=144, gs_up_factor=1.0, window_size=12, img_range=1.0, shuffle_scale1 = 2, shuffle_scale2 = 2, use_checkpoint = False,
rope_mixed = True, rope_theta = 10.0):
"""
Args:
gs_repeat_factor: the ratio of gs embedding number and pixel number along width&height, will generate
(h * gs_repeat_factor) * (w * gs_repeat_factor) gs embedding, higher values means repeat more gs embedding.
gs_up_factor: how many 2d gaussian are generated by one gasussian embedding.
"""
super(Fea2GS_ROPE_AMP, self).__init__()
self.channel = channel
self.nhead = num_heads
self.gs_up_factor = gs_up_factor
self.num_gs_seed = num_gs_seed
self.window_size = window_size
self.img_range = img_range
self.use_checkpoint = use_checkpoint
self.num_gs_seed_sqrt = int(math.sqrt(num_gs_seed))
self.gs_up_factor_sqrt = int(math.sqrt(gs_up_factor))
self.shuffle_scale1 = shuffle_scale1
self.shuffle_scale2 = shuffle_scale2
# shared gaussian embedding and its pos embedding
self.gs_embedding = nn.Parameter(torch.randn(self.num_gs_seed, channel), requires_grad=True)
self.pos_embedding = nn.Parameter(torch.randn(self.num_gs_seed, channel), requires_grad=True)
self.img_feat_proj = nn.Sequential(
nn.Conv2d(inchannel, channel, 3, 1, 1),
nn.ReLU(),
nn.Conv2d(channel, channel, 3, 1, 1)
)
self.window_crossattn_blocks = nn.ModuleList([
WindowCrossAttnBlock(dim=channel,
window_size=window_size,
num_heads=num_heads,
num_layers=num_crossattn_layers,
num_gs_seed=num_gs_seed, rope_mixed = rope_mixed, rope_theta = rope_theta) for i in range(num_crossattn_blocks)
])
self.gs_selfattn_blocks = nn.ModuleList([
GSSelfAttnBlock(dim=channel,
num_heads=num_heads,
num_selfattn_layers=num_selfattn_layers,
num_gs_seed_sqrt=self.num_gs_seed_sqrt,
rope_mixed = rope_mixed, rope_theta=rope_theta
) for i in range(num_selfattn_blocks)
])
# GS sigma_x, sigma_y
self.mlp_block_sigma = nn.Sequential(
nn.Linear(channel, channel),
nn.ReLU(),
nn.Linear(channel, channel * 4),
nn.ReLU(),
nn.Linear(channel * 4, int(2 * gs_up_factor))
)
# GS rho
self.mlp_block_rho = nn.Sequential(
nn.Linear(channel, channel),
nn.ReLU(),
nn.Linear(channel, channel * 4),
nn.ReLU(),
nn.Linear(channel * 4, int(1 * gs_up_factor))
)
# GS alpha
self.mlp_block_alpha = nn.Sequential(
nn.Linear(channel, channel),
nn.ReLU(),
nn.Linear(channel, channel * 4),
nn.ReLU(),
nn.Linear(channel * 4, int(1 * gs_up_factor))
)
# GS RGB values
self.mlp_block_rgb = nn.Sequential(
nn.Linear(channel, channel),
nn.ReLU(),
nn.Linear(channel, channel * 4),
nn.ReLU(),
nn.Linear(channel * 4, int(3 * gs_up_factor))
)
# GS mean_x, mean_y
self.mlp_block_mean = nn.Sequential(
nn.Linear(channel, channel),
nn.ReLU(),
nn.Linear(channel, channel * 4),
nn.ReLU(),
nn.Linear(channel * 4, int(2 * gs_up_factor))
)
self.scale_mlp = nn.Sequential(
nn.Linear(1, channel * 4),
nn.ReLU(),
nn.Linear(channel * 4, channel)
)
self.UPNet = nn.Sequential(
nn.Conv2d(channel, channel * self.shuffle_scale1 * self.shuffle_scale1, 3, 1, 1),
nn.PixelShuffle(self.shuffle_scale1),
nn.Conv2d(channel, channel * self.shuffle_scale2 * self.shuffle_scale2, 3, 1, 1),
nn.PixelShuffle(self.shuffle_scale2)
)
self.conv_final = nn.Conv2d(channel, channel, 3, 1, 1)
@staticmethod
def get_N_reference_points(h, w, device='cuda'):
        # Cell-centered coordinates in [0, 1]: one reference point per output pixel.
        step_y = 1 / h
        step_x = 1 / w
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(step_y / 2, 1 - step_y / 2, h, dtype=torch.float32, device=device),
            torch.linspace(step_x / 2, 1 - step_x / 2, w, dtype=torch.float32, device=device),
            indexing='ij')  # explicit 'ij' matches the legacy default and avoids the newer-PyTorch warning
reference_points = torch.stack((ref_x.reshape(-1), ref_y.reshape(-1)), -1)
reference_points = reference_points[None, :, None]
return reference_points
    def forward(self, srcs, scale):
        '''
        Args:
            srcs: backbone features of shape (b, dim, h, w), already padded so
                that h and w are divisible by window_size
            scale: per-sample super-resolution scale factor, shape (b,)
        '''
        b, c, h, w = srcs.shape
query = self.gs_embedding.unsqueeze(0).unsqueeze(1).repeat(b, (h // self.window_size) * (w // self.window_size),
1, 1) # b, h_count*w_count, num_gs_seed, channel
query = query.reshape(b * (h // self.window_size) * (w // self.window_size), -1,
self.channel) # b*h_count*w_count, num_gs_seed, channel
        scale = 1 / scale
        scale = scale.unsqueeze(1)  # (b, 1)
        scale_embedding = self.scale_mlp(scale)  # (b, channel)
scale_embedding = scale_embedding.unsqueeze(1).unsqueeze(2).repeat(1, (h // self.window_size) * (
w // self.window_size), self.num_gs_seed, 1) # b, h_count*w_count, num_gs_seed, channel
scale_embedding = scale_embedding.reshape(b * (h // self.window_size) * (w // self.window_size), -1,
self.channel) # b*h_count*w_count, num_gs_seed, channel
query_pos = self.pos_embedding.unsqueeze(0).unsqueeze(1).repeat(b, (h // self.window_size) * (
w // self.window_size), 1, 1) # b, h_count*w_count, num_gs_seed, channel
feat = self.img_feat_proj(srcs) # b*channel*h*w
query_pos = query_pos.reshape(b * (h // self.window_size) * (w // self.window_size), -1,
self.channel) # b*h_count*w_count, num_gs_seed, channel
for block in self.window_crossattn_blocks:
if self.use_checkpoint:
query = checkpoint(block, query, query_pos, feat, scale_embedding, h // self.window_size, w // self.window_size)
else:
query = block(query, query_pos, feat, scale_embedding, h // self.window_size, w // self.window_size) # b*h_count*w_count, num_gs_seed, channel
resi = query
for block in self.gs_selfattn_blocks:
if self.use_checkpoint:
query = checkpoint(block, query, query_pos, h // self.window_size, w // self.window_size, scale_embedding)
else:
query = block(query, query_pos, h // self.window_size, w // self.window_size, scale_embedding)
query = rearrange(query, '(b m n) (h w) c -> b c (m h) (n w)', m=h // self.window_size, n=w // self.window_size,
h=self.num_gs_seed_sqrt)
query = self.conv_final(query)
resi = rearrange(resi, '(b m n) (h w) c -> b c (m h) (n w)', m=h // self.window_size, n=w // self.window_size,
h=self.num_gs_seed_sqrt)
query = query + resi
query = self.UPNet(query)
query = query.permute(0,2,3,1)
query_sigma = self.mlp_block_sigma(query).reshape(b, -1, 2)
query_rho = self.mlp_block_rho(query).reshape(b, -1, 1)
query_alpha = self.mlp_block_alpha(query).reshape(b, -1, 1)
query_rgb = self.mlp_block_rgb(query).reshape(b, -1, 3)
query_mean = self.mlp_block_mean(query).reshape(b, -1, 2)
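        # Interpret the raw mean outputs as offsets in output-pixel units: divide
        # by the output resolution (W_out, H_out) to express them in [0, 1]
        # coordinates before adding the cell-centered reference grid below.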
query_mean = query_mean / torch.tensor(
[self.num_gs_seed_sqrt * (w // self.window_size) * self.shuffle_scale1 * self.shuffle_scale2,
self.num_gs_seed_sqrt * (h // self.window_size) * self.shuffle_scale1 * self.shuffle_scale2])[
None, None].to(query_mean.device) # b, h_count*w_count*num_gs_seed, 2
reference_offset = self.get_N_reference_points(self.num_gs_seed_sqrt * (h // self.window_size) * self.shuffle_scale1 * self.shuffle_scale2,
self.num_gs_seed_sqrt * (w // self.window_size) * self.shuffle_scale1 * self.shuffle_scale2, srcs.device)
query_mean = query_mean + reference_offset.reshape(1, -1, 2)
query = torch.cat([query_sigma, query_rho, query_alpha, query_rgb, query_mean],
dim=-1) # b, h_count*w_count*num_gs_seed, 9
return query
if __name__ == '__main__':
    srcs = torch.randn(6, 64, 64, 64, requires_grad=True).cuda()
    scale = (torch.rand(6) * 3 + 1).cuda()  # positive scale factors: forward takes 1 / scale, so values at or below zero would be invalid
decoder = Fea2GS_ROPE_AMP(inchannel=64, channel=192, num_heads=6,
num_crossattn_blocks=1, num_crossattn_layers=2,
num_selfattn_blocks = 6, num_selfattn_layers = 6,
num_gs_seed=256, gs_up_factor=1.0, window_size=16,
img_range=1.0, shuffle_scale1 = 2, shuffle_scale2 = 2).cuda()
import time
for i in range(10):
torch.cuda.synchronize()
time1 = time.time()
# with torch.autocast(device_type = 'cuda'):
y = decoder(srcs, scale)
torch.cuda.synchronize()
time2 = time.time()
print(f"decoder time is {time2 - time1}")
print(y.shape)
torch.cuda.synchronize()
time3 = time.time()
y.sum().backward()
torch.cuda.synchronize()
time4 = time.time()
print(f"backward time is {time4 - time3}")