未验证 提交 acd1615e 编写于 作者: W wangna11BD 提交者: GitHub

Add BasicVSR++ (#383)

* add BasicVSR++
上级 8c7878d9
total_iters: 600000
output_dir: output_dir
find_unused_parameters: True
checkpoints_dir: checkpoints
use_dataset: True
# tensor range for function tensor2img
min_max:
(0., 1.)
model:
name: BasicVSRModel
fix_iter: 5000
lr_mult: 0.25
generator:
name: BasicVSRPlusPlus
mid_channels: 64
num_blocks: 7
is_low_res_input: True
pixel_criterion:
name: CharbonnierLoss
reduction: mean
dataset:
train:
name: RepeatDataset
times: 1000
num_workers: 4
batch_size: 2 #4 gpus
dataset:
name: SRREDSMultipleGTDataset
mode: train
lq_folder: data/REDS/train_sharp_bicubic/X4
gt_folder: data/REDS/train_sharp/X4
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 30
use_flip: True
use_rot: True
scale: 4
val_partition: REDS4
test:
name: SRREDSMultipleGTDataset
mode: test
lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
gt_folder: data/REDS/REDS4_test_sharp/X4
interval_list: [1]
random_reverse: False
number_frames: 100
use_flip: False
use_rot: False
scale: 4
val_partition: REDS4
num_workers: 0
batch_size: 1
lr_scheduler:
name: CosineAnnealingRestartLR
learning_rate: !!float 1e-4
periods: [600000]
restart_weights: [1]
eta_min: !!float 1e-7
optimizer:
name: Adam
# add parameters of net_name to optim
# name should in self.nets
net_names:
- generator
beta1: 0.9
beta2: 0.99
validate:
interval: 5000
save_img: false
metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 0
test_y_channel: False
ssim:
name: SSIM
crop_border: 0
test_y_channel: False
log_config:
interval: 10
visiual_interval: 500
snapshot_config:
interval: 5000
......@@ -10,6 +10,7 @@ min_max:
model:
name: BasicVSRModel
fix_iter: 5000
lr_mult: 0.125
generator:
name: BasicVSRNet
mid_channels: 64
......
......@@ -10,6 +10,7 @@ min_max:
model:
name: BasicVSRModel
fix_iter: 5000
lr_mult: 0.125
generator:
name: IconVSR
mid_channels: 64
......
......@@ -3,7 +3,7 @@
## 1.1 Principle
Video super-resolution originates from image super-resolution, which aims to recover high-resolution (HR) images from one or more low resolution (LR) images. The difference between them is that the video is composed of multiple frames, so the video super-resolution usually uses the information between frames to repair. Here we provide the video super-resolution model [EDVR](https://arxiv.org/pdf/1905.02716.pdf).
Video super-resolution originates from image super-resolution, which aims to recover high-resolution (HR) images from one or more low resolution (LR) images. The difference between them is that the video is composed of multiple frames, so the video super-resolution usually uses the information between frames to repair. Here we provide the video super-resolution model [EDVR](https://arxiv.org/pdf/1905.02716.pdf).[BasicVSR](https://arxiv.org/pdf/2012.02181.pdf),[IconVSR](https://arxiv.org/pdf/2012.02181.pdf),[BasicVSR++](https://arxiv.org/pdf/2104.13371v1.pdf).
[EDVR](https://arxiv.org/pdf/1905.02716.pdf) wins the champions and outperforms the second place by a large margin in all four tracks in the NTIRE19 video restoration and enhancement challenges. The main difficulties of video super-resolution from two aspects: (1) how to align multiple frames given large motions, and (2) how to effectively fuse different frames with diverse motion and blur. First, to handle large motions, EDVR devise a Pyramid, Cascading and Deformable (PCD) alignment module, in which frame alignment is done at the feature level using deformable convolutions in a coarse-to-fine manner. Second, EDVR propose a Temporal and Spatial Attention (TSA) fusion module, in which attention is applied both temporally and spatially, so as to emphasize important features for subsequent restoration.
......@@ -79,6 +79,7 @@ The metrics are PSNR / SSIM.
| EDVR_L_w_tsa_deblur | 35.1473 / 0.9526 |
| BasicVSR_x4 | 31.4325 / 0.8913 |
| IconVSR_x4 | 31.6882 / 0.8950 |
| BasicVSR++_x4 | 32.4018 / 0.9071 |
## 1.4 Model Download
......@@ -92,6 +93,7 @@ The metrics are PSNR / SSIM.
| EDVR_L_w_tsa_deblur | REDS | [EDVR_L_w_tsa_deblur](https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_deblur.pdparams)
| BasicVSR_x4 | REDS | [BasicVSR_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR_reds_x4.pdparams)
| IconVSR_x4 | REDS | [IconVSR_x4](https://paddlegan.bj.bcebos.com/models/IconVSR_reds_x4.pdparams)
| BasicVSR++_x4 | REDS | [BasicVSR++_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR%2B%2B_reds_x4.pdparams)
......@@ -120,3 +122,14 @@ The metrics are PSNR / SSIM.
year = {2021}
}
```
- 3. [BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment](https://arxiv.org/pdf/2104.13371v1.pdf)
```
@article{chan2021basicvsr++,
author = {Chan, Kelvin C.K. and Zhou, Shangchen and Xu, Xiangyu and Loy, Chen Change},
title = {BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment},
booktitle = {arXiv preprint arXiv:2104.13371},
year = {2021}
}
```
......@@ -3,7 +3,7 @@
## 1.1 原理介绍
视频超分源于图像超分,其目的是从一个或多个低分辨率(LR)图像中恢复高分辨率(HR)图像。它们的区别也很明显,由于视频是由多个帧组成的,所以视频超分通常利用帧间的信息来进行修复。这里我们提供视频超分模型[EDVR](https://arxiv.org/pdf/1905.02716.pdf).
视频超分源于图像超分,其目的是从一个或多个低分辨率(LR)图像中恢复高分辨率(HR)图像。它们的区别也很明显,由于视频是由多个帧组成的,所以视频超分通常利用帧间的信息来进行修复。这里我们提供视频超分模型[EDVR](https://arxiv.org/pdf/1905.02716.pdf),[BasicVSR](https://arxiv.org/pdf/2012.02181.pdf),[IconVSR](https://arxiv.org/pdf/2012.02181.pdf),[BasicVSR++](https://arxiv.org/pdf/2104.13371v1.pdf).
[EDVR](https://arxiv.org/pdf/1905.02716.pdf)模型在NTIRE19视频恢复和增强挑战赛的四个赛道中都赢得了冠军,并以巨大的优势超过了第二名。视频超分的主要难点在于(1)如何在给定大运动的情况下对齐多个帧;(2)如何有效地融合具有不同运动和模糊的不同帧。首先,为了处理大的运动,EDVR模型设计了一个金字塔级联的可变形(PCD)对齐模块,在该模块中,从粗到精的可变形卷积被使用来进行特征级的帧对齐。其次,EDVR使用了时空注意力(TSA)融合模块,该模块在时间和空间上同时应用注意力机制,以强调后续恢复的重要特征。
......@@ -75,6 +75,7 @@
| EDVR_L_w_tsa_deblur | 35.1473 / 0.9526 |
| BasicVSR_x4 | 31.4325 / 0.8913 |
| IconVSR_x4 | 31.6882 / 0.8950 |
| BasicVSR++_x4 | 32.4018 / 0.9071 |
## 1.4 模型下载
| 模型 | 数据集 | 下载地址 |
......@@ -87,6 +88,7 @@
| EDVR_L_w_tsa_deblur | REDS | [EDVR_L_w_tsa_deblur](https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_deblur.pdparams)
| BasicVSR_x4 | REDS | [BasicVSR_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR_reds_x4.pdparams)
| IconVSR_x4 | REDS | [IconVSR_x4](https://paddlegan.bj.bcebos.com/models/IconVSR_reds_x4.pdparams)
| BasicVSR++_x4 | REDS | [BasicVSR++_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR%2B%2B_reds_x4.pdparams)
......@@ -113,3 +115,13 @@
year = {2021}
}
```
- 3. [BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment](https://arxiv.org/pdf/2104.13371v1.pdf)
```
@article{chan2021basicvsr++,
author = {Chan, Kelvin C.K. and Zhou, Shangchen and Xu, Xiangyu and Loy, Chen Change},
title = {BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment},
booktitle = {arXiv preprint arXiv:2104.13371},
year = {2021}
}
```
......@@ -29,7 +29,7 @@ class BasicVSRModel(BaseSRModel):
Paper: BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond, CVPR, 2021
"""
def __init__(self, generator, fix_iter, pixel_criterion=None):
def __init__(self, generator, fix_iter, lr_mult, pixel_criterion=None):
"""Initialize the BasicVSR class.
Args:
......@@ -41,6 +41,7 @@ class BasicVSRModel(BaseSRModel):
self.fix_iter = fix_iter
self.current_iter = 1
self.flag = True
self.lr_mult = lr_mult
init_basicvsr_weight(self.nets['generator'])
def setup_input(self, input):
......@@ -65,7 +66,7 @@ class BasicVSRModel(BaseSRModel):
for name, param in self.nets['generator'].named_parameters():
param.trainable = True
if 'spynet' in name:
param.optimize_attr['learning_rate'] = 0.125
param.optimize_attr['learning_rate'] = self.lr_mult
self.flag = False
for net in self.nets.values():
net.find_unused_parameters = False
......
......@@ -35,3 +35,4 @@ from .mpr import MPRNet
from .iconvsr import IconVSR
from .gpen import GPEN
from .pan import PAN
from .basicvsr_plus_plus import BasicVSRPlusPlus
# Copyright (c) MMEditing Authors.
import paddle
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.vision.ops import DeformConv2D
from ...utils.download import get_path_from_url
from ...modules.init import kaiming_normal_, constant_
from .builder import GENERATORS
......@@ -607,3 +606,69 @@ class BasicVSRNet(nn.Layer):
outputs[i] = out
return paddle.stack(outputs, axis=1)
class SecondOrderDeformableAlignment(nn.Layer):
"""Second-order deformable alignment module.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int or tuple[int]): Same as nn.Conv2d.
padding (int or tuple[int]): Same as nn.Conv2d.
dilation (int or tuple[int]): Same as nn.Conv2d.
groups (int): Same as nn.Conv2d.
deformable_groups (int).
"""
def __init__(self,
in_channels=128,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
dilation=1,
groups=1,
deformable_groups=16):
super(SecondOrderDeformableAlignment, self).__init__()
self.conv_offset = nn.Sequential(
nn.Conv2D(3 * out_channels + 4, out_channels, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.1),
nn.Conv2D(out_channels, out_channels, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.1),
nn.Conv2D(out_channels, out_channels, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.1),
nn.Conv2D(out_channels, 27 * deformable_groups, 3, 1, 1),
)
self.dcn = DeformConv2D(in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
deformable_groups=deformable_groups)
self.init_offset()
def init_offset(self):
constant_(self.conv_offset[-1].weight, 0)
constant_(self.conv_offset[-1].bias, 0)
def forward(self, x, extra_feat, flow_1, flow_2):
extra_feat = paddle.concat([extra_feat, flow_1, flow_2], axis=1)
out = self.conv_offset(extra_feat)
o1, o2, mask = paddle.chunk(out, 3, axis=1)
# offset
offset = 10 * paddle.tanh(paddle.concat((o1, o2), axis=1))
offset_1, offset_2 = paddle.chunk(offset, 2, axis=1)
offset_1 = offset_1 + flow_1.flip(1).tile(
[1, offset_1.shape[1] // 2, 1, 1])
offset_2 = offset_2 + flow_2.flip(1).tile(
[1, offset_2.shape[1] // 2, 1, 1])
offset = paddle.concat([offset_1, offset_2], axis=1)
# mask
mask = F.sigmoid(mask)
out = self.dcn(x, offset, mask)
return out
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ...utils.download import get_path_from_url
from .basicvsr import PixelShufflePack, flow_warp, SPyNet, \
ResidualBlocksWithInputConv, SecondOrderDeformableAlignment
from .builder import GENERATORS
@GENERATORS.register()
class BasicVSRPlusPlus(nn.Layer):
"""BasicVSR++ network structure.
Support either x4 upsampling or same size output. Since DCN is used in this
model, it can only be used with CUDA enabled.
Paper:
BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation
and Alignment
Adapted from 'https://github.com/open-mmlab/mmediting'
'mmediting/blob/master/mmedit/models/backbones/sr_backbones/basicvsr_pp.py'
Copyright (c) MMEditing Authors.
Args:
mid_channels (int, optional): Channel number of the intermediate
features. Default: 64.
num_blocks (int, optional): The number of residual blocks in each
propagation branch. Default: 7.
is_low_res_input (bool, optional): Whether the input is low-resolution
or not. If False, the output resolution is equal to the input
resolution. Default: True.
"""
def __init__(self, mid_channels=64, num_blocks=7, is_low_res_input=True):
super().__init__()
self.mid_channels = mid_channels
self.is_low_res_input = is_low_res_input
# optical flow
self.spynet = SPyNet()
weight_path = get_path_from_url(
'https://paddlegan.bj.bcebos.com/models/spynet.pdparams')
self.spynet.set_state_dict(paddle.load(weight_path))
# feature extraction module
if is_low_res_input:
self.feat_extract = ResidualBlocksWithInputConv(3, mid_channels, 5)
else:
self.feat_extract = nn.Sequential(
nn.Conv2D(3, mid_channels, 3, 2, 1),
nn.LeakyReLU(negative_slope=0.1),
nn.Conv2D(mid_channels, mid_channels, 3, 2, 1),
nn.LeakyReLU(negative_slope=0.1),
ResidualBlocksWithInputConv(mid_channels, mid_channels, 5))
# propagation branches
self.deform_align_backward_1 = SecondOrderDeformableAlignment(
2 * mid_channels, mid_channels, 3, padding=1, deformable_groups=16)
self.deform_align_forward_1 = SecondOrderDeformableAlignment(
2 * mid_channels, mid_channels, 3, padding=1, deformable_groups=16)
self.deform_align_backward_2 = SecondOrderDeformableAlignment(
2 * mid_channels, mid_channels, 3, padding=1, deformable_groups=16)
self.deform_align_forward_2 = SecondOrderDeformableAlignment(
2 * mid_channels, mid_channels, 3, padding=1, deformable_groups=16)
self.backbone_backward_1 = ResidualBlocksWithInputConv(
2 * mid_channels, mid_channels, num_blocks)
self.backbone_forward_1 = ResidualBlocksWithInputConv(
3 * mid_channels, mid_channels, num_blocks)
self.backbone_backward_2 = ResidualBlocksWithInputConv(
4 * mid_channels, mid_channels, num_blocks)
self.backbone_forward_2 = ResidualBlocksWithInputConv(
5 * mid_channels, mid_channels, num_blocks)
# upsampling module
self.reconstruction = ResidualBlocksWithInputConv(
5 * mid_channels, mid_channels, 5)
self.upsample1 = PixelShufflePack(mid_channels,
mid_channels,
2,
upsample_kernel=3)
self.upsample2 = PixelShufflePack(mid_channels,
64,
2,
upsample_kernel=3)
self.conv_hr = nn.Conv2D(64, 64, 3, 1, 1)
self.conv_last = nn.Conv2D(64, 3, 3, 1, 1)
self.img_upsample = nn.Upsample(scale_factor=4,
mode='bilinear',
align_corners=False)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.1)
def check_if_mirror_extended(self, lrs):
"""Check whether the input is a mirror-extended sequence.
If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the
(t-1-i)-th frame.
Args:
lqs (tensor): Input LR images with shape (n, t, c, h, w)
"""
with paddle.no_grad():
self.is_mirror_extended = False
if lrs.shape[1] % 2 == 0:
lrs_1, lrs_2 = paddle.chunk(lrs, 2, axis=1)
lrs_2 = paddle.flip(lrs_2, [1])
if paddle.norm(lrs_1 - lrs_2) == 0:
self.is_mirror_extended = True
def compute_flow(self, lrs):
"""Compute optical flow using SPyNet for feature alignment.
Note that if the input is an mirror-extended sequence, 'flows_forward'
is not needed, since it is equal to 'flows_backward.flip(1)'.
Args:
lqs (tensor): Input LR images with shape (n, t, c, h, w)
Return:
tuple(Tensor): Optical flow. 'flows_forward' corresponds to the
flows used for forward-time propagation (current to previous).
'flows_backward' corresponds to the flows used for
backward-time propagation (current to next).
"""
n, t, c, h, w = lrs.shape
lrs_1 = lrs[:, :-1, :, :, :].reshape([-1, c, h, w])
lrs_2 = lrs[:, 1:, :, :, :].reshape([-1, c, h, w])
flows_backward = self.spynet(lrs_1, lrs_2).reshape([n, t - 1, 2, h, w])
if self.is_mirror_extended:
flows_forward = flows_backward.flip(1)
else:
flows_forward = self.spynet(lrs_2,
lrs_1).reshape([n, t - 1, 2, h, w])
return flows_forward, flows_backward
def upsample(self, lqs, feats):
"""Compute the output image given the features.
Args:
lqs (tensor): Input LR images with shape (n, t, c, h, w).
feats (dict): The features from the propgation branches.
Returns:
Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
"""
outputs = []
num_outputs = len(feats['spatial'])
mapping_idx = list(range(0, num_outputs))
mapping_idx += mapping_idx[::-1]
for i in range(0, lqs.shape[1]):
hr = [feats[k].pop(0) for k in feats if k != 'spatial']
hr.insert(0, feats['spatial'][mapping_idx[i]])
hr = paddle.concat(hr, axis=1)
hr = self.reconstruction(hr)
hr = self.lrelu(self.upsample1(hr))
hr = self.lrelu(self.upsample2(hr))
hr = self.lrelu(self.conv_hr(hr))
hr = self.conv_last(hr)
if self.is_low_res_input:
hr += self.img_upsample(lqs[:, i, :, :, :])
else:
hr += lqs[:, i, :, :, :]
outputs.append(hr)
return paddle.stack(outputs, axis=1)
def forward(self, lqs):
"""Forward function for BasicVSR++.
Args:
lqs (Tensor): Input LR sequence with shape (n, t, c, h, w).
Returns:
Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
"""
n, t, c, h, w = lqs.shape
if self.is_low_res_input:
lqs_downsample = lqs
else:
lqs_downsample = F.interpolate(lqs.reshape([-1, c, h, w]),
scale_factor=0.25,
mode='bicubic').reshape(
[n, t, c, h // 4, w // 4])
# check whether the input is an extended sequence
self.check_if_mirror_extended(lqs)
feats = {}
feats_ = self.feat_extract(lqs.reshape([-1, c, h, w]))
h, w = feats_.shape[2:]
feats_ = feats_.reshape([n, t, -1, h, w])
feats['spatial'] = [feats_[:, i, :, :, :] for i in range(0, t)]
# compute optical flow using the low-res inputs
assert lqs_downsample.shape[3] >= 64 and lqs_downsample.shape[4] >= 64, (
'The height and width of low-res inputs must be at least 64, '
f'but got {h} and {w}.')
flows_forward, flows_backward = self.compute_flow(lqs_downsample)
# feature propgation
# backward_1
feats['backward_1'] = []
flows = flows_backward
n, t, _, h, w = flows.shape
frame_idx = range(t, -1, -1)
flow_idx = range(t, -1, -1)
mapping_idx = list(range(0, len(feats['spatial'])))
mapping_idx += mapping_idx[::-1]
feat_prop = paddle.zeros([n, self.mid_channels, h, w])
for i, idx in enumerate(frame_idx):
feat_current = feats['spatial'][mapping_idx[idx]]
if i > 0:
flow_n1 = flows[:, flow_idx[i], :, :, :]
cond_n1 = flow_warp(feat_prop, flow_n1.transpose([0, 2, 3, 1]))
# initialize second-order features
feat_n2 = paddle.zeros_like(feat_prop)
flow_n2 = paddle.zeros_like(flow_n1)
cond_n2 = paddle.zeros_like(cond_n1)
if i > 1: # second-order features
feat_n2 = feats['backward_1'][-2]
flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
flow_n2 = flow_n1 + flow_warp(
flow_n2, flow_n1.transpose([0, 2, 3, 1]))
cond_n2 = flow_warp(feat_n2, flow_n2.transpose([0, 2, 3,
1]))
# flow-guided deformable convolution
cond = paddle.concat([cond_n1, feat_current, cond_n2], axis=1)
feat_prop = paddle.concat([feat_prop, feat_n2], axis=1)
feat_prop = self.deform_align_backward_1(
feat_prop, cond, flow_n1, flow_n2)
# concatenate and residual blocks
feat = [feat_current] + [
feats[k][idx]
for k in feats if k not in ['spatial', 'backward_1']
] + [feat_prop]
feat = paddle.concat(feat, axis=1)
feat_prop = feat_prop + self.backbone_backward_1(feat)
feats['backward_1'].append(feat_prop)
feats['backward_1'] = feats['backward_1'][::-1]
# forward_1
feats['forward_1'] = []
flows = flows_forward
n, t, _, h, w = flows.shape
frame_idx = range(0, t + 1)
flow_idx = range(-1, t)
mapping_idx = list(range(0, len(feats['spatial'])))
mapping_idx += mapping_idx[::-1]
feat_prop = paddle.zeros([n, self.mid_channels, h, w])
for i, idx in enumerate(frame_idx):
feat_current = feats['spatial'][mapping_idx[idx]]
if i > 0:
flow_n1 = flows[:, flow_idx[i], :, :, :]
cond_n1 = flow_warp(feat_prop, flow_n1.transpose([0, 2, 3, 1]))
# initialize second-order features
feat_n2 = paddle.zeros_like(feat_prop)
flow_n2 = paddle.zeros_like(flow_n1)
cond_n2 = paddle.zeros_like(cond_n1)
if i > 1: # second-order features
feat_n2 = feats['forward_1'][-2]
flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
flow_n2 = flow_n1 + flow_warp(
flow_n2, flow_n1.transpose([0, 2, 3, 1]))
cond_n2 = flow_warp(feat_n2, flow_n2.transpose([0, 2, 3,
1]))
# flow-guided deformable convolution
cond = paddle.concat([cond_n1, feat_current, cond_n2], axis=1)
feat_prop = paddle.concat([feat_prop, feat_n2], axis=1)
feat_prop = self.deform_align_forward_1(feat_prop, cond,
flow_n1, flow_n2)
# concatenate and residual blocks
feat = [feat_current] + [
feats[k][idx]
for k in feats if k not in ['spatial', 'forward_1']
] + [feat_prop]
feat = paddle.concat(feat, axis=1)
feat_prop = feat_prop + self.backbone_forward_1(feat)
feats['forward_1'].append(feat_prop)
# backward_2
feats['backward_2'] = []
flows = flows_backward
n, t, _, h, w = flows.shape
frame_idx = range(t, -1, -1)
flow_idx = range(t, -1, -1)
mapping_idx = list(range(0, len(feats['spatial'])))
mapping_idx += mapping_idx[::-1]
feat_prop = paddle.zeros([n, self.mid_channels, h, w])
for i, idx in enumerate(frame_idx):
feat_current = feats['spatial'][mapping_idx[idx]]
if i > 0:
flow_n1 = flows[:, flow_idx[i], :, :, :]
cond_n1 = flow_warp(feat_prop, flow_n1.transpose([0, 2, 3, 1]))
# initialize second-order features
feat_n2 = paddle.zeros_like(feat_prop)
flow_n2 = paddle.zeros_like(flow_n1)
cond_n2 = paddle.zeros_like(cond_n1)
if i > 1: # second-order features
feat_n2 = feats['backward_2'][-2]
flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
flow_n2 = flow_n1 + flow_warp(
flow_n2, flow_n1.transpose([0, 2, 3, 1]))
cond_n2 = flow_warp(feat_n2, flow_n2.transpose([0, 2, 3,
1]))
# flow-guided deformable convolution
cond = paddle.concat([cond_n1, feat_current, cond_n2], axis=1)
feat_prop = paddle.concat([feat_prop, feat_n2], axis=1)
feat_prop = self.deform_align_backward_2(
feat_prop, cond, flow_n1, flow_n2)
# concatenate and residual blocks
feat = [feat_current] + [
feats[k][idx]
for k in feats if k not in ['spatial', 'backward_2']
] + [feat_prop]
feat = paddle.concat(feat, axis=1)
feat_prop = feat_prop + self.backbone_backward_2(feat)
feats['backward_2'].append(feat_prop)
feats['backward_2'] = feats['backward_2'][::-1]
# forward_2
feats['forward_2'] = []
flows = flows_forward
n, t, _, h, w = flows.shape
frame_idx = range(0, t + 1)
flow_idx = range(-1, t)
mapping_idx = list(range(0, len(feats['spatial'])))
mapping_idx += mapping_idx[::-1]
feat_prop = paddle.zeros([n, self.mid_channels, h, w])
for i, idx in enumerate(frame_idx):
feat_current = feats['spatial'][mapping_idx[idx]]
if i > 0:
flow_n1 = flows[:, flow_idx[i], :, :, :]
cond_n1 = flow_warp(feat_prop, flow_n1.transpose([0, 2, 3, 1]))
# initialize second-order features
feat_n2 = paddle.zeros_like(feat_prop)
flow_n2 = paddle.zeros_like(flow_n1)
cond_n2 = paddle.zeros_like(cond_n1)
if i > 1: # second-order features
feat_n2 = feats['forward_2'][-2]
flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
flow_n2 = flow_n1 + flow_warp(
flow_n2, flow_n1.transpose([0, 2, 3, 1]))
cond_n2 = flow_warp(feat_n2, flow_n2.transpose([0, 2, 3,
1]))
# flow-guided deformable convolution
cond = paddle.concat([cond_n1, feat_current, cond_n2], axis=1)
feat_prop = paddle.concat([feat_prop, feat_n2], axis=1)
feat_prop = self.deform_align_forward_2(feat_prop, cond,
flow_n1, flow_n2)
# concatenate and residual blocks
feat = [feat_current] + [
feats[k][idx]
for k in feats if k not in ['spatial', 'forward_2']
] + [feat_prop]
feat = paddle.concat(feat, axis=1)
feat_prop = feat_prop + self.backbone_forward_2(feat)
feats['forward_2'].append(feat_prop)
return self.upsample(lqs, feats)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册