From 64181fa9b433864e25a0c4ae630a5f017b15ffa8 Mon Sep 17 00:00:00 2001 From: wangna11BD <79366697+wangna11BD@users.noreply.github.com> Date: Tue, 9 Mar 2021 21:09:01 +0800 Subject: [PATCH] add edvr model (#208) * add edvr model * modifying code formats and comments * modifying code formats and comments * modifying code formats and comments * add notes Co-authored-by: LielinJiang <50691816+LielinJiang@users.noreply.github.com> --- configs/edvr.yaml | 100 ++++ configs/edvr_wo_tsa.yaml | 100 ++++ ppgan/datasets/__init__.py | 1 + ppgan/datasets/edvr_dataset.py | 314 ++++++++++ ppgan/models/__init__.py | 1 + ppgan/models/criterions/__init__.py | 2 +- ppgan/models/criterions/pixel_loss.py | 21 + ppgan/models/edvr_model.py | 86 +++ ppgan/models/generators/__init__.py | 1 + ppgan/models/generators/edvr.py | 819 ++++++++++++++++++++++++++ ppgan/modules/dcn.py | 147 +++++ ppgan/modules/init.py | 7 + 12 files changed, 1598 insertions(+), 1 deletion(-) create mode 100644 configs/edvr.yaml create mode 100644 configs/edvr_wo_tsa.yaml create mode 100644 ppgan/datasets/edvr_dataset.py create mode 100644 ppgan/models/edvr_model.py create mode 100644 ppgan/models/generators/edvr.py create mode 100644 ppgan/modules/dcn.py diff --git a/configs/edvr.yaml b/configs/edvr.yaml new file mode 100644 index 0000000..8aa206a --- /dev/null +++ b/configs/edvr.yaml @@ -0,0 +1,100 @@ +total_iters: 600000 +output_dir: output_dir +checkpoints_dir: checkpoints +# tensor range for function tensor2img +min_max: + (0., 1.) + +model: + name: EDVRModel + tsa_iter: 50000 + generator: + name: EDVRNet + in_nf: 3 + out_nf: 3 + scale_factor: 4 + nf: 64 + nframes: 5 + groups: 8 + front_RBs: 5 + back_RBs: 10 + center: 2 + predeblur: False + HR_in: False + w_TSA: True + TSA_only: False + pixel_criterion: + name: CharbonnierLoss + +dataset: + train: + name: REDSDataset + mode: train + gt_folder: data/REDS/train_sharp/X4 + lq_folder: data/REDS/train_sharp_bicubic/X4 + img_format: png + crop_size: 256 + interval_list: [1] + random_reverse: False + number_frames: 5 + use_flip: True + use_rot: True + buf_size: 1024 + scale: 4 + fix_random_seed: 10 + num_workers: 3 + batch_size: 4 + + + test: + name: REDSDataset + mode: test + gt_folder: data/REDS/REDS4_test_sharp/X4 + lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4 + img_format: png + interval_list: [1] + random_reverse: False + number_frames: 5 + batch_size: 1 + use_flip: False + use_rot: False + buf_size: 1024 + scale: 4 + fix_random_seed: 10 + +lr_scheduler: + name: CosineAnnealingRestartLR + learning_rate: !!float 4e-4 + periods: [50000, 100000, 150000, 150000, 150000] + restart_weights: [1, 1, 1, 1, 1] + eta_min: !!float 1e-7 + +optimizer: + name: Adam + # add the parameters of each net in net_names to the optimizer + # every name should be in self.nets + net_names: + - generator + beta1: 0.9 + beta2: 0.99 + +validate: + interval: 5000 + save_img: false + + metrics: + psnr: # metric name, can be arbitrary + name: PSNR + crop_border: 0 + test_y_channel: False + ssim: + name: SSIM + crop_border: 0 + test_y_channel: False + +log_config: + interval: 10 + visiual_interval: 5000 + +snapshot_config: + interval: 5000 diff --git a/configs/edvr_wo_tsa.yaml b/configs/edvr_wo_tsa.yaml new file mode 100644 index 0000000..776da6a --- /dev/null +++ b/configs/edvr_wo_tsa.yaml @@ -0,0 +1,100 @@ +total_iters: 600000 +output_dir: output_dir +checkpoints_dir: checkpoints +# tensor range for function tensor2img +min_max: + (0., 1.)
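+ +# Same as configs/edvr.yaml except that the TSA fusion module is disabled +# (tsa_iter: 0, w_TSA: False), the cosine-restart schedule uses four equal +# 150k-iteration periods instead of five, and visual logging is more frequent.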
+ +model: + name: EDVRModel + tsa_iter: 0 + generator: + name: EDVRNet + in_nf: 3 + out_nf: 3 + scale_factor: 4 + nf: 64 + nframes: 5 + groups: 8 + front_RBs: 5 + back_RBs: 10 + center: 2 + predeblur: False + HR_in: False + w_TSA: False + TSA_only: False + pixel_criterion: + name: CharbonnierLoss + +dataset: + train: + name: REDSDataset + mode: train + gt_folder: data/REDS/train_sharp/X4 + lq_folder: data/REDS/train_sharp_bicubic/X4 + img_format: png + crop_size: 256 + interval_list: [1] + random_reverse: False + number_frames: 5 + use_flip: True + use_rot: True + buf_size: 1024 + scale: 4 + fix_random_seed: 10 + num_workers: 3 + batch_size: 4 + + + test: + name: REDSDataset + mode: test + gt_folder: data/REDS/REDS4_test_sharp/X4 + lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4 + img_format: png + interval_list: [1] + random_reverse: False + number_frames: 5 + batch_size: 1 + use_flip: False + use_rot: False + buf_size: 1024 + scale: 4 + fix_random_seed: 10 + +lr_scheduler: + name: CosineAnnealingRestartLR + learning_rate: !!float 4e-4 + periods: [150000, 150000, 150000, 150000] + restart_weights: [1, 1, 1, 1] + eta_min: !!float 1e-7 + +optimizer: + name: Adam + # add the parameters of each net in net_names to the optimizer + # every name should be in self.nets + net_names: + - generator + beta1: 0.9 + beta2: 0.99 + +validate: + interval: 5000 + save_img: false + + metrics: + psnr: # metric name, can be arbitrary + name: PSNR + crop_border: 0 + test_y_channel: False + ssim: + name: SSIM + crop_border: 0 + test_y_channel: False + +log_config: + interval: 10 + visiual_interval: 500 + +snapshot_config: + interval: 5000 diff --git a/ppgan/datasets/__init__.py b/ppgan/datasets/__init__.py index cd5b544..84849ed 100755 --- a/ppgan/datasets/__init__.py +++ b/ppgan/datasets/__init__.py @@ -21,3 +21,4 @@ from .common_vision_dataset import CommonVisionDataset from .animeganv2_dataset import AnimeGANV2Dataset from .wav2lip_dataset import Wav2LipDataset from .starganv2_dataset import StarGANv2Dataset +from .edvr_dataset import REDSDataset diff --git a/ppgan/datasets/edvr_dataset.py b/ppgan/datasets/edvr_dataset.py new file mode 100644 index 0000000..97a1a2f --- /dev/null +++ b/ppgan/datasets/edvr_dataset.py @@ -0,0 +1,314 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
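+ +# Expected REDS layout (paths as configured in configs/edvr.yaml); each item +# in self.filelist is '<video>_<frame>', e.g. '010_00000015': +# <lq_folder>/<video>/<frame>.png e.g. data/REDS/train_sharp_bicubic/X4/010/00000015.png +# <gt_folder>/<video>/<frame>.png e.g. data/REDS/train_sharp/X4/010/00000015.png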
+ +import logging +import os +import random +import numpy as np +import cv2 +from paddle.io import Dataset +from .builder import DATASETS + +logger = logging.getLogger(__name__) + + +@DATASETS.register() +class REDSDataset(Dataset): + """ + REDS dataset for the EDVR model. + """ + def __init__(self, + mode, + lq_folder, + gt_folder, + img_format="png", + crop_size=256, + interval_list=[1], + random_reverse=False, + number_frames=5, + batch_size=32, + use_flip=False, + use_rot=False, + buf_size=1024, + scale=4, + fix_random_seed=False): + super(REDSDataset, self).__init__() + self.format = img_format + self.mode = mode + self.crop_size = crop_size + self.interval_list = interval_list + self.random_reverse = random_reverse + self.number_frames = number_frames + self.batch_size = batch_size + self.fileroot = lq_folder + self.use_flip = use_flip + self.use_rot = use_rot + self.buf_size = buf_size + self.fix_random_seed = fix_random_seed + + if self.mode != 'infer': + self.gtroot = gt_folder + self.scale = scale + self.LR_input = (self.scale > 1) + if self.fix_random_seed: + random.seed(10) + np.random.seed(10) + self.num_reader_threads = 1 + + self._init_() + + def _init_(self): + logger.info('initialize reader ... ') + self.filelist = [] + for video_name in os.listdir(self.fileroot): + if (self.mode == 'train') and (video_name in [ + '000', '011', '015', '020' + ]): # these four videos are held out for validation + continue + for frame_name in os.listdir(os.path.join(self.fileroot, + video_name)): + frame_idx = frame_name.split('.')[0] + video_frame_idx = video_name + '_' + str(frame_idx) + # each item in self.filelist is like '010_00000015' or '260_00000090' + self.filelist.append(video_frame_idx) + if self.mode == 'test': + self.filelist.sort() + logger.info('found {} samples'.format(len(self.filelist))) + + def __getitem__(self, index): + """Get a training sample. + + Returns: lq: LQ frames, ndarray of shape [5, 3, H, W], + gt: GT frame, ndarray of shape [3, H, W], + lq_path: str + """ + item = self.filelist[index] + img_LQs, img_GT = self.get_sample_data( + item, self.number_frames, self.interval_list, self.random_reverse, + self.gtroot, self.fileroot, self.LR_input, self.crop_size, + self.scale, self.use_flip, self.use_rot, self.mode) + return {'lq': img_LQs, 'gt': img_GT, 'lq_path': self.filelist[index]} + + def get_sample_data(self, + item, + number_frames, + interval_list, + random_reverse, + gtroot, + fileroot, + LR_input, + crop_size, + scale, + use_flip, + use_rot, + mode='train'): + video_name = item.split('_')[0] + frame_name = item.split('_')[1] + if (mode == 'train') or (mode == 'valid'): + ngb_frames, name_b = self.get_neighbor_frames(frame_name, \ + number_frames=number_frames, \ + interval_list=interval_list, \ + random_reverse=random_reverse) + elif mode == 'test': + ngb_frames, name_b = self.get_test_neighbor_frames( + int(frame_name), number_frames) + else: + raise NotImplementedError('mode {} not implemented'.format(mode)) + frame_name = name_b + img_GT = self.read_img( + os.path.join(gtroot, video_name, frame_name + '.png')) + frame_list = [] + for ngb_frm in ngb_frames: + ngb_name = "%08d" % ngb_frm + img = self.read_img( + os.path.join(fileroot, video_name, ngb_name + '.png')) + frame_list.append(img) + H, W, C = frame_list[0].shape + # add random crop + if (mode == 'train') or (mode == 'valid'): + if LR_input: + LQ_size = crop_size // scale + rnd_h = random.randint(0, max(0, H - LQ_size)) + rnd_w = random.randint(0, max(0, W - LQ_size)) + frame_list = [ + v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] + for v in frame_list + ] + rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale) + img_GT = img_GT[rnd_h_HR:rnd_h_HR + crop_size, + rnd_w_HR:rnd_w_HR + crop_size, :] + else: + rnd_h = random.randint(0, max(0, H - crop_size)) + rnd_w = random.randint(0, max(0, W - crop_size)) + frame_list = [ + v[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :] + for v in frame_list + ] + img_GT = img_GT[rnd_h:rnd_h + crop_size, + rnd_w:rnd_w + crop_size, :] + + # add random flip and rotation + frame_list.append(img_GT) + if (mode == 'train') or (mode == 'valid'): + rlt = self.img_augment(frame_list, use_flip, use_rot) + else: + rlt = frame_list + frame_list = rlt[0:-1] + img_GT = rlt[-1] + + # stack LQ images to NHWC, N is the frame number + img_LQs = np.stack(frame_list, axis=0) + # BGR to RGB, HWC to CHW, numpy to tensor + img_GT = img_GT[:, :, [2, 1, 0]] + img_LQs = img_LQs[:, :, :, [2, 1, 0]] + img_GT = np.transpose(img_GT, (2, 0, 1)).astype('float32') + img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32') + + return img_LQs, img_GT + + def get_neighbor_frames(self, + frame_name, + number_frames, + interval_list, + random_reverse, + max_frame=99, + bordermode=False): + center_frame_idx = int(frame_name) + half_N_frames = number_frames // 2 + interval = random.choice(interval_list) + if bordermode: + direction = 1 + if random_reverse and random.random() < 0.5: + direction = random.choice([0, 1]) + if center_frame_idx + interval * (number_frames - 1) > max_frame: + direction = 0 + elif center_frame_idx - interval * (number_frames - 1) < 0: + direction = 1 + if direction == 1: + neighbor_list = list( + range(center_frame_idx, + center_frame_idx + interval * number_frames, + interval)) + else: + neighbor_list = list( + range(center_frame_idx, + center_frame_idx - interval * number_frames, + -interval)) + name_b = '{:08d}'.format(neighbor_list[0]) + else: + # ensure not exceeding the borders + while (center_frame_idx + half_N_frames * interval > max_frame) or ( + center_frame_idx - half_N_frames * interval < 0): + center_frame_idx = random.randint(0, max_frame) + neighbor_list = list( + range(center_frame_idx - half_N_frames * interval, + center_frame_idx + half_N_frames * interval + 1, + interval)) + if random_reverse and random.random() < 0.5: + neighbor_list.reverse() + name_b = '{:08d}'.format(neighbor_list[half_N_frames]) + assert len(neighbor_list) == number_frames, \ + "frames selected have length({}), but it should be ({})".format(len(neighbor_list), number_frames) + + return neighbor_list, name_b + + def read_img(self, path, size=None): + """read image by cv2 + + return: Numpy float32, HWC, BGR, [0,1] + """ + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + img = img.astype(np.float32) / 255.
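+ # cv2 loads HWC in BGR order; the checks below coerce grayscale and + # RGBA images to 3-channel HWC so callers always get the shape documented above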
+ if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + + def img_augment(self, img_list, hflip=True, rot=True): + """Randomly apply horizontal flip, vertical flip and 90-degree rotation. + """ + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + return [_augment(img) for img in img_list] + + def get_test_neighbor_frames(self, crt_i, N, max_n=100, padding='new_info'): + """Generate an index list for reading N frames around a center frame. + Args: + crt_i (int): current center index + N (int): number of frames to read + max_n (int): length of the image sequence (counting from 1) + padding (str): padding mode, one of replicate | reflection | new_info | circle + Example: crt_i = 0, N = 5 + replicate: [0, 0, 0, 1, 2] + reflection: [2, 1, 0, 1, 2] + new_info: [4, 3, 0, 1, 2] + circle: [3, 4, 0, 1, 2] + + Returns: + return_l (list [int]): a list of frame indexes + """ + max_n = max_n - 1 + n_pad = N // 2 + return_l = [] + + for i in range(crt_i - n_pad, crt_i + n_pad + 1): + if i < 0: + if padding == 'replicate': + add_idx = 0 + elif padding == 'reflection': + add_idx = -i + elif padding == 'new_info': + add_idx = (crt_i + n_pad) + (-i) + elif padding == 'circle': + add_idx = N + i + else: + raise ValueError('Wrong padding mode') + elif i > max_n: + if padding == 'replicate': + add_idx = max_n + elif padding == 'reflection': + add_idx = max_n * 2 - i + elif padding == 'new_info': + add_idx = (crt_i - n_pad) - (i - max_n) + elif padding == 'circle': + add_idx = i - N + else: + raise ValueError('Wrong padding mode') + else: + add_idx = i + return_l.append(add_idx) + name_b = '{:08d}'.format(crt_i) + return return_l, name_b + + def __len__(self): + """Return the total number of images in the dataset. + """ + return len(self.filelist) diff --git a/ppgan/models/__init__.py b/ppgan/models/__init__.py index 52d07a9..a723ac1 100644 --- a/ppgan/models/__init__.py +++ b/ppgan/models/__init__.py @@ -27,3 +27,4 @@ from .styleganv2_model import StyleGAN2Model from .wav2lip_model import Wav2LipModel from .wav2lip_hq_model import Wav2LipModelHq from .starganv2_model import StarGANv2Model +from .edvr_model import EDVRModel diff --git a/ppgan/models/criterions/__init__.py b/ppgan/models/criterions/__init__.py index 4c49542..a172f37 100644 --- a/ppgan/models/criterions/__init__.py +++ b/ppgan/models/criterions/__init__.py @@ -1,5 +1,5 @@ from .gan_loss import GANLoss from .perceptual_loss import PerceptualLoss -from .pixel_loss import L1Loss, MSELoss +from .pixel_loss import L1Loss, MSELoss, CharbonnierLoss from .builder import build_criterion diff --git a/ppgan/models/criterions/pixel_loss.py b/ppgan/models/criterions/pixel_loss.py index 4c94976..f496ef7 100644 --- a/ppgan/models/criterions/pixel_loss.py +++ b/ppgan/models/criterions/pixel_loss.py @@ -49,6 +49,27 @@ class L1Loss(): return self.loss_weight * self._l1_loss(pred, target) + +@CRITERIONS.register() +class CharbonnierLoss(): + """Charbonnier loss (a differentiable variant of L1 loss), + computed as sum(sqrt((pred - target)**2 + eps)). + + Args: + eps (float): A small constant for numerical stability. Default: 1e-12. + + """ + def __init__(self, eps=1e-12): + self.eps = eps + + def __call__(self, pred, target, **kwargs): + """Forward function. + + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W).
Ground truth tensor. + """ + return paddle.sum(paddle.sqrt((pred - target)**2 + self.eps)) + + @CRITERIONS.register() +class MSELoss(): + """MSE (L2) loss. diff --git a/ppgan/models/edvr_model.py b/ppgan/models/edvr_model.py new file mode 100644 index 0000000..3a3330e --- /dev/null +++ b/ppgan/models/edvr_model.py @@ -0,0 +1,86 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn + +from .builder import MODELS +from .sr_model import BaseSRModel +from .generators.edvr import ResidualBlockNoBN +from ..modules.init import reset_parameters + + +@MODELS.register() +class EDVRModel(BaseSRModel): + """EDVR Model. + + Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks. + """ + def __init__(self, generator, tsa_iter, pixel_criterion=None): + """Initialize the EDVR class. + + Args: + generator (dict): config of generator. + tsa_iter (int): number of iterations during which only the TSA module is trained. + pixel_criterion (dict): config of pixel criterion. + """ + super(EDVRModel, self).__init__(generator, pixel_criterion) + self.tsa_iter = tsa_iter + self.current_iter = 1 + init_edvr_weight(self.nets['generator']) + + def setup_input(self, input): + self.lq = paddle.to_tensor(input['lq']) + # the visual items below assume nframes == 5 with center frame index 2 + self.visual_items['lq'] = self.lq[:, 2, :, :, :] + self.visual_items['lq-2'] = self.lq[:, 0, :, :, :] + self.visual_items['lq-1'] = self.lq[:, 1, :, :, :] + self.visual_items['lq+1'] = self.lq[:, 3, :, :, :] + self.visual_items['lq+2'] = self.lq[:, 4, :, :, :] + if 'gt' in input: + self.gt = paddle.to_tensor(input['gt']) + self.visual_items['gt'] = self.gt + self.image_paths = input['lq_path'] + + def train_iter(self, optims=None): + optims['optim'].clear_grad() + if self.tsa_iter: + if self.current_iter == 1: + print('Only train TSA module for', self.tsa_iter, 'iters.') + for name, param in self.nets['generator'].named_parameters(): + if 'TSAModule' not in name: + param.trainable = False + elif self.current_iter == self.tsa_iter + 1: + print('Train all the parameters.') + for param in self.nets['generator'].parameters(): + param.trainable = True + self.output = self.nets['generator'](self.lq) + self.visual_items['output'] = self.output + # pixel loss + loss_pixel = self.pixel_criterion(self.output, self.gt) + self.losses['loss_pixel'] = loss_pixel + + loss_pixel.backward() + optims['optim'].step() + self.current_iter += 1 + + +def init_edvr_weight(net): + def reset_func(m): + if hasattr(m, 'weight') and ( + not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))) and ( + not isinstance(m, ResidualBlockNoBN)): + reset_parameters(m) + + net.apply(reset_func) diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py index b420254..90f7b6a 100755 --- a/ppgan/models/generators/__init__.py +++ b/ppgan/models/generators/__init__.py @@ -27,4 +27,5 @@ from .generator_styleganv2 import StyleGANv2Generator from .generator_pixel2style2pixel import Pixel2Style2Pixel from .drn import DRNGenerator from .generator_starganv2
import StarGANv2Generator, StarGANv2Style, StarGANv2Mapping, FAN +from .edvr import EDVRNet diff --git a/ppgan/models/generators/edvr.py b/ppgan/models/generators/edvr.py new file mode 100644 index 0000000..57cb859 --- /dev/null +++ b/ppgan/models/generators/edvr.py @@ -0,0 +1,819 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +from ...modules.init import kaiming_normal_, constant_ + +from ...modules.dcn import DeformableConv_dygraph +# from paddle.vision.ops import DeformConv2D # to be compiled + +from .builder import GENERATORS + + +@paddle.no_grad() +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + + Args: + module_list (list[nn.Layer] | nn.Layer): Layers to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0. + kwargs (dict): Other arguments for the initialization function. + """ + if not isinstance(module_list, list): + module_list = [module_list] + for m in module_list: + if isinstance(m, nn.Conv2D): + kaiming_normal_(m.weight, **kwargs) + scale_weight = scale * m.weight + m.weight.set_value(scale_weight) + if m.bias is not None: + constant_(m.bias, bias_fill) + elif isinstance(m, nn.Linear): + kaiming_normal_(m.weight, **kwargs) + scale_weight = scale * m.weight + m.weight.set_value(scale_weight) + if m.bias is not None: + constant_(m.bias, bias_fill) + + +class ResidualBlockNoBN(nn.Layer): + """Residual block without BN. + + It has a style of: + ---Conv-ReLU-Conv-+- + |________________| + + Args: + nf (int): Channel number of intermediate features. Default: 64. + + The two convolutions are initialized by default_init_weights with a + scale of 0.1, the residual scaling used by EDVR. + """ + def __init__(self, nf=64): + super(ResidualBlockNoBN, self).__init__() + self.nf = nf + self.conv1 = nn.Conv2D(self.nf, self.nf, 3, 1, 1) + self.conv2 = nn.Conv2D(self.nf, self.nf, 3, 1, 1) + self.relu = nn.ReLU() + default_init_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = self.conv2(self.relu(self.conv1(x))) + return identity + out + + +def MakeMultiBlocks(func, num_layers, nf=64): + """Make layers by stacking the same blocks. + + Args: + func (nn.Layer): nn.Layer class of the basic block. + num_layers (int): number of blocks. + nf (int): Channel number of intermediate features. Default: 64. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + Blocks = nn.Sequential() + for i in range(num_layers): + Blocks.add_sublayer('block%d' % i, func(nf)) + return Blocks + + +class PredeblurResNetPyramid(nn.Layer): + """Pre-deblur module. + + Args: + in_nf (int): Channel number of input image. Default: 3. + nf (int): Channel number of intermediate features. Default: 64.
+ HR_in (bool): Whether the input has high resolution. Default: False. + """ + def __init__(self, in_nf=3, nf=64, HR_in=False): + super(PredeblurResNetPyramid, self).__init__() + self.in_nf = in_nf + self.nf = nf + self.HR_in = True if HR_in else False + self.Leaky_relu = nn.LeakyReLU(negative_slope=0.1) + if self.HR_in: + self.conv_first_1 = nn.Conv2D(in_channels=self.in_nf, + out_channels=self.nf, + kernel_size=3, + stride=1, + padding=1) + self.conv_first_2 = nn.Conv2D(in_channels=self.nf, + out_channels=self.nf, + kernel_size=3, + stride=2, + padding=1) + self.conv_first_3 = nn.Conv2D(in_channels=self.nf, + out_channels=self.nf, + kernel_size=3, + stride=2, + padding=1) + else: + self.conv_first = nn.Conv2D(in_channels=self.in_nf, + out_channels=self.nf, + kernel_size=3, + stride=1, + padding=1) + self.RB_L1_1 = ResidualBlockNoBN(nf=self.nf) + self.RB_L1_2 = ResidualBlockNoBN(nf=self.nf) + self.RB_L1_3 = ResidualBlockNoBN(nf=self.nf) + self.RB_L1_4 = ResidualBlockNoBN(nf=self.nf) + self.RB_L1_5 = ResidualBlockNoBN(nf=self.nf) + self.RB_L2_1 = ResidualBlockNoBN(nf=self.nf) + self.RB_L2_2 = ResidualBlockNoBN(nf=self.nf) + self.RB_L3_1 = ResidualBlockNoBN(nf=self.nf) + self.deblur_L2_conv = nn.Conv2D(in_channels=self.nf, + out_channels=self.nf, + kernel_size=3, + stride=2, + padding=1) + self.deblur_L3_conv = nn.Conv2D(in_channels=self.nf, + out_channels=self.nf, + kernel_size=3, + stride=2, + padding=1) + self.upsample = nn.Upsample(scale_factor=2, + mode="bilinear", + align_corners=False, + align_mode=0) + + def forward(self, x): + if self.HR_in: + L1_fea = self.Leaky_relu(self.conv_first_1(x)) + L1_fea = self.Leaky_relu(self.conv_first_2(L1_fea)) + L1_fea = self.Leaky_relu(self.conv_first_3(L1_fea)) + else: + L1_fea = self.Leaky_relu(self.conv_first(x)) + L2_fea = self.deblur_L2_conv(L1_fea) + L2_fea = self.Leaky_relu(L2_fea) + L3_fea = self.deblur_L3_conv(L2_fea) + L3_fea = self.Leaky_relu(L3_fea) + L3_fea = self.RB_L3_1(L3_fea) + L3_fea = self.upsample(L3_fea) + L2_fea = self.RB_L2_1(L2_fea) + L3_fea + L2_fea = self.RB_L2_2(L2_fea) + L2_fea = self.upsample(L2_fea) + L1_fea = self.RB_L1_1(L1_fea) + L1_fea = self.RB_L1_2(L1_fea) + L2_fea + out = self.RB_L1_3(L1_fea) + out = self.RB_L1_4(out) + out = self.RB_L1_5(out) + return out + + +class TSAFusion(nn.Layer): + """Temporal Spatial Attention (TSA) fusion module. + + Temporal: Calculate the correlation between center frame and + neighboring frames; + Spatial: It has 3 pyramid levels, the attention is similar to SFT. + (SFT: Recovering realistic texture in image super-resolution by deep + spatial feature transform.) + + Args: + nf (int): Channel number of middle features. Default: 64. + nframes (int): Number of frames. Default: 5. + center (int): The index of center frame. Default: 2. 
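+ + Example (an illustrative sketch only; the shapes are assumed, not taken + from a test in this patch): + >>> fusion = TSAFusion(nf=64, nframes=5, center=2) + >>> aligned = paddle.rand([1, 5, 64, 64, 64]) # (b, n, c, h, w) + >>> fused = fusion(aligned) # (b, c, h, w) -> [1, 64, 64, 64]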
+ """ + def __init__(self, nf=64, nframes=5, center=2): + super(TSAFusion, self).__init__() + self.nf = nf + self.nframes = nframes + self.center = center + self.sigmoid = nn.Sigmoid() + self.Leaky_relu = nn.LeakyReLU(negative_slope=0.1) + self.tAtt_2 = nn.Conv2D(in_channels=self.nf, + out_channels=self.nf, + kernel_size=3, + stride=1, + padding=1) + self.tAtt_1 = nn.Conv2D(in_channels=self.nf, + out_channels=self.nf, + kernel_size=3, + stride=1, + padding=1) + self.fea_fusion = nn.Conv2D(in_channels=self.nf * self.nframes, + out_channels=self.nf, + kernel_size=1, + stride=1, + padding=0) + self.sAtt_1 = nn.Conv2D(in_channels=self.nf * self.nframes, + out_channels=self.nf, + kernel_size=1, + stride=1, + padding=0) + self.max_pool = nn.MaxPool2D(3, stride=2, padding=1) + self.avg_pool = nn.AvgPool2D(3, stride=2, padding=1, exclusive=False) + self.sAtt_2 = nn.Conv2D(in_channels=2 * self.nf, + out_channels=self.nf, + kernel_size=1, + stride=1, + padding=0) + self.sAtt_3 = nn.Conv2D(in_channels=self.nf, + out_channels=self.nf, + kernel_size=3, + stride=1, + padding=1) + self.sAtt_4 = nn.Conv2D( + in_channels=self.nf, + out_channels=self.nf, + kernel_size=1, + stride=1, + padding=0, + ) + self.sAtt_5 = nn.Conv2D(in_channels=self.nf, + out_channels=self.nf, + kernel_size=3, + stride=1, + padding=1) + self.sAtt_add_1 = nn.Conv2D(in_channels=self.nf, + out_channels=self.nf, + kernel_size=1, + stride=1, + padding=0) + self.sAtt_add_2 = nn.Conv2D(in_channels=self.nf, + out_channels=self.nf, + kernel_size=1, + stride=1, + padding=0) + self.sAtt_L1 = nn.Conv2D(in_channels=self.nf, + out_channels=self.nf, + kernel_size=1, + stride=1, + padding=0) + self.sAtt_L2 = nn.Conv2D( + in_channels=2 * self.nf, + out_channels=self.nf, + kernel_size=3, + stride=1, + padding=1, + ) + self.sAtt_L3 = nn.Conv2D(in_channels=self.nf, + out_channels=self.nf, + kernel_size=3, + stride=1, + padding=1) + self.upsample = nn.Upsample(scale_factor=2, + mode="bilinear", + align_corners=False, + align_mode=0) + + def forward(self, aligned_fea): + """ + Args: + aligned_feat (Tensor): Aligned features with shape (b, n, c, h, w). + + Returns: + Tensor: Features after TSA with the shape (b, c, h, w). 
+ """ + B, N, C, H, W = aligned_fea.shape + x_center = aligned_fea[:, self.center, :, :, :] + emb_rf = self.tAtt_2(x_center) + emb = aligned_fea.reshape([-1, C, H, W]) + emb = self.tAtt_1(emb) + emb = emb.reshape([-1, N, self.nf, H, W]) + cor_l = [] + for i in range(N): + emb_nbr = emb[:, i, :, :, :] #[B,C,W,H] + cor_tmp = paddle.sum(emb_nbr * emb_rf, axis=1) + cor_tmp = paddle.unsqueeze(cor_tmp, axis=1) + cor_l.append(cor_tmp) + cor_prob = paddle.concat(cor_l, axis=1) #[B,N,H,W] + + cor_prob = self.sigmoid(cor_prob) + cor_prob = paddle.unsqueeze(cor_prob, axis=2) #[B,N,1,H,W] + cor_prob = paddle.expand(cor_prob, [B, N, self.nf, H, W]) #[B,N,C,H,W] + cor_prob = cor_prob.reshape([B, -1, H, W]) + aligned_fea = aligned_fea.reshape([B, -1, H, W]) + aligned_fea = aligned_fea * cor_prob + + fea = self.fea_fusion(aligned_fea) + fea = self.Leaky_relu(fea) + + #spatial fusion + att = self.sAtt_1(aligned_fea) + att = self.Leaky_relu(att) + att_max = self.max_pool(att) + att_avg = self.avg_pool(att) + att_pool = paddle.concat([att_max, att_avg], axis=1) + att = self.sAtt_2(att_pool) + att = self.Leaky_relu(att) + + #pyramid + att_L = self.sAtt_L1(att) + att_L = self.Leaky_relu(att_L) + att_max = self.max_pool(att_L) + att_avg = self.avg_pool(att_L) + att_pool = paddle.concat([att_max, att_avg], axis=1) + att_L = self.sAtt_L2(att_pool) + att_L = self.Leaky_relu(att_L) + att_L = self.sAtt_L3(att_L) + att_L = self.Leaky_relu(att_L) + att_L = self.upsample(att_L) + + att = self.sAtt_3(att) + att = self.Leaky_relu(att) + att = att + att_L + att = self.sAtt_4(att) + att = self.Leaky_relu(att) + att = self.upsample(att) + att = self.sAtt_5(att) + att_add = self.sAtt_add_1(att) + att_add = self.Leaky_relu(att_add) + att_add = self.sAtt_add_2(att_add) + att = self.sigmoid(att) + + fea = fea * att * 2 + att_add + return fea + + +class DCNPack(nn.Layer): + """Modulated deformable conv for deformable alignment. + + Ref: + Delving Deep into Deformable Alignment in Video Super-Resolution. 
+ """ + def __init__(self, + num_filters=64, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + deformable_groups=8, + extra_offset_mask=True): + super(DCNPack, self).__init__() + self.extra_offset_mask = extra_offset_mask + self.deformable_groups = deformable_groups + self.num_filters = num_filters + if isinstance(kernel_size, int): + self.kernel_size = [kernel_size, kernel_size] + self.conv_offset_mask = nn.Conv2D(in_channels=self.num_filters, + out_channels=self.deformable_groups * + 3 * self.kernel_size[0] * + self.kernel_size[1], + kernel_size=self.kernel_size, + stride=stride, + padding=padding) + self.total_channels = self.deformable_groups * 3 * self.kernel_size[ + 0] * self.kernel_size[1] + self.split_channels = self.total_channels // 3 + self.dcn = DeformableConv_dygraph( + num_filters=self.num_filters, + filter_size=self.kernel_size, + dilation=dilation, + stride=stride, + padding=padding, + deformable_groups=self.deformable_groups) + # self.dcn = DeformConv2D(in_channels=self.num_filters,out_channels=self.num_filters,kernel_size=self.kernel_size,stride=stride,padding=padding,dilation=dilation,deformable_groups=self.deformable_groups,groups=1) # to be compiled + self.sigmoid = nn.Sigmoid() + + def forward(self, fea_and_offset): + out = None + x = None + if self.extra_offset_mask: + out = self.conv_offset_mask(fea_and_offset[1]) + x = fea_and_offset[0] + o1 = out[:, 0:self.split_channels, :, :] + o2 = out[:, self.split_channels:2 * self.split_channels, :, :] + mask = out[:, 2 * self.split_channels:, :, :] + offset = paddle.concat([o1, o2], axis=1) + mask = self.sigmoid(mask) + y = self.dcn(x, offset, mask) + return y + + +class PCDAlign(nn.Layer): + """Alignment module using Pyramid, Cascading and Deformable convolution + (PCD). It is used in EDVR. + + Ref: + EDVR: Video Restoration with Enhanced Deformable Convolutional Networks + + Args: + nf (int): Channel number of middle features. Default: 64. + groups (int): Deformable groups. Defaults: 8. 
+ """ + def __init__(self, nf=64, groups=8): + super(PCDAlign, self).__init__() + self.nf = nf + self.groups = groups + self.Leaky_relu = nn.LeakyReLU(negative_slope=0.1) + self.upsample = nn.Upsample(scale_factor=2, + mode="bilinear", + align_corners=False, + align_mode=0) + # Pyramid has three levels: + # L3: level 3, 1/4 spatial size + # L2: level 2, 1/2 spatial size + # L1: level 1, original spatial size + + # L3 + self.PCD_Align_L3_offset_conv1 = nn.Conv2D(in_channels=nf * 2, + out_channels=nf, + kernel_size=3, + stride=1, + padding=1) + self.PCD_Align_L3_offset_conv2 = nn.Conv2D(in_channels=nf, + out_channels=nf, + kernel_size=3, + stride=1, + padding=1) + self.PCD_Align_L3_dcn = DCNPack(num_filters=nf, + kernel_size=3, + stride=1, + padding=1, + deformable_groups=groups) + #L2 + self.PCD_Align_L2_offset_conv1 = nn.Conv2D(in_channels=nf * 2, + out_channels=nf, + kernel_size=3, + stride=1, + padding=1) + self.PCD_Align_L2_offset_conv2 = nn.Conv2D(in_channels=nf * 2, + out_channels=nf, + kernel_size=3, + stride=1, + padding=1) + self.PCD_Align_L2_offset_conv3 = nn.Conv2D(in_channels=nf, + out_channels=nf, + kernel_size=3, + stride=1, + padding=1) + self.PCD_Align_L2_dcn = DCNPack(num_filters=nf, + kernel_size=3, + stride=1, + padding=1, + deformable_groups=groups) + self.PCD_Align_L2_fea_conv = nn.Conv2D(in_channels=nf * 2, + out_channels=nf, + kernel_size=3, + stride=1, + padding=1) + #L1 + self.PCD_Align_L1_offset_conv1 = nn.Conv2D(in_channels=nf * 2, + out_channels=nf, + kernel_size=3, + stride=1, + padding=1) + self.PCD_Align_L1_offset_conv2 = nn.Conv2D(in_channels=nf * 2, + out_channels=nf, + kernel_size=3, + stride=1, + padding=1) + self.PCD_Align_L1_offset_conv3 = nn.Conv2D(in_channels=nf, + out_channels=nf, + kernel_size=3, + stride=1, + padding=1) + self.PCD_Align_L1_dcn = DCNPack(num_filters=nf, + kernel_size=3, + stride=1, + padding=1, + deformable_groups=groups) + self.PCD_Align_L1_fea_conv = nn.Conv2D(in_channels=nf * 2, + out_channels=nf, + kernel_size=3, + stride=1, + padding=1) + #cascade + self.PCD_Align_cas_offset_conv1 = nn.Conv2D(in_channels=nf * 2, + out_channels=nf, + kernel_size=3, + stride=1, + padding=1) + self.PCD_Align_cas_offset_conv2 = nn.Conv2D(in_channels=nf, + out_channels=nf, + kernel_size=3, + stride=1, + padding=1) + self.PCD_Align_cascade_dcn = DCNPack(num_filters=nf, + kernel_size=3, + stride=1, + padding=1, + deformable_groups=groups) + + def forward(self, nbr_fea_l, ref_fea_l): + """Align neighboring frame features to the reference frame features. + + Args: + nbr_fea_l (list[Tensor]): Neighboring feature list. It + contains three pyramid levels (L1, L2, L3), + each with shape (b, c, h, w). + ref_fea_l (list[Tensor]): Reference feature list. It + contains three pyramid levels (L1, L2, L3), + each with shape (b, c, h, w). + + Returns: + Tensor: Aligned features. 
+ """ + #L3 + L3_offset = paddle.concat([nbr_fea_l[2], ref_fea_l[2]], axis=1) + L3_offset = self.PCD_Align_L3_offset_conv1(L3_offset) + L3_offset = self.Leaky_relu(L3_offset) + L3_offset = self.PCD_Align_L3_offset_conv2(L3_offset) + L3_offset = self.Leaky_relu(L3_offset) + + L3_fea = self.PCD_Align_L3_dcn([nbr_fea_l[2], L3_offset]) + L3_fea = self.Leaky_relu(L3_fea) + #L2 + L2_offset = paddle.concat([nbr_fea_l[1], ref_fea_l[1]], axis=1) + L2_offset = self.PCD_Align_L2_offset_conv1(L2_offset) + L2_offset = self.Leaky_relu(L2_offset) + L3_offset = self.upsample(L3_offset) + L2_offset = paddle.concat([L2_offset, L3_offset * 2], axis=1) + L2_offset = self.PCD_Align_L2_offset_conv2(L2_offset) + L2_offset = self.Leaky_relu(L2_offset) + L2_offset = self.PCD_Align_L2_offset_conv3(L2_offset) + L2_offset = self.Leaky_relu(L2_offset) + L2_fea = self.PCD_Align_L2_dcn([nbr_fea_l[1], L2_offset]) + L3_fea = self.upsample(L3_fea) + L2_fea = paddle.concat([L2_fea, L3_fea], axis=1) + L2_fea = self.PCD_Align_L2_fea_conv(L2_fea) + L2_fea = self.Leaky_relu(L2_fea) + #L1 + L1_offset = paddle.concat([nbr_fea_l[0], ref_fea_l[0]], axis=1) + L1_offset = self.PCD_Align_L1_offset_conv1(L1_offset) + L1_offset = self.Leaky_relu(L1_offset) + L2_offset = self.upsample(L2_offset) + L1_offset = paddle.concat([L1_offset, L2_offset * 2], axis=1) + L1_offset = self.PCD_Align_L1_offset_conv2(L1_offset) + L1_offset = self.Leaky_relu(L1_offset) + L1_offset = self.PCD_Align_L1_offset_conv3(L1_offset) + L1_offset = self.Leaky_relu(L1_offset) + L1_fea = self.PCD_Align_L1_dcn([nbr_fea_l[0], L1_offset]) + L2_fea = self.upsample(L2_fea) + L1_fea = paddle.concat([L1_fea, L2_fea], axis=1) + L1_fea = self.PCD_Align_L1_fea_conv(L1_fea) + #cascade + offset = paddle.concat([L1_fea, ref_fea_l[0]], axis=1) + offset = self.PCD_Align_cas_offset_conv1(offset) + offset = self.Leaky_relu(offset) + offset = self.PCD_Align_cas_offset_conv2(offset) + offset = self.Leaky_relu(offset) + L1_fea = self.PCD_Align_cascade_dcn([L1_fea, offset]) + L1_fea = self.Leaky_relu(L1_fea) + + return L1_fea + + +@GENERATORS.register() +class EDVRNet(nn.Layer): + """EDVR network structure for video super-resolution. + + Now only support X4 upsampling factor. + Paper: + EDVR: Video Restoration with Enhanced Deformable Convolutional Networks + + Args: + in_nf (int): Channel number of input image. Default: 3. + out_nf (int): Channel number of output image. Default: 3. + scale_factor (int): Scale factor from input image to output image. Default: 4. + nf (int): Channel number of intermediate features. Default: 64. + nframes (int): Number of input frames. Default: 5. + groups (int): Deformable groups. Defaults: 8. + front_RBs (int): Number of blocks for feature extraction. Default: 5. + back_RBs (int): Number of blocks for reconstruction. Default: 10. + center (int): The index of center frame. Frame counting from 0. Default: None. + predeblur (bool): Whether has predeblur module. Default: False. + HR_in (bool): Whether the input has high resolution. Default: False. + with_tsa (bool): Whether has TSA module. Default: True. + TSA_only (bool): Whether only use TSA module. Default: False. 
+ """ + def __init__(self, + in_nf=3, + out_nf=3, + scale_factor=4, + nf=64, + nframes=5, + groups=8, + front_RBs=5, + back_RBs=10, + center=None, + predeblur=False, + HR_in=False, + w_TSA=True, + TSA_only=False): + super(EDVRNet, self).__init__() + self.in_nf = in_nf + self.out_nf = out_nf + self.scale_factor = scale_factor + self.nf = nf + self.nframes = nframes + self.groups = groups + self.front_RBs = front_RBs + self.back_RBs = back_RBs + self.center = nframes // 2 if center is None else center + self.predeblur = True if predeblur else False + self.HR_in = True if HR_in else False + self.w_TSA = True if w_TSA else False + + self.Leaky_relu = nn.LeakyReLU(negative_slope=0.1) + if self.predeblur: + self.pre_deblur = PredeblurResNetPyramid(in_nf=self.in_nf, + nf=self.nf, + HR_in=self.HR_in) + self.cov_1 = nn.Conv2D(in_channels=self.nf, + out_channels=self.nf, + kernel_size=1, + stride=1) + else: + if self.HR_in: + self.conv_first_1 = nn.Conv2D(in_channels=self.in_nf, + out_channels=self.nf, + kernel_size=3, + stride=1, + padding=1) + self.conv_first_2 = nn.Conv2D(in_channels=self.nf, + out_channels=self.nf, + kernel_size=3, + stride=2, + padding=1) + self.conv_first_3 = nn.Conv2D(in_channels=self.nf, + out_channels=self.nf, + kernel_size=3, + stride=2, + padding=1) + else: + self.conv_first = nn.Conv2D(in_channels=self.in_nf, + out_channels=self.nf, + kernel_size=3, + stride=1, + padding=1) + + #feature extraction module + self.feature_extractor = MakeMultiBlocks(ResidualBlockNoBN, + self.front_RBs, self.nf) + self.fea_L2_conv1 = nn.Conv2D(in_channels=self.nf, + out_channels=self.nf, + kernel_size=3, + stride=2, + padding=1) + self.fea_L2_conv2 = nn.Conv2D(in_channels=self.nf, + out_channels=self.nf, + kernel_size=3, + stride=1, + padding=1) + self.fea_L3_conv1 = nn.Conv2D( + in_channels=self.nf, + out_channels=self.nf, + kernel_size=3, + stride=2, + padding=1, + ) + self.fea_L3_conv2 = nn.Conv2D(in_channels=self.nf, + out_channels=self.nf, + kernel_size=3, + stride=1, + padding=1) + + #PCD alignment module + self.PCDModule = PCDAlign(nf=self.nf, groups=self.groups) + + #TSA Fusion module + if self.w_TSA: + self.TSAModule = TSAFusion(nf=self.nf, + nframes=self.nframes, + center=self.center) + else: + self.TSAModule = nn.Conv2D(in_channels=self.nframes * self.nf, + out_channels=self.nf, + kernel_size=1, + stride=1) + + #reconstruction module + self.reconstructor = MakeMultiBlocks(ResidualBlockNoBN, self.back_RBs, + self.nf) + self.upconv1 = nn.Conv2D(in_channels=self.nf, + out_channels=4 * self.nf, + kernel_size=3, + stride=1, + padding=1) + self.pixel_shuffle = nn.PixelShuffle(2) + self.upconv2 = nn.Conv2D(in_channels=self.nf, + out_channels=4 * self.nf, + kernel_size=3, + stride=1, + padding=1) + self.HRconv = nn.Conv2D(in_channels=self.nf, + out_channels=self.nf, + kernel_size=3, + stride=1, + padding=1) + self.conv_last = nn.Conv2D(in_channels=self.nf, + out_channels=self.out_nf, + kernel_size=3, + stride=1, + padding=1) + self.upsample = nn.Upsample(scale_factor=self.scale_factor, + mode="bilinear", + align_corners=False, + align_mode=0) + + def forward(self, x): + """ + Args: + x (Tensor): Input features with shape (b, n, c, h, w). + + Returns: + Tensor: Features after EDVR with the shape (b, c, scale_factor*h, scale_factor*w). 
+ """ + B, N, C, H, W = x.shape + x_center = x[:, self.center, :, :, :] + L1_fea = x.reshape([-1, C, H, W]) #[B*N,C,W,H] + if self.predeblur: + L1_fea = self.pre_deblur(L1_fea) + L1_fea = self.cov_1(L1_fea) + if self.HR_in: + H, W = H // self.scale_factor, W // self.scale_factor + else: + if self.HR_in: + L1_fea = self.conv_first_1(L1_fea) + L1_fea = self.Leaky_relu(L1_fea) + L1_fea = self.conv_first_2(L1_fea) + L1_fea = self.Leaky_relu(L1_fea) + L1_fea = self.conv_first_3(L1_fea) + L1_fea = self.Leaky_relu(L1_fea) + H = H // self.scale_factor + W = W // self.scale_factor + else: + L1_fea = self.conv_first(L1_fea) + L1_fea = self.Leaky_relu(L1_fea) + + # feature extraction and create Pyramid + L1_fea = self.feature_extractor(L1_fea) + # L2 + L2_fea = self.fea_L2_conv1(L1_fea) + L2_fea = self.Leaky_relu(L2_fea) + L2_fea = self.fea_L2_conv2(L2_fea) + L2_fea = self.Leaky_relu(L2_fea) + # L3 + L3_fea = self.fea_L3_conv1(L2_fea) + L3_fea = self.Leaky_relu(L3_fea) + L3_fea = self.fea_L3_conv2(L3_fea) + L3_fea = self.Leaky_relu(L3_fea) + + L1_fea = L1_fea.reshape([-1, N, self.nf, H, W]) + L2_fea = L2_fea.reshape([-1, N, self.nf, H // 2, W // 2]) + L3_fea = L3_fea.reshape([-1, N, self.nf, H // 4, W // 4]) + + # pcd align + ref_fea_l = [ + L1_fea[:, self.center, :, :, :], L2_fea[:, self.center, :, :, :], + L3_fea[:, self.center, :, :, :] + ] + aligned_fea = [] + for i in range(N): + nbr_fea_l = [ + L1_fea[:, i, :, :, :], L2_fea[:, i, :, :, :], L3_fea[:, + i, :, :, :] + ] + aligned_fea.append(self.PCDModule(nbr_fea_l, ref_fea_l)) + + # TSA Fusion + aligned_fea = paddle.stack(aligned_fea, axis=1) # [B, N, C, H, W] + fea = None + if not self.w_TSA: + aligned_fea = aligned_fea.reshape([B, -1, H, W]) + fea = self.TSAModule(aligned_fea) # [B, N, C, H, W] + + #Reconstruct + out = self.reconstructor(fea) + + out = self.upconv1(out) + out = self.pixel_shuffle(out) + out = self.Leaky_relu(out) + out = self.upconv2(out) + out = self.pixel_shuffle(out) + out = self.Leaky_relu(out) + + out = self.HRconv(out) + out = self.Leaky_relu(out) + out = self.conv_last(out) + + if self.HR_in: + base = x_center + else: + base = self.upsample(x_center) + out += base + return out diff --git a/ppgan/modules/dcn.py b/ppgan/modules/dcn.py new file mode 100644 index 0000000..cf9a5a4 --- /dev/null +++ b/ppgan/modules/dcn.py @@ -0,0 +1,147 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import paddle +from paddle.fluid import core +from paddle.fluid.layers import nn, utils +from paddle.nn import Layer +from paddle.fluid.initializer import Normal +from paddle.common_ops_import import * # provides in_dygraph_mode + + +class DeformConv2D(Layer): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + deformable_groups=1, + groups=1, + weight_attr=None, + bias_attr=None): + super(DeformConv2D, self).__init__() + assert weight_attr is not False, "weight_attr should not be False in Conv." + self._weight_attr = weight_attr + self._bias_attr = bias_attr + self._deformable_groups = deformable_groups + self._groups = groups + self._in_channels = in_channels + self._out_channels = out_channels + self.padding = padding + self.stride = stride + self._channel_dim = 1 + + self._stride = utils.convert_to_list(stride, 2, 'stride') + self._dilation = utils.convert_to_list(dilation, 2, 'dilation') + self._kernel_size = utils.convert_to_list(kernel_size, 2, 'kernel_size') + + if in_channels % groups != 0: + raise ValueError("in_channels must be divisible by groups.") + + self._padding = utils.convert_to_list(padding, 2, 'padding') + + filter_shape = [out_channels, in_channels // groups] + self._kernel_size + + def _get_default_param_initializer(): + filter_elem_num = np.prod(self._kernel_size) * self._in_channels + std = (2.0 / filter_elem_num)**0.5 + return Normal(0.0, std, 0) + + self.weight = self.create_parameter( + shape=filter_shape, + attr=self._weight_attr, + default_initializer=_get_default_param_initializer()) + self.bias = self.create_parameter( + attr=self._bias_attr, shape=[self._out_channels], is_bias=True) + + def forward(self, x, offset, mask): + out = deform_conv2d( + x=x, + offset=offset, + mask=mask, + weight=self.weight, + bias=self.bias, + stride=self._stride, + padding=self._padding, + dilation=self._dilation, + deformable_groups=self._deformable_groups, + groups=self._groups, + ) + return out + + +def deform_conv2d(x, + offset, + weight, + mask, + bias=None, + stride=1, + padding=0, + dilation=1, + deformable_groups=1, + groups=1, + name=None): + + stride = utils.convert_to_list(stride, 2, 'stride') + padding = utils.convert_to_list(padding, 2, 'padding') + dilation = utils.convert_to_list(dilation, 2, 'dilation') + + # modulated deformable conv (DCNv2) when a mask is given, DCNv1 otherwise + use_deform_conv2d_v1 = mask is None + + if in_dygraph_mode(): + attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation, + 'deformable_groups', deformable_groups, 'groups', groups, + 'im2col_step', 1) + if use_deform_conv2d_v1: + op_type = 'deformable_conv_v1' + pre_bias = getattr(core.ops, op_type)(x, offset, weight, *attrs) + else: + op_type = 'deformable_conv' + pre_bias = getattr(core.ops, op_type)(x, offset, mask, weight, + *attrs) + if bias is not None: + out = nn.elementwise_add(pre_bias, bias, axis=1) + else: + out = pre_bias + return out + # static graph mode is not wired up here; fail loudly instead of returning None + raise NotImplementedError('deform_conv2d in ppgan.modules.dcn only supports dynamic graph mode') + + +class DeformableConv_dygraph(Layer): + def __init__(self, num_filters, filter_size, dilation, + stride, padding, deformable_groups=1, groups=1): + super(DeformableConv_dygraph, self).__init__() + self.num_filters = num_filters + self.filter_size = filter_size + self.dilation = dilation + self.stride = stride + self.padding = padding + self.deformable_groups = deformable_groups + self.groups = groups + self.defor_conv = DeformConv2D( + in_channels=self.num_filters, + out_channels=self.num_filters, + kernel_size=self.filter_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + deformable_groups=self.deformable_groups, + groups=self.groups, + weight_attr=None, + bias_attr=None) + + def forward(self, x, offset, mask): + out = self.defor_conv(x, offset, mask) + return out diff --git a/ppgan/modules/init.py b/ppgan/modules/init.py index 91dfd06..2730337 100644 --- a/ppgan/modules/init.py +++ b/ppgan/modules/init.py @@ -324,3 +324,10 @@ def init_weights(net, logger = get_logger() logger.debug('initialize network with %s' % init_type) net.apply(init_func) # apply the initialization function + +def reset_parameters(m): + kaiming_uniform_(m.weight, a=math.sqrt(5)) + if m.bias is not None: + fan_in, _ = _calculate_fan_in_and_fan_out(m.weight) + bound = 1 / math.sqrt(fan_in) + uniform_(m.bias, -bound, bound) -- GitLab