|
import collections.abc |
|
import math |
|
import torch |
|
import torchvision |
|
import warnings |
|
from distutils.version import LooseVersion |
|
from itertools import repeat |
|
from torch import nn as nn |
|
from torch.nn import functional as F |
|
from torch.nn import init as init |
|
from torch.nn.modules.batchnorm import _BatchNorm |
|
|
|
class RDB_Conv(nn.Module):
    """Single densely-connected layer of a residual dense block.

    Applies ``Conv2d -> ReLU`` to produce ``growRate`` new feature maps and
    concatenates them onto the input along dimension 1. The output therefore
    has ``inChannels + growRate`` channels.

    Args:
        inChannels: number of channels of the incoming tensor.
        growRate: number of channels produced by the convolution.
        kSize: convolution kernel size. Padding is ``(kSize - 1) // 2`` with
            stride 1, which keeps the spatial size only for odd ``kSize``.
            The concatenation in ``forward`` requires matching spatial size.
    """

    def __init__(self, inChannels, growRate, kSize=3):
        super().__init__()
        pad = (kSize - 1) // 2
        # The attribute name and layer order are kept so that state_dict keys
        # remain ``conv.0.weight`` / ``conv.0.bias``.
        self.conv = nn.Sequential(
            nn.Conv2d(inChannels, growRate, kSize, padding=pad, stride=1),
            nn.ReLU(),
        )

    def forward(self, x):
        """Return ``x`` with ``growRate`` new ReLU feature maps appended on dim 1."""
        new_features = self.conv(x)
        return torch.cat([x, new_features], dim=1)
|
|
|
class RDB(nn.Module):
    """Residual dense block (RDB).

    Stacks ``nConvLayers`` ``RDB_Conv`` layers, each of which appends
    ``growRate`` channels to its input. The resulting
    ``growRate0 + nConvLayers * growRate`` channels are fused back to
    ``growRate0`` with a 1x1 convolution (``LFF``, local feature fusion), and
    the block input is then added as a residual.

    Args:
        growRate0: channels of the block input and output.
        growRate: channels added by each inner conv layer.
        nConvLayers: number of inner conv layers.
        kSize: kernel size of the inner conv layers (default 3).
    """

    def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
        super().__init__()
        G0, G, C = growRate0, growRate, nConvLayers

        # Layer c sees the block input plus the c*G channels produced by the
        # layers before it. kSize is forwarded so that this parameter actually
        # controls the inner convolutions.
        self.convs = nn.Sequential(
            *[RDB_Conv(G0 + c * G, G, kSize) for c in range(C)]
        )

        # Local feature fusion: a 1x1 conv back to G0 channels so the residual
        # add in forward() lines up with the block input.
        self.LFF = nn.Conv2d(G0 + C * G, G0, 1, padding=0, stride=1)

    def forward(self, x):
        """Return ``LFF(convs(x)) + x``.

        ``x`` must have ``growRate0`` channels. For odd ``kSize`` the output
        has the same shape as ``x``.
        """
        return self.LFF(self.convs(x)) + x
|
|
|
class RDNNOUP(nn.Module):
    """Residual Dense Network backbone with optional upsampling (off by default).

    ``forward`` runs the following pipeline:

    1. Multiply the input by ``img_range``.
    2. Shallow feature extraction (``SFENet1`` then ``SFENet2``).
    3. ``D`` residual dense blocks, whose outputs are concatenated and fused
       by ``GFF``.
    4. A global residual connection to the ``SFENet1`` output.
    5. ``UPNet``, only when ``no_upsampling`` is False.

    Args:
        G0: number of feature channels in the trunk.
        kSize: kernel size of the SFENet, GFF and UPNet convolutions. The RDBs
            are constructed without it and use their own default kernel size.
        r: upsampling scale used by ``UPNet``. Must be 2, 3 or 4 when
            upsampling is enabled. Any value is accepted when
            ``no_upsampling`` is True.
        n_colors: number of input channels, and of output channels when
            upsampling.
        RDNconfig: preset for (D, C, G), where D is the number of RDBs, C the
            conv layers per RDB and G the RDB growth rate.
            'A' -> (20, 6, 32); 'B' -> (16, 8, 64).
        no_upsampling: if True, ``forward`` returns the G0-channel feature map
            produced before ``UPNet``.
        img_range: factor the input is multiplied by before processing.

    Attributes:
        out_dim: channel count of the ``forward`` output (G0 or n_colors).

    Raises:
        KeyError: if ``RDNconfig`` is not 'A' or 'B'.
        ValueError: if upsampling is enabled and ``r`` is not 2, 3 or 4.
    """

    def __init__(self, G0=64, kSize=3, r=4, n_colors=3, RDNconfig='B',
                 no_upsampling=True, img_range=1.0):
        super().__init__()

        self.no_upsampling = no_upsampling
        self.img_range = img_range

        # (D, C, G): number of RDBs, conv layers per RDB, RDB growth rate.
        configs = {
            'A': (20, 6, 32),
            'B': (16, 8, 64),
        }
        try:
            self.D, C, G = configs[RDNconfig]
        except KeyError:
            raise KeyError(
                f"RDNconfig must be one of {sorted(configs)}, got {RDNconfig!r}"
            ) from None

        pad = (kSize - 1) // 2

        # Shallow feature extraction.
        self.SFENet1 = nn.Conv2d(n_colors, G0, kSize, padding=pad, stride=1)
        self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=pad, stride=1)

        # Residual dense blocks. kSize is not passed here, so they use RDB's
        # default kernel size.
        self.RDBs = nn.ModuleList(
            RDB(growRate0=G0, growRate=G, nConvLayers=C) for _ in range(self.D)
        )

        # Global feature fusion: 1x1 over all concatenated RDB outputs, then a
        # kSize x kSize conv.
        self.GFF = nn.Sequential(
            nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1),
            nn.Conv2d(G0, G0, kSize, padding=pad, stride=1),
        )

        self.out_dim = G0 if no_upsampling else n_colors

        if r in (2, 3, 4):
            # Built even when no_upsampling is True, which keeps the module
            # layout and state_dict keys unchanged for existing checkpoints.
            self.UPNet = self._build_upnet(G0, G, r, kSize, n_colors)
        elif no_upsampling:
            # forward() never calls UPNet in this mode, so an unsupported
            # scale is tolerated rather than rejected.
            self.UPNet = None
        else:
            raise ValueError("scale must be 2 or 3 or 4.")

    @staticmethod
    def _build_upnet(G0, G, r, kSize, n_colors):
        """Return the sub-pixel (PixelShuffle) upsampler for scale r (2, 3 or 4)."""
        pad = (kSize - 1) // 2
        if r == 4:
            # x4 is implemented as two successive x2 PixelShuffle stages.
            return nn.Sequential(
                nn.Conv2d(G0, G * 4, kSize, padding=pad, stride=1),
                nn.PixelShuffle(2),
                nn.Conv2d(G, G * 4, kSize, padding=pad, stride=1),
                nn.PixelShuffle(2),
                nn.Conv2d(G, n_colors, kSize, padding=pad, stride=1),
            )
        return nn.Sequential(
            nn.Conv2d(G0, G * r * r, kSize, padding=pad, stride=1),
            nn.PixelShuffle(r),
            nn.Conv2d(G, n_colors, kSize, padding=pad, stride=1),
        )

    def forward(self, x):
        """Run the network.

        Returns the G0-channel fused feature map when ``no_upsampling`` is
        True; otherwise returns ``UPNet`` applied to that map.
        """
        x = x * self.img_range
        f__1 = self.SFENet1(x)
        x = self.SFENet2(f__1)

        rdb_outs = []
        for rdb in self.RDBs:
            x = rdb(x)
            rdb_outs.append(x)

        # Global feature fusion over every RDB output, plus the global
        # residual to the first shallow feature map.
        x = self.GFF(torch.cat(rdb_outs, 1)) + f__1

        if self.no_upsampling:
            return x
        if self.UPNet is None:
            raise RuntimeError(
                "upsampling requested but no UPNet was built "
                "(scale r must be 2, 3 or 4)"
            )
        return self.UPNet(x)
|
|
|
if __name__ == '__main__':
    # Smoke test: a batch of 8 random 3-channel 48x48 images through the
    # default (no-upsampling) network; prints the output tensor shape.
    dummy = torch.randn(8, 3, 48, 48)
    net = RDNNOUP()
    print(net(dummy).shape)