srntt-pytorch-master/models/srntt.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import models
class SRNTT(nn.Module):
    """
    PyTorch Module for SRNTT.
    Only x4 upscaling is currently supported.

    Parameters
    ---
    ngf : int, optional
        the number of filters of the generator.
    n_blocks : int, optional
        the number of residual blocks for each module.
    use_weights : bool, optional
        passed to ``TextureTransfer``; if True, the texture transfer module
        learns parameters that gate the swapped maps with ``weights``.
    """

    def __init__(self, ngf=64, n_blocks=16, use_weights=False):
        super(SRNTT, self).__init__()
        self.content_extractor = ContentExtractor(ngf, n_blocks)
        self.texture_transfer = TextureTransfer(ngf, n_blocks, use_weights)
        models.init_weights(self, init_type='normal', init_gain=0.02)

    def forward(self, x, maps, weights=None):
        """
        Parameters
        ---
        x : torch.Tensor
            the input image of SRNTT.
        maps : dict of torch.Tensor or None
            the swapped feature maps on relu3_1, relu2_1 and relu1_1.
            depths of the maps are 256, 128 and 64 respectively.
            If None, only the content branch is evaluated.
        weights : torch.Tensor, optional
            matching weights forwarded to the texture transfer module; they
            are only used when the model was built with ``use_weights=True``.

        Returns
        ---
        upscale_plain : torch.Tensor
            content-only x4 reconstruction added to a bilinear x4 upsample of x.
        upscale_srntt : torch.Tensor or None
            texture-transferred x4 reconstruction added to the same bilinear
            base, or None when ``maps`` is None.
        """
        # Both branches predict a residual on top of a plain bilinear x4 upsample.
        base = F.interpolate(x, scale_factor=4, mode='bilinear',
                             align_corners=False)
        upscale_plain, content_feat = self.content_extractor(x)
        upscale_plain = upscale_plain + base
        if maps is None:
            return upscale_plain, None
        # TextureTransfer ignores `weights` unless it owns the learnable gate
        # parameters (use_weights=True), so they can always be forwarded.
        upscale_srntt = self.texture_transfer(content_feat, maps, weights)
        return upscale_plain, upscale_srntt + base
class ContentExtractor(nn.Module):
"""
Content Extractor for SRNTT, which outputs maps before-and-after upscale.
more detail: https://github.com/ZZUTK/SRNTT/blob/master/SRNTT/model.py#L73.
Currently this module only supports `scale_factor=4`.
Parameters
---
ngf : int, optional
a number of generator's features.
n_blocks : int, optional
a number of residual blocks, see also `ResBlock` class.
"""
def __init__(self, ngf=64, n_blocks=16):
super(ContentExtractor, self).__init__()
self.head = nn.Sequential(
nn.Conv2d(3, ngf, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.1, True)
)
self.body = nn.Sequential(
*[ResBlock(ngf) for _ in range(n_blocks)],
# nn.Conv2d(ngf, ngf, kernel_size=3, stride=1, padding=1),
# nn.BatchNorm2d(ngf)
)
self.tail = nn.Sequential(
nn.Conv2d(ngf, ngf * 4, kernel_size=3, stride=1, padding=1),
nn.PixelShuffle(2),
nn.LeakyReLU(0.1, True),
nn.Conv2d(ngf, ngf * 4, kernel_size=3, stride=1, padding=1),
nn.PixelShuffle(2),
nn.LeakyReLU(0.1, True),
nn.Conv2d(ngf, ngf, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.1, True),
nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1),
# nn.Tanh()
)
def forward(self, x):
h = self.head(x)
h = self.body(h) + h
upscale = self.tail(h)
return upscale, h
class TextureTransfer(nn.Module):
    """
    Conditional Texture Transfer for SRNTT,
    see https://github.com/ZZUTK/SRNTT/blob/master/SRNTT/model.py#L116.

    This module is divided into 3 stages, one per scale (small, medium, large).
    Each stage concatenates the current content features with the swapped
    reference map of that scale, fuses them with a conv head and a residual
    body, and then either upsamples x2 with PixelShuffle (small, medium) or
    projects to a 3-channel output (large).

    Parameters
    ---
    ngf : int, optional
        a number of generator's filters.
    n_blocks : int, optional
        a number of residual blocks per stage, see also `ResBlock` class.
    use_weights : bool, optional
        if True, learn a per-layer scale ``a`` and bias ``b`` that turn the
        matching ``weights`` into sigmoid gates applied to the swapped maps.
    """

    def __init__(self, ngf=64, n_blocks=16, use_weights=False):
        super(TextureTransfer, self).__init__()
        # NOTE: unlike the TF reference, no trailing conv + BatchNorm follows
        # the residual blocks of each stage.

        # for small scale (relu3_1, 256 channels)
        self.head_small = nn.Sequential(
            nn.Conv2d(ngf + 256, ngf, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1, True),
        )
        self.body_small = nn.Sequential(
            *[ResBlock(ngf) for _ in range(n_blocks)],
        )
        self.tail_small = nn.Sequential(
            nn.Conv2d(ngf, ngf * 4, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.1, True),
        )
        # for medium scale (relu2_1, 128 channels)
        self.head_medium = nn.Sequential(
            nn.Conv2d(ngf + 128, ngf, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1, True),
        )
        self.body_medium = nn.Sequential(
            *[ResBlock(ngf) for _ in range(n_blocks)],
        )
        self.tail_medium = nn.Sequential(
            nn.Conv2d(ngf, ngf * 4, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.1, True),
        )
        # for large scale (relu1_1, 64 channels)
        self.head_large = nn.Sequential(
            nn.Conv2d(ngf + 64, ngf, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1, True),
        )
        self.body_large = nn.Sequential(
            *[ResBlock(ngf) for _ in range(n_blocks)],
        )
        # Output is left unbounded: no final Tanh is applied.
        self.tail_large = nn.Sequential(
            nn.Conv2d(ngf, ngf // 2, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(ngf // 2, 3, kernel_size=3, stride=1, padding=1),
        )
        if use_weights:
            # One (scale, bias) pair per layer: relu3_1, relu2_1, relu1_1.
            self.a = nn.Parameter(torch.ones(3), requires_grad=True)
            self.b = nn.Parameter(torch.ones(3), requires_grad=True)

    def forward(self, x, maps, weights=None):
        """
        Parameters
        ---
        x : torch.Tensor
            content features with ``ngf`` channels.
        maps : dict of torch.Tensor
            swapped maps: 'relu3_1' (256 channels, same size as x),
            'relu2_1' (128 channels, 2x size) and 'relu1_1' (64 channels,
            4x size). The dict and its tensors are not modified.
        weights : torch.Tensor, optional
            matching weights; only used when built with ``use_weights=True``.
            Assumed to be 2 px smaller than relu3_1 in height and width, since
            1 px of replicate padding per side must restore that size.
            TODO confirm against the caller.

        Returns
        ---
        torch.Tensor
            3-channel output at 4x the spatial size of ``x``.
        """
        if hasattr(self, 'a') and weights is not None:
            maps = self._gate_maps(maps, weights)
        # small scale
        h = torch.cat([x, maps['relu3_1']], 1)
        h = self.head_small(h)
        h = self.body_small(h) + x  # skip connection to this stage's input
        x = self.tail_small(h)
        # medium scale
        h = torch.cat([x, maps['relu2_1']], 1)
        h = self.head_medium(h)
        h = self.body_medium(h) + x
        x = self.tail_medium(h)
        # large scale
        h = torch.cat([x, maps['relu1_1']], 1)
        h = self.head_large(h)
        h = self.body_large(h) + x
        x = self.tail_large(h)
        return x

    def _gate_maps(self, maps, weights):
        """Return a new dict with each swapped map multiplied by its sigmoid gate."""
        # The padded weights are identical for every layer, so pad only once.
        padded = F.pad(weights, (1, 1, 1, 1), mode='replicate')
        # Shallow copy plus out-of-place multiply: the caller's dict and
        # tensors must not be modified, or reused maps would be re-gated.
        gated = dict(maps)
        for idx, layer in enumerate(('relu3_1', 'relu2_1', 'relu1_1')):
            weights_scaled = F.interpolate(
                padded,
                scale_factor=2 ** idx,
                mode='bicubic',
                align_corners=True) * self.a[idx] + self.b[idx]
            gated[layer] = maps[layer] * torch.sigmoid(weights_scaled)
        return gated
class ResBlock(nn.Module):
    """
    Basic residual block for SRNTT: ``x + conv(relu(conv(x)))``.

    Both convolutions are 3x3 with stride 1 and padding 1, so the spatial
    size and the channel count are preserved.

    Parameters
    ---
    n_filters : int, optional
        number of input and output channels of both convolutions.
    """

    def __init__(self, n_filters=64):
        super().__init__()
        first_conv = nn.Conv2d(n_filters, n_filters, kernel_size=3,
                               stride=1, padding=1)
        second_conv = nn.Conv2d(n_filters, n_filters, kernel_size=3,
                                stride=1, padding=1)
        self.body = nn.Sequential(first_conv, nn.ReLU(inplace=True),
                                  second_conv)

    def forward(self, x):
        residual = self.body(x)
        return residual + x
if __name__ == "__main__":
    # Smoke test: one forward pass on random inputs, then check the output size.
    # Fall back to the CPU so the check also runs on machines without a GPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    x = torch.rand(16, 3, 24, 24).to(device)
    maps = {
        'relu3_1': torch.rand(16, 256, 24, 24).to(device),
        'relu2_1': torch.rand(16, 128, 48, 48).to(device),
        'relu1_1': torch.rand(16, 64, 96, 96).to(device),
    }
    model = SRNTT().to(device)
    with torch.no_grad():  # inference only; no need to build the autograd graph
        _, out = model(x, maps)
    print(tuple(out.shape))
    assert out.shape == (16, 3, 96, 96), out.shape