diff --git a/configs/esrgan_psnr_x4_div2k.yaml b/configs/esrgan_psnr_x4_div2k.yaml
index 642818e2ca6d9048289a08f4ce9b37d000f62e4d..7dd7428f8fbda251f4cd5dc8dc0704e791e573ab 100644
--- a/configs/esrgan_psnr_x4_div2k.yaml
+++ b/configs/esrgan_psnr_x4_div2k.yaml
@@ -91,11 +91,11 @@ validate:
     psnr: # metric name, can be arbitrary
       name: PSNR
       crop_border: 4
-      test_y_channel: false
+      test_y_channel: true
     ssim:
       name: SSIM
       crop_border: 4
-      test_y_channel: false
+      test_y_channel: true

 log_config:
   interval: 10
diff --git a/configs/lesrcnn_psnr_x4_div2k.yaml b/configs/lesrcnn_psnr_x4_div2k.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e9759e06126fa367030c100e0abfd5e21b118e51
--- /dev/null
+++ b/configs/lesrcnn_psnr_x4_div2k.yaml
@@ -0,0 +1,101 @@
+total_iters: 1000000
+output_dir: output_dir
+# tensor range for function tensor2img
+min_max: (0., 1.)
+
+model:
+  name: BaseSRModel
+  generator:
+    name: LESRCNNGenerator
+  pixel_criterion:
+    name: L1Loss
+
+dataset:
+  train:
+    name: SRDataset
+    gt_folder: data/DIV2K/DIV2K_train_HR_sub
+    lq_folder: data/DIV2K/DIV2K_train_LR_bicubic/X4_sub
+    num_workers: 4
+    batch_size: 16
+    scale: 4
+    preprocess:
+      - name: LoadImageFromFile
+        key: lq
+      - name: LoadImageFromFile
+        key: gt
+      - name: Transforms
+        input_keys: [lq, gt]
+        pipeline:
+          - name: SRPairedRandomCrop
+            gt_patch_size: 128
+            scale: 4
+            keys: [image, image]
+          - name: PairedRandomHorizontalFlip
+            keys: [image, image]
+          - name: PairedRandomVerticalFlip
+            keys: [image, image]
+          - name: PairedRandomTransposeHW
+            keys: [image, image]
+          - name: Transpose
+            keys: [image, image]
+          - name: Normalize
+            mean: [0., 0., 0.]
+            std: [255., 255., 255.]
+            keys: [image, image]
+  test:
+    name: SRDataset
+    gt_folder: data/DIV2K/val_set14/Set14
+    lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
+    scale: 4
+    preprocess:
+      - name: LoadImageFromFile
+        key: lq
+      - name: LoadImageFromFile
+        key: gt
+      - name: Transforms
+        input_keys: [lq, gt]
+        pipeline:
+          - name: Transpose
+            keys: [image, image]
+          - name: Normalize
+            mean: [0., 0., 0.]
+            std: [255., 255., 255.]
+            keys: [image, image]
+
+lr_scheduler:
+  name: CosineAnnealingRestartLR
+  learning_rate: 0.0002
+  periods: [250000, 250000, 250000, 250000]
+  restart_weights: [1, 1, 1, 1]
+  eta_min: !!float 1e-7
+
+optimizer:
+  name: Adam
+  # add parameters of net_name to optim
+  # name should be in self.nets
+  net_names:
+    - generator
+  beta1: 0.9
+  beta2: 0.99
+
+validate:
+  interval: 5000
+  save_img: false
+
+  metrics:
+    psnr: # metric name, can be arbitrary
+      name: PSNR
+      crop_border: 4
+      test_y_channel: true
+    ssim:
+      name: SSIM
+      crop_border: 4
+      test_y_channel: true
+
+log_config:
+  interval: 100
+  visiual_interval: 5000
+
+snapshot_config:
+  interval: 5000
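A note on the `min_max`/`Normalize` pairing in the config above: the `Normalize` step divides 8-bit pixels by 255, so tensors entering and leaving the model live in [0, 1], and `min_max: (0., 1.)` is assumed to tell `tensor2img` how to map outputs back to uint8. A minimal NumPy sketch of that round trip (stand-in arithmetic only; the actual ppgan helpers may differ in detail):

```python
# Round trip implied by the config: Normalize(mean=0, std=255) on the way in,
# tensor2img with min_max=(0., 1.) on the way out.
import numpy as np

pixels = np.array([0., 128., 255.], dtype=np.float32)
normalized = (pixels - 0.) / 255.                  # into [0, 1]
assert 0. <= normalized.min() and normalized.max() <= 1.

min_, max_ = 0., 1.                                # the min_max pair
restored = (np.clip(normalized, min_, max_) - min_) / (max_ - min_) * 255.
assert np.allclose(restored, pixels)
```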
diff --git a/ppgan/datasets/preprocess/transforms.py b/ppgan/datasets/preprocess/transforms.py
index 57daf44a43d0d7b9f1464949fcd1cdc6565844cc..b48a4a4c8c72301309f92a4d15376018ebf44962 100644
--- a/ppgan/datasets/preprocess/transforms.py
+++ b/ppgan/datasets/preprocess/transforms.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import sys
+import cv2
 import random
 import numbers
 import collections
@@ -40,8 +41,9 @@ TRANSFORMS.register(T.Transpose)

 @PREPROCESS.register()
 class Transforms():
-    def __init__(self, pipeline, input_keys):
+    def __init__(self, pipeline, input_keys, output_keys=None):
         self.input_keys = input_keys
+        self.output_keys = output_keys
         self.transforms = []
         for transform_cfg in pipeline:
             self.transforms.append(build_from_config(transform_cfg, TRANSFORMS))
@@ -58,6 +60,11 @@ class Transforms():
                     transform.params, dict):
                 datas.update(transform.params)

+        if self.output_keys is not None:
+            for i, k in enumerate(self.output_keys):
+                datas[k] = data[i]
+            return datas
+
         for i, k in enumerate(self.input_keys):
             datas[k] = data[i]

@@ -183,10 +190,11 @@ class SRPairedRandomCrop(T.BaseTransform):
         scale (int): model upscale factor.
         gt_patch_size (int): cropped gt patch size.
     """
-    def __init__(self, scale, gt_patch_size, keys=None):
+    def __init__(self, scale, gt_patch_size, scale_list=False, keys=None):
         self.gt_patch_size = gt_patch_size
         self.scale = scale
         self.keys = keys
+        self.scale_list = scale_list

     def __call__(self, inputs):
         """inputs must be (lq_img, gt_img)"""
@@ -214,5 +222,11 @@ class SRPairedRandomCrop(T.BaseTransform):
         gt = gt[top_gt:top_gt + self.gt_patch_size,
                 left_gt:left_gt + self.gt_patch_size, ...]

+        if self.scale_list and self.scale == 4:
+            lqx2 = F.resize(gt, (lq_patch_size * 2, lq_patch_size * 2),
+                            'bicubic')
+            outputs = (lq, lqx2, gt)
+            return outputs
+
         outputs = (lq, gt)
         return outputs
diff --git a/ppgan/datasets/transforms/__init__.py b/ppgan/datasets/transforms/__init__.py
index 3b0e160826a488a942807e210489c7f75f9048e6..acb1b770db0c05f74cce8e0350be8d0ef4e96b89 100644
--- a/ppgan/datasets/transforms/__init__.py
+++ b/ppgan/datasets/transforms/__init__.py
@@ -1 +1 @@
-from .transforms import ResizeToScale, PairedRandomCrop, PairedRandomHorizontalFlip, Add
\ No newline at end of file
+from .transforms import ResizeToScale, PairedRandomCrop, PairedRandomHorizontalFlip, Add
diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py
index 6d6b22835471883ca654962b627deb6c7ca75169..2db728c0a2ead372f3e828342c8f1b9b8285a15d 100644
--- a/ppgan/engine/trainer.py
+++ b/ppgan/engine/trainer.py
@@ -386,7 +386,6 @@ class Trainer:
                 self.logger.warning(
                     'Can not find state dict of net {}. Skip load pretrained weight for net {}'
                     .format(net_name, net_name))
-                net.set_state_dict(state_dicts[net_name])

     def close(self):
         """
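The two new hooks in `transforms.py` above are designed to compose: with `scale_list=True` and `scale == 4`, `SRPairedRandomCrop` returns a third, x2-sized patch, and `output_keys` lets `Transforms` name that extra tensor. A hypothetical wiring (not one of the shipped configs; in practice ppgan builds this from YAML through its registry, so the plain dicts with a `name` key here stand in for parsed config nodes):

```python
# Hypothetical pipeline fragment; key names 'lq', 'lqx2', 'gt' are illustrative.
pipeline = [
    dict(name='SRPairedRandomCrop',
         gt_patch_size=128,
         scale=4,
         scale_list=True,          # crop now yields (lq, lqx2, gt)
         keys=['image', 'image']),
]
transforms = Transforms(pipeline,
                        input_keys=['lq', 'gt'],
                        output_keys=['lq', 'lqx2', 'gt'])  # names all 3 outputs
```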
+ """ + img_type = img.dtype + + if img_type != np.uint8: + img *= 255. + + if y_only: + out_img = np.dot(img, [65.481, 128.553, 24.966]) / 255. + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) + [16, 128, 128] + + if img_type != np.uint8: + out_img /= 255. + else: + out_img = out_img.round() + + return out_img + + def to_y_channel(img): """Change to Y channel of YCbCr. @@ -281,6 +323,6 @@ def to_y_channel(img): """ img = img.astype(np.float32) / 255. if img.ndim == 3 and img.shape[2] == 3: - img = bgr2ycbcr(img, y_only=True) + img = rgb2ycbcr(img, y_only=True) img = img[..., None] return img * 255. diff --git a/ppgan/models/esrgan_model.py b/ppgan/models/esrgan_model.py index 09dc28318358c1fb40d4be486c1aac889002d9d6..fe67cff05550dd777e35ea9ae65a706cffa32916 100644 --- a/ppgan/models/esrgan_model.py +++ b/ppgan/models/esrgan_model.py @@ -61,7 +61,6 @@ class ESRGAN(BaseSRModel): self.gan_criterion = build_criterion(gan_criterion) def train_iter(self, optimizers=None): - self.set_requires_grad(self.nets['discriminator'], False) optimizers['optimG'].clear_grad() l_total = 0 self.output = self.nets['generator'](self.lq) @@ -83,41 +82,48 @@ class ESRGAN(BaseSRModel): self.losses['loss_style'] = l_g_style # gan loss (relativistic gan) - real_d_pred = self.nets['discriminator'](self.gt).detach() - fake_g_pred = self.nets['discriminator'](self.output) - l_g_real = self.gan_criterion(real_d_pred - paddle.mean(fake_g_pred), - False, - is_disc=False) - l_g_fake = self.gan_criterion(fake_g_pred - paddle.mean(real_d_pred), - True, - is_disc=False) - l_g_gan = (l_g_real + l_g_fake) / 2 - - l_total += l_g_gan - self.losses['l_g_gan'] = l_g_gan - - l_total.backward() - optimizers['optimG'].step() - - self.set_requires_grad(self.nets['discriminator'], True) - optimizers['optimD'].clear_grad() - # real - fake_d_pred = self.nets['discriminator'](self.output).detach() - real_d_pred = self.nets['discriminator'](self.gt) - l_d_real = self.gan_criterion( - real_d_pred - paddle.mean(fake_d_pred), True, is_disc=True) * 0.5 - - # fake - fake_d_pred = self.nets['discriminator'](self.output.detach()) - l_d_fake = self.gan_criterion( - fake_d_pred - paddle.mean(real_d_pred.detach()), - False, - is_disc=True) * 0.5 - - (l_d_real + l_d_fake).backward() - optimizers['optimD'].step() - - self.losses['l_d_real'] = l_d_real - self.losses['l_d_fake'] = l_d_fake - self.losses['out_d_real'] = paddle.mean(real_d_pred.detach()) - self.losses['out_d_fake'] = paddle.mean(fake_d_pred.detach()) + if hasattr(self, 'gan_criterion'): + self.set_requires_grad(self.nets['discriminator'], False) + real_d_pred = self.nets['discriminator'](self.gt).detach() + fake_g_pred = self.nets['discriminator'](self.output) + l_g_real = self.gan_criterion(real_d_pred - + paddle.mean(fake_g_pred), + False, + is_disc=False) + l_g_fake = self.gan_criterion(fake_g_pred - + paddle.mean(real_d_pred), + True, + is_disc=False) + l_g_gan = (l_g_real + l_g_fake) / 2 + + l_total += l_g_gan + self.losses['l_g_gan'] = l_g_gan + l_total.backward() + optimizers['optimG'].step() + + self.set_requires_grad(self.nets['discriminator'], True) + optimizers['optimD'].clear_grad() + # real + fake_d_pred = self.nets['discriminator'](self.output).detach() + real_d_pred = self.nets['discriminator'](self.gt) + l_d_real = self.gan_criterion( + real_d_pred - paddle.mean(fake_d_pred), True, + is_disc=True) * 0.5 + + # fake + fake_d_pred = self.nets['discriminator'](self.output.detach()) + l_d_fake = 
diff --git a/ppgan/models/esrgan_model.py b/ppgan/models/esrgan_model.py
index 09dc28318358c1fb40d4be486c1aac889002d9d6..fe67cff05550dd777e35ea9ae65a706cffa32916 100644
--- a/ppgan/models/esrgan_model.py
+++ b/ppgan/models/esrgan_model.py
@@ -61,7 +61,6 @@ class ESRGAN(BaseSRModel):
             self.gan_criterion = build_criterion(gan_criterion)

     def train_iter(self, optimizers=None):
-        self.set_requires_grad(self.nets['discriminator'], False)
         optimizers['optimG'].clear_grad()
         l_total = 0
         self.output = self.nets['generator'](self.lq)
@@ -83,41 +82,48 @@ class ESRGAN(BaseSRModel):
             self.losses['loss_style'] = l_g_style

         # gan loss (relativistic gan)
-        real_d_pred = self.nets['discriminator'](self.gt).detach()
-        fake_g_pred = self.nets['discriminator'](self.output)
-        l_g_real = self.gan_criterion(real_d_pred - paddle.mean(fake_g_pred),
-                                      False,
-                                      is_disc=False)
-        l_g_fake = self.gan_criterion(fake_g_pred - paddle.mean(real_d_pred),
-                                      True,
-                                      is_disc=False)
-        l_g_gan = (l_g_real + l_g_fake) / 2
-
-        l_total += l_g_gan
-        self.losses['l_g_gan'] = l_g_gan
-
-        l_total.backward()
-        optimizers['optimG'].step()
-
-        self.set_requires_grad(self.nets['discriminator'], True)
-        optimizers['optimD'].clear_grad()
-        # real
-        fake_d_pred = self.nets['discriminator'](self.output).detach()
-        real_d_pred = self.nets['discriminator'](self.gt)
-        l_d_real = self.gan_criterion(
-            real_d_pred - paddle.mean(fake_d_pred), True, is_disc=True) * 0.5
-
-        # fake
-        fake_d_pred = self.nets['discriminator'](self.output.detach())
-        l_d_fake = self.gan_criterion(
-            fake_d_pred - paddle.mean(real_d_pred.detach()),
-            False,
-            is_disc=True) * 0.5
-
-        (l_d_real + l_d_fake).backward()
-        optimizers['optimD'].step()
-
-        self.losses['l_d_real'] = l_d_real
-        self.losses['l_d_fake'] = l_d_fake
-        self.losses['out_d_real'] = paddle.mean(real_d_pred.detach())
-        self.losses['out_d_fake'] = paddle.mean(fake_d_pred.detach())
+        if hasattr(self, 'gan_criterion'):
+            self.set_requires_grad(self.nets['discriminator'], False)
+            real_d_pred = self.nets['discriminator'](self.gt).detach()
+            fake_g_pred = self.nets['discriminator'](self.output)
+            l_g_real = self.gan_criterion(real_d_pred -
+                                          paddle.mean(fake_g_pred),
+                                          False,
+                                          is_disc=False)
+            l_g_fake = self.gan_criterion(fake_g_pred -
+                                          paddle.mean(real_d_pred),
+                                          True,
+                                          is_disc=False)
+            l_g_gan = (l_g_real + l_g_fake) / 2
+
+            l_total += l_g_gan
+            self.losses['l_g_gan'] = l_g_gan
+            l_total.backward()
+            optimizers['optimG'].step()
+
+            self.set_requires_grad(self.nets['discriminator'], True)
+            optimizers['optimD'].clear_grad()
+            # real
+            fake_d_pred = self.nets['discriminator'](self.output).detach()
+            real_d_pred = self.nets['discriminator'](self.gt)
+            l_d_real = self.gan_criterion(
+                real_d_pred - paddle.mean(fake_d_pred), True,
+                is_disc=True) * 0.5
+
+            # fake
+            fake_d_pred = self.nets['discriminator'](self.output.detach())
+            l_d_fake = self.gan_criterion(
+                fake_d_pred - paddle.mean(real_d_pred.detach()),
+                False,
+                is_disc=True) * 0.5
+
+            (l_d_real + l_d_fake).backward()
+            optimizers['optimD'].step()
+
+            self.losses['l_d_real'] = l_d_real
+            self.losses['l_d_fake'] = l_d_fake
+            self.losses['out_d_real'] = paddle.mean(real_d_pred.detach())
+            self.losses['out_d_fake'] = paddle.mean(fake_d_pred.detach())
+        else:
+            l_total.backward()
+            optimizers['optimG'].step()
diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py
index 247238571bd63286e8bd97c810592c2b6e572be1..7d069fe7ad38a55bf80aba9b01952d414881168f 100644
--- a/ppgan/models/generators/__init__.py
+++ b/ppgan/models/generators/__init__.py
@@ -21,6 +21,7 @@ from .resnet_ugatit import ResnetUGATITGenerator
 from .dcgenerator import DCGenerator
 from .generater_animegan import AnimeGenerator, AnimeGeneratorLite
 from .wav2lip import Wav2Lip
+from .lesrcnn import LESRCNNGenerator
 from .resnet_ugatit_p2c import ResnetUGATITP2CGenerator
 from .generator_styleganv2 import StyleGANv2Generator
 from .generator_pixel2style2pixel import Pixel2Style2Pixel
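For reference, the `l_d_real`/`l_d_fake` terms in `train_iter` above implement the relativistic average GAN objective: each logit is shifted by the mean of the opposite batch before the usual sigmoid BCE. A standalone sketch of the discriminator side, assuming `gan_criterion` reduces to BCE-with-logits (the common "vanilla" GAN loss configuration):

```python
import paddle
import paddle.nn.functional as F

real_scores = paddle.randn([4, 1])   # stand-ins for discriminator outputs
fake_scores = paddle.randn([4, 1])

# D pushes real scores above the average fake score and fake scores below
# the average real score; the 0.5 factors mirror the weighting above.
l_d_real = F.binary_cross_entropy_with_logits(
    real_scores - fake_scores.mean(), paddle.ones_like(real_scores)) * 0.5
l_d_fake = F.binary_cross_entropy_with_logits(
    fake_scores - real_scores.mean(), paddle.zeros_like(fake_scores)) * 0.5
loss_d = l_d_real + l_d_fake
```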
diff --git a/ppgan/models/generators/lesrcnn.py b/ppgan/models/generators/lesrcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..19befcf4125ce0c1e6ea23e74fc53b2a591d31bd
--- /dev/null
+++ b/ppgan/models/generators/lesrcnn.py
@@ -0,0 +1,331 @@
+import math
+import numpy as np
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from .builder import GENERATORS
+
+
+class MeanShift(nn.Layer):
+    def __init__(self, mean_rgb, sub):
+        super(MeanShift, self).__init__()
+
+        sign = -1 if sub else 1
+        r = mean_rgb[0] * sign
+        g = mean_rgb[1] * sign
+        b = mean_rgb[2] * sign
+
+        self.shifter = nn.Conv2D(3, 3, 1, 1, 0)
+        self.shifter.weight.set_value(paddle.eye(3).reshape([3, 3, 1, 1]))
+        self.shifter.bias.set_value(np.array([r, g, b]).astype('float32'))
+        # Freeze the mean shift layer
+        for params in self.shifter.parameters():
+            params.trainable = False
+
+    def forward(self, x):
+        x = self.shifter(x)
+        return x
+
+
+class UpsampleBlock(nn.Layer):
+    def __init__(self, n_channels, scale, multi_scale, group=1):
+        super(UpsampleBlock, self).__init__()
+
+        if multi_scale:
+            self.up2 = _UpsampleBlock(n_channels, scale=2, group=group)
+            self.up3 = _UpsampleBlock(n_channels, scale=3, group=group)
+            self.up4 = _UpsampleBlock(n_channels, scale=4, group=group)
+        else:
+            self.up = _UpsampleBlock(n_channels, scale=scale, group=group)
+
+        self.multi_scale = multi_scale
+
+    def forward(self, x, scale):
+        if self.multi_scale:
+            if scale == 2:
+                return self.up2(x)
+            elif scale == 3:
+                return self.up3(x)
+            elif scale == 4:
+                return self.up4(x)
+        else:
+            return self.up(x)
+
+
+class _UpsampleBlock(nn.Layer):
+    def __init__(self, n_channels, scale, group=1):
+        super(_UpsampleBlock, self).__init__()
+
+        modules = []
+        if scale == 2 or scale == 4 or scale == 8:
+            for _ in range(int(math.log(scale, 2))):
+                modules += [
+                    nn.Conv2D(n_channels, 4 * n_channels, 3, 1, 1, groups=group)
+                ]
+                modules += [nn.PixelShuffle(2)]
+        elif scale == 3:
+            modules += [
+                nn.Conv2D(n_channels, 9 * n_channels, 3, 1, 1, groups=group)
+            ]
+            modules += [nn.PixelShuffle(3)]
+
+        self.body = nn.Sequential(*modules)
+
+    def forward(self, x):
+        out = self.body(x)
+        return out
+
+
+@GENERATORS.register()
+class LESRCNNGenerator(nn.Layer):
+    """Generator of LESRCNN, a lightweight super-resolution CNN that stacks
+    alternating 3x3 and 1x1 convolution blocks with residual connections,
+    then reconstructs the image after sub-pixel (PixelShuffle) upsampling.
+
+    Args:
+        scale (int): scale of upsample.
+        multi_scale (bool): Whether to train multi scale model.
+        group (int): group option for convolution.
+    """
+    def __init__(
+        self,
+        scale=4,
+        multi_scale=False,
+        group=1,
+    ):
+        super(LESRCNNGenerator, self).__init__()
+
+        kernel_size = 3
+        kernel_size1 = 1
+        padding = 1
+        features = 64
+        groups = 1
+        channels = 3
+        self.scale = scale
+        self.sub_mean = MeanShift((0.4488, 0.4371, 0.4040), sub=True)
+        self.add_mean = MeanShift((0.4488, 0.4371, 0.4040), sub=False)
+
+        self.conv1 = nn.Sequential(
+            nn.Conv2D(in_channels=channels,
+                      out_channels=features,
+                      kernel_size=kernel_size,
+                      padding=padding,
+                      groups=1,
+                      bias_attr=False))
+        self.conv2 = nn.Sequential(
+            nn.Conv2D(in_channels=features,
+                      out_channels=features,
+                      kernel_size=kernel_size,
+                      padding=1,
+                      groups=1,
+                      bias_attr=False), nn.ReLU())
+        self.conv3 = nn.Sequential(
+            nn.Conv2D(in_channels=features,
+                      out_channels=features,
+                      kernel_size=kernel_size1,
+                      padding=0,
+                      groups=groups,
+                      bias_attr=False))
+        self.conv4 = nn.Sequential(
+            nn.Conv2D(in_channels=features,
+                      out_channels=features,
+                      kernel_size=kernel_size,
+                      padding=1,
+                      groups=1,
+                      bias_attr=False), nn.ReLU())
+        self.conv5 = nn.Sequential(
+            nn.Conv2D(in_channels=features,
+                      out_channels=features,
+                      kernel_size=kernel_size1,
+                      padding=0,
+                      groups=groups,
+                      bias_attr=False))
+        self.conv6 = nn.Sequential(
+            nn.Conv2D(in_channels=features,
+                      out_channels=features,
+                      kernel_size=kernel_size,
+                      padding=1,
+                      groups=1,
+                      bias_attr=False), nn.ReLU())
+        self.conv7 = nn.Sequential(
+            nn.Conv2D(in_channels=features,
+                      out_channels=features,
+                      kernel_size=kernel_size1,
+                      padding=0,
+                      groups=groups,
+                      bias_attr=False))
+        self.conv8 = nn.Sequential(
+            nn.Conv2D(in_channels=features,
+                      out_channels=features,
+                      kernel_size=kernel_size,
+                      padding=1,
+                      groups=1,
+                      bias_attr=False), nn.ReLU())
+        self.conv9 = nn.Sequential(
+            nn.Conv2D(in_channels=features,
+                      out_channels=features,
+                      kernel_size=kernel_size1,
+                      padding=0,
+                      groups=groups,
+                      bias_attr=False))
+        self.conv10 = nn.Sequential(
+            nn.Conv2D(in_channels=features,
+                      out_channels=features,
+                      kernel_size=kernel_size,
+                      padding=1,
+                      groups=1,
+                      bias_attr=False), nn.ReLU())
+        self.conv11 = nn.Sequential(
+            nn.Conv2D(in_channels=features,
+                      out_channels=features,
+                      kernel_size=kernel_size1,
+                      padding=0,
+                      groups=groups,
+                      bias_attr=False))
+        self.conv12 = nn.Sequential(
+            nn.Conv2D(in_channels=features,
+                      out_channels=features,
+                      kernel_size=kernel_size,
+                      padding=1,
+                      groups=1,
+                      bias_attr=False), nn.ReLU())
+        self.conv13 = nn.Sequential(
+            nn.Conv2D(in_channels=features,
+                      out_channels=features,
+                      kernel_size=kernel_size1,
+                      padding=0,
+                      groups=groups,
+                      bias_attr=False))
+        self.conv14 = nn.Sequential(
+            nn.Conv2D(in_channels=features,
+                      out_channels=features,
+                      kernel_size=kernel_size,
+                      padding=1,
+                      groups=1,
+                      bias_attr=False), nn.ReLU())
+        self.conv15 = nn.Sequential(
+            nn.Conv2D(in_channels=features,
+                      out_channels=features,
+                      kernel_size=kernel_size1,
+                      padding=0,
+                      groups=groups,
+                      bias_attr=False))
+        self.conv16 = nn.Sequential(
+            nn.Conv2D(in_channels=features,
+                      out_channels=features,
+                      kernel_size=kernel_size,
+                      padding=1,
+                      groups=1,
+                      bias_attr=False), nn.ReLU())
+        self.conv17 = nn.Sequential(
+            nn.Conv2D(in_channels=features,
+                      out_channels=features,
+                      kernel_size=kernel_size1,
+                      padding=0,
+                      groups=groups,
+                      bias_attr=False))
+        self.conv17_1 = nn.Sequential(
+            nn.Conv2D(in_channels=features,
+                      out_channels=features,
+                      kernel_size=kernel_size,
+                      padding=1,
+                      groups=1,
+                      bias_attr=False), nn.ReLU())
+        self.conv17_2 = nn.Sequential(
+            nn.Conv2D(in_channels=features,
+                      out_channels=features,
+                      kernel_size=kernel_size,
+                      padding=1,
+                      groups=1,
+                      bias_attr=False), nn.ReLU())
+        self.conv17_3 = nn.Sequential(
+            nn.Conv2D(in_channels=features,
+                      out_channels=features,
+                      kernel_size=kernel_size,
+                      padding=1,
+                      groups=1,
+                      bias_attr=False), nn.ReLU())
+        self.conv17_4 = nn.Sequential(
+            nn.Conv2D(in_channels=features,
+                      out_channels=features,
+                      kernel_size=kernel_size,
+                      padding=1,
+                      groups=1,
+                      bias_attr=False), nn.ReLU())
+        self.conv18 = nn.Sequential(
+            nn.Conv2D(in_channels=features,
+                      out_channels=3,
+                      kernel_size=kernel_size,
+                      padding=padding,
+                      groups=groups,
+                      bias_attr=False))
+
+        self.ReLU = nn.ReLU()
+        self.upsample = UpsampleBlock(64,
+                                      scale=scale,
+                                      multi_scale=multi_scale,
+                                      group=1)
+    def forward(self, x, scale=None):
+        if scale is None:
+            scale = self.scale
+
+        x = self.sub_mean(x)
+
+        # Feature extraction: pairs of 3x3(+ReLU) and 1x1 convolutions,
+        # with a residual connection closing over every pair.
+        x1 = self.conv1(x)
+        x1_1 = self.ReLU(x1)
+        x2 = self.conv2(x1_1)
+        x3 = self.conv3(x2)
+        x2_3 = x1 + x3
+
+        x2_4 = self.ReLU(x2_3)
+        x4 = self.conv4(x2_4)
+        x5 = self.conv5(x4)
+        x3_5 = x2_3 + x5
+
+        x3_6 = self.ReLU(x3_5)
+        x6 = self.conv6(x3_6)
+        x7 = self.conv7(x6)
+        x7_1 = x3_5 + x7
+
+        x7_2 = self.ReLU(x7_1)
+        x8 = self.conv8(x7_2)
+        x9 = self.conv9(x8)
+        x9_2 = x7_1 + x9
+
+        x9_1 = self.ReLU(x9_2)
+        x10 = self.conv10(x9_1)
+        x11 = self.conv11(x10)
+        x11_1 = x9_2 + x11
+
+        x11_2 = self.ReLU(x11_1)
+        x12 = self.conv12(x11_2)
+        x13 = self.conv13(x12)
+        x13_1 = x11_1 + x13
+
+        x13_2 = self.ReLU(x13_1)
+        x14 = self.conv14(x13_2)
+        x15 = self.conv15(x14)
+        x15_1 = x15 + x13_1
+
+        x15_2 = self.ReLU(x15_1)
+        x16 = self.conv16(x15_2)
+        x17 = self.conv17(x16)
+        x17_2 = x17 + x15_1
+
+        # Upsample both the deep features and the shallow conv1 features
+        # with the same sub-pixel block, then fuse them.
+        x17_3 = self.ReLU(x17_2)
+        temp = self.upsample(x17_3, scale=scale)
+        x1111 = self.upsample(x1_1, scale=scale)
+        temp1 = x1111 + temp
+        temp2 = self.ReLU(temp1)
+
+        # Refinement at high resolution, then reconstruction to 3 channels.
+        temp3 = self.conv17_1(temp2)
+        temp4 = self.conv17_2(temp3)
+        temp5 = self.conv17_3(temp4)
+        temp6 = self.conv17_4(temp5)
+        x18 = self.conv18(temp6)
+        out = self.add_mean(x18)
+
+        return out
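A quick smoke test for the new generator, assuming a working PaddlePaddle install with ppgan on the path: each `_UpsampleBlock` stage is a conv followed by `PixelShuffle(2)`, so `scale=4` applies two x2 stages and a 32x32 input should come back at 128x128.

```python
import paddle
from ppgan.models.generators import LESRCNNGenerator

model = LESRCNNGenerator(scale=4)
model.eval()

lq = paddle.rand([1, 3, 32, 32])     # dummy low-resolution batch
with paddle.no_grad():
    sr = model(lq)
print(sr.shape)                      # expected: [1, 3, 128, 128]
```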