未验证 提交 099c595e 编写于 作者: L LielinJiang 提交者: GitHub

Implement IconVSR (#384)

* add iconvsr
上级 1ffa3fba
total_iters: 300000
output_dir: output_dir
find_unused_parameters: True
checkpoints_dir: checkpoints
use_dataset: True
# tensor range for function tensor2img
min_max:
  (0., 1.)

model:
  name: BasicVSRModel
  fix_iter: 5000
  generator:
    name: IconVSR
    mid_channels: 64
    num_blocks: 30
  pixel_criterion:
    name: CharbonnierLoss
    reduction: mean

dataset:
  train:
    name: RepeatDataset
    times: 1000
    num_workers: 4  # 6
    batch_size: 2  # 4*2
    dataset:
      name: SRREDSMultipleGTDataset
      mode: train
      lq_folder: data/REDS/train_sharp_bicubic/X4
      gt_folder: data/REDS/train_sharp/X4
      crop_size: 256
      interval_list: [1]
      random_reverse: False
      number_frames: 15
      use_flip: True
      use_rot: True
      scale: 4
      val_partition: REDS4

  test:
    name: SRREDSMultipleGTDataset
    mode: test
    lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
    gt_folder: data/REDS/REDS4_test_sharp/X4
    interval_list: [1]
    random_reverse: False
    number_frames: 100
    use_flip: False
    use_rot: False
    scale: 4
    val_partition: REDS4
    num_workers: 0
    batch_size: 1

lr_scheduler:
  name: CosineAnnealingRestartLR
  learning_rate: !!float 2e-4
  periods: [300000]
  restart_weights: [1]
  eta_min: !!float 1e-7

optimizer:
  name: Adam
  # add parameters of net_name to optim
  # name should be in self.nets
  net_names:
    - generator
  beta1: 0.9
  beta2: 0.99

validate:
  interval: 5000
  save_img: false

  metrics:
    psnr: # metric name, can be arbitrary
      name: PSNR
      crop_border: 0
      test_y_channel: False
    ssim:
      name: SSIM
      crop_border: 0
      test_y_channel: False

log_config:
  interval: 100
  visiual_interval: 500

snapshot_config:
  interval: 5000
| EDVR_L_wo_tsa_deblur | 34.9587 / 0.9509 |
| EDVR_L_w_tsa_deblur | 35.1473 / 0.9526 |
| BasicVSR_x4 | 31.4325 / 0.8913 |
| IconVSR_x4 | 31.6882 / 0.8950 |

## 1.4 Model Download

| EDVR_L_wo_tsa_deblur | REDS | [EDVR_L_wo_tsa_deblur](https://paddlegan.bj.bcebos.com/models/EDVR_L_wo_tsa_deblur.pdparams) |
| EDVR_L_w_tsa_deblur | REDS | [EDVR_L_w_tsa_deblur](https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_deblur.pdparams) |
| BasicVSR_x4 | REDS | [BasicVSR_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR_reds_x4.pdparams) |
| IconVSR_x4 | REDS | [IconVSR_x4](https://paddlegan.bj.bcebos.com/models/IconVSR_reds_x4.pdparams) |
......
...@@ -74,7 +74,7 @@ ...@@ -74,7 +74,7 @@
| EDVR_L_wo_tsa_deblur | 34.9587 / 0.9509 | | EDVR_L_wo_tsa_deblur | 34.9587 / 0.9509 |
| EDVR_L_w_tsa_deblur | 35.1473 / 0.9526 | | EDVR_L_w_tsa_deblur | 35.1473 / 0.9526 |
| BasicVSR_x4 | 31.4325 / 0.8913 | | BasicVSR_x4 | 31.4325 / 0.8913 |
| IconVSR_x4 | 31.6882 / 0.8950 |
## 1.4 模型下载 ## 1.4 模型下载
| 模型 | 数据集 | 下载地址 | | 模型 | 数据集 | 下载地址 |
...@@ -86,7 +86,7 @@ ...@@ -86,7 +86,7 @@
| EDVR_L_wo_tsa_deblur | REDS | [EDVR_L_wo_tsa_deblur](https://paddlegan.bj.bcebos.com/models/EDVR_L_wo_tsa_deblur.pdparams) | EDVR_L_wo_tsa_deblur | REDS | [EDVR_L_wo_tsa_deblur](https://paddlegan.bj.bcebos.com/models/EDVR_L_wo_tsa_deblur.pdparams)
| EDVR_L_w_tsa_deblur | REDS | [EDVR_L_w_tsa_deblur](https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_deblur.pdparams) | EDVR_L_w_tsa_deblur | REDS | [EDVR_L_w_tsa_deblur](https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_deblur.pdparams)
| BasicVSR_x4 | REDS | [BasicVSR_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR_reds_x4.pdparams) | BasicVSR_x4 | REDS | [BasicVSR_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR_reds_x4.pdparams)
| IconVSR_x4 | REDS | [IconVSR_x4](https://paddlegan.bj.bcebos.com/models/IconVSR_reds_x4.pdparams)
......
...@@ -17,7 +17,8 @@ import paddle.nn as nn ...@@ -17,7 +17,8 @@ import paddle.nn as nn
from .builder import MODELS from .builder import MODELS
from .sr_model import BaseSRModel from .sr_model import BaseSRModel
from .generators.basicvsr import ResidualBlockNoBN, PixelShufflePack, SPyNet from .generators.iconvsr import EDVRFeatureExtractor
from .generators.basicvsr import ResidualBlockNoBN, PixelShufflePack, SPyNet
from ..modules.init import reset_parameters from ..modules.init import reset_parameters
from ..utils.visual import tensor2img from ..utils.visual import tensor2img
...@@ -57,7 +58,7 @@ class BasicVSRModel(BaseSRModel): ...@@ -57,7 +58,7 @@ class BasicVSRModel(BaseSRModel):
print('Train BasicVSR with fixed spynet for', self.fix_iter, print('Train BasicVSR with fixed spynet for', self.fix_iter,
'iters.') 'iters.')
for name, param in self.nets['generator'].named_parameters(): for name, param in self.nets['generator'].named_parameters():
if 'spynet' in name: if 'spynet' in name or 'edvr' in name:
param.trainable = False param.trainable = False
elif self.current_iter >= self.fix_iter + 1 and self.flag: elif self.current_iter >= self.fix_iter + 1 and self.flag:
print('Train all the parameters.') print('Train all the parameters.')
...@@ -107,5 +108,5 @@ def init_basicvsr_weight(net): ...@@ -107,5 +108,5 @@ def init_basicvsr_weight(net):
continue continue
if (not isinstance( if (not isinstance(
m, (ResidualBlockNoBN, PixelShufflePack, SPyNet))): m, (ResidualBlockNoBN, PixelShufflePack, SPyNet, EDVRFeatureExtractor))):
init_basicvsr_weight(m) init_basicvsr_weight(m)
...@@ -32,4 +32,5 @@ from .generator_firstorder import FirstOrderGenerator ...@@ -32,4 +32,5 @@ from .generator_firstorder import FirstOrderGenerator
from .generater_lapstyle import DecoderNet, Encoder, RevisionNet from .generater_lapstyle import DecoderNet, Encoder, RevisionNet
from .basicvsr import BasicVSRNet from .basicvsr import BasicVSRNet
from .mpr import MPRNet from .mpr import MPRNet
from .iconvsr import IconVSR
from .gpen import GPEN from .gpen import GPEN
...@@ -15,11 +15,9 @@ ...@@ -15,11 +15,9 @@
import paddle import paddle
import numpy as np import numpy as np
import scipy.io as scio
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.nn import initializer
from ...utils.download import get_path_from_url from ...utils.download import get_path_from_url
from ...modules.init import kaiming_normal_, constant_ from ...modules.init import kaiming_normal_, constant_
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# basicvsr and iconvsr code are heavily based on mmedit
import paddle
import numpy as np
import paddle.nn as nn
import paddle.nn.functional as F
from .builder import GENERATORS
from .edvr import PCDAlign, TSAFusion
from .basicvsr import SPyNet, PixelShufflePack, ResidualBlockNoBN, \
ResidualBlocksWithInputConv, flow_warp
from ...utils.download import get_path_from_url
@GENERATORS.register()
class IconVSR(nn.Layer):
    """IconVSR network structure for video super-resolution.

    Support only x4 upsampling.

    Paper:
        BasicVSR: The Search for Essential Components in Video
        Super-Resolution and Beyond, CVPR, 2021

    Args:
        mid_channels (int): Channel number of the intermediate features.
            Default: 64.
        num_blocks (int): Number of residual blocks in each propagation branch.
            Default: 30.
        padding (int): Number of frames to be padded at two ends of the
            sequence. 2 for REDS and 3 for Vimeo-90K. Default: 2.
        keyframe_stride (int): Number determining the keyframes. If stride=5,
            then the (0, 5, 10, 15, ...)-th frame will be the keyframes.
            Default: 5.
    """
    def __init__(self,
                 mid_channels=64,
                 num_blocks=30,
                 padding=2,
                 keyframe_stride=5):
        super().__init__()

        self.mid_channels = mid_channels
        self.padding = padding
        self.keyframe_stride = keyframe_stride

        # optical flow network for feature alignment, initialized from a
        # pretrained SPyNet checkpoint
        self.spynet = SPyNet()
        weight_path = get_path_from_url(
            'https://paddlegan.bj.bcebos.com/models/spynet.pdparams')
        self.spynet.set_state_dict(paddle.load(weight_path))

        # information-refill: keyframe features come from a pretrained EDVR-M
        self.edvr = EDVRFeatureExtractor(num_frames=padding * 2 + 1,
                                         center_frame_idx=padding)
        edvr_wight_path = get_path_from_url(
            'https://paddlegan.bj.bcebos.com/models/edvrm.pdparams')
        self.edvr.set_state_dict(paddle.load(edvr_wight_path))

        # 1x conv layers that fuse the refilled keyframe features into the
        # propagated features (one per propagation direction)
        self.backward_fusion = nn.Conv2D(2 * mid_channels,
                                         mid_channels,
                                         3,
                                         1,
                                         1,
                                         bias_attr=True)
        self.forward_fusion = nn.Conv2D(2 * mid_channels,
                                        mid_channels,
                                        3,
                                        1,
                                        1,
                                        bias_attr=True)

        # propagation branches
        self.backward_resblocks = ResidualBlocksWithInputConv(
            mid_channels + 3, mid_channels, num_blocks)
        self.forward_resblocks = ResidualBlocksWithInputConv(
            2 * mid_channels + 3, mid_channels, num_blocks)

        # upsample
        self.upsample1 = PixelShufflePack(mid_channels,
                                          mid_channels,
                                          2,
                                          upsample_kernel=3)
        self.upsample2 = PixelShufflePack(mid_channels,
                                          64,
                                          2,
                                          upsample_kernel=3)
        self.conv_hr = nn.Conv2D(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2D(64, 3, 3, 1, 1)
        self.img_upsample = nn.Upsample(scale_factor=4,
                                        mode='bilinear',
                                        align_corners=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1)

    def spatial_padding(self, lrs):
        """Apply padding spatially.

        Since the PCD module in EDVR requires that the resolution is a multiple
        of 4, we apply padding to the input LR images if their resolution is
        not divisible by 4.

        Args:
            lrs (Tensor): Input LR sequence with shape (n, t, c, h, w).

        Returns:
            Tensor: Padded LR sequence with shape (n, t, c, h_pad, w_pad).
        """
        n, t, c, h, w = lrs.shape
        pad_h = (4 - h % 4) % 4
        pad_w = (4 - w % 4) % 4

        # padding is applied frame-wise, so fold the time axis into the batch
        lrs = lrs.reshape([-1, c, h, w])
        lrs = F.pad(lrs, [0, pad_w, 0, pad_h], mode='reflect')

        return lrs.reshape([n, t, c, h + pad_h, w + pad_w])

    def check_if_mirror_extended(self, lrs):
        """Check whether the input is a mirror-extended sequence.

        If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the
        (t-1-i)-th frame.

        Args:
            lrs (tensor): Input LR images with shape (n, t, c, h, w)

        Returns:
            bool: whether the input is a mirror-extended sequence
        """
        self.is_mirror_extended = False
        if lrs.shape[1] % 2 == 0:
            lrs_1, lrs_2 = paddle.chunk(lrs, 2, axis=1)
            lrs_2 = paddle.flip(lrs_2, [1])
            if paddle.norm(lrs_1 - lrs_2) == 0:
                self.is_mirror_extended = True

    def compute_refill_features(self, lrs, keyframe_idx):
        """Compute keyframe features for information-refill.

        Since EDVR-M is used, padding is performed before feature computation.

        Args:
            lrs (Tensor): Input LR images with shape (n, t, c, h, w)
            keyframe_idx (list(int)): The indices specifying the keyframes.

        Return:
            dict(Tensor): The keyframe features. Each key corresponds to the
                indices in keyframe_idx.
        """
        # temporally pad the sequence by mirroring frames at both ends so
        # every keyframe has a full window of 2*padding+1 neighbors
        if self.padding == 2:
            lrs = [
                lrs[:, 4:5, :, :], lrs[:, 3:4, :, :], lrs, lrs[:, -4:-3, :, :],
                lrs[:, -5:-4, :, :]
            ]
        elif self.padding == 3:
            lrs = [lrs[:, [6, 5, 4]], lrs, lrs[:, [-5, -6, -7]]]
        lrs = paddle.concat(lrs, axis=1)

        num_frames = 2 * self.padding + 1
        feats_refill = {}
        for i in keyframe_idx:
            feats_refill[i] = self.edvr(lrs[:, i:i + num_frames])
        return feats_refill

    def compute_flow(self, lrs):
        """Compute optical flow using SPyNet for feature warping.

        Note that if the input is an mirror-extended sequence, 'flows_forward'
        is not needed, since it is equal to 'flows_backward.flip(1)'.

        Args:
            lrs (tensor): Input LR images with shape (n, t, c, h, w)

        Return:
            tuple(Tensor): Optical flow. 'flows_forward' corresponds to the
                flows used for forward-time propagation (current to previous).
                'flows_backward' corresponds to the flows used for
                backward-time propagation (current to next).
        """
        n, t, c, h, w = lrs.shape
        lrs_1 = lrs[:, :-1, :, :, :].reshape([-1, c, h, w])
        lrs_2 = lrs[:, 1:, :, :, :].reshape([-1, c, h, w])

        flows_backward = self.spynet(lrs_1, lrs_2).reshape([n, t - 1, 2, h, w])

        if self.is_mirror_extended:  # flows_forward = flows_backward.flip(1)
            flows_forward = None
        else:
            flows_forward = self.spynet(lrs_2,
                                        lrs_1).reshape([n, t - 1, 2, h, w])

        return flows_forward, flows_backward

    def forward(self, lrs):
        """Forward function for IconVSR.

        Args:
            lrs (Tensor): Input LR sequence with shape (n, t, c, h, w).

        Returns:
            Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
        """
        n, t, c, h_input, w_input = lrs.shape
        assert h_input >= 64 and w_input >= 64, (
            'The height and width of inputs should be at least 64, '
            f'but got {h_input} and {w_input}.')

        # check whether the input is an extended sequence
        self.check_if_mirror_extended(lrs)

        lrs = self.spatial_padding(lrs)
        h, w = lrs.shape[3], lrs.shape[4]

        # get the keyframe indices for information-refill
        keyframe_idx = list(range(0, t, self.keyframe_stride))
        if keyframe_idx[-1] != t - 1:
            keyframe_idx.append(t - 1)  # the last frame must be a keyframe

        # compute optical flow and keyframe features once.  (The original
        # implementation called compute_flow a second time right after this,
        # doubling the SPyNet cost for identical results.)
        flows_forward, flows_backward = self.compute_flow(lrs)
        feats_refill = self.compute_refill_features(lrs, keyframe_idx)

        # backward-time propagation
        outputs = []
        feat_prop = paddle.zeros([n, self.mid_channels, h, w], dtype='float32')
        for i in range(t - 1, -1, -1):
            # no warping required for the last timestep
            if i < t - 1:
                flow = flows_backward[:, i, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.transpose([0, 2, 3, 1]))

            # information refill at keyframes
            if i in keyframe_idx:
                feat_prop = paddle.concat([feat_prop, feats_refill[i]], axis=1)
                feat_prop = self.backward_fusion(feat_prop)

            feat_prop = paddle.concat([lrs[:, i, :, :, :], feat_prop], axis=1)
            feat_prop = self.backward_resblocks(feat_prop)

            outputs.append(feat_prop)
        outputs = outputs[::-1]

        # forward-time propagation and upsampling
        feat_prop = paddle.zeros_like(feat_prop)
        for i in range(0, t):
            lr_curr = lrs[:, i, :, :, :]
            if i > 0:  # no warping required for the first timestep
                if flows_forward is not None:
                    flow = flows_forward[:, i - 1, :, :, :]
                else:
                    # mirror-extended input: reuse the backward flows
                    flow = flows_backward[:, -i, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.transpose([0, 2, 3, 1]))

            # information refill at keyframes
            if i in keyframe_idx:
                feat_prop = paddle.concat([feat_prop, feats_refill[i]], axis=1)
                feat_prop = self.forward_fusion(feat_prop)

            feat_prop = paddle.concat([lr_curr, outputs[i], feat_prop], axis=1)
            feat_prop = self.forward_resblocks(feat_prop)

            # upsampling given the backward and forward features
            out = self.lrelu(self.upsample1(feat_prop))
            out = self.lrelu(self.upsample2(out))
            out = self.lrelu(self.conv_hr(out))
            out = self.conv_last(out)
            base = self.img_upsample(lr_curr)
            out += base
            outputs[i] = out

        # crop away the spatially padded region so the output is exactly
        # 4x the original input resolution (matches the reference mmedit
        # implementation; a no-op when h_input and w_input are multiples of 4)
        return paddle.stack(outputs, axis=1)[:, :, :, :4 * h_input,
                                             :4 * w_input]
class EDVRFeatureExtractor(nn.Layer):
    """EDVR feature extractor for information-refill in IconVSR.

    We use EDVR-M in IconVSR. To adopt pretrained models, please
    specify "pretrained".

    Paper:
    EDVR: Video Restoration with Enhanced Deformable Convolutional Networks.

    Args:
        in_channels (int): Channel number of inputs.
        out_channel (int): Channel number of outputs.
            NOTE(review): currently unused by this implementation; kept for
            interface compatibility.
        mid_channels (int): Channel number of intermediate features.
            Default: 64.
        num_frames (int): Number of input frames. Default: 5.
        deform_groups (int): Deformable groups. Defaults: 8.
        num_blocks_extraction (int): Number of blocks for feature extraction.
            Default: 5.
        num_blocks_reconstruction (int): Number of blocks for reconstruction.
            Default: 10.
        center_frame_idx (int): The index of center frame. Frame counting from
            0. Default: 2.
        with_tsa (bool): Whether to use TSA module. Default: True.
    """
    def __init__(self,
                 in_channels=3,
                 out_channel=3,
                 mid_channels=64,
                 num_frames=5,
                 deform_groups=8,
                 num_blocks_extraction=5,
                 num_blocks_reconstruction=10,
                 center_frame_idx=2,
                 with_tsa=True):
        super().__init__()

        self.center_frame_idx = center_frame_idx
        self.with_tsa = with_tsa

        self.conv_first = nn.Conv2D(in_channels, mid_channels, 3, 1, 1)
        # shallow feature extraction: a stack of plain residual blocks
        self.feature_extraction = make_layer(ResidualBlockNoBN,
                                             num_blocks_extraction,
                                             nf=mid_channels)

        # generate pyramid features (each conv1 halves the resolution)
        self.feat_l2_conv1 = nn.Conv2D(mid_channels, mid_channels, 3, 2, 1)
        self.feat_l2_conv2 = nn.Conv2D(mid_channels, mid_channels, 3, 1, 1)
        self.feat_l3_conv1 = nn.Conv2D(mid_channels, mid_channels, 3, 2, 1)
        self.feat_l3_conv2 = nn.Conv2D(mid_channels, mid_channels, 3, 1, 1)
        # pcd alignment
        self.pcd_alignment = PCDAlign(nf=mid_channels, groups=deform_groups)
        # fusion of the aligned neighbor features into one feature map
        if self.with_tsa:
            self.fusion = TSAFusion(nf=mid_channels,
                                    nframes=num_frames,
                                    center=self.center_frame_idx)
        else:
            self.fusion = nn.Conv2D(num_frames * mid_channels, mid_channels, 1,
                                    1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1)

    def forward(self, x):
        """Forward function for EDVRFeatureExtractor.

        Args:
            x (Tensor): Input tensor with shape (n, t, 3, h, w).

        Returns:
            Tensor: Intermediate feature with shape (n, mid_channels, h, w).
        """
        n, t, c, h, w = x.shape

        # extract LR features frame-wise (time folded into the batch axis)
        # L1
        l1_feat = self.lrelu(self.conv_first(x.reshape([-1, c, h, w])))
        l1_feat = self.feature_extraction(l1_feat)
        # L2
        l2_feat = self.lrelu(
            self.feat_l2_conv2(self.lrelu(self.feat_l2_conv1(l1_feat))))
        # L3
        l3_feat = self.lrelu(
            self.feat_l3_conv2(self.lrelu(self.feat_l3_conv1(l2_feat))))

        l1_feat = l1_feat.reshape([n, t, -1, h, w])
        l2_feat = l2_feat.reshape([n, t, -1, h // 2, w // 2])
        l3_feat = l3_feat.reshape([n, t, -1, h // 4, w // 4])

        # pcd alignment: align every frame's pyramid to the center frame
        ref_feats = [  # reference feature list
            l1_feat[:, self.center_frame_idx, :, :, :].clone(),
            l2_feat[:, self.center_frame_idx, :, :, :].clone(),
            l3_feat[:, self.center_frame_idx, :, :, :].clone()
        ]
        aligned_feat = []
        for i in range(t):
            neighbor_feats = [
                l1_feat[:, i, :, :, :].clone(), l2_feat[:, i, :, :, :].clone(),
                l3_feat[:, i, :, :, :].clone()
            ]
            aligned_feat.append(self.pcd_alignment(neighbor_feats, ref_feats))
        aligned_feat = paddle.stack(aligned_feat, axis=1)  # (n, t, c, h, w)

        if self.with_tsa:
            feat = self.fusion(aligned_feat)
        else:
            # plain 1x1-conv fusion over the concatenated frames
            aligned_feat = aligned_feat.reshape([n, -1, h, w])
            feat = self.fusion(aligned_feat)

        return feat
def make_layer(block, num_blocks, **kwarg):
    """Stack ``num_blocks`` instances of the same block into a Sequential.

    Args:
        block (nn.Layer): nn.Layer class used as the basic building block.
        num_blocks (int): Number of block instances to stack.
        **kwarg: Keyword arguments forwarded to every ``block`` constructor.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    return nn.Sequential(*(block(**kwarg) for _ in range(num_blocks)))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册