diff --git a/configs/esrgan_psnr_x2_div2k.yaml b/configs/esrgan_psnr_x2_div2k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dca1581385e6fdfea784dab2e9e37f7d2ba20758 --- /dev/null +++ b/configs/esrgan_psnr_x2_div2k.yaml @@ -0,0 +1,106 @@ +total_iters: 1000000 +output_dir: output_dir +# tensor range for function tensor2img +min_max: + (0., 1.) + +model: + name: BaseSRModel + generator: + name: RRDBNet + in_nc: 3 + out_nc: 3 + nf: 64 + nb: 23 + scale: 2 + pixel_criterion: + name: L1Loss + +dataset: + train: + name: SRDataset + gt_folder: data/DIV2K/DIV2K_train_HR_sub + lq_folder: data/DIV2K/DIV2K_train_LR_bicubic/X2_sub + num_workers: 4 + batch_size: 8 + scale: 2 + preprocess: + - name: LoadImageFromFile + key: lq + - name: LoadImageFromFile + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: SRPairedRandomCrop + gt_patch_size: 128 + scale: 2 + keys: [image, image] + - name: PairedRandomHorizontalFlip + keys: [image, image] + - name: PairedRandomVerticalFlip + keys: [image, image] + - name: PairedRandomTransposeHW + keys: [image, image] + - name: Transpose + keys: [image, image] + - name: Normalize + mean: [0., .0, 0.] + std: [255., 255., 255.] + keys: [image, image] + test: + name: SRDataset + gt_folder: data/Set14/GTmod12 + lq_folder: data/Set14/LRbicx2 + scale: 2 + preprocess: + - name: LoadImageFromFile + key: lq + - name: LoadImageFromFile + key: gt + - name: Transforms + input_keys: [lq, gt] + pipeline: + - name: Transpose + keys: [image, image] + - name: Normalize + mean: [0., .0, 0.] + std: [255., 255., 255.] 
def pixel_unshuffle(x, scale):
    """Rearrange spatial blocks of a feature map into channels.

    Inverse of pixel shuffle: every non-overlapping ``scale x scale`` spatial
    patch is folded into the channel axis, so the spatial size shrinks by
    ``scale`` while the channel count grows by ``scale ** 2``.

    Args:
        x (paddle.Tensor): Input feature with 4 dimensions, unpacked here as
            ``(b, c, h, w)``.
        scale (int): Downsample ratio. Must be >= 1 and divide both ``h``
            and ``w``.

    Returns:
        paddle.Tensor: The pixel unshuffled feature of shape
        ``(b, c * scale**2, h // scale, w // scale)``. Output channel
        ``ci * scale**2 + dy * scale + dx`` holds ``x[:, ci, dy::scale,
        dx::scale]``.

    Raises:
        ValueError: If ``x`` is not 4-D, ``scale`` is smaller than 1, or
            ``h``/``w`` is not divisible by ``scale``.
    """
    shape = x.shape
    if len(shape) != 4:
        raise ValueError(
            'pixel_unshuffle expects a 4-D input (b, c, h, w), got shape '
            '{}'.format(list(shape)))
    b, c, h, w = shape

    # Explicit checks instead of `assert`: asserts vanish under `python -O`,
    # and scale == 0 would otherwise surface as a bare ZeroDivisionError.
    if scale < 1:
        raise ValueError('scale must be >= 1, got {}'.format(scale))
    if h % scale != 0 or w % scale != 0:
        raise ValueError(
            'spatial size ({}, {}) is not divisible by scale {}'.format(
                h, w, scale))

    out_channel = c * (scale**2)
    hh = h // scale
    ww = w // scale

    # Split each spatial axis into (coarse index, offset inside the patch).
    x_reshaped = x.reshape([b, c, hh, scale, ww, scale])
    # Move both in-patch offsets next to the channel axis, then merge them
    # into it: (b, c, dy, dx, hh, ww) -> (b, c * scale**2, hh, ww).
    return x_reshaped.transpose([0, 1, 3, 5, 2,
                                 4]).reshape([b, out_channel, hh, ww])