未验证 提交 cd642c08 编写于 作者: L LielinJiang 提交者: GitHub

Add Lesrcnn model (#136)

* add lesrcnn model
上级 ffb1a225
...@@ -91,11 +91,11 @@ validate: ...@@ -91,11 +91,11 @@ validate:
psnr: # metric name, can be arbitrary psnr: # metric name, can be arbitrary
name: PSNR name: PSNR
crop_border: 4 crop_border: 4
test_y_channel: false test_y_channel: True
ssim: ssim:
name: SSIM name: SSIM
crop_border: 4 crop_border: 4
test_y_channel: false test_y_channel: True
log_config: log_config:
interval: 10 interval: 10
......
total_iters: 1000000
output_dir: output_dir
# tensor range for function tensor2img
min_max:
(0., 1.)
model:
name: BaseSRModel
generator:
name: LESRCNNGenerator
pixel_criterion:
name: L1Loss
dataset:
train:
name: SRDataset
gt_folder: data/DIV2K/DIV2K_train_HR_sub
lq_folder: data/DIV2K/DIV2K_train_LR_bicubic/X4_sub
num_workers: 4
batch_size: 16
scale: 4
preprocess:
- name: LoadImageFromFile
key: lq
- name: LoadImageFromFile
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: SRPairedRandomCrop
gt_patch_size: 128
scale: 4
keys: [image, image]
- name: PairedRandomHorizontalFlip
keys: [image, image]
- name: PairedRandomVerticalFlip
keys: [image, image]
- name: PairedRandomTransposeHW
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [0., .0, 0.]
std: [255., 255., 255.]
keys: [image, image]
test:
name: SRDataset
gt_folder: data/DIV2K/val_set14/Set14
lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
scale: 4
preprocess:
- name: LoadImageFromFile
key: lq
- name: LoadImageFromFile
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [0., .0, 0.]
std: [255., 255., 255.]
keys: [image, image]
lr_scheduler:
name: CosineAnnealingRestartLR
learning_rate: 0.0002
periods: [250000, 250000, 250000, 250000]
restart_weights: [1, 1, 1, 1]
eta_min: !!float 1e-7
optimizer:
name: Adam
# add parameters of net_name to optim
# name should in self.nets
net_names:
- generator
beta1: 0.9
beta2: 0.99
validate:
interval: 5000
save_img: false
metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 4
test_y_channel: True
ssim:
name: SSIM
crop_border: 4
test_y_channel: True
log_config:
interval: 100
visiual_interval: 5000
snapshot_config:
interval: 5000
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import sys import sys
import cv2
import random import random
import numbers import numbers
import collections import collections
...@@ -40,8 +41,9 @@ TRANSFORMS.register(T.Transpose) ...@@ -40,8 +41,9 @@ TRANSFORMS.register(T.Transpose)
@PREPROCESS.register() @PREPROCESS.register()
class Transforms(): class Transforms():
def __init__(self, pipeline, input_keys): def __init__(self, pipeline, input_keys, output_keys=None):
self.input_keys = input_keys self.input_keys = input_keys
self.output_keys = output_keys
self.transforms = [] self.transforms = []
for transform_cfg in pipeline: for transform_cfg in pipeline:
self.transforms.append(build_from_config(transform_cfg, TRANSFORMS)) self.transforms.append(build_from_config(transform_cfg, TRANSFORMS))
...@@ -58,6 +60,11 @@ class Transforms(): ...@@ -58,6 +60,11 @@ class Transforms():
transform.params, dict): transform.params, dict):
datas.update(transform.params) datas.update(transform.params)
if self.output_keys is not None:
for i, k in enumerate(self.output_keys):
datas[k] = data[i]
return datas
for i, k in enumerate(self.input_keys): for i, k in enumerate(self.input_keys):
datas[k] = data[i] datas[k] = data[i]
...@@ -183,10 +190,11 @@ class SRPairedRandomCrop(T.BaseTransform): ...@@ -183,10 +190,11 @@ class SRPairedRandomCrop(T.BaseTransform):
scale (int): model upscale factor. scale (int): model upscale factor.
gt_patch_size (int): cropped gt patch size. gt_patch_size (int): cropped gt patch size.
""" """
def __init__(self, scale, gt_patch_size, keys=None): def __init__(self, scale, gt_patch_size, scale_list=False, keys=None):
self.gt_patch_size = gt_patch_size self.gt_patch_size = gt_patch_size
self.scale = scale self.scale = scale
self.keys = keys self.keys = keys
self.scale_list = scale_list
def __call__(self, inputs): def __call__(self, inputs):
"""inputs must be (lq_img, gt_img)""" """inputs must be (lq_img, gt_img)"""
...@@ -214,5 +222,11 @@ class SRPairedRandomCrop(T.BaseTransform): ...@@ -214,5 +222,11 @@ class SRPairedRandomCrop(T.BaseTransform):
gt = gt[top_gt:top_gt + self.gt_patch_size, gt = gt[top_gt:top_gt + self.gt_patch_size,
left_gt:left_gt + self.gt_patch_size, ...] left_gt:left_gt + self.gt_patch_size, ...]
if self.scale_list and self.scale == 4:
lqx2 = F.resize(gt, (lq_patch_size * 2, lq_patch_size * 2),
'bicubic')
outputs = (lq, lqx2, gt)
return outputs
outputs = (lq, gt) outputs = (lq, gt)
return outputs return outputs
from .transforms import ResizeToScale, PairedRandomCrop, PairedRandomHorizontalFlip, Add from .transforms import ResizeToScale, PairedRandomCrop, PairedRandomHorizontalFlip, Add
\ No newline at end of file
...@@ -386,7 +386,6 @@ class Trainer: ...@@ -386,7 +386,6 @@ class Trainer:
self.logger.warning( self.logger.warning(
'Can not find state dict of net {}. Skip load pretrained weight for net {}' 'Can not find state dict of net {}. Skip load pretrained weight for net {}'
.format(net_name, net_name)) .format(net_name, net_name))
net.set_state_dict(state_dicts[net_name])
def close(self): def close(self):
""" """
......
...@@ -270,6 +270,48 @@ def bgr2ycbcr(img, y_only=False): ...@@ -270,6 +270,48 @@ def bgr2ycbcr(img, y_only=False):
return out_img return out_img
def rgb2ycbcr(img, y_only=False):
"""Convert a RGB image to YCbCr image.
The RGB version of rgb2ycbcr.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
y_only (bool): Whether to only return Y channel. Default: False.
Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
if img_type != np.uint8:
img *= 255.
if y_only:
out_img = np.dot(img, [65.481, 128.553, 24.966]) / 255. + 16.0
else:
out_img = np.matmul(
img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
[65.481, -37.797, 112.0]]) + [16, 128, 128]
if img_type != np.uint8:
out_img /= 255.
else:
out_img = out_img.round()
return out_img
def to_y_channel(img): def to_y_channel(img):
"""Change to Y channel of YCbCr. """Change to Y channel of YCbCr.
...@@ -281,6 +323,6 @@ def to_y_channel(img): ...@@ -281,6 +323,6 @@ def to_y_channel(img):
""" """
img = img.astype(np.float32) / 255. img = img.astype(np.float32) / 255.
if img.ndim == 3 and img.shape[2] == 3: if img.ndim == 3 and img.shape[2] == 3:
img = bgr2ycbcr(img, y_only=True) img = rgb2ycbcr(img, y_only=True)
img = img[..., None] img = img[..., None]
return img * 255. return img * 255.
...@@ -61,7 +61,6 @@ class ESRGAN(BaseSRModel): ...@@ -61,7 +61,6 @@ class ESRGAN(BaseSRModel):
self.gan_criterion = build_criterion(gan_criterion) self.gan_criterion = build_criterion(gan_criterion)
def train_iter(self, optimizers=None): def train_iter(self, optimizers=None):
self.set_requires_grad(self.nets['discriminator'], False)
optimizers['optimG'].clear_grad() optimizers['optimG'].clear_grad()
l_total = 0 l_total = 0
self.output = self.nets['generator'](self.lq) self.output = self.nets['generator'](self.lq)
...@@ -83,41 +82,48 @@ class ESRGAN(BaseSRModel): ...@@ -83,41 +82,48 @@ class ESRGAN(BaseSRModel):
self.losses['loss_style'] = l_g_style self.losses['loss_style'] = l_g_style
# gan loss (relativistic gan) # gan loss (relativistic gan)
real_d_pred = self.nets['discriminator'](self.gt).detach() if hasattr(self, 'gan_criterion'):
fake_g_pred = self.nets['discriminator'](self.output) self.set_requires_grad(self.nets['discriminator'], False)
l_g_real = self.gan_criterion(real_d_pred - paddle.mean(fake_g_pred), real_d_pred = self.nets['discriminator'](self.gt).detach()
False, fake_g_pred = self.nets['discriminator'](self.output)
is_disc=False) l_g_real = self.gan_criterion(real_d_pred -
l_g_fake = self.gan_criterion(fake_g_pred - paddle.mean(real_d_pred), paddle.mean(fake_g_pred),
True, False,
is_disc=False) is_disc=False)
l_g_gan = (l_g_real + l_g_fake) / 2 l_g_fake = self.gan_criterion(fake_g_pred -
paddle.mean(real_d_pred),
l_total += l_g_gan True,
self.losses['l_g_gan'] = l_g_gan is_disc=False)
l_g_gan = (l_g_real + l_g_fake) / 2
l_total.backward()
optimizers['optimG'].step() l_total += l_g_gan
self.losses['l_g_gan'] = l_g_gan
self.set_requires_grad(self.nets['discriminator'], True) l_total.backward()
optimizers['optimD'].clear_grad() optimizers['optimG'].step()
# real
fake_d_pred = self.nets['discriminator'](self.output).detach() self.set_requires_grad(self.nets['discriminator'], True)
real_d_pred = self.nets['discriminator'](self.gt) optimizers['optimD'].clear_grad()
l_d_real = self.gan_criterion( # real
real_d_pred - paddle.mean(fake_d_pred), True, is_disc=True) * 0.5 fake_d_pred = self.nets['discriminator'](self.output).detach()
real_d_pred = self.nets['discriminator'](self.gt)
# fake l_d_real = self.gan_criterion(
fake_d_pred = self.nets['discriminator'](self.output.detach()) real_d_pred - paddle.mean(fake_d_pred), True,
l_d_fake = self.gan_criterion( is_disc=True) * 0.5
fake_d_pred - paddle.mean(real_d_pred.detach()),
False, # fake
is_disc=True) * 0.5 fake_d_pred = self.nets['discriminator'](self.output.detach())
l_d_fake = self.gan_criterion(
(l_d_real + l_d_fake).backward() fake_d_pred - paddle.mean(real_d_pred.detach()),
optimizers['optimD'].step() False,
is_disc=True) * 0.5
self.losses['l_d_real'] = l_d_real
self.losses['l_d_fake'] = l_d_fake (l_d_real + l_d_fake).backward()
self.losses['out_d_real'] = paddle.mean(real_d_pred.detach()) optimizers['optimD'].step()
self.losses['out_d_fake'] = paddle.mean(fake_d_pred.detach())
self.losses['l_d_real'] = l_d_real
self.losses['l_d_fake'] = l_d_fake
self.losses['out_d_real'] = paddle.mean(real_d_pred.detach())
self.losses['out_d_fake'] = paddle.mean(fake_d_pred.detach())
else:
l_total.backward()
optimizers['optimG'].step()
...@@ -21,6 +21,7 @@ from .resnet_ugatit import ResnetUGATITGenerator ...@@ -21,6 +21,7 @@ from .resnet_ugatit import ResnetUGATITGenerator
from .dcgenerator import DCGenerator from .dcgenerator import DCGenerator
from .generater_animegan import AnimeGenerator, AnimeGeneratorLite from .generater_animegan import AnimeGenerator, AnimeGeneratorLite
from .wav2lip import Wav2Lip from .wav2lip import Wav2Lip
from .lesrcnn import LESRCNNGenerator
from .resnet_ugatit_p2c import ResnetUGATITP2CGenerator from .resnet_ugatit_p2c import ResnetUGATITP2CGenerator
from .generator_styleganv2 import StyleGANv2Generator from .generator_styleganv2 import StyleGANv2Generator
from .generator_pixel2style2pixel import Pixel2Style2Pixel from .generator_pixel2style2pixel import Pixel2Style2Pixel
import math
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .builder import GENERATORS
class MeanShift(nn.Layer):
def __init__(self, mean_rgb, sub):
super(MeanShift, self).__init__()
sign = -1 if sub else 1
r = mean_rgb[0] * sign
g = mean_rgb[1] * sign
b = mean_rgb[2] * sign
self.shifter = nn.Conv2D(3, 3, 1, 1, 0)
self.shifter.weight.set_value(paddle.eye(3).reshape([3, 3, 1, 1]))
self.shifter.bias.set_value(np.array([r, g, b]).astype('float32'))
# Freeze the mean shift layer
for params in self.shifter.parameters():
params.trainable = False
def forward(self, x):
x = self.shifter(x)
return x
class UpsampleBlock(nn.Layer):
def __init__(self, n_channels, scale, multi_scale, group=1):
super(UpsampleBlock, self).__init__()
if multi_scale:
self.up2 = _UpsampleBlock(n_channels, scale=2, group=group)
self.up3 = _UpsampleBlock(n_channels, scale=3, group=group)
self.up4 = _UpsampleBlock(n_channels, scale=4, group=group)
else:
self.up = _UpsampleBlock(n_channels, scale=scale, group=group)
self.multi_scale = multi_scale
def forward(self, x, scale):
if self.multi_scale:
if scale == 2:
return self.up2(x)
elif scale == 3:
return self.up3(x)
elif scale == 4:
return self.up4(x)
else:
return self.up(x)
class _UpsampleBlock(nn.Layer):
def __init__(self, n_channels, scale, group=1):
super(_UpsampleBlock, self).__init__()
modules = []
if scale == 2 or scale == 4 or scale == 8:
for _ in range(int(math.log(scale, 2))):
modules += [
nn.Conv2D(n_channels, 4 * n_channels, 3, 1, 1, groups=group)
]
modules += [nn.PixelShuffle(2)]
elif scale == 3:
modules += [
nn.Conv2D(n_channels, 9 * n_channels, 3, 1, 1, groups=group)
]
modules += [nn.PixelShuffle(3)]
self.body = nn.Sequential(*modules)
def forward(self, x):
out = self.body(x)
return out
@GENERATORS.register()
class LESRCNNGenerator(nn.Layer):
"""Construct a Resnet-based generator that consists of residual blocks
between a few downsampling/upsampling operations.
Args:
scale (int): scale of upsample.
multi_scale (bool): Whether to train multi scale model.
group (int): group option for convolution.
"""
def __init__(
self,
scale=4,
multi_scale=False,
group=1,
):
super(LESRCNNGenerator, self).__init__()
kernel_size = 3
kernel_size1 = 1
padding1 = 0
padding = 1
features = 64
groups = 1
channels = 3
features1 = 64
self.scale = scale
self.sub_mean = MeanShift((0.4488, 0.4371, 0.4040), sub=True)
self.add_mean = MeanShift((0.4488, 0.4371, 0.4040), sub=False)
self.conv1 = nn.Sequential(
nn.Conv2D(in_channels=channels,
out_channels=features,
kernel_size=kernel_size,
padding=padding,
groups=1,
bias_attr=False))
self.conv2 = nn.Sequential(
nn.Conv2D(in_channels=features,
out_channels=features,
kernel_size=kernel_size,
padding=1,
groups=1,
bias_attr=False), nn.ReLU())
self.conv3 = nn.Sequential(
nn.Conv2D(in_channels=features,
out_channels=features,
kernel_size=kernel_size1,
padding=0,
groups=groups,
bias_attr=False))
self.conv4 = nn.Sequential(
nn.Conv2D(in_channels=features,
out_channels=features,
kernel_size=kernel_size,
padding=1,
groups=1,
bias_attr=False), nn.ReLU())
self.conv5 = nn.Sequential(
nn.Conv2D(in_channels=features,
out_channels=features,
kernel_size=kernel_size1,
padding=0,
groups=groups,
bias_attr=False))
self.conv6 = nn.Sequential(
nn.Conv2D(in_channels=features,
out_channels=features,
kernel_size=kernel_size,
padding=1,
groups=1,
bias_attr=False), nn.ReLU())
self.conv7 = nn.Sequential(
nn.Conv2D(in_channels=features,
out_channels=features,
kernel_size=kernel_size1,
padding=0,
groups=groups,
bias_attr=False))
self.conv8 = nn.Sequential(
nn.Conv2D(in_channels=features,
out_channels=features,
kernel_size=kernel_size,
padding=1,
groups=1,
bias_attr=False), nn.ReLU())
self.conv9 = nn.Sequential(
nn.Conv2D(in_channels=features,
out_channels=features,
kernel_size=kernel_size1,
padding=0,
groups=groups,
bias_attr=False))
self.conv10 = nn.Sequential(
nn.Conv2D(in_channels=features,
out_channels=features,
kernel_size=kernel_size,
padding=1,
groups=1,
bias_attr=False), nn.ReLU())
self.conv11 = nn.Sequential(
nn.Conv2D(in_channels=features,
out_channels=features,
kernel_size=kernel_size1,
padding=0,
groups=groups,
bias_attr=False))
self.conv12 = nn.Sequential(
nn.Conv2D(in_channels=features,
out_channels=features,
kernel_size=kernel_size,
padding=1,
groups=1,
bias_attr=False), nn.ReLU())
self.conv13 = nn.Sequential(
nn.Conv2D(in_channels=features,
out_channels=features,
kernel_size=kernel_size1,
padding=0,
groups=groups,
bias_attr=False))
self.conv14 = nn.Sequential(
nn.Conv2D(in_channels=features,
out_channels=features,
kernel_size=kernel_size,
padding=1,
groups=1,
bias_attr=False), nn.ReLU())
self.conv15 = nn.Sequential(
nn.Conv2D(in_channels=features,
out_channels=features,
kernel_size=kernel_size1,
padding=0,
groups=groups,
bias_attr=False))
self.conv16 = nn.Sequential(
nn.Conv2D(in_channels=features,
out_channels=features,
kernel_size=kernel_size,
padding=1,
groups=1,
bias_attr=False), nn.ReLU())
self.conv17 = nn.Sequential(
nn.Conv2D(in_channels=features,
out_channels=features,
kernel_size=kernel_size1,
padding=0,
groups=groups,
bias_attr=False))
self.conv17_1 = nn.Sequential(
nn.Conv2D(in_channels=features,
out_channels=features,
kernel_size=kernel_size,
padding=1,
groups=1,
bias_attr=False), nn.ReLU())
self.conv17_2 = nn.Sequential(
nn.Conv2D(in_channels=features,
out_channels=features,
kernel_size=kernel_size,
padding=1,
groups=1,
bias_attr=False), nn.ReLU())
self.conv17_3 = nn.Sequential(
nn.Conv2D(in_channels=features,
out_channels=features,
kernel_size=kernel_size,
padding=1,
groups=1,
bias_attr=False), nn.ReLU())
self.conv17_4 = nn.Sequential(
nn.Conv2D(in_channels=features,
out_channels=features,
kernel_size=kernel_size,
padding=1,
groups=1,
bias_attr=False), nn.ReLU())
self.conv18 = nn.Sequential(
nn.Conv2D(in_channels=features,
out_channels=3,
kernel_size=kernel_size,
padding=padding,
groups=groups,
bias_attr=False))
self.ReLU = nn.ReLU()
self.upsample = UpsampleBlock(64,
scale=scale,
multi_scale=multi_scale,
group=1)
def forward(self, x, scale=None):
if scale is None:
scale = self.scale
x = self.sub_mean(x)
x1 = self.conv1(x)
x1_1 = self.ReLU(x1)
x2 = self.conv2(x1_1)
x3 = self.conv3(x2)
x2_3 = x1 + x3
x2_4 = self.ReLU(x2_3)
x4 = self.conv4(x2_4)
x5 = self.conv5(x4)
x3_5 = x2_3 + x5
x3_6 = self.ReLU(x3_5)
x6 = self.conv6(x3_6)
x7 = self.conv7(x6)
x7_1 = x3_5 + x7
x7_2 = self.ReLU(x7_1)
x8 = self.conv8(x7_2)
x9 = self.conv9(x8)
x9_2 = x7_1 + x9
x9_1 = self.ReLU(x9_2)
x10 = self.conv10(x9_1)
x11 = self.conv11(x10)
x11_1 = x9_2 + x11
x11_2 = self.ReLU(x11_1)
x12 = self.conv12(x11_2)
x13 = self.conv13(x12)
x13_1 = x11_1 + x13
x13_2 = self.ReLU(x13_1)
x14 = self.conv14(x13_2)
x15 = self.conv15(x14)
x15_1 = x15 + x13_1
x15_2 = self.ReLU(x15_1)
x16 = self.conv16(x15_2)
x17 = self.conv17(x16)
x17_2 = x17 + x15_1
x17_3 = self.ReLU(x17_2)
temp = self.upsample(x17_3, scale=scale)
x1111 = self.upsample(x1_1, scale=scale)
temp1 = x1111 + temp
temp2 = self.ReLU(temp1)
temp3 = self.conv17_1(temp2)
temp4 = self.conv17_2(temp3)
temp5 = self.conv17_3(temp4)
temp6 = self.conv17_4(temp5)
x18 = self.conv18(temp6)
out = self.add_mean(x18)
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册