From acd1615eab107b8108ec9297a2a5c8241c0d2f53 Mon Sep 17 00:00:00 2001
From: wangna11BD <79366697+wangna11BD@users.noreply.github.com>
Date: Sat, 6 Nov 2021 09:40:50 +0800
Subject: [PATCH] Add BasicVSR++ (#383)

* add BasicVSR++
---
 configs/basicvsr++_reds.yaml                  |  93 ++++
 configs/basicvsr_reds.yaml                    |   1 +
 configs/iconvsr_reds.yaml                     |   1 +
 .../en_US/tutorials/video_super_resolution.md |  15 +-
 .../zh_CN/tutorials/video_super_resolution.md |  14 +-
 ppgan/models/basicvsr_model.py                |   5 +-
 ppgan/models/generators/__init__.py           |   1 +
 ppgan/models/generators/basicvsr.py           |  71 ++-
 ppgan/models/generators/basicvsr_plus_plus.py | 439 ++++++++++++++++++
 9 files changed, 633 insertions(+), 7 deletions(-)
 create mode 100644 configs/basicvsr++_reds.yaml
 create mode 100644 ppgan/models/generators/basicvsr_plus_plus.py

diff --git a/configs/basicvsr++_reds.yaml b/configs/basicvsr++_reds.yaml
new file mode 100644
index 0000000..61a8e50
--- /dev/null
+++ b/configs/basicvsr++_reds.yaml
@@ -0,0 +1,93 @@
+total_iters: 600000
+output_dir: output_dir
+find_unused_parameters: True
+checkpoints_dir: checkpoints
+use_dataset: True
+# tensor range for function tensor2img
+min_max:
+  (0., 1.)
+
+model:
+  name: BasicVSRModel
+  fix_iter: 5000
+  lr_mult: 0.25
+  generator:
+    name: BasicVSRPlusPlus
+    mid_channels: 64
+    num_blocks: 7
+    is_low_res_input: True
+  pixel_criterion:
+    name: CharbonnierLoss
+    reduction: mean
+
+dataset:
+  train:
+    name: RepeatDataset
+    times: 1000
+    num_workers: 4
+    batch_size: 2  # per GPU (4 GPUs)
+    dataset:
+      name: SRREDSMultipleGTDataset
+      mode: train
+      lq_folder: data/REDS/train_sharp_bicubic/X4
+      gt_folder: data/REDS/train_sharp/X4
+      crop_size: 256
+      interval_list: [1]
+      random_reverse: False
+      number_frames: 30
+      use_flip: True
+      use_rot: True
+      scale: 4
+      val_partition: REDS4
+
+  test:
+    name: SRREDSMultipleGTDataset
+    mode: test
+    lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
+    gt_folder: data/REDS/REDS4_test_sharp/X4
+    interval_list: [1]
+    random_reverse: False
+    number_frames: 100
+    use_flip: False
+    use_rot: False
+    scale: 4
+    val_partition: REDS4
+    num_workers: 0
+    batch_size: 1
+
+lr_scheduler:
+  name: CosineAnnealingRestartLR
+  learning_rate: !!float 1e-4
+  periods: [600000]
+  restart_weights: [1]
+  eta_min: !!float 1e-7
+
+optimizer:
+  name: Adam
+  # add the parameters of the nets listed in net_names to the optimizer;
+  # each name should be a key of self.nets
+  net_names:
+    - generator
+  beta1: 0.9
+  beta2: 0.99
+
+validate:
+  interval: 5000
+  save_img: false
+
+  metrics:
+    psnr: # metric name, can be arbitrary
+      name: PSNR
+      crop_border: 0
+      test_y_channel: False
+    ssim:
+      name: SSIM
+      crop_border: 0
+      test_y_channel: False
+
+log_config:
+  interval: 10
+  visiual_interval: 500
+
+snapshot_config:
+  interval: 5000
diff --git a/configs/basicvsr_reds.yaml b/configs/basicvsr_reds.yaml
index b708197..566a4b2 100644
--- a/configs/basicvsr_reds.yaml
+++ b/configs/basicvsr_reds.yaml
@@ -10,6 +10,7 @@ min_max:
 model:
   name: BasicVSRModel
   fix_iter: 5000
+  lr_mult: 0.125
   generator:
     name: BasicVSRNet
     mid_channels: 64
diff --git a/configs/iconvsr_reds.yaml b/configs/iconvsr_reds.yaml
index 4a959c5..314d29a 100644
--- a/configs/iconvsr_reds.yaml
+++ b/configs/iconvsr_reds.yaml
@@ -10,6 +10,7 @@ min_max:
 model:
   name: BasicVSRModel
   fix_iter: 5000
+  lr_mult: 0.125
   generator:
     name: IconVSR
     mid_channels: 64
diff --git a/docs/en_US/tutorials/video_super_resolution.md b/docs/en_US/tutorials/video_super_resolution.md
index 0482571..4c482fa 100644
--- a/docs/en_US/tutorials/video_super_resolution.md
+++ b/docs/en_US/tutorials/video_super_resolution.md
@@ -3,7 +3,7 @@
 
 ## 1.1 Principle
 
-  Video super-resolution originates from image super-resolution, which aims to recover high-resolution (HR) images from one or more low resolution (LR) images. The difference between them is that the video is composed of multiple frames, so the video super-resolution usually uses the information between frames to repair. Here we provide the video super-resolution model [EDVR](https://arxiv.org/pdf/1905.02716.pdf).
+  Video super-resolution originates from image super-resolution, which aims to recover high-resolution (HR) images from one or more low-resolution (LR) images. The difference between them is that a video is composed of multiple frames, so video super-resolution usually exploits the information between frames for restoration. Here we provide the video super-resolution models [EDVR](https://arxiv.org/pdf/1905.02716.pdf), [BasicVSR](https://arxiv.org/pdf/2012.02181.pdf), [IconVSR](https://arxiv.org/pdf/2012.02181.pdf), and [BasicVSR++](https://arxiv.org/pdf/2104.13371v1.pdf).
 
   [EDVR](https://arxiv.org/pdf/1905.02716.pdf) won the championship and outperformed the second-place entry by a large margin in all four tracks of the NTIRE19 video restoration and enhancement challenge. The main difficulties of video super-resolution come from two aspects: (1) how to align multiple frames given large motions, and (2) how to effectively fuse different frames with diverse motion and blur. First, to handle large motions, EDVR devises a Pyramid, Cascading and Deformable (PCD) alignment module, in which frame alignment is done at the feature level using deformable convolutions in a coarse-to-fine manner. Second, EDVR proposes a Temporal and Spatial Attention (TSA) fusion module, in which attention is applied both temporally and spatially, so as to emphasize important features for subsequent restoration.
 
@@ -79,6 +79,7 @@ The metrics are PSNR / SSIM.
 | EDVR_L_w_tsa_deblur | 35.1473 / 0.9526 |
 | BasicVSR_x4 | 31.4325 / 0.8913 |
 | IconVSR_x4 | 31.6882 / 0.8950 |
+| BasicVSR++_x4 | 32.4018 / 0.9071 |
 
 ## 1.4 Model Download
 
@@ -92,6 +93,7 @@ The metrics are PSNR / SSIM.
 | EDVR_L_w_tsa_deblur | REDS | [EDVR_L_w_tsa_deblur](https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_deblur.pdparams)
 | BasicVSR_x4 | REDS | [BasicVSR_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR_reds_x4.pdparams)
 | IconVSR_x4 | REDS | [IconVSR_x4](https://paddlegan.bj.bcebos.com/models/IconVSR_reds_x4.pdparams)
+| BasicVSR++_x4 | REDS | [BasicVSR++_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR%2B%2B_reds_x4.pdparams)
 
 
 
@@ -120,3 +122,14 @@ The metrics are PSNR / SSIM.
   year = {2021}
 }
 ```
+
+- 3. [BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment](https://arxiv.org/pdf/2104.13371v1.pdf)
+
+  ```
+  @article{chan2021basicvsr++,
+    author = {Chan, Kelvin C.K. and Zhou, Shangchen and Xu, Xiangyu and Loy, Chen Change},
+    title = {BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment},
+    journal = {arXiv preprint arXiv:2104.13371},
+    year = {2021}
+  }
+  ```
diff --git a/docs/zh_CN/tutorials/video_super_resolution.md b/docs/zh_CN/tutorials/video_super_resolution.md
index 4283db3..0e88959 100644
--- a/docs/zh_CN/tutorials/video_super_resolution.md
+++ b/docs/zh_CN/tutorials/video_super_resolution.md
@@ -3,7 +3,7 @@
 
 ## 1.1 Principle
 
-  Video super-resolution originates from image super-resolution and aims to recover high-resolution (HR) images from one or more low-resolution (LR) images. The difference between them is clear: since a video consists of multiple frames, video super-resolution usually exploits inter-frame information for restoration. Here we provide the video super-resolution model [EDVR](https://arxiv.org/pdf/1905.02716.pdf).
+  Video super-resolution originates from image super-resolution and aims to recover high-resolution (HR) images from one or more low-resolution (LR) images. The difference between them is clear: since a video consists of multiple frames, video super-resolution usually exploits inter-frame information for restoration. Here we provide the video super-resolution models [EDVR](https://arxiv.org/pdf/1905.02716.pdf), [BasicVSR](https://arxiv.org/pdf/2012.02181.pdf), [IconVSR](https://arxiv.org/pdf/2012.02181.pdf), and [BasicVSR++](https://arxiv.org/pdf/2104.13371v1.pdf).
 
   The [EDVR](https://arxiv.org/pdf/1905.02716.pdf) model won the championship in all four tracks of the NTIRE19 video restoration and enhancement challenge, surpassing the runner-up by a large margin. The main difficulties of video super-resolution are (1) how to align multiple frames given large motions, and (2) how to effectively fuse different frames with diverse motion and blur. First, to handle large motions, EDVR designs a Pyramid, Cascading and Deformable (PCD) alignment module, in which coarse-to-fine deformable convolutions are used to perform frame alignment at the feature level. Second, EDVR uses a Temporal and Spatial Attention (TSA) fusion module, which applies attention both temporally and spatially to emphasize important features for subsequent restoration.
@@ -75,6 +75,7 @@
 | EDVR_L_w_tsa_deblur | 35.1473 / 0.9526 |
 | BasicVSR_x4 | 31.4325 / 0.8913 |
 | IconVSR_x4 | 31.6882 / 0.8950 |
+| BasicVSR++_x4 | 32.4018 / 0.9071 |
 
 ## 1.4 Model Download
 | Model | Dataset | Download Link |
@@ -87,6 +88,7 @@
 | EDVR_L_w_tsa_deblur | REDS | [EDVR_L_w_tsa_deblur](https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_deblur.pdparams)
 | BasicVSR_x4 | REDS | [BasicVSR_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR_reds_x4.pdparams)
 | IconVSR_x4 | REDS | [IconVSR_x4](https://paddlegan.bj.bcebos.com/models/IconVSR_reds_x4.pdparams)
+| BasicVSR++_x4 | REDS | [BasicVSR++_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR%2B%2B_reds_x4.pdparams)
 
 
 
@@ -113,3 +115,13 @@
   year = {2021}
 }
 ```
+- 3. [BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment](https://arxiv.org/pdf/2104.13371v1.pdf)
+
+  ```
+  @article{chan2021basicvsr++,
+    author = {Chan, Kelvin C.K. and Zhou, Shangchen and Xu, Xiangyu and Loy, Chen Change},
+    title = {BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment},
+    journal = {arXiv preprint arXiv:2104.13371},
+    year = {2021}
+  }
+  ```
diff --git a/ppgan/models/basicvsr_model.py b/ppgan/models/basicvsr_model.py
index faa9911..b2db128 100644
--- a/ppgan/models/basicvsr_model.py
+++ b/ppgan/models/basicvsr_model.py
@@ -29,7 +29,7 @@ class BasicVSRModel(BaseSRModel):
     Paper: BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond, CVPR, 2021
     """
 
-    def __init__(self, generator, fix_iter, pixel_criterion=None):
+    def __init__(self, generator, fix_iter, lr_mult, pixel_criterion=None):
         """Initialize the BasicVSR class.
 
         Args:
@@ -41,6 +41,7 @@ class BasicVSRModel(BaseSRModel):
         self.fix_iter = fix_iter
         self.current_iter = 1
         self.flag = True
+        self.lr_mult = lr_mult
         init_basicvsr_weight(self.nets['generator'])
 
     def setup_input(self, input):
@@ -65,7 +66,7 @@ class BasicVSRModel(BaseSRModel):
         for name, param in self.nets['generator'].named_parameters():
             param.trainable = True
             if 'spynet' in name:
-                param.optimize_attr['learning_rate'] = 0.125
+                param.optimize_attr['learning_rate'] = self.lr_mult
         self.flag = False
         for net in self.nets.values():
             net.find_unused_parameters = False
diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py
index 278afba..1a4feb6 100755
--- a/ppgan/models/generators/__init__.py
+++ b/ppgan/models/generators/__init__.py
@@ -35,3 +35,4 @@ from .mpr import MPRNet
 from .iconvsr import IconVSR
 from .gpen import GPEN
 from .pan import PAN
+from .basicvsr_plus_plus import BasicVSRPlusPlus
diff --git a/ppgan/models/generators/basicvsr.py b/ppgan/models/generators/basicvsr.py
index f307e85..48bc7cc 100644
--- a/ppgan/models/generators/basicvsr.py
+++ b/ppgan/models/generators/basicvsr.py
@@ -1,14 +1,13 @@
 # Copyright (c) MMEditing Authors.
-import paddle
-
 import numpy as np
+import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
+from paddle.vision.ops import DeformConv2D
 
 from ...utils.download import get_path_from_url
 from ...modules.init import kaiming_normal_, constant_
-
 from .builder import GENERATORS
@@ -607,3 +606,69 @@ class BasicVSRNet(nn.Layer):
             outputs[i] = out
 
         return paddle.stack(outputs, axis=1)
+
+
+class SecondOrderDeformableAlignment(nn.Layer):
+    """Second-order deformable alignment module.
+    Args:
+        in_channels (int): Same as nn.Conv2D.
+        out_channels (int): Same as nn.Conv2D.
+        kernel_size (int or tuple[int]): Same as nn.Conv2D.
+        stride (int or tuple[int]): Same as nn.Conv2D.
+        padding (int or tuple[int]): Same as nn.Conv2D.
+        dilation (int or tuple[int]): Same as nn.Conv2D.
+        groups (int): Same as nn.Conv2D.
+        deformable_groups (int): Number of deformable groups. Default: 16.
+    """
+    def __init__(self,
+                 in_channels=128,
+                 out_channels=64,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 groups=1,
+                 deformable_groups=16):
+        super(SecondOrderDeformableAlignment, self).__init__()
+
+        self.conv_offset = nn.Sequential(
+            nn.Conv2D(3 * out_channels + 4, out_channels, 3, 1, 1),
+            nn.LeakyReLU(negative_slope=0.1),
+            nn.Conv2D(out_channels, out_channels, 3, 1, 1),
+            nn.LeakyReLU(negative_slope=0.1),
+            nn.Conv2D(out_channels, out_channels, 3, 1, 1),
+            nn.LeakyReLU(negative_slope=0.1),
+            nn.Conv2D(out_channels, 27 * deformable_groups, 3, 1, 1),
+        )
+        self.dcn = DeformConv2D(in_channels,
+                                out_channels,
+                                kernel_size=kernel_size,
+                                stride=stride,
+                                padding=padding,
+                                dilation=dilation,
+                                deformable_groups=deformable_groups)
+        self.init_offset()
+
+    def init_offset(self):
+        constant_(self.conv_offset[-1].weight, 0)
+        constant_(self.conv_offset[-1].bias, 0)
+
+    def forward(self, x, extra_feat, flow_1, flow_2):
+        extra_feat = paddle.concat([extra_feat, flow_1, flow_2], axis=1)
+        out = self.conv_offset(extra_feat)
+        o1, o2, mask = paddle.chunk(out, 3, axis=1)
+
+        # offsets: learned residuals bounded by tanh, added to the flows
+        offset = 10 * paddle.tanh(paddle.concat((o1, o2), axis=1))
+        offset_1, offset_2 = paddle.chunk(offset, 2, axis=1)
+        offset_1 = offset_1 + flow_1.flip(1).tile(
+            [1, offset_1.shape[1] // 2, 1, 1])
+        offset_2 = offset_2 + flow_2.flip(1).tile(
+            [1, offset_2.shape[1] // 2, 1, 1])
+        offset = paddle.concat([offset_1, offset_2], axis=1)
+
+        # modulation mask
+        mask = F.sigmoid(mask)
+
+        out = self.dcn(x, offset, mask)
+        return out
diff --git a/ppgan/models/generators/basicvsr_plus_plus.py b/ppgan/models/generators/basicvsr_plus_plus.py
new file mode 100644
index 0000000..b078325
--- /dev/null
+++ b/ppgan/models/generators/basicvsr_plus_plus.py
@@ -0,0 +1,439 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from ...utils.download import get_path_from_url
+from .basicvsr import PixelShufflePack, flow_warp, SPyNet, \
+    ResidualBlocksWithInputConv, SecondOrderDeformableAlignment
+from .builder import GENERATORS
+
+
+@GENERATORS.register()
+class BasicVSRPlusPlus(nn.Layer):
+    """BasicVSR++ network structure.
+    Supports either x4 upsampling or same-size output. Since DCN is used in
+    this model, it can only be used with CUDA enabled.
+    Paper:
+        BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation
+        and Alignment
+
+    Adapted from 'https://github.com/open-mmlab/mmediting'
+    'mmediting/blob/master/mmedit/models/backbones/sr_backbones/basicvsr_pp.py'
+    Copyright (c) MMEditing Authors.
+
+    Args:
+        mid_channels (int, optional): Channel number of the intermediate
+            features. Default: 64.
+        num_blocks (int, optional): The number of residual blocks in each
+            propagation branch. Default: 7.
+        is_low_res_input (bool, optional): Whether the input is low-resolution
+            or not. If False, the output resolution is equal to the input
+            resolution. Default: True.
+    """
+    def __init__(self, mid_channels=64, num_blocks=7, is_low_res_input=True):
+
+        super().__init__()
+
+        self.mid_channels = mid_channels
+        self.is_low_res_input = is_low_res_input
+
+        # optical flow estimator (SPyNet) used for alignment
+        self.spynet = SPyNet()
+        weight_path = get_path_from_url(
+            'https://paddlegan.bj.bcebos.com/models/spynet.pdparams')
+        self.spynet.set_state_dict(paddle.load(weight_path))
+
+        # feature extraction module
+        if is_low_res_input:
+            self.feat_extract = ResidualBlocksWithInputConv(3, mid_channels, 5)
+        else:
+            self.feat_extract = nn.Sequential(
+                nn.Conv2D(3, mid_channels, 3, 2, 1),
+                nn.LeakyReLU(negative_slope=0.1),
+                nn.Conv2D(mid_channels, mid_channels, 3, 2, 1),
+                nn.LeakyReLU(negative_slope=0.1),
+                ResidualBlocksWithInputConv(mid_channels, mid_channels, 5))
+
+        # propagation branches
+        self.deform_align_backward_1 = SecondOrderDeformableAlignment(
+            2 * mid_channels, mid_channels, 3, padding=1, deformable_groups=16)
+        self.deform_align_forward_1 = SecondOrderDeformableAlignment(
+            2 * mid_channels, mid_channels, 3, padding=1, deformable_groups=16)
+        self.deform_align_backward_2 = SecondOrderDeformableAlignment(
+            2 * mid_channels, mid_channels, 3, padding=1, deformable_groups=16)
+        self.deform_align_forward_2 = SecondOrderDeformableAlignment(
+            2 * mid_channels, mid_channels, 3, padding=1, deformable_groups=16)
+        self.backbone_backward_1 = ResidualBlocksWithInputConv(
+            2 * mid_channels, mid_channels, num_blocks)
+        self.backbone_forward_1 = ResidualBlocksWithInputConv(
+            3 * mid_channels, mid_channels, num_blocks)
+        self.backbone_backward_2 = ResidualBlocksWithInputConv(
+            4 * mid_channels, mid_channels, num_blocks)
+        self.backbone_forward_2 = ResidualBlocksWithInputConv(
+            5 * mid_channels, mid_channels, num_blocks)
+
+        # upsampling module
+        self.reconstruction = ResidualBlocksWithInputConv(
+            5 * mid_channels, mid_channels, 5)
+        self.upsample1 = PixelShufflePack(mid_channels,
+                                          mid_channels,
+                                          2,
+                                          upsample_kernel=3)
+        self.upsample2 = PixelShufflePack(mid_channels,
+                                          64,
+                                          2,
+                                          upsample_kernel=3)
+        self.conv_hr = nn.Conv2D(64, 64, 3, 1, 1)
+        self.conv_last = nn.Conv2D(64, 3, 3, 1, 1)
+        self.img_upsample = nn.Upsample(scale_factor=4,
+                                        mode='bilinear',
+                                        align_corners=False)
+
+        # activation function
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1)
+
+    def check_if_mirror_extended(self, lrs):
+        """Check whether the input is a mirror-extended sequence.
+        If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the
+        (t-1-i)-th frame.
+        Args:
+            lrs (Tensor): Input LR images with shape (n, t, c, h, w).
+        """
+
+        with paddle.no_grad():
+            self.is_mirror_extended = False
+            if lrs.shape[1] % 2 == 0:
+                lrs_1, lrs_2 = paddle.chunk(lrs, 2, axis=1)
+                lrs_2 = paddle.flip(lrs_2, [1])
+                if paddle.norm(lrs_1 - lrs_2) == 0:
+                    self.is_mirror_extended = True
+
+    def compute_flow(self, lrs):
+        """Compute optical flow using SPyNet for feature alignment.
+        Note that if the input is a mirror-extended sequence, 'flows_forward'
+        is not needed, since it is equal to 'flows_backward.flip(1)'.
+        Args:
+            lrs (Tensor): Input LR images with shape (n, t, c, h, w).
+        Returns:
+            tuple(Tensor): Optical flow. 'flows_forward' corresponds to the
+                flows used for forward-time propagation (current to previous).
+                'flows_backward' corresponds to the flows used for
+                backward-time propagation (current to next).
+        """
+
+        n, t, c, h, w = lrs.shape
+
+        lrs_1 = lrs[:, :-1, :, :, :].reshape([-1, c, h, w])
+        lrs_2 = lrs[:, 1:, :, :, :].reshape([-1, c, h, w])
+
+        flows_backward = self.spynet(lrs_1, lrs_2).reshape([n, t - 1, 2, h, w])
+
+        if self.is_mirror_extended:
+            flows_forward = flows_backward.flip(1)
+        else:
+            flows_forward = self.spynet(lrs_2,
+                                        lrs_1).reshape([n, t - 1, 2, h, w])
+
+        return flows_forward, flows_backward
+
+    def upsample(self, lqs, feats):
+        """Compute the output image given the features.
+        Args:
+            lqs (Tensor): Input LR images with shape (n, t, c, h, w).
+            feats (dict): The features from the propagation branches.
+        Returns:
+            Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
+        """
+
+        outputs = []
+        num_outputs = len(feats['spatial'])
+
+        mapping_idx = list(range(0, num_outputs))
+        mapping_idx += mapping_idx[::-1]
+
+        for i in range(0, lqs.shape[1]):
+            hr = [feats[k].pop(0) for k in feats if k != 'spatial']
+            hr.insert(0, feats['spatial'][mapping_idx[i]])
+            hr = paddle.concat(hr, axis=1)
+
+            hr = self.reconstruction(hr)
+            hr = self.lrelu(self.upsample1(hr))
+            hr = self.lrelu(self.upsample2(hr))
+            hr = self.lrelu(self.conv_hr(hr))
+            hr = self.conv_last(hr)
+            if self.is_low_res_input:
+                hr += self.img_upsample(lqs[:, i, :, :, :])
+            else:
+                hr += lqs[:, i, :, :, :]
+
+            outputs.append(hr)
+
+        return paddle.stack(outputs, axis=1)
+
+    def forward(self, lqs):
+        """Forward function for BasicVSR++.
+        Args:
+            lqs (Tensor): Input LR sequence with shape (n, t, c, h, w).
+        Returns:
+            Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
+ """ + + n, t, c, h, w = lqs.shape + + if self.is_low_res_input: + lqs_downsample = lqs + else: + lqs_downsample = F.interpolate(lqs.reshape([-1, c, h, w]), + scale_factor=0.25, + mode='bicubic').reshape( + [n, t, c, h // 4, w // 4]) + + # check whether the input is an extended sequence + self.check_if_mirror_extended(lqs) + + feats = {} + feats_ = self.feat_extract(lqs.reshape([-1, c, h, w])) + h, w = feats_.shape[2:] + feats_ = feats_.reshape([n, t, -1, h, w]) + feats['spatial'] = [feats_[:, i, :, :, :] for i in range(0, t)] + + # compute optical flow using the low-res inputs + assert lqs_downsample.shape[3] >= 64 and lqs_downsample.shape[4] >= 64, ( + 'The height and width of low-res inputs must be at least 64, ' + f'but got {h} and {w}.') + flows_forward, flows_backward = self.compute_flow(lqs_downsample) + + # feature propgation + + # backward_1 + feats['backward_1'] = [] + flows = flows_backward + + n, t, _, h, w = flows.shape + + frame_idx = range(t, -1, -1) + flow_idx = range(t, -1, -1) + mapping_idx = list(range(0, len(feats['spatial']))) + mapping_idx += mapping_idx[::-1] + + feat_prop = paddle.zeros([n, self.mid_channels, h, w]) + + for i, idx in enumerate(frame_idx): + feat_current = feats['spatial'][mapping_idx[idx]] + + if i > 0: + flow_n1 = flows[:, flow_idx[i], :, :, :] + cond_n1 = flow_warp(feat_prop, flow_n1.transpose([0, 2, 3, 1])) + + # initialize second-order features + feat_n2 = paddle.zeros_like(feat_prop) + flow_n2 = paddle.zeros_like(flow_n1) + cond_n2 = paddle.zeros_like(cond_n1) + + if i > 1: # second-order features + feat_n2 = feats['backward_1'][-2] + + flow_n2 = flows[:, flow_idx[i - 1], :, :, :] + + flow_n2 = flow_n1 + flow_warp( + flow_n2, flow_n1.transpose([0, 2, 3, 1])) + + cond_n2 = flow_warp(feat_n2, flow_n2.transpose([0, 2, 3, + 1])) + + # flow-guided deformable convolution + cond = paddle.concat([cond_n1, feat_current, cond_n2], axis=1) + feat_prop = paddle.concat([feat_prop, feat_n2], axis=1) + + feat_prop = self.deform_align_backward_1( + feat_prop, cond, flow_n1, flow_n2) + + # concatenate and residual blocks + feat = [feat_current] + [ + feats[k][idx] + for k in feats if k not in ['spatial', 'backward_1'] + ] + [feat_prop] + + feat = paddle.concat(feat, axis=1) + feat_prop = feat_prop + self.backbone_backward_1(feat) + feats['backward_1'].append(feat_prop) + + feats['backward_1'] = feats['backward_1'][::-1] + + # forward_1 + feats['forward_1'] = [] + flows = flows_forward + + n, t, _, h, w = flows.shape + + frame_idx = range(0, t + 1) + flow_idx = range(-1, t) + mapping_idx = list(range(0, len(feats['spatial']))) + mapping_idx += mapping_idx[::-1] + + feat_prop = paddle.zeros([n, self.mid_channels, h, w]) + + for i, idx in enumerate(frame_idx): + feat_current = feats['spatial'][mapping_idx[idx]] + + if i > 0: + flow_n1 = flows[:, flow_idx[i], :, :, :] + cond_n1 = flow_warp(feat_prop, flow_n1.transpose([0, 2, 3, 1])) + + # initialize second-order features + feat_n2 = paddle.zeros_like(feat_prop) + flow_n2 = paddle.zeros_like(flow_n1) + cond_n2 = paddle.zeros_like(cond_n1) + + if i > 1: # second-order features + feat_n2 = feats['forward_1'][-2] + + flow_n2 = flows[:, flow_idx[i - 1], :, :, :] + + flow_n2 = flow_n1 + flow_warp( + flow_n2, flow_n1.transpose([0, 2, 3, 1])) + + cond_n2 = flow_warp(feat_n2, flow_n2.transpose([0, 2, 3, + 1])) + + # flow-guided deformable convolution + cond = paddle.concat([cond_n1, feat_current, cond_n2], axis=1) + feat_prop = paddle.concat([feat_prop, feat_n2], axis=1) + + feat_prop = 
self.deform_align_forward_1(feat_prop, cond, + flow_n1, flow_n2) + + # concatenate and residual blocks + feat = [feat_current] + [ + feats[k][idx] + for k in feats if k not in ['spatial', 'forward_1'] + ] + [feat_prop] + + feat = paddle.concat(feat, axis=1) + feat_prop = feat_prop + self.backbone_forward_1(feat) + feats['forward_1'].append(feat_prop) + + # backward_2 + feats['backward_2'] = [] + flows = flows_backward + + n, t, _, h, w = flows.shape + + frame_idx = range(t, -1, -1) + flow_idx = range(t, -1, -1) + mapping_idx = list(range(0, len(feats['spatial']))) + mapping_idx += mapping_idx[::-1] + + feat_prop = paddle.zeros([n, self.mid_channels, h, w]) + + for i, idx in enumerate(frame_idx): + feat_current = feats['spatial'][mapping_idx[idx]] + + if i > 0: + flow_n1 = flows[:, flow_idx[i], :, :, :] + cond_n1 = flow_warp(feat_prop, flow_n1.transpose([0, 2, 3, 1])) + + # initialize second-order features + feat_n2 = paddle.zeros_like(feat_prop) + flow_n2 = paddle.zeros_like(flow_n1) + cond_n2 = paddle.zeros_like(cond_n1) + + if i > 1: # second-order features + feat_n2 = feats['backward_2'][-2] + + flow_n2 = flows[:, flow_idx[i - 1], :, :, :] + + flow_n2 = flow_n1 + flow_warp( + flow_n2, flow_n1.transpose([0, 2, 3, 1])) + + cond_n2 = flow_warp(feat_n2, flow_n2.transpose([0, 2, 3, + 1])) + + # flow-guided deformable convolution + cond = paddle.concat([cond_n1, feat_current, cond_n2], axis=1) + feat_prop = paddle.concat([feat_prop, feat_n2], axis=1) + + feat_prop = self.deform_align_backward_2( + feat_prop, cond, flow_n1, flow_n2) + + # concatenate and residual blocks + feat = [feat_current] + [ + feats[k][idx] + for k in feats if k not in ['spatial', 'backward_2'] + ] + [feat_prop] + + feat = paddle.concat(feat, axis=1) + feat_prop = feat_prop + self.backbone_backward_2(feat) + feats['backward_2'].append(feat_prop) + + feats['backward_2'] = feats['backward_2'][::-1] + + # forward_2 + feats['forward_2'] = [] + flows = flows_forward + + n, t, _, h, w = flows.shape + + frame_idx = range(0, t + 1) + flow_idx = range(-1, t) + mapping_idx = list(range(0, len(feats['spatial']))) + mapping_idx += mapping_idx[::-1] + + feat_prop = paddle.zeros([n, self.mid_channels, h, w]) + + for i, idx in enumerate(frame_idx): + feat_current = feats['spatial'][mapping_idx[idx]] + + if i > 0: + flow_n1 = flows[:, flow_idx[i], :, :, :] + cond_n1 = flow_warp(feat_prop, flow_n1.transpose([0, 2, 3, 1])) + + # initialize second-order features + feat_n2 = paddle.zeros_like(feat_prop) + flow_n2 = paddle.zeros_like(flow_n1) + cond_n2 = paddle.zeros_like(cond_n1) + + if i > 1: # second-order features + feat_n2 = feats['forward_2'][-2] + + flow_n2 = flows[:, flow_idx[i - 1], :, :, :] + + flow_n2 = flow_n1 + flow_warp( + flow_n2, flow_n1.transpose([0, 2, 3, 1])) + + cond_n2 = flow_warp(feat_n2, flow_n2.transpose([0, 2, 3, + 1])) + + # flow-guided deformable convolution + cond = paddle.concat([cond_n1, feat_current, cond_n2], axis=1) + feat_prop = paddle.concat([feat_prop, feat_n2], axis=1) + + feat_prop = self.deform_align_forward_2(feat_prop, cond, + flow_n1, flow_n2) + + # concatenate and residual blocks + feat = [feat_current] + [ + feats[k][idx] + for k in feats if k not in ['spatial', 'forward_2'] + ] + [feat_prop] + + feat = paddle.concat(feat, axis=1) + feat_prop = feat_prop + self.backbone_forward_2(feat) + feats['forward_2'].append(feat_prop) + + return self.upsample(lqs, feats) -- GitLab
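
A minimal way to exercise the new generator end-to-end is sketched below. This is an illustrative snippet, not code from the patch: it assumes a CUDA-enabled PaddlePaddle install (the class docstring notes DCN requires CUDA), network access for the SPyNet weight download triggered in `BasicVSRPlusPlus.__init__`, and that it runs from a PaddleGAN checkout so the `ppgan` package is importable. The shapes follow the documented contract: input `(n, t, c, h, w)` with `h, w >= 64`, and with `is_low_res_input=True` the output is `(n, t, c, 4h, 4w)`.

```python
import paddle

from ppgan.models.generators import BasicVSRPlusPlus

# constructor defaults match the generator section of configs/basicvsr++_reds.yaml
model = BasicVSRPlusPlus(mid_channels=64, num_blocks=7, is_low_res_input=True)
model.eval()

# dummy LR clip: batch 1, 5 frames, RGB, 64x64 (the minimum spatial size
# accepted by the assert guarding the flow computation in forward())
lqs = paddle.rand([1, 5, 3, 64, 64])

with paddle.no_grad():
    sr = model(lqs)

print(sr.shape)  # expected: [1, 5, 3, 256, 256], i.e. x4 upsampling per frame
```

An odd frame count (5 here) sidesteps the mirror-extension check, so both forward and backward flows are computed, covering all four propagation branches in a single pass.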