diff --git a/configs/basicvsr_reds.yaml b/configs/basicvsr_reds.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..91c6a8889f96d692dbb1419f2828e08ec936ef47
--- /dev/null
+++ b/configs/basicvsr_reds.yaml
@@ -0,0 +1,92 @@
+total_iters: 300000
+output_dir: output_dir
+find_unused_parameters: True
+checkpoints_dir: checkpoints
+use_dataset: True
+# tensor range for function tensor2img
+min_max:
+  (0., 1.)
+
+model:
+  name: BasicVSRModel
+  fix_iter: 5000
+  generator:
+    name: BasicVSRNet
+    mid_channels: 64
+    num_blocks: 30
+  pixel_criterion:
+    name: CharbonnierLoss
+    reduction: mean
+
+dataset:
+  train:
+    name: RepeatDataset
+    times: 1000
+    num_workers: 4  # 6
+    batch_size: 2  # 4*2
+    dataset:
+      name: SRREDSMultipleGTDataset
+      mode: train
+      lq_folder: data/REDS/train_sharp_bicubic/X4
+      gt_folder: data/REDS/train_sharp/X4
+      crop_size: 256
+      interval_list: [1]
+      random_reverse: False
+      number_frames: 15
+      use_flip: True
+      use_rot: True
+      scale: 4
+      val_partition: REDS4
+
+  test:
+    name: SRREDSMultipleGTDataset
+    mode: test
+    lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
+    gt_folder: data/REDS/REDS4_test_sharp/X4
+    interval_list: [1]
+    random_reverse: False
+    number_frames: 100
+    use_flip: False
+    use_rot: False
+    scale: 4
+    val_partition: REDS4
+    num_workers: 0
+    batch_size: 1
+
+lr_scheduler:
+  name: CosineAnnealingRestartLR
+  learning_rate: !!float 2e-4
+  periods: [300000]
+  restart_weights: [1]
+  eta_min: !!float 1e-7
+
+optimizer:
+  name: Adam
+  # add parameters of net_name to optim
+  # name should be in self.nets
+  net_names:
+    - generator
+  beta1: 0.9
+  beta2: 0.99
+
+validate:
+  # FIXME: avoid OOM
+  interval: 5000000
+  save_img: false
+
+  metrics:
+    psnr: # metric name, can be arbitrary
+      name: PSNR
+      crop_border: 0
+      test_y_channel: False
+    ssim:
+      name: SSIM
+      crop_border: 0
+      test_y_channel: False
+
+log_config:
+  interval: 100
+  visiual_interval: 500
+
+snapshot_config:
+  interval: 5000
diff --git a/ppgan/datasets/__init__.py b/ppgan/datasets/__init__.py
index ef9d7ad314fb9fb6a7636ca60cedae98414d95c4..7e3182ab01138e56542cdc37c70a9a5bf0e06a02 100755
--- a/ppgan/datasets/__init__.py
+++ b/ppgan/datasets/__init__.py
@@ -24,4 +24,5 @@ from .starganv2_dataset import StarGANv2Dataset
 from .edvr_dataset import REDSDataset
 from .firstorder_dataset import FirstOrderDataset
 from .lapstyle_dataset import LapStyleDataset
+from .sr_reds_multiple_gt_dataset import SRREDSMultipleGTDataset
 from .mpr_dataset import MPRTrain, MPRVal, MPRTest
diff --git a/ppgan/datasets/builder.py b/ppgan/datasets/builder.py
index a192f3f73d2596d31163eed0ef3a2b579a5f10d5..2e1bfb9ba7307db64bb471b120016fc4d220b28a 100644
--- a/ppgan/datasets/builder.py
+++ b/ppgan/datasets/builder.py
@@ -19,11 +19,25 @@ import numpy as np
 from paddle.distributed import ParallelEnv
 from paddle.io import DistributedBatchSampler
-from ..utils.registry import Registry
+
+from .repeat_dataset import RepeatDataset
+from ..utils.registry import Registry, build_from_config
 
 DATASETS = Registry("DATASETS")
 
 
+def build_dataset(cfg):
+    name = cfg.pop('name')
+
+    if name == 'RepeatDataset':
+        dataset_ = build_from_config(cfg['dataset'], DATASETS)
+        dataset = RepeatDataset(dataset_, cfg['times'])
+    else:
+        dataset = DATASETS.get(name)(**cfg)
+
+    return dataset
+
+
 def build_dataloader(cfg, is_train=True, distributed=True):
     cfg_ = cfg.copy()
@@ -31,9 +45,7 @@ def build_dataloader(cfg, is_train=True, distributed=True):
     num_workers = cfg_.pop('num_workers', 0)
     use_shared_memory = cfg_.pop('use_shared_memory', True)
-    name = cfg_.pop('name')
-
-    dataset = DATASETS.get(name)(**cfg_)
+    dataset = build_dataset(cfg_)
 
     if distributed:
         sampler = DistributedBatchSampler(dataset,
diff --git a/ppgan/datasets/repeat_dataset.py b/ppgan/datasets/repeat_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f1e803500479dca9519b16fdb5548e36131d64f
--- /dev/null
+++ b/ppgan/datasets/repeat_dataset.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+
+
+class RepeatDataset(paddle.io.Dataset):
+    """A wrapper that repeats a dataset.
+
+    The length of the repeated dataset is `times` larger than that of the
+    original dataset. This is useful when the data loading time is long but
+    the dataset is small. Using RepeatDataset reduces the data loading time
+    between epochs.
+
+    Args:
+        dataset (:obj:`Dataset`): The dataset to be repeated.
+        times (int): Repeat times.
+    """
+    def __init__(self, dataset, times):
+        self.dataset = dataset
+        self.times = times
+
+        self._ori_len = len(self.dataset)
+
+    def __getitem__(self, idx):
+        """Get item at each call.
+
+        Args:
+            idx (int): Index for getting each item.
+        """
+        return self.dataset[idx % self._ori_len]
+
+    def __len__(self):
+        """Length of the dataset.
+
+        Returns:
+            int: Length of the dataset.
+        """
+        return self.times * self._ori_len
diff --git a/ppgan/datasets/sr_reds_multiple_gt_dataset.py b/ppgan/datasets/sr_reds_multiple_gt_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..bec816052fbfd5e789d6fb487119f9d9712988ad
--- /dev/null
+++ b/ppgan/datasets/sr_reds_multiple_gt_dataset.py
@@ -0,0 +1,233 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import random
+import numpy as np
+import cv2
+from paddle.io import Dataset
+
+from .builder import DATASETS
+
+logger = logging.getLogger(__name__)
+
+
+@DATASETS.register()
+class SRREDSMultipleGTDataset(Dataset):
+    """REDS dataset for video super-resolution with recurrent networks.
+
+    The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth)
+    frames, applies the specified transforms, and finally returns a dict
+    containing paired data and other information.
+
+    Args:
+        mode (str): Dataset mode: 'train' or 'test'.
+        lq_folder (str | :obj:`Path`): Path to a lq folder.
+        gt_folder (str | :obj:`Path`): Path to a gt folder.
+        crop_size (int): GT patch size for random cropping. Default: 256.
+        interval_list (list[int]): Candidate frame-sampling intervals.
+            Default: [1].
+        random_reverse (bool): Whether to randomly reverse the frame order.
+            Default: False.
+        number_frames (int): Number of input frames. Default: 15.
+        use_flip (bool): Whether to apply random flipping. Default: False.
+        use_rot (bool): Whether to apply random rotation. Default: False.
+        scale (int): Upsampling scale ratio. Default: 4.
+        val_partition (str): Validation partition mode. Choices are
+            'official' and 'REDS4'. Default: 'REDS4'.
+        batch_size (int): Batch size. Default: 4.
+    """
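+    # A minimal usage sketch (illustrative; the folder paths follow
+    # configs/basicvsr_reds.yaml added in this change):
+    #     dataset = SRREDSMultipleGTDataset(
+    #         mode='train',
+    #         lq_folder='data/REDS/train_sharp_bicubic/X4',
+    #         gt_folder='data/REDS/train_sharp/X4')
+    #     sample = dataset[0]
+    #     sample['lq'].shape  # (15, 3, 64, 64): crop_size // scale = 64
+    #     sample['gt'].shape  # (15, 3, 256, 256)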
+    def __init__(self,
+                 mode,
+                 lq_folder,
+                 gt_folder,
+                 crop_size=256,
+                 interval_list=[1],
+                 random_reverse=False,
+                 number_frames=15,
+                 use_flip=False,
+                 use_rot=False,
+                 scale=4,
+                 val_partition='REDS4',
+                 batch_size=4):
+        super(SRREDSMultipleGTDataset, self).__init__()
+        self.mode = mode
+        self.fileroot = str(lq_folder)
+        self.gtroot = str(gt_folder)
+        self.crop_size = crop_size
+        self.interval_list = interval_list
+        self.random_reverse = random_reverse
+        self.number_frames = number_frames
+        self.use_flip = use_flip
+        self.use_rot = use_rot
+        self.scale = scale
+        self.val_partition = val_partition
+        self.batch_size = batch_size
+        self.data_infos = self.load_annotations()
+
+    def __getitem__(self, idx):
+        """Get item at each call.
+
+        Args:
+            idx (int): Index for getting each item.
+        """
+        item = self.data_infos[idx]
+        idt = random.randint(0, 100 - self.number_frames)
+        item = item + '_' + f'{idt:03d}'
+        img_LQs, img_GTs = self.get_sample_data(
+            item, self.number_frames, self.interval_list, self.random_reverse,
+            self.gtroot, self.fileroot, self.crop_size, self.scale,
+            self.use_flip, self.use_rot, self.mode)
+        return {'lq': img_LQs, 'gt': img_GTs, 'lq_path': self.data_infos[idx]}
+
+    def load_annotations(self):
+        """Load annotations for the REDS dataset.
+
+        Returns:
+            list[str]: Clip keys; LQ and GT pairs share the same key.
+        """
+        # generate keys
+        keys = [f'{i:03d}' for i in range(0, 270)]
+
+        if self.val_partition == 'REDS4':
+            val_partition = ['000', '011', '015', '020']
+        elif self.val_partition == 'official':
+            val_partition = [f'{i:03d}' for i in range(240, 270)]
+        else:
+            raise ValueError(f'Wrong validation partition {self.val_partition}.'
+                             f' Supported ones are ["official", "REDS4"]')
+
+        if self.mode == 'train':
+            keys = [v for v in keys if v not in val_partition]
+        else:
+            keys = [v for v in keys if v in val_partition]
+
+        data_infos = []
+        for key in keys:
+            data_infos.append(key)
+
+        return data_infos
+
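+    # Note: get_sample_data below selects neighboring frame indices, loads
+    # the paired LQ/GT frames, applies a matched random crop (the LQ crop is
+    # crop_size // scale), applies joint flip/rotation augmentation, and
+    # converts HWC -> CHW before stacking each clip to (T, C, H, W).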
+    def get_sample_data(self,
+                        item,
+                        number_frames,
+                        interval_list,
+                        random_reverse,
+                        gtroot,
+                        fileroot,
+                        crop_size,
+                        scale,
+                        use_flip,
+                        use_rot,
+                        mode='train'):
+        video_name = item.split('_')[0]
+        frame_name = item.split('_')[1]
+        frame_idxs = self.get_neighbor_frames(frame_name,
+                                              number_frames=number_frames,
+                                              interval_list=interval_list,
+                                              random_reverse=random_reverse)
+
+        frame_list = []
+        gt_list = []
+        for frame_idx in frame_idxs:
+            frame_idx_name = "%08d" % frame_idx
+            img = self.read_img(
+                os.path.join(fileroot, video_name, frame_idx_name + '.png'))
+            frame_list.append(img)
+            gt_img = self.read_img(
+                os.path.join(gtroot, video_name, frame_idx_name + '.png'))
+            gt_list.append(gt_img)
+        H, W, C = frame_list[0].shape
+        # add random crop
+        if (mode == 'train') or (mode == 'valid'):
+            LQ_size = crop_size // scale
+            rnd_h = random.randint(0, max(0, H - LQ_size))
+            rnd_w = random.randint(0, max(0, W - LQ_size))
+            frame_list = [
+                v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
+                for v in frame_list
+            ]
+            rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
+            gt_list = [
+                v[rnd_h_HR:rnd_h_HR + crop_size,
+                  rnd_w_HR:rnd_w_HR + crop_size, :] for v in gt_list
+            ]
+
+        # add random flip and rotation (LQ and GT frames are augmented jointly)
+        for v in gt_list:
+            frame_list.append(v)
+        if (mode == 'train') or (mode == 'valid'):
+            rlt = self.img_augment(frame_list, use_flip, use_rot)
+        else:
+            rlt = frame_list
+        frame_list = rlt[0:number_frames]
+        gt_list = rlt[number_frames:]
+
+        # stack LQ images to NCHW, N is the frame number
+        frame_list = [v.transpose(2, 0, 1).astype('float32') for v in frame_list]
+        gt_list = [v.transpose(2, 0, 1).astype('float32') for v in gt_list]
+
+        img_LQs = np.stack(frame_list, axis=0)
+        img_GTs = np.stack(gt_list, axis=0)
+
+        return img_LQs, img_GTs
+
+    def get_neighbor_frames(self, frame_name, number_frames, interval_list,
+                            random_reverse):
+        frame_idx = int(frame_name)
+        interval = random.choice(interval_list)
+        # honor the sampling interval so that number_frames frames are drawn
+        neighbor_list = list(
+            range(frame_idx, frame_idx + number_frames * interval, interval))
+        if random_reverse and random.random() < 0.5:
+            neighbor_list.reverse()
+
+        assert len(neighbor_list) == number_frames, \
+            "frames selected have length({}), but it should be ({})".format(
+                len(neighbor_list), number_frames)
+
+        return neighbor_list
+
+    def read_img(self, path, size=None):
+        """Read an image with cv2.
+
+        return: Numpy float32, HWC, RGB, [0,1]
+        """
+        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+        img = img.astype(np.float32) / 255.
+        if img.ndim == 2:
+            img = np.expand_dims(img, axis=2)
+        # some images have 4 channels
+        if img.shape[2] > 3:
+            img = img[:, :, :3]
+        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+    def img_augment(self, img_list, hflip=True, rot=True):
+        """horizontal flip OR rotate (0, 90, 180, 270 degrees)
+        """
+        hflip = hflip and random.random() < 0.5
+        vflip = rot and random.random() < 0.5
+        rot90 = rot and random.random() < 0.5
+
+        def _augment(img):
+            if hflip:
+                img = img[:, ::-1, :]
+            if vflip:
+                img = img[::-1, :, :]
+            if rot90:
+                img = img.transpose(1, 0, 2)
+            return img
+
+        return [_augment(img) for img in img_list]
+
+    def __len__(self):
+        """Length of the dataset.
+
+        Returns:
+            int: Length of the dataset.
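+
+        Note: when wrapped in RepeatDataset (as configs/basicvsr_reds.yaml
+        does for training), the dataloader sees this length multiplied
+        by `times`.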
+ """ + return len(self.data_infos) diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py index 756cbfc9c0288fcfc412c216b88aa1f7f01e762f..1e6407fc21e2df5606b3eceff25175b3f5e76722 100755 --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -323,7 +323,6 @@ class Trainer: is_save_image=False): """ visual the images, use visualdl or directly write to the directory - Parameters: results_dir (str) -- directory name which contains saved images visual_results (dict) -- the results images dict @@ -440,7 +439,6 @@ class Trainer: def close(self): """ when finish the training need close file handler or other. - """ if self.enable_visualdl: - self.vdl_logger.close() + self.vdl_logger.close() \ No newline at end of file diff --git a/ppgan/models/__init__.py b/ppgan/models/__init__.py index 31cf00de83edc4ce164361e8b0b841b83a4c9d54..0bbc6e16e667519f66519c24d1b4654c0119ee07 100644 --- a/ppgan/models/__init__.py +++ b/ppgan/models/__init__.py @@ -30,4 +30,5 @@ from .starganv2_model import StarGANv2Model from .edvr_model import EDVRModel from .firstorder_model import FirstOrderModel from .lapstyle_model import LapStyleDraModel, LapStyleRevFirstModel, LapStyleRevSecondModel +from .basicvsr_model import BasicVSRModel from .mpr_model import MPRModel diff --git a/ppgan/models/basicvsr_model.py b/ppgan/models/basicvsr_model.py new file mode 100644 index 0000000000000000000000000000000000000000..76680b76070aab5e091fca0e8b4bd8163a3b7c03 --- /dev/null +++ b/ppgan/models/basicvsr_model.py @@ -0,0 +1,111 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn + +from .builder import MODELS +from .sr_model import BaseSRModel +from .generators.basicvsr import ResidualBlockNoBN, PixelShufflePack, SPyNet +from ..modules.init import reset_parameters +from ..utils.visual import tensor2img + + +@MODELS.register() +class BasicVSRModel(BaseSRModel): + """BasicVSR Model. + + Paper: BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond, CVPR, 2021 + """ + def __init__(self, generator, fix_iter, pixel_criterion=None): + """Initialize the BasicVSR class. + + Args: + generator (dict): config of generator. + fix_iter (dict): config of fix_iter. + pixel_criterion (dict): config of pixel criterion. 
+ """ + super(BasicVSRModel, self).__init__(generator, pixel_criterion) + self.fix_iter = fix_iter + self.current_iter = 1 + self.flag = True + init_basicvsr_weight(self.nets['generator']) + + def setup_input(self, input): + self.lq = paddle.to_tensor(input['lq']) + self.visual_items['lq'] = self.lq[:, 0, :, :, :] + if 'gt' in input: + self.gt = paddle.to_tensor(input['gt']) + self.visual_items['gt'] = self.gt[:, 0, :, :, :] + self.image_paths = input['lq_path'] + + def train_iter(self, optims=None): + optims['optim'].clear_grad() + if self.fix_iter: + if self.current_iter == 1: + print('Train BasicVSR with fixed spynet for', self.fix_iter, + 'iters.') + for name, param in self.nets['generator'].named_parameters(): + if 'spynet' in name: + param.trainable = False + elif self.current_iter >= self.fix_iter + 1 and self.flag: + print('Train all the parameters.') + for name, param in self.nets['generator'].named_parameters(): + param.trainable = True + if 'spynet' in name: + param.optimize_attr['learning_rate'] = 0.125 + self.flag = False + for net in self.nets.values(): + net.find_unused_parameters = False + + self.output = self.nets['generator'](self.lq) + self.visual_items['output'] = self.output[:, 0, :, :, :] + # pixel loss + loss_pixel = self.pixel_criterion(self.output, self.gt) + + loss_pixel.backward() + optims['optim'].step() + + self.losses['loss_pixel'] = loss_pixel + + self.current_iter += 1 + + def test_iter(self, metrics=None): + self.nets['generator'].eval() + with paddle.no_grad(): + self.output = self.nets['generator'](self.lq) + self.visual_items['output'] = self.output[:, 0, :, :, :] + self.nets['generator'].train() + + out_img = [] + gt_img = [] + for out_tensor, gt_tensor in zip(self.output[0], self.gt[0]): + # print(out_tensor.shape, gt_tensor.shape) + out_img.append(tensor2img(out_tensor, (0., 1.))) + gt_img.append(tensor2img(gt_tensor, (0., 1.))) + + if metrics is not None: + for metric in metrics.values(): + metric.update(out_img, gt_img) + + +def init_basicvsr_weight(net): + for m in net.children(): + if hasattr(m, 'weight') and not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D)): + reset_parameters(m) + continue + + if (not isinstance( + m, (ResidualBlockNoBN, PixelShufflePack, SPyNet))): + init_basicvsr_weight(m) diff --git a/ppgan/models/criterions/pixel_loss.py b/ppgan/models/criterions/pixel_loss.py index f33e632dad541052c9a7bba0ac257a6c2ed1ee97..6e878ad735306222b8a3a12bb85f3ee26c1c992b 100644 --- a/ppgan/models/criterions/pixel_loss.py +++ b/ppgan/models/criterions/pixel_loss.py @@ -59,8 +59,9 @@ class CharbonnierLoss(): eps (float): Default: 1e-12. """ - def __init__(self, eps=1e-12): + def __init__(self, eps=1e-12, reduction='sum'): self.eps = eps + self.reduction = reduction def __call__(self, pred, target, **kwargs): """Forward Function. @@ -69,7 +70,14 @@ class CharbonnierLoss(): pred (Tensor): of shape (N, C, H, W). Predicted tensor. target (Tensor): of shape (N, C, H, W). Ground truth tensor. 
""" - return paddle.sum(paddle.sqrt((pred - target)**2 + self.eps)) + if self.reduction == 'sum': + out = paddle.sum(paddle.sqrt((pred - target)**2 + self.eps)) + elif self.reduction == 'mean': + out = paddle.mean(paddle.sqrt((pred - target)**2 + self.eps)) + else: + raise NotImplementedError('CharbonnierLoss %s not implemented' % + self.reduction) + return out @CRITERIONS.register() diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py index 8e193107a789b469ee80bf517ae81fb1fbcd4d0d..8df2ec1d1a18d7523a29558bb330ad442197da0b 100755 --- a/ppgan/models/generators/__init__.py +++ b/ppgan/models/generators/__init__.py @@ -30,4 +30,5 @@ from .generator_starganv2 import StarGANv2Generator, StarGANv2Style, StarGANv2Ma from .edvr import EDVRNet from .generator_firstorder import FirstOrderGenerator from .generater_lapstyle import DecoderNet, Encoder, RevisionNet +from .basicvsr import BasicVSRNet from .mpr import MPRNet diff --git a/ppgan/models/generators/basicvsr.py b/ppgan/models/generators/basicvsr.py new file mode 100644 index 0000000000000000000000000000000000000000..acb412b9924e6524aad28d37e877a0619776f1bb --- /dev/null +++ b/ppgan/models/generators/basicvsr.py @@ -0,0 +1,623 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle + +import numpy as np +import scipy.io as scio + +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn import initializer +from ...utils.download import get_path_from_url +from ...modules.init import kaiming_normal_, constant_ + +from .builder import GENERATORS + + +@paddle.no_grad() +def default_init_weights(layer_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + + Args: + layer_list (list[nn.Layer] | nn.Layer): Layers to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. + """ + if not isinstance(layer_list, list): + layer_list = [layer_list] + for m in layer_list: + if isinstance(m, nn.Conv2D): + kaiming_normal_(m.weight, **kwargs) + scale_weight = scale * m.weight + m.weight.set_value(scale_weight) + if m.bias is not None: + constant_(m.bias, bias_fill) + elif isinstance(m, nn.Linear): + kaiming_normal_(m.weight, **kwargs) + scale_weight = scale * m.weight + m.weight.set_value(scale_weight) + if m.bias is not None: + constant_(m.bias, bias_fill) + elif isinstance(m, nn.BatchNorm): + constant_(m.weight, 1) + + +class PixelShufflePack(nn.Layer): + """ Pixel Shuffle upsample layer. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + scale_factor (int): Upsample ratio. + upsample_kernel (int): Kernel size of Conv layer to expand channels. + + Returns: + Upsampled feature map. 
+ """ + def __init__(self, in_channels, out_channels, scale_factor, + upsample_kernel): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.scale_factor = scale_factor + self.upsample_kernel = upsample_kernel + self.upsample_conv = nn.Conv2D(self.in_channels, + self.out_channels * scale_factor * + scale_factor, + self.upsample_kernel, + padding=(self.upsample_kernel - 1) // 2) + self.pixel_shuffle = nn.PixelShuffle(self.scale_factor) + self.init_weights() + + def init_weights(self): + """Initialize weights for PixelShufflePack. + """ + default_init_weights(self, 1) + + def forward(self, x): + """Forward function for PixelShufflePack. + + Args: + x (Tensor): Input tensor with shape (in_channels, c, h, w). + + Returns: + Tensor with shape (out_channels, c, scale_factor*h, scale_factor*w). + """ + x = self.upsample_conv(x) + x = self.pixel_shuffle(x) + return x + + +def MakeMultiBlocks(func, num_layers, nf=64): + """Make layers by stacking the same blocks. + + Args: + func (nn.Layer): nn.Layer class for basic block. + num_layers (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + Blocks = nn.Sequential() + for i in range(num_layers): + Blocks.add_sublayer('block%d' % i, func(nf)) + return Blocks + + +class ResidualBlockNoBN(nn.Layer): + """Residual block without BN. + + It has a style of: + ---Conv-ReLU-Conv-+- + |________________| + + Args: + nf (int): Channel number of intermediate features. + Default: 64. + res_scale (float): Residual scale. Default: 1.0. + """ + def __init__(self, nf=64, res_scale=1.0): + super(ResidualBlockNoBN, self).__init__() + self.nf = nf + self.res_scale = res_scale + self.conv1 = nn.Conv2D(self.nf, self.nf, 3, 1, 1) + self.conv2 = nn.Conv2D(self.nf, self.nf, 3, 1, 1) + self.relu = nn.ReLU() + if self.res_scale == 1.0: + default_init_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor with shape (n, c, h, w). + """ + identity = x + out = self.conv2(self.relu(self.conv1(x))) + return identity + out * self.res_scale + + +def flow_warp(x, + flow, + interpolation='bilinear', + padding_mode='zeros', + align_corners=True): + """Warp an image or a feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is + a two-channel, denoting the width and height relative offsets. + Note that the values are not normalized to [-1, 1]. + interpolation (str): Interpolation mode: 'nearest' or 'bilinear'. + Default: 'bilinear'. + padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Whether align corners. Default: True. + + Returns: + Tensor: Warped image or feature map. 
+ """ + if x.shape[-2:] != flow.shape[1:3]: + raise ValueError(f'The spatial sizes of input ({x.shape[-2:]}) and ' + f'flow ({flow.shape[1:3]}) are not the same.') + _, _, h, w = x.shape + # create mesh grid + grid_y, grid_x = paddle.meshgrid(paddle.arange(0, h), paddle.arange(0, w)) + grid = paddle.stack((grid_x, grid_y), axis=2) # (w, h, 2) + grid = paddle.cast(grid, 'float32') + grid.stop_gradient = True + + grid_flow = grid + flow + # scale grid_flow to [-1,1] + grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0 + grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0 + grid_flow = paddle.stack((grid_flow_x, grid_flow_y), axis=3) + output = F.grid_sample(x, + grid_flow, + mode=interpolation, + padding_mode=padding_mode, + align_corners=align_corners) + return output + + +class SPyNetBasicModule(nn.Layer): + """Basic Module for SPyNet. + + Paper: + Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017 + """ + def __init__(self): + super().__init__() + + self.conv1 = nn.Conv2D(in_channels=8, + out_channels=32, + kernel_size=7, + stride=1, + padding=3) + self.conv2 = nn.Conv2D(in_channels=32, + out_channels=64, + kernel_size=7, + stride=1, + padding=3) + self.conv3 = nn.Conv2D(in_channels=64, + out_channels=32, + kernel_size=7, + stride=1, + padding=3) + self.conv4 = nn.Conv2D(in_channels=32, + out_channels=16, + kernel_size=7, + stride=1, + padding=3) + self.conv5 = nn.Conv2D(in_channels=16, + out_channels=2, + kernel_size=7, + stride=1, + padding=3) + self.relu = nn.ReLU() + + def forward(self, tensor_input): + """ + Args: + tensor_input (Tensor): Input tensor with shape (b, 8, h, w). + 8 channels contain: + [reference image (3), neighbor image (3), initial flow (2)]. + + Returns: + Tensor: Refined flow with shape (b, 2, h, w) + """ + out = self.relu(self.conv1(tensor_input)) + out = self.relu(self.conv2(out)) + out = self.relu(self.conv3(out)) + out = self.relu(self.conv4(out)) + out = self.conv5(out) + return out + + +class SPyNet(nn.Layer): + """SPyNet network structure. + + The difference to the SPyNet in paper is that + 1. more SPyNetBasicModule is used in this version, and + 2. no batch normalization is used in this version. + + Paper: + Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017 + + """ + def __init__(self): + super().__init__() + + self.basic_module0 = SPyNetBasicModule() + self.basic_module1 = SPyNetBasicModule() + self.basic_module2 = SPyNetBasicModule() + self.basic_module3 = SPyNetBasicModule() + self.basic_module4 = SPyNetBasicModule() + self.basic_module5 = SPyNetBasicModule() + + self.register_buffer( + 'mean', + paddle.to_tensor([0.485, 0.456, 0.406]).reshape([1, 3, 1, 1])) + self.register_buffer( + 'std', + paddle.to_tensor([0.229, 0.224, 0.225]).reshape([1, 3, 1, 1])) + + def compute_flow(self, ref, supp): + """Compute flow from ref to supp. + + Note that in this function, the images are already resized to a + multiple of 32. + + Args: + ref (Tensor): Reference image with shape of (n, 3, h, w). + supp (Tensor): Supporting image with shape of (n, 3, h, w). + + Returns: + Tensor: Estimated optical flow: (n, 2, h, w). 
+ """ + + n, _, h, w = ref.shape + + # normalize the input images + ref = [(ref - self.mean) / self.std] + supp = [(supp - self.mean) / self.std] + + # generate downsampled frames + for level in range(5): + ref.append(F.avg_pool2d(ref[-1], kernel_size=2, stride=2)) + supp.append(F.avg_pool2d(supp[-1], kernel_size=2, stride=2)) + ref = ref[::-1] + supp = supp[::-1] + + # flow computation + flow = paddle.to_tensor(np.zeros([n, 2, h // 32, w // 32], 'float32')) + + # level=0 + flow_up = flow + flow = flow_up + self.basic_module0( + paddle.concat([ + ref[0], + flow_warp(supp[0], + flow_up.transpose([0, 2, 3, 1]), + padding_mode='border'), flow_up + ], 1)) + + # level=1 + flow_up = F.interpolate( + flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 + flow = flow_up + self.basic_module1( + paddle.concat([ + ref[1], + flow_warp(supp[1], + flow_up.transpose([0, 2, 3, 1]), + padding_mode='border'), flow_up + ], 1)) + + # level=2 + flow_up = F.interpolate( + flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 + flow = flow_up + self.basic_module2( + paddle.concat([ + ref[2], + flow_warp(supp[2], + flow_up.transpose([0, 2, 3, 1]), + padding_mode='border'), flow_up + ], 1)) + + # level=3 + flow_up = F.interpolate( + flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 + flow = flow_up + self.basic_module3( + paddle.concat([ + ref[3], + flow_warp(supp[3], + flow_up.transpose([0, 2, 3, 1]), + padding_mode='border'), flow_up + ], 1)) + + # level=4 + flow_up = F.interpolate( + flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 + flow = flow_up + self.basic_module4( + paddle.concat([ + ref[4], + flow_warp(supp[4], + flow_up.transpose([0, 2, 3, 1]), + padding_mode='border'), flow_up + ], 1)) + + # level=5 + flow_up = F.interpolate( + flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 + flow = flow_up + self.basic_module5( + paddle.concat([ + ref[5], + flow_warp(supp[5], + flow_up.transpose([0, 2, 3, 1]), + padding_mode='border'), flow_up + ], 1)) + + return flow + + def forward(self, ref, supp): + """Forward function of SPyNet. + + This function computes the optical flow from ref to supp. + + Args: + ref (Tensor): Reference image with shape of (n, 3, h, w). + supp (Tensor): Supporting image with shape of (n, 3, h, w). + + Returns: + Tensor: Estimated optical flow: (n, 2, h, w). + """ + + # upsize to a multiple of 32 + h, w = ref.shape[2:4] + w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1) + h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1) + ref = F.interpolate(ref, + size=(h_up, w_up), + mode='bilinear', + align_corners=False) + supp = F.interpolate(supp, + size=(h_up, w_up), + mode='bilinear', + align_corners=False) + ref.stop_gradient = False + supp.stop_gradient = False + # compute flow, and resize back to the original resolution + flow_up = self.compute_flow(ref, supp) + flow = F.interpolate(flow_up, + size=(h, w), + mode='bilinear', + align_corners=False) + + # adjust the flow values + # todo: grad bug + # flow[:, 0, :, :] *= (float(w) / float(w_up)) + # flow[:, 1, :, :] *= (float(h) / float(h_up)) + + flow_x = flow[:, 0:1, :, :] * (float(w) / float(w_up)) + flow_y = flow[:, 1:2, :, :] * (float(h) / float(h_up)) + flow = paddle.concat([flow_x, flow_y], 1) + + return flow + + +class ResidualBlocksWithInputConv(nn.Layer): + """Residual blocks with a convolution in front. + + Args: + in_channels (int): Number of input channels of the first conv. + out_channels (int): Number of channels of the residual blocks. + Default: 64. 
num_blocks (int): Number of residual blocks. Default: 30.
+    """
+    def __init__(self, in_channels, out_channels=64, num_blocks=30):
+        super().__init__()
+
+        # a convolution used to match the channels of the residual blocks
+        self.conv1 = nn.Conv2D(in_channels, out_channels, 3, 1, 1)
+        self.Leaky_relu = nn.LeakyReLU(negative_slope=0.1)
+
+        # residual blocks
+        self.ResidualBlocks = MakeMultiBlocks(ResidualBlockNoBN,
+                                              num_blocks,
+                                              nf=out_channels)
+
+    def forward(self, feat):
+        """
+        Forward function for ResidualBlocksWithInputConv.
+
+        Args:
+            feat (Tensor): Input feature with shape (n, in_channels, h, w)
+
+        Returns:
+            Tensor: Output feature with shape (n, out_channels, h, w)
+        """
+        out = self.Leaky_relu(self.conv1(feat))
+        out = self.ResidualBlocks(out)
+        return out
+
+
+@GENERATORS.register()
+class BasicVSRNet(nn.Layer):
+    """BasicVSR network structure for video super-resolution.
+
+    Only x4 upsampling is supported.
+
+    Paper:
+        BasicVSR: The Search for Essential Components in Video Super-Resolution
+        and Beyond, CVPR, 2021
+
+    Args:
+        mid_channels (int): Channel number of the intermediate features.
+            Default: 64.
+        num_blocks (int): Number of residual blocks in each propagation branch.
+            Default: 30.
+    """
+    def __init__(self, mid_channels=64, num_blocks=30):
+
+        super().__init__()
+
+        self.mid_channels = mid_channels
+
+        # optical flow network for feature alignment
+        self.spynet = SPyNet()
+        weight_path = get_path_from_url(
+            'https://paddlegan.bj.bcebos.com/models/spynet.pdparams')
+        self.spynet.set_state_dict(paddle.load(weight_path))
+
+        # propagation branches
+        self.backward_resblocks = ResidualBlocksWithInputConv(
+            mid_channels + 3, mid_channels, num_blocks)
+        self.forward_resblocks = ResidualBlocksWithInputConv(
+            mid_channels + 3, mid_channels, num_blocks)
+
+        # upsample
+        self.fusion = nn.Conv2D(mid_channels * 2, mid_channels, 1, 1, 0)
+        self.upsample1 = PixelShufflePack(mid_channels,
+                                          mid_channels,
+                                          2,
+                                          upsample_kernel=3)
+        self.upsample2 = PixelShufflePack(mid_channels,
+                                          64,
+                                          2,
+                                          upsample_kernel=3)
+        self.conv_hr = nn.Conv2D(64, 64, 3, 1, 1)
+        self.conv_last = nn.Conv2D(64, 3, 3, 1, 1)
+        self.img_upsample = nn.Upsample(scale_factor=4,
+                                        mode='bilinear',
+                                        align_corners=False)
+
+        # activation function
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1)
+
+    def check_if_mirror_extended(self, lrs):
+        """Check whether the input is a mirror-extended sequence.
+
+        If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the
+        (t-1-i)-th frame.
+
+        Args:
+            lrs (tensor): Input LR images with shape (n, t, c, h, w)
+        """
+
+        self.is_mirror_extended = False
+        if lrs.shape[1] % 2 == 0:
+            lrs_1, lrs_2 = paddle.chunk(lrs, 2, axis=1)
+            lrs_2 = paddle.flip(lrs_2, [1])
+            if paddle.norm(lrs_1 - lrs_2) == 0:
+                self.is_mirror_extended = True
+
+    def compute_flow(self, lrs):
+        """Compute optical flow using SPyNet for feature warping.
+
+        Note that if the input is a mirror-extended sequence, 'flows_forward'
+        is not needed, since it is equal to 'flows_backward.flip(1)'.
+
+        Args:
+            lrs (tensor): Input LR images with shape (n, t, c, h, w)
+
+        Return:
+            tuple(Tensor): Optical flow. 'flows_forward' corresponds to the
+                flows used for forward-time propagation (current to previous).
+                'flows_backward' corresponds to the flows used for
+                backward-time propagation (current to next).
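+
+        Example (illustrative):
+            >>> lrs = paddle.rand([1, 5, 3, 64, 64])
+            >>> self.check_if_mirror_extended(lrs)  # sets is_mirror_extended
+            >>> fwd, bwd = self.compute_flow(lrs)   # each: (1, 4, 2, 64, 64)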
+ """ + + n, t, c, h, w = lrs.shape + + lrs_1 = lrs[:, :-1, :, :, :].reshape([-1, c, h, w]) + lrs_2 = lrs[:, 1:, :, :, :].reshape([-1, c, h, w]) + + flows_backward = self.spynet(lrs_1, lrs_2).reshape([n, t - 1, 2, h, w]) + + if self.is_mirror_extended: # flows_forward = flows_backward.flip(1) + flows_forward = None + else: + flows_forward = self.spynet(lrs_2, + lrs_1).reshape([n, t - 1, 2, h, w]) + + return flows_forward, flows_backward + + def forward(self, lrs): + """Forward function for BasicVSR. + + Args: + lrs (Tensor): Input LR sequence with shape (n, t, c, h, w). + + Returns: + Tensor: Output HR sequence with shape (n, t, c, 4h, 4w). + """ + + n, t, c, h, w = lrs.shape + assert h >= 64 and w >= 64, ( + 'The height and width of inputs should be at least 64, ' + f'but got {h} and {w}.') + + # check whether the input is an extended sequence + self.check_if_mirror_extended(lrs) + + # compute optical flow + flows_forward, flows_backward = self.compute_flow(lrs) + + # backward-time propgation + outputs = [] + feat_prop = paddle.to_tensor( + np.zeros([n, self.mid_channels, h, w], 'float32')) + for i in range(t - 1, -1, -1): + if i < t - 1: # no warping required for the last timestep + flow = flows_backward[:, i, :, :, :] + feat_prop = flow_warp(feat_prop, flow.transpose([0, 2, 3, 1])) + + feat_prop = paddle.concat([lrs[:, i, :, :, :], feat_prop], axis=1) + feat_prop = self.backward_resblocks(feat_prop) + + outputs.append(feat_prop) + outputs = outputs[::-1] + + # forward-time propagation and upsampling + feat_prop = paddle.zeros_like(feat_prop) + for i in range(0, t): + lr_curr = lrs[:, i, :, :, :] + if i > 0: # no warping required for the first timestep + if flows_forward is not None: + flow = flows_forward[:, i - 1, :, :, :] + else: + flow = flows_backward[:, -i, :, :, :] + feat_prop = flow_warp(feat_prop, flow.transpose([0, 2, 3, 1])) + + feat_prop = paddle.concat([lr_curr, feat_prop], axis=1) + feat_prop = self.forward_resblocks(feat_prop) + + # upsampling given the backward and forward features + out = paddle.concat([outputs[i], feat_prop], axis=1) + out = self.lrelu(self.fusion(out)) + out = self.lrelu(self.upsample1(out)) + out = self.lrelu(self.upsample2(out)) + out = self.lrelu(self.conv_hr(out)) + out = self.conv_last(out) + base = self.img_upsample(lr_curr) + out += base + outputs[i] = out + + return paddle.stack(outputs, axis=1)