未验证 提交 4bb26861 编写于 作者: L LielinJiang 提交者: GitHub

Implement basicvsr (#356)

* add basicvsr model
上级 058faa87
total_iters: 300000
output_dir: output_dir
find_unused_parameters: True
checkpoints_dir: checkpoints
use_dataset: True
# tensor range for function tensor2img
min_max:
(0., 1.)
model:
name: BasicVSRModel
fix_iter: 5000
generator:
name: BasicVSRNet
mid_channels: 64
num_blocks: 30
pixel_criterion:
name: CharbonnierLoss
reduction: mean
dataset:
train:
name: RepeatDataset
times: 1000
num_workers: 4 # 6
batch_size: 2 # 4*2
dataset:
name: SRREDSMultipleGTDataset
mode: train
lq_folder: data/REDS/train_sharp_bicubic/X4
gt_folder: data/REDS/train_sharp/X4
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 15
use_flip: True
use_rot: True
scale: 4
val_partition: REDS4
test:
name: SRREDSMultipleGTDataset
mode: test
lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
gt_folder: data/REDS/REDS4_test_sharp/X4
interval_list: [1]
random_reverse: False
number_frames: 100
use_flip: False
use_rot: False
scale: 4
val_partition: REDS4
num_workers: 0
batch_size: 1
lr_scheduler:
name: CosineAnnealingRestartLR
learning_rate: !!float 2e-4
periods: [300000]
restart_weights: [1]
eta_min: !!float 1e-7
optimizer:
name: Adam
# add parameters of net_name to optim
# name should in self.nets
net_names:
- generator
beta1: 0.9
beta2: 0.99
validate:
# FIXME: avoid oom
interval: 5000000
save_img: false
metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 0
test_y_channel: False
ssim:
name: SSIM
crop_border: 0
test_y_channel: False
log_config:
interval: 100
visiual_interval: 500
snapshot_config:
interval: 5000
......@@ -24,4 +24,5 @@ from .starganv2_dataset import StarGANv2Dataset
from .edvr_dataset import REDSDataset
from .firstorder_dataset import FirstOrderDataset
from .lapstyle_dataset import LapStyleDataset
from .sr_reds_multiple_gt_dataset import SRREDSMultipleGTDataset
from .mpr_dataset import MPRTrain, MPRVal, MPRTest
......@@ -19,11 +19,25 @@ import numpy as np
from paddle.distributed import ParallelEnv
from paddle.io import DistributedBatchSampler
from ..utils.registry import Registry
from .repeat_dataset import RepeatDataset
from ..utils.registry import Registry, build_from_config
DATASETS = Registry("DATASETS")
def build_dataset(cfg):
name = cfg.pop('name')
if name == 'RepeatDataset':
dataset_ = build_from_config(cfg['dataset'], DATASETS)
dataset = RepeatDataset(dataset_, cfg['times'])
else:
dataset = dataset = DATASETS.get(name)(**cfg)
return dataset
def build_dataloader(cfg, is_train=True, distributed=True):
cfg_ = cfg.copy()
......@@ -31,9 +45,7 @@ def build_dataloader(cfg, is_train=True, distributed=True):
num_workers = cfg_.pop('num_workers', 0)
use_shared_memory = cfg_.pop('use_shared_memory', True)
name = cfg_.pop('name')
dataset = DATASETS.get(name)(**cfg_)
dataset = build_dataset(cfg_)
if distributed:
sampler = DistributedBatchSampler(dataset,
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
class RepeatDataset(paddle.io.Dataset):
"""A wrapper of repeated dataset.
The length of repeated dataset will be `times` larger than the original
dataset. This is useful when the data loading time is long but the dataset
is small. Using RepeatDataset can reduce the data loading time between
epochs.
Args:
dataset (:obj:`Dataset`): The dataset to be repeated.
times (int): Repeat times.
"""
def __init__(self, dataset, times):
self.dataset = dataset
self.times = times
self._ori_len = len(self.dataset)
def __getitem__(self, idx):
"""Get item at each call.
Args:
idx (int): Index for getting each item.
"""
return self.dataset[idx % self._ori_len]
def __len__(self):
"""Length of the dataset.
Returns:
int: Length of the dataset.
"""
return self.times * self._ori_len
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import random
import numpy as np
import cv2
from paddle.io import Dataset
from .builder import DATASETS
logger = logging.getLogger(__name__)
@DATASETS.register()
class SRREDSMultipleGTDataset(Dataset):
"""REDS dataset for video super resolution for recurrent networks.
The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth)
frames. Then it applies specified transforms and finally returns a dict
containing paired data and other information.
Args:
lq_folder (str | :obj:`Path`): Path to a lq folder.
gt_folder (str | :obj:`Path`): Path to a gt folder.
num_input_frames (int): Number of input frames.
pipeline (list[dict | callable]): A sequence of data transformations.
scale (int): Upsampling scale ratio.
val_partition (str): Validation partition mode. Choices ['official' or
'REDS4']. Default: 'official'.
test_mode (bool): Store `True` when building test dataset.
Default: `False`.
"""
def __init__(self,
mode,
lq_folder,
gt_folder,
crop_size=256,
interval_list=[1],
random_reverse=False,
number_frames=15,
use_flip=False,
use_rot=False,
scale=4,
val_partition='REDS4',
batch_size=4):
super(SRREDSMultipleGTDataset, self).__init__()
self.mode = mode
self.fileroot = str(lq_folder)
self.gtroot = str(gt_folder)
self.crop_size = crop_size
self.interval_list = interval_list
self.random_reverse = random_reverse
self.number_frames = number_frames
self.use_flip = use_flip
self.use_rot = use_rot
self.scale = scale
self.val_partition = val_partition
self.batch_size = batch_size
self.data_infos = self.load_annotations()
def __getitem__(self, idx):
"""Get item at each call.
Args:
idx (int): Index for getting each item.
"""
item = self.data_infos[idx]
idt = random.randint(0, 100 - self.number_frames)
item = item + '_' + f'{idt:03d}'
img_LQs, img_GTs = self.get_sample_data(
item, self.number_frames, self.interval_list, self.random_reverse,
self.gtroot, self.fileroot, self.crop_size, self.scale,
self.use_flip, self.use_rot, self.mode)
return {'lq': img_LQs, 'gt': img_GTs, 'lq_path': self.data_infos[idx]}
def load_annotations(self):
"""Load annoations for REDS dataset.
Returns:
dict: Returned dict for LQ and GT pairs.
"""
# generate keys
keys = [f'{i:03d}' for i in range(0, 270)]
if self.val_partition == 'REDS4':
val_partition = ['000', '011', '015', '020']
elif self.val_partition == 'official':
val_partition = [f'{i:03d}' for i in range(240, 270)]
else:
raise ValueError(f'Wrong validation partition {self.val_partition}.'
f'Supported ones are ["official", "REDS4"]')
if self.mode == 'train':
keys = [v for v in keys if v not in val_partition]
else:
keys = [v for v in keys if v in val_partition]
data_infos = []
for key in keys:
data_infos.append(key)
return data_infos
def get_sample_data(self,
item,
number_frames,
interval_list,
random_reverse,
gtroot,
fileroot,
crop_size,
scale,
use_flip,
use_rot,
mode='train'):
video_name = item.split('_')[0]
frame_name = item.split('_')[1]
frame_idxs = self.get_neighbor_frames(frame_name,
number_frames=number_frames,
interval_list=interval_list,
random_reverse=random_reverse)
frame_list = []
gt_list = []
for frame_idx in frame_idxs:
frame_idx_name = "%08d" % frame_idx
img = self.read_img(
os.path.join(fileroot, video_name, frame_idx_name + '.png'))
frame_list.append(img)
gt_img = self.read_img(
os.path.join(gtroot, video_name, frame_idx_name + '.png'))
gt_list.append(gt_img)
H, W, C = frame_list[0].shape
# add random crop
if (mode == 'train') or (mode == 'valid'):
LQ_size = crop_size // scale
rnd_h = random.randint(0, max(0, H - LQ_size))
rnd_w = random.randint(0, max(0, W - LQ_size))
frame_list = [
v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
for v in frame_list
]
rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
gt_list = [
v[rnd_h_HR:rnd_h_HR + crop_size,
rnd_w_HR:rnd_w_HR + crop_size, :] for v in gt_list
]
# add random flip and rotation
for v in gt_list:
frame_list.append(v)
if (mode == 'train') or (mode == 'valid'):
rlt = self.img_augment(frame_list, use_flip, use_rot)
else:
rlt = frame_list
frame_list = rlt[0:number_frames]
gt_list = rlt[number_frames:]
# stack LQ images to NHWC, N is the frame number
frame_list = [v.transpose(2, 0, 1).astype('float32') for v in frame_list]
gt_list = [v.transpose(2, 0, 1).astype('float32') for v in gt_list]
img_LQs = np.stack(frame_list, axis=0)
img_GTs = np.stack(gt_list, axis=0)
return img_LQs, img_GTs
def get_neighbor_frames(self, frame_name, number_frames, interval_list,
random_reverse):
frame_idx = int(frame_name)
interval = random.choice(interval_list)
neighbor_list = list(
range(frame_idx, frame_idx + number_frames, interval))
if random_reverse and random.random() < 0.5:
neighbor_list.reverse()
assert len(neighbor_list) == number_frames, \
"frames slected have length({}), but it should be ({})".format(len(neighbor_list), number_frames)
return neighbor_list
def read_img(self, path, size=None):
"""read image by cv2
return: Numpy float32, HWC, BGR, [0,1]
"""
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
img = img.astype(np.float32) / 255.
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
# some images have 4 channels
if img.shape[2] > 3:
img = img[:, :, :3]
return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def img_augment(self, img_list, hflip=True, rot=True):
"""horizontal flip OR rotate (0, 90, 180, 270 degrees)
"""
hflip = hflip and random.random() < 0.5
vflip = rot and random.random() < 0.5
rot90 = rot and random.random() < 0.5
def _augment(img):
if hflip:
img = img[:, ::-1, :]
if vflip:
img = img[::-1, :, :]
if rot90:
img = img.transpose(1, 0, 2)
return img
return [_augment(img) for img in img_list]
def __len__(self):
"""Length of the dataset.
Returns:
int: Length of the dataset.
"""
return len(self.data_infos)
......@@ -323,7 +323,6 @@ class Trainer:
is_save_image=False):
"""
visual the images, use visualdl or directly write to the directory
Parameters:
results_dir (str) -- directory name which contains saved images
visual_results (dict) -- the results images dict
......@@ -440,7 +439,6 @@ class Trainer:
def close(self):
"""
when finish the training need close file handler or other.
"""
if self.enable_visualdl:
self.vdl_logger.close()
self.vdl_logger.close()
\ No newline at end of file
......@@ -30,4 +30,5 @@ from .starganv2_model import StarGANv2Model
from .edvr_model import EDVRModel
from .firstorder_model import FirstOrderModel
from .lapstyle_model import LapStyleDraModel, LapStyleRevFirstModel, LapStyleRevSecondModel
from .basicvsr_model import BasicVSRModel
from .mpr_model import MPRModel
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from .builder import MODELS
from .sr_model import BaseSRModel
from .generators.basicvsr import ResidualBlockNoBN, PixelShufflePack, SPyNet
from ..modules.init import reset_parameters
from ..utils.visual import tensor2img
@MODELS.register()
class BasicVSRModel(BaseSRModel):
"""BasicVSR Model.
Paper: BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond, CVPR, 2021
"""
def __init__(self, generator, fix_iter, pixel_criterion=None):
"""Initialize the BasicVSR class.
Args:
generator (dict): config of generator.
fix_iter (dict): config of fix_iter.
pixel_criterion (dict): config of pixel criterion.
"""
super(BasicVSRModel, self).__init__(generator, pixel_criterion)
self.fix_iter = fix_iter
self.current_iter = 1
self.flag = True
init_basicvsr_weight(self.nets['generator'])
def setup_input(self, input):
self.lq = paddle.to_tensor(input['lq'])
self.visual_items['lq'] = self.lq[:, 0, :, :, :]
if 'gt' in input:
self.gt = paddle.to_tensor(input['gt'])
self.visual_items['gt'] = self.gt[:, 0, :, :, :]
self.image_paths = input['lq_path']
def train_iter(self, optims=None):
optims['optim'].clear_grad()
if self.fix_iter:
if self.current_iter == 1:
print('Train BasicVSR with fixed spynet for', self.fix_iter,
'iters.')
for name, param in self.nets['generator'].named_parameters():
if 'spynet' in name:
param.trainable = False
elif self.current_iter >= self.fix_iter + 1 and self.flag:
print('Train all the parameters.')
for name, param in self.nets['generator'].named_parameters():
param.trainable = True
if 'spynet' in name:
param.optimize_attr['learning_rate'] = 0.125
self.flag = False
for net in self.nets.values():
net.find_unused_parameters = False
self.output = self.nets['generator'](self.lq)
self.visual_items['output'] = self.output[:, 0, :, :, :]
# pixel loss
loss_pixel = self.pixel_criterion(self.output, self.gt)
loss_pixel.backward()
optims['optim'].step()
self.losses['loss_pixel'] = loss_pixel
self.current_iter += 1
def test_iter(self, metrics=None):
self.nets['generator'].eval()
with paddle.no_grad():
self.output = self.nets['generator'](self.lq)
self.visual_items['output'] = self.output[:, 0, :, :, :]
self.nets['generator'].train()
out_img = []
gt_img = []
for out_tensor, gt_tensor in zip(self.output[0], self.gt[0]):
# print(out_tensor.shape, gt_tensor.shape)
out_img.append(tensor2img(out_tensor, (0., 1.)))
gt_img.append(tensor2img(gt_tensor, (0., 1.)))
if metrics is not None:
for metric in metrics.values():
metric.update(out_img, gt_img)
def init_basicvsr_weight(net):
for m in net.children():
if hasattr(m, 'weight') and not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D)):
reset_parameters(m)
continue
if (not isinstance(
m, (ResidualBlockNoBN, PixelShufflePack, SPyNet))):
init_basicvsr_weight(m)
......@@ -59,8 +59,9 @@ class CharbonnierLoss():
eps (float): Default: 1e-12.
"""
def __init__(self, eps=1e-12):
def __init__(self, eps=1e-12, reduction='sum'):
self.eps = eps
self.reduction = reduction
def __call__(self, pred, target, **kwargs):
"""Forward Function.
......@@ -69,7 +70,14 @@ class CharbonnierLoss():
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
"""
return paddle.sum(paddle.sqrt((pred - target)**2 + self.eps))
if self.reduction == 'sum':
out = paddle.sum(paddle.sqrt((pred - target)**2 + self.eps))
elif self.reduction == 'mean':
out = paddle.mean(paddle.sqrt((pred - target)**2 + self.eps))
else:
raise NotImplementedError('CharbonnierLoss %s not implemented' %
self.reduction)
return out
@CRITERIONS.register()
......
......@@ -30,4 +30,5 @@ from .generator_starganv2 import StarGANv2Generator, StarGANv2Style, StarGANv2Ma
from .edvr import EDVRNet
from .generator_firstorder import FirstOrderGenerator
from .generater_lapstyle import DecoderNet, Encoder, RevisionNet
from .basicvsr import BasicVSRNet
from .mpr import MPRNet
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册