# GSASR / utils / rdn.py
# mt-cly
# init
# 909940e
import collections.abc
import math
import torch
import torchvision
import warnings
from distutils.version import LooseVersion
from itertools import repeat
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm
class RDB_Conv(nn.Module):
    """Single densely connected layer used inside a residual dense block.

    A ``kSize`` x ``kSize`` convolution followed by ReLU produces
    ``growRate`` new feature maps, which are concatenated onto the input
    along the channel axis (dim 1). The output therefore carries
    ``inChannels + growRate`` channels.

    Args:
        inChannels: number of channels of the incoming tensor.
        growRate: number of channels produced by the convolution.
        kSize: convolution kernel size; padding is ``(kSize - 1) // 2``
            with stride 1.
    """

    def __init__(self, inChannels, growRate, kSize=3):
        super().__init__()
        same_pad = (kSize - 1) // 2
        conv_layer = nn.Conv2d(
            inChannels, growRate, kSize, padding=same_pad, stride=1
        )
        # Kept as a Sequential named ``conv`` so parameter names stay
        # ``conv.0.weight`` / ``conv.0.bias``.
        self.conv = nn.Sequential(conv_layer, nn.ReLU())

    def forward(self, x):
        """Return ``x`` with the freshly computed features appended on dim 1."""
        new_features = self.conv(x)
        return torch.cat((x, new_features), 1)
class RDB(nn.Module):
    """Residual dense block: stacked dense layers, 1x1 fusion, skip connection.

    ``nConvLayers`` ``RDB_Conv`` layers are chained; layer ``c`` receives
    ``growRate0 + c * growRate`` channels and appends ``growRate`` more.
    A 1x1 convolution (local feature fusion) then maps the accumulated
    ``growRate0 + nConvLayers * growRate`` channels back to ``growRate0``,
    and the block input is added to that result.

    Args:
        growRate0: channel count of the block input and output (G0).
        growRate: channels contributed by each dense layer (G).
        nConvLayers: number of dense layers in the block (C).
        kSize: kernel size handed to every dense layer.
    """

    def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
        super(RDB, self).__init__()
        G0 = growRate0
        G = growRate
        C = nConvLayers

        # Forward kSize to each layer; it used to be dropped here, so the
        # layers always fell back to RDB_Conv's default kernel size.
        self.convs = nn.Sequential(*[
            RDB_Conv(G0 + c * G, G, kSize) for c in range(C)
        ])

        # Local Feature Fusion
        self.LFF = nn.Conv2d(G0 + C * G, G0, 1, padding=0, stride=1)

    def forward(self, x):
        """Return the fused dense features plus the block input ``x``."""
        return self.LFF(self.convs(x)) + x
class RDNNOUP(nn.Module):
    """Residual Dense Network whose up-sampling tail is optional.

    The network runs two shallow convolutions, ``D`` residual dense blocks
    whose outputs are concatenated and fused (global feature fusion), and
    adds the first shallow feature map back onto the fused result. With
    ``no_upsampling=True`` that ``G0``-channel feature map is returned
    directly; otherwise a PixelShuffle-based tail maps it to ``n_colors``
    channels at ``r`` times the spatial resolution.

    Args:
        G0: number of feature channels outside the dense blocks.
        kSize: kernel size of the shallow, fusion and up-sampling
            convolutions; it is also forwarded to every ``RDB``.
        r: up-sampling factor (2, 3 or 4). Only read when
            ``no_upsampling`` is False.
        n_colors: number of input image channels (and of output channels
            when up-sampling).
        RDNconfig: ``'A'`` -> (D=20, C=6, G=32); ``'B'`` -> (D=16, C=8, G=64).
        no_upsampling: if True, no up-sampling tail is built and ``forward``
            returns features.
        img_range: factor the input is multiplied by before the first
            convolution.

    Attributes:
        D: number of residual dense blocks.
        out_dim: channel count of the tensor returned by ``forward``.

    Raises:
        KeyError: if ``RDNconfig`` is not ``'A'`` or ``'B'``.
        ValueError: if ``no_upsampling`` is False and ``r`` is not 2, 3 or 4.
    """

    def __init__(self, G0=64, kSize=3, r=4, n_colors=3, RDNconfig='B',
                 no_upsampling=True, img_range=1.0):
        super(RDNNOUP, self).__init__()
        self.no_upsampling = no_upsampling
        self.img_range = img_range
        pad = (kSize - 1) // 2

        # (number of RDB blocks, conv layers per block, growth rate)
        self.D, C, G = {
            'A': (20, 6, 32),
            'B': (16, 8, 64),
        }[RDNconfig]

        # Shallow feature extraction net
        self.SFENet1 = nn.Conv2d(n_colors, G0, kSize, padding=pad, stride=1)
        self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=pad, stride=1)

        # Residual dense blocks and dense feature fusion.
        # kSize is forwarded so the blocks honour the same kernel-size
        # argument as the rest of the network (identical for the default 3).
        self.RDBs = nn.ModuleList([
            RDB(growRate0=G0, growRate=G, nConvLayers=C, kSize=kSize)
            for _ in range(self.D)
        ])

        # Global Feature Fusion: 1x1 conv over all block outputs, then kSize conv
        self.GFF = nn.Sequential(*[
            nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1),
            nn.Conv2d(G0, G0, kSize, padding=pad, stride=1)
        ])

        if no_upsampling:
            self.out_dim = G0
        else:
            self.out_dim = n_colors
            # Up-sampling net
            if r == 2 or r == 3:
                # Single PixelShuffle step by the full factor.
                self.UPNet = nn.Sequential(*[
                    nn.Conv2d(G0, G * r * r, kSize, padding=pad, stride=1),
                    nn.PixelShuffle(r),
                    nn.Conv2d(G, n_colors, kSize, padding=pad, stride=1)
                ])
            elif r == 4:
                # Two consecutive x2 PixelShuffle steps.
                self.UPNet = nn.Sequential(*[
                    nn.Conv2d(G0, G * 4, kSize, padding=pad, stride=1),
                    nn.PixelShuffle(2),
                    nn.Conv2d(G, G * 4, kSize, padding=pad, stride=1),
                    nn.PixelShuffle(2),
                    nn.Conv2d(G, n_colors, kSize, padding=pad, stride=1)
                ])
            else:
                raise ValueError("scale must be 2 or 3 or 4.")

    def forward(self, x):
        """Run the network on ``x``.

        Returns the fused feature map when ``no_upsampling`` is set,
        otherwise the output of the up-sampling net.
        """
        x = x * self.img_range
        shallow = self.SFENet1(x)
        x = self.SFENet2(shallow)

        # Keep every block's output: global fusion consumes all of them.
        block_outputs = []
        for block in self.RDBs:
            x = block(x)
            block_outputs.append(x)

        x = self.GFF(torch.cat(block_outputs, 1))
        x += shallow  # global residual from the first shallow conv

        if self.no_upsampling:
            return x
        return self.UPNet(x)
if __name__ == '__main__':
    # Smoke test: push a random batch through a default-configured model
    # and print the shape of what comes out.
    dummy_batch = torch.randn(8, 3, 48, 48)
    net = RDNNOUP()
    print(net(dummy_batch).shape)