未验证 提交 4bb26861 编写于 作者: L LielinJiang 提交者: GitHub

Implement basicvsr (#356)

* add basicvsr model
上级 058faa87
total_iters: 300000
output_dir: output_dir
find_unused_parameters: True
checkpoints_dir: checkpoints
use_dataset: True
# tensor range for function tensor2img
min_max:
(0., 1.)
model:
name: BasicVSRModel
fix_iter: 5000
generator:
name: BasicVSRNet
mid_channels: 64
num_blocks: 30
pixel_criterion:
name: CharbonnierLoss
reduction: mean
dataset:
train:
name: RepeatDataset
times: 1000
num_workers: 4 # 6
batch_size: 2 # 4*2
dataset:
name: SRREDSMultipleGTDataset
mode: train
lq_folder: data/REDS/train_sharp_bicubic/X4
gt_folder: data/REDS/train_sharp/X4
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 15
use_flip: True
use_rot: True
scale: 4
val_partition: REDS4
test:
name: SRREDSMultipleGTDataset
mode: test
lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
gt_folder: data/REDS/REDS4_test_sharp/X4
interval_list: [1]
random_reverse: False
number_frames: 100
use_flip: False
use_rot: False
scale: 4
val_partition: REDS4
num_workers: 0
batch_size: 1
lr_scheduler:
name: CosineAnnealingRestartLR
learning_rate: !!float 2e-4
periods: [300000]
restart_weights: [1]
eta_min: !!float 1e-7
optimizer:
name: Adam
# add parameters of net_name to optim
# name should in self.nets
net_names:
- generator
beta1: 0.9
beta2: 0.99
validate:
# FIXME: avoid oom
interval: 5000000
save_img: false
metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 0
test_y_channel: False
ssim:
name: SSIM
crop_border: 0
test_y_channel: False
log_config:
interval: 100
visiual_interval: 500
snapshot_config:
interval: 5000
......@@ -24,4 +24,5 @@ from .starganv2_dataset import StarGANv2Dataset
from .edvr_dataset import REDSDataset
from .firstorder_dataset import FirstOrderDataset
from .lapstyle_dataset import LapStyleDataset
from .sr_reds_multiple_gt_dataset import SRREDSMultipleGTDataset
from .mpr_dataset import MPRTrain, MPRVal, MPRTest
......@@ -19,11 +19,25 @@ import numpy as np
from paddle.distributed import ParallelEnv
from paddle.io import DistributedBatchSampler
from ..utils.registry import Registry
from .repeat_dataset import RepeatDataset
from ..utils.registry import Registry, build_from_config
DATASETS = Registry("DATASETS")
def build_dataset(cfg):
name = cfg.pop('name')
if name == 'RepeatDataset':
dataset_ = build_from_config(cfg['dataset'], DATASETS)
dataset = RepeatDataset(dataset_, cfg['times'])
else:
dataset = dataset = DATASETS.get(name)(**cfg)
return dataset
def build_dataloader(cfg, is_train=True, distributed=True):
cfg_ = cfg.copy()
......@@ -31,9 +45,7 @@ def build_dataloader(cfg, is_train=True, distributed=True):
num_workers = cfg_.pop('num_workers', 0)
use_shared_memory = cfg_.pop('use_shared_memory', True)
name = cfg_.pop('name')
dataset = DATASETS.get(name)(**cfg_)
dataset = build_dataset(cfg_)
if distributed:
sampler = DistributedBatchSampler(dataset,
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
class RepeatDataset(paddle.io.Dataset):
"""A wrapper of repeated dataset.
The length of repeated dataset will be `times` larger than the original
dataset. This is useful when the data loading time is long but the dataset
is small. Using RepeatDataset can reduce the data loading time between
epochs.
Args:
dataset (:obj:`Dataset`): The dataset to be repeated.
times (int): Repeat times.
"""
def __init__(self, dataset, times):
self.dataset = dataset
self.times = times
self._ori_len = len(self.dataset)
def __getitem__(self, idx):
"""Get item at each call.
Args:
idx (int): Index for getting each item.
"""
return self.dataset[idx % self._ori_len]
def __len__(self):
"""Length of the dataset.
Returns:
int: Length of the dataset.
"""
return self.times * self._ori_len
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import random
import numpy as np
import cv2
from paddle.io import Dataset
from .builder import DATASETS
logger = logging.getLogger(__name__)
@DATASETS.register()
class SRREDSMultipleGTDataset(Dataset):
"""REDS dataset for video super resolution for recurrent networks.
The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth)
frames. Then it applies specified transforms and finally returns a dict
containing paired data and other information.
Args:
lq_folder (str | :obj:`Path`): Path to a lq folder.
gt_folder (str | :obj:`Path`): Path to a gt folder.
num_input_frames (int): Number of input frames.
pipeline (list[dict | callable]): A sequence of data transformations.
scale (int): Upsampling scale ratio.
val_partition (str): Validation partition mode. Choices ['official' or
'REDS4']. Default: 'official'.
test_mode (bool): Store `True` when building test dataset.
Default: `False`.
"""
def __init__(self,
mode,
lq_folder,
gt_folder,
crop_size=256,
interval_list=[1],
random_reverse=False,
number_frames=15,
use_flip=False,
use_rot=False,
scale=4,
val_partition='REDS4',
batch_size=4):
super(SRREDSMultipleGTDataset, self).__init__()
self.mode = mode
self.fileroot = str(lq_folder)
self.gtroot = str(gt_folder)
self.crop_size = crop_size
self.interval_list = interval_list
self.random_reverse = random_reverse
self.number_frames = number_frames
self.use_flip = use_flip
self.use_rot = use_rot
self.scale = scale
self.val_partition = val_partition
self.batch_size = batch_size
self.data_infos = self.load_annotations()
def __getitem__(self, idx):
"""Get item at each call.
Args:
idx (int): Index for getting each item.
"""
item = self.data_infos[idx]
idt = random.randint(0, 100 - self.number_frames)
item = item + '_' + f'{idt:03d}'
img_LQs, img_GTs = self.get_sample_data(
item, self.number_frames, self.interval_list, self.random_reverse,
self.gtroot, self.fileroot, self.crop_size, self.scale,
self.use_flip, self.use_rot, self.mode)
return {'lq': img_LQs, 'gt': img_GTs, 'lq_path': self.data_infos[idx]}
def load_annotations(self):
"""Load annoations for REDS dataset.
Returns:
dict: Returned dict for LQ and GT pairs.
"""
# generate keys
keys = [f'{i:03d}' for i in range(0, 270)]
if self.val_partition == 'REDS4':
val_partition = ['000', '011', '015', '020']
elif self.val_partition == 'official':
val_partition = [f'{i:03d}' for i in range(240, 270)]
else:
raise ValueError(f'Wrong validation partition {self.val_partition}.'
f'Supported ones are ["official", "REDS4"]')
if self.mode == 'train':
keys = [v for v in keys if v not in val_partition]
else:
keys = [v for v in keys if v in val_partition]
data_infos = []
for key in keys:
data_infos.append(key)
return data_infos
def get_sample_data(self,
item,
number_frames,
interval_list,
random_reverse,
gtroot,
fileroot,
crop_size,
scale,
use_flip,
use_rot,
mode='train'):
video_name = item.split('_')[0]
frame_name = item.split('_')[1]
frame_idxs = self.get_neighbor_frames(frame_name,
number_frames=number_frames,
interval_list=interval_list,
random_reverse=random_reverse)
frame_list = []
gt_list = []
for frame_idx in frame_idxs:
frame_idx_name = "%08d" % frame_idx
img = self.read_img(
os.path.join(fileroot, video_name, frame_idx_name + '.png'))
frame_list.append(img)
gt_img = self.read_img(
os.path.join(gtroot, video_name, frame_idx_name + '.png'))
gt_list.append(gt_img)
H, W, C = frame_list[0].shape
# add random crop
if (mode == 'train') or (mode == 'valid'):
LQ_size = crop_size // scale
rnd_h = random.randint(0, max(0, H - LQ_size))
rnd_w = random.randint(0, max(0, W - LQ_size))
frame_list = [
v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
for v in frame_list
]
rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
gt_list = [
v[rnd_h_HR:rnd_h_HR + crop_size,
rnd_w_HR:rnd_w_HR + crop_size, :] for v in gt_list
]
# add random flip and rotation
for v in gt_list:
frame_list.append(v)
if (mode == 'train') or (mode == 'valid'):
rlt = self.img_augment(frame_list, use_flip, use_rot)
else:
rlt = frame_list
frame_list = rlt[0:number_frames]
gt_list = rlt[number_frames:]
# stack LQ images to NHWC, N is the frame number
frame_list = [v.transpose(2, 0, 1).astype('float32') for v in frame_list]
gt_list = [v.transpose(2, 0, 1).astype('float32') for v in gt_list]
img_LQs = np.stack(frame_list, axis=0)
img_GTs = np.stack(gt_list, axis=0)
return img_LQs, img_GTs
def get_neighbor_frames(self, frame_name, number_frames, interval_list,
random_reverse):
frame_idx = int(frame_name)
interval = random.choice(interval_list)
neighbor_list = list(
range(frame_idx, frame_idx + number_frames, interval))
if random_reverse and random.random() < 0.5:
neighbor_list.reverse()
assert len(neighbor_list) == number_frames, \
"frames slected have length({}), but it should be ({})".format(len(neighbor_list), number_frames)
return neighbor_list
def read_img(self, path, size=None):
"""read image by cv2
return: Numpy float32, HWC, BGR, [0,1]
"""
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
img = img.astype(np.float32) / 255.
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
# some images have 4 channels
if img.shape[2] > 3:
img = img[:, :, :3]
return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def img_augment(self, img_list, hflip=True, rot=True):
"""horizontal flip OR rotate (0, 90, 180, 270 degrees)
"""
hflip = hflip and random.random() < 0.5
vflip = rot and random.random() < 0.5
rot90 = rot and random.random() < 0.5
def _augment(img):
if hflip:
img = img[:, ::-1, :]
if vflip:
img = img[::-1, :, :]
if rot90:
img = img.transpose(1, 0, 2)
return img
return [_augment(img) for img in img_list]
def __len__(self):
"""Length of the dataset.
Returns:
int: Length of the dataset.
"""
return len(self.data_infos)
......@@ -323,7 +323,6 @@ class Trainer:
is_save_image=False):
"""
visual the images, use visualdl or directly write to the directory
Parameters:
results_dir (str) -- directory name which contains saved images
visual_results (dict) -- the results images dict
......@@ -440,7 +439,6 @@ class Trainer:
def close(self):
"""
when finish the training need close file handler or other.
"""
if self.enable_visualdl:
self.vdl_logger.close()
self.vdl_logger.close()
\ No newline at end of file
......@@ -30,4 +30,5 @@ from .starganv2_model import StarGANv2Model
from .edvr_model import EDVRModel
from .firstorder_model import FirstOrderModel
from .lapstyle_model import LapStyleDraModel, LapStyleRevFirstModel, LapStyleRevSecondModel
from .basicvsr_model import BasicVSRModel
from .mpr_model import MPRModel
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from .builder import MODELS
from .sr_model import BaseSRModel
from .generators.basicvsr import ResidualBlockNoBN, PixelShufflePack, SPyNet
from ..modules.init import reset_parameters
from ..utils.visual import tensor2img
@MODELS.register()
class BasicVSRModel(BaseSRModel):
"""BasicVSR Model.
Paper: BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond, CVPR, 2021
"""
def __init__(self, generator, fix_iter, pixel_criterion=None):
"""Initialize the BasicVSR class.
Args:
generator (dict): config of generator.
fix_iter (dict): config of fix_iter.
pixel_criterion (dict): config of pixel criterion.
"""
super(BasicVSRModel, self).__init__(generator, pixel_criterion)
self.fix_iter = fix_iter
self.current_iter = 1
self.flag = True
init_basicvsr_weight(self.nets['generator'])
def setup_input(self, input):
self.lq = paddle.to_tensor(input['lq'])
self.visual_items['lq'] = self.lq[:, 0, :, :, :]
if 'gt' in input:
self.gt = paddle.to_tensor(input['gt'])
self.visual_items['gt'] = self.gt[:, 0, :, :, :]
self.image_paths = input['lq_path']
def train_iter(self, optims=None):
optims['optim'].clear_grad()
if self.fix_iter:
if self.current_iter == 1:
print('Train BasicVSR with fixed spynet for', self.fix_iter,
'iters.')
for name, param in self.nets['generator'].named_parameters():
if 'spynet' in name:
param.trainable = False
elif self.current_iter >= self.fix_iter + 1 and self.flag:
print('Train all the parameters.')
for name, param in self.nets['generator'].named_parameters():
param.trainable = True
if 'spynet' in name:
param.optimize_attr['learning_rate'] = 0.125
self.flag = False
for net in self.nets.values():
net.find_unused_parameters = False
self.output = self.nets['generator'](self.lq)
self.visual_items['output'] = self.output[:, 0, :, :, :]
# pixel loss
loss_pixel = self.pixel_criterion(self.output, self.gt)
loss_pixel.backward()
optims['optim'].step()
self.losses['loss_pixel'] = loss_pixel
self.current_iter += 1
def test_iter(self, metrics=None):
self.nets['generator'].eval()
with paddle.no_grad():
self.output = self.nets['generator'](self.lq)
self.visual_items['output'] = self.output[:, 0, :, :, :]
self.nets['generator'].train()
out_img = []
gt_img = []
for out_tensor, gt_tensor in zip(self.output[0], self.gt[0]):
# print(out_tensor.shape, gt_tensor.shape)
out_img.append(tensor2img(out_tensor, (0., 1.)))
gt_img.append(tensor2img(gt_tensor, (0., 1.)))
if metrics is not None:
for metric in metrics.values():
metric.update(out_img, gt_img)
def init_basicvsr_weight(net):
for m in net.children():
if hasattr(m, 'weight') and not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D)):
reset_parameters(m)
continue
if (not isinstance(
m, (ResidualBlockNoBN, PixelShufflePack, SPyNet))):
init_basicvsr_weight(m)
......@@ -59,8 +59,9 @@ class CharbonnierLoss():
eps (float): Default: 1e-12.
"""
def __init__(self, eps=1e-12):
def __init__(self, eps=1e-12, reduction='sum'):
self.eps = eps
self.reduction = reduction
def __call__(self, pred, target, **kwargs):
"""Forward Function.
......@@ -69,7 +70,14 @@ class CharbonnierLoss():
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
"""
return paddle.sum(paddle.sqrt((pred - target)**2 + self.eps))
if self.reduction == 'sum':
out = paddle.sum(paddle.sqrt((pred - target)**2 + self.eps))
elif self.reduction == 'mean':
out = paddle.mean(paddle.sqrt((pred - target)**2 + self.eps))
else:
raise NotImplementedError('CharbonnierLoss %s not implemented' %
self.reduction)
return out
@CRITERIONS.register()
......
......@@ -30,4 +30,5 @@ from .generator_starganv2 import StarGANv2Generator, StarGANv2Style, StarGANv2Ma
from .edvr import EDVRNet
from .generator_firstorder import FirstOrderGenerator
from .generater_lapstyle import DecoderNet, Encoder, RevisionNet
from .basicvsr import BasicVSRNet
from .mpr import MPRNet
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import scipy.io as scio
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import initializer
from ...utils.download import get_path_from_url
from ...modules.init import kaiming_normal_, constant_
from .builder import GENERATORS
@paddle.no_grad()
def default_init_weights(layer_list, scale=1, bias_fill=0, **kwargs):
"""Initialize network weights.
Args:
layer_list (list[nn.Layer] | nn.Layer): Layers to be initialized.
scale (float): Scale initialized weights, especially for residual
blocks. Default: 1.
bias_fill (float): The value to fill bias. Default: 0
kwargs (dict): Other arguments for initialization function.
"""
if not isinstance(layer_list, list):
layer_list = [layer_list]
for m in layer_list:
if isinstance(m, nn.Conv2D):
kaiming_normal_(m.weight, **kwargs)
scale_weight = scale * m.weight
m.weight.set_value(scale_weight)
if m.bias is not None:
constant_(m.bias, bias_fill)
elif isinstance(m, nn.Linear):
kaiming_normal_(m.weight, **kwargs)
scale_weight = scale * m.weight
m.weight.set_value(scale_weight)
if m.bias is not None:
constant_(m.bias, bias_fill)
elif isinstance(m, nn.BatchNorm):
constant_(m.weight, 1)
class PixelShufflePack(nn.Layer):
""" Pixel Shuffle upsample layer.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
scale_factor (int): Upsample ratio.
upsample_kernel (int): Kernel size of Conv layer to expand channels.
Returns:
Upsampled feature map.
"""
def __init__(self, in_channels, out_channels, scale_factor,
upsample_kernel):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.scale_factor = scale_factor
self.upsample_kernel = upsample_kernel
self.upsample_conv = nn.Conv2D(self.in_channels,
self.out_channels * scale_factor *
scale_factor,
self.upsample_kernel,
padding=(self.upsample_kernel - 1) // 2)
self.pixel_shuffle = nn.PixelShuffle(self.scale_factor)
self.init_weights()
def init_weights(self):
"""Initialize weights for PixelShufflePack.
"""
default_init_weights(self, 1)
def forward(self, x):
"""Forward function for PixelShufflePack.
Args:
x (Tensor): Input tensor with shape (in_channels, c, h, w).
Returns:
Tensor with shape (out_channels, c, scale_factor*h, scale_factor*w).
"""
x = self.upsample_conv(x)
x = self.pixel_shuffle(x)
return x
def MakeMultiBlocks(func, num_layers, nf=64):
"""Make layers by stacking the same blocks.
Args:
func (nn.Layer): nn.Layer class for basic block.
num_layers (int): number of blocks.
Returns:
nn.Sequential: Stacked blocks in nn.Sequential.
"""
Blocks = nn.Sequential()
for i in range(num_layers):
Blocks.add_sublayer('block%d' % i, func(nf))
return Blocks
class ResidualBlockNoBN(nn.Layer):
"""Residual block without BN.
It has a style of:
---Conv-ReLU-Conv-+-
|________________|
Args:
nf (int): Channel number of intermediate features.
Default: 64.
res_scale (float): Residual scale. Default: 1.0.
"""
def __init__(self, nf=64, res_scale=1.0):
super(ResidualBlockNoBN, self).__init__()
self.nf = nf
self.res_scale = res_scale
self.conv1 = nn.Conv2D(self.nf, self.nf, 3, 1, 1)
self.conv2 = nn.Conv2D(self.nf, self.nf, 3, 1, 1)
self.relu = nn.ReLU()
if self.res_scale == 1.0:
default_init_weights([self.conv1, self.conv2], 0.1)
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor with shape (n, c, h, w).
"""
identity = x
out = self.conv2(self.relu(self.conv1(x)))
return identity + out * self.res_scale
def flow_warp(x,
flow,
interpolation='bilinear',
padding_mode='zeros',
align_corners=True):
"""Warp an image or a feature map with optical flow.
Args:
x (Tensor): Tensor with size (n, c, h, w).
flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is
a two-channel, denoting the width and height relative offsets.
Note that the values are not normalized to [-1, 1].
interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
Default: 'bilinear'.
padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
Default: 'zeros'.
align_corners (bool): Whether align corners. Default: True.
Returns:
Tensor: Warped image or feature map.
"""
if x.shape[-2:] != flow.shape[1:3]:
raise ValueError(f'The spatial sizes of input ({x.shape[-2:]}) and '
f'flow ({flow.shape[1:3]}) are not the same.')
_, _, h, w = x.shape
# create mesh grid
grid_y, grid_x = paddle.meshgrid(paddle.arange(0, h), paddle.arange(0, w))
grid = paddle.stack((grid_x, grid_y), axis=2) # (w, h, 2)
grid = paddle.cast(grid, 'float32')
grid.stop_gradient = True
grid_flow = grid + flow
# scale grid_flow to [-1,1]
grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
grid_flow = paddle.stack((grid_flow_x, grid_flow_y), axis=3)
output = F.grid_sample(x,
grid_flow,
mode=interpolation,
padding_mode=padding_mode,
align_corners=align_corners)
return output
class SPyNetBasicModule(nn.Layer):
"""Basic Module for SPyNet.
Paper:
Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
"""
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2D(in_channels=8,
out_channels=32,
kernel_size=7,
stride=1,
padding=3)
self.conv2 = nn.Conv2D(in_channels=32,
out_channels=64,
kernel_size=7,
stride=1,
padding=3)
self.conv3 = nn.Conv2D(in_channels=64,
out_channels=32,
kernel_size=7,
stride=1,
padding=3)
self.conv4 = nn.Conv2D(in_channels=32,
out_channels=16,
kernel_size=7,
stride=1,
padding=3)
self.conv5 = nn.Conv2D(in_channels=16,
out_channels=2,
kernel_size=7,
stride=1,
padding=3)
self.relu = nn.ReLU()
def forward(self, tensor_input):
"""
Args:
tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
8 channels contain:
[reference image (3), neighbor image (3), initial flow (2)].
Returns:
Tensor: Refined flow with shape (b, 2, h, w)
"""
out = self.relu(self.conv1(tensor_input))
out = self.relu(self.conv2(out))
out = self.relu(self.conv3(out))
out = self.relu(self.conv4(out))
out = self.conv5(out)
return out
class SPyNet(nn.Layer):
"""SPyNet network structure.
The difference to the SPyNet in paper is that
1. more SPyNetBasicModule is used in this version, and
2. no batch normalization is used in this version.
Paper:
Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
"""
def __init__(self):
super().__init__()
self.basic_module0 = SPyNetBasicModule()
self.basic_module1 = SPyNetBasicModule()
self.basic_module2 = SPyNetBasicModule()
self.basic_module3 = SPyNetBasicModule()
self.basic_module4 = SPyNetBasicModule()
self.basic_module5 = SPyNetBasicModule()
self.register_buffer(
'mean',
paddle.to_tensor([0.485, 0.456, 0.406]).reshape([1, 3, 1, 1]))
self.register_buffer(
'std',
paddle.to_tensor([0.229, 0.224, 0.225]).reshape([1, 3, 1, 1]))
def compute_flow(self, ref, supp):
"""Compute flow from ref to supp.
Note that in this function, the images are already resized to a
multiple of 32.
Args:
ref (Tensor): Reference image with shape of (n, 3, h, w).
supp (Tensor): Supporting image with shape of (n, 3, h, w).
Returns:
Tensor: Estimated optical flow: (n, 2, h, w).
"""
n, _, h, w = ref.shape
# normalize the input images
ref = [(ref - self.mean) / self.std]
supp = [(supp - self.mean) / self.std]
# generate downsampled frames
for level in range(5):
ref.append(F.avg_pool2d(ref[-1], kernel_size=2, stride=2))
supp.append(F.avg_pool2d(supp[-1], kernel_size=2, stride=2))
ref = ref[::-1]
supp = supp[::-1]
# flow computation
flow = paddle.to_tensor(np.zeros([n, 2, h // 32, w // 32], 'float32'))
# level=0
flow_up = flow
flow = flow_up + self.basic_module0(
paddle.concat([
ref[0],
flow_warp(supp[0],
flow_up.transpose([0, 2, 3, 1]),
padding_mode='border'), flow_up
], 1))
# level=1
flow_up = F.interpolate(
flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
flow = flow_up + self.basic_module1(
paddle.concat([
ref[1],
flow_warp(supp[1],
flow_up.transpose([0, 2, 3, 1]),
padding_mode='border'), flow_up
], 1))
# level=2
flow_up = F.interpolate(
flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
flow = flow_up + self.basic_module2(
paddle.concat([
ref[2],
flow_warp(supp[2],
flow_up.transpose([0, 2, 3, 1]),
padding_mode='border'), flow_up
], 1))
# level=3
flow_up = F.interpolate(
flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
flow = flow_up + self.basic_module3(
paddle.concat([
ref[3],
flow_warp(supp[3],
flow_up.transpose([0, 2, 3, 1]),
padding_mode='border'), flow_up
], 1))
# level=4
flow_up = F.interpolate(
flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
flow = flow_up + self.basic_module4(
paddle.concat([
ref[4],
flow_warp(supp[4],
flow_up.transpose([0, 2, 3, 1]),
padding_mode='border'), flow_up
], 1))
# level=5
flow_up = F.interpolate(
flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
flow = flow_up + self.basic_module5(
paddle.concat([
ref[5],
flow_warp(supp[5],
flow_up.transpose([0, 2, 3, 1]),
padding_mode='border'), flow_up
], 1))
return flow
def forward(self, ref, supp):
"""Forward function of SPyNet.
This function computes the optical flow from ref to supp.
Args:
ref (Tensor): Reference image with shape of (n, 3, h, w).
supp (Tensor): Supporting image with shape of (n, 3, h, w).
Returns:
Tensor: Estimated optical flow: (n, 2, h, w).
"""
# upsize to a multiple of 32
h, w = ref.shape[2:4]
w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1)
h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1)
ref = F.interpolate(ref,
size=(h_up, w_up),
mode='bilinear',
align_corners=False)
supp = F.interpolate(supp,
size=(h_up, w_up),
mode='bilinear',
align_corners=False)
ref.stop_gradient = False
supp.stop_gradient = False
# compute flow, and resize back to the original resolution
flow_up = self.compute_flow(ref, supp)
flow = F.interpolate(flow_up,
size=(h, w),
mode='bilinear',
align_corners=False)
# adjust the flow values
# todo: grad bug
# flow[:, 0, :, :] *= (float(w) / float(w_up))
# flow[:, 1, :, :] *= (float(h) / float(h_up))
flow_x = flow[:, 0:1, :, :] * (float(w) / float(w_up))
flow_y = flow[:, 1:2, :, :] * (float(h) / float(h_up))
flow = paddle.concat([flow_x, flow_y], 1)
return flow
class ResidualBlocksWithInputConv(nn.Layer):
"""Residual blocks with a convolution in front.
Args:
in_channels (int): Number of input channels of the first conv.
out_channels (int): Number of channels of the residual blocks.
Default: 64.
num_blocks (int): Number of residual blocks. Default: 30.
"""
def __init__(self, in_channels, out_channels=64, num_blocks=30):
super().__init__()
# a convolution used to match the channels of the residual blocks
self.covn1 = nn.Conv2D(in_channels, out_channels, 3, 1, 1)
self.Leaky_relu = nn.LeakyReLU(negative_slope=0.1)
# residual blocks
self.ResidualBlocks = MakeMultiBlocks(ResidualBlockNoBN,
num_blocks,
nf=out_channels)
def forward(self, feat):
"""
Forward function for ResidualBlocksWithInputConv.
Args:
feat (Tensor): Input feature with shape (n, in_channels, h, w)
Returns:
Tensor: Output feature with shape (n, out_channels, h, w)
"""
out = self.Leaky_relu(self.covn1(feat))
out = self.ResidualBlocks(out)
return out
@GENERATORS.register()
class BasicVSRNet(nn.Layer):
"""BasicVSR network structure for video super-resolution.
Support only x4 upsampling.
Paper:
BasicVSR: The Search for Essential Components in Video Super-Resolution
and Beyond, CVPR, 2021
Args:
mid_channels (int): Channel number of the intermediate features.
Default: 64.
num_blocks (int): Number of residual blocks in each propagation branch.
Default: 30.
"""
def __init__(self, mid_channels=64, num_blocks=30):
super().__init__()
self.mid_channels = mid_channels
# optical flow network for feature alignment
self.spynet = SPyNet()
weight_path = get_path_from_url(
'https://paddlegan.bj.bcebos.com/models/spynet.pdparams')
self.spynet.set_state_dict(paddle.load(weight_path))
# propagation branches
self.backward_resblocks = ResidualBlocksWithInputConv(
mid_channels + 3, mid_channels, num_blocks)
self.forward_resblocks = ResidualBlocksWithInputConv(
mid_channels + 3, mid_channels, num_blocks)
# upsample
self.fusion = nn.Conv2D(mid_channels * 2, mid_channels, 1, 1, 0)
self.upsample1 = PixelShufflePack(mid_channels,
mid_channels,
2,
upsample_kernel=3)
self.upsample2 = PixelShufflePack(mid_channels,
64,
2,
upsample_kernel=3)
self.conv_hr = nn.Conv2D(64, 64, 3, 1, 1)
self.conv_last = nn.Conv2D(64, 3, 3, 1, 1)
self.img_upsample = nn.Upsample(scale_factor=4,
mode='bilinear',
align_corners=False)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.1)
def check_if_mirror_extended(self, lrs):
"""Check whether the input is a mirror-extended sequence.
If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the
(t-1-i)-th frame.
Args:
lrs (tensor): Input LR images with shape (n, t, c, h, w)
"""
self.is_mirror_extended = False
if lrs.shape[1] % 2 == 0:
lrs_1, lrs_2 = paddle.chunk(lrs, 2, axis=1)
lrs_2 = paddle.flip(lrs_2, [1])
if paddle.norm(lrs_1 - lrs_2) == 0:
self.is_mirror_extended = True
def compute_flow(self, lrs):
"""Compute optical flow using SPyNet for feature warping.
Note that if the input is an mirror-extended sequence, 'flows_forward'
is not needed, since it is equal to 'flows_backward.flip(1)'.
Args:
lrs (tensor): Input LR images with shape (n, t, c, h, w)
Return:
tuple(Tensor): Optical flow. 'flows_forward' corresponds to the
flows used for forward-time propagation (current to previous).
'flows_backward' corresponds to the flows used for
backward-time propagation (current to next).
"""
n, t, c, h, w = lrs.shape
lrs_1 = lrs[:, :-1, :, :, :].reshape([-1, c, h, w])
lrs_2 = lrs[:, 1:, :, :, :].reshape([-1, c, h, w])
flows_backward = self.spynet(lrs_1, lrs_2).reshape([n, t - 1, 2, h, w])
if self.is_mirror_extended: # flows_forward = flows_backward.flip(1)
flows_forward = None
else:
flows_forward = self.spynet(lrs_2,
lrs_1).reshape([n, t - 1, 2, h, w])
return flows_forward, flows_backward
def forward(self, lrs):
"""Forward function for BasicVSR.
Args:
lrs (Tensor): Input LR sequence with shape (n, t, c, h, w).
Returns:
Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
"""
n, t, c, h, w = lrs.shape
assert h >= 64 and w >= 64, (
'The height and width of inputs should be at least 64, '
f'but got {h} and {w}.')
# check whether the input is an extended sequence
self.check_if_mirror_extended(lrs)
# compute optical flow
flows_forward, flows_backward = self.compute_flow(lrs)
# backward-time propgation
outputs = []
feat_prop = paddle.to_tensor(
np.zeros([n, self.mid_channels, h, w], 'float32'))
for i in range(t - 1, -1, -1):
if i < t - 1: # no warping required for the last timestep
flow = flows_backward[:, i, :, :, :]
feat_prop = flow_warp(feat_prop, flow.transpose([0, 2, 3, 1]))
feat_prop = paddle.concat([lrs[:, i, :, :, :], feat_prop], axis=1)
feat_prop = self.backward_resblocks(feat_prop)
outputs.append(feat_prop)
outputs = outputs[::-1]
# forward-time propagation and upsampling
feat_prop = paddle.zeros_like(feat_prop)
for i in range(0, t):
lr_curr = lrs[:, i, :, :, :]
if i > 0: # no warping required for the first timestep
if flows_forward is not None:
flow = flows_forward[:, i - 1, :, :, :]
else:
flow = flows_backward[:, -i, :, :, :]
feat_prop = flow_warp(feat_prop, flow.transpose([0, 2, 3, 1]))
feat_prop = paddle.concat([lr_curr, feat_prop], axis=1)
feat_prop = self.forward_resblocks(feat_prop)
# upsampling given the backward and forward features
out = paddle.concat([outputs[i], feat_prop], axis=1)
out = self.lrelu(self.fusion(out))
out = self.lrelu(self.upsample1(out))
out = self.lrelu(self.upsample2(out))
out = self.lrelu(self.conv_hr(out))
out = self.conv_last(out)
base = self.img_upsample(lr_curr)
out += base
outputs[i] = out
return paddle.stack(outputs, axis=1)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册