Unverified commit 6e3dad37, authored by LielinJiang, committed by GitHub

add vimeo90k dataset and vsr_folder test dataset (#485)

* add vimeo90k dataset, vsr folder test dataset
Parent c21b08f5
total_iters: 600000
output_dir: output_dir
find_unused_parameters: True
checkpoints_dir: checkpoints
# tensor range for function tensor2img
min_max:
  (0., 1.)

model:
  name: BasicVSRModel
  fix_iter: 5000
  lr_mult: 0.25
  generator:
    name: BasicVSRPlusPlus
    mid_channels: 64
    num_blocks: 7
    is_low_res_input: True
  pixel_criterion:
    name: CharbonnierLoss
    reduction: mean

dataset:
  train:
    name: RepeatDataset
    times: 1000
    num_workers: 4
    batch_size: 1  # 4 GPUs
    dataset:
      name: VSRVimeo90KDataset
      # mode: train
      lq_folder: data/vimeo90k/vimeo_septuplet_BD_matlabLRx4/sequences
      gt_folder: data/vimeo90k/vimeo_septuplet/sequences
      ann_file: data/vimeo90k/vimeo_septuplet/sep_trainlist.txt
      preprocess:
        - name: ReadImageSequence
          key: lq
        - name: ReadImageSequence
          key: gt
        - name: Transforms
          input_keys: [lq, gt]
          pipeline:
            - name: SRPairedRandomCrop
              gt_patch_size: 256
              scale: 4
              keys: [image, image]
            - name: PairedRandomHorizontalFlip
              keys: [image, image]
            - name: PairedRandomVerticalFlip
              keys: [image, image]
            - name: PairedRandomTransposeHW
              keys: [image, image]
            - name: TransposeSequence
              keys: [image, image]
            - name: MirrorVideoSequence
            - name: NormalizeSequence
              mean: [0., 0., 0.]
              std: [255., 255., 255.]
              keys: [image, image]
  test:
    name: VSRFolderDataset
    # for udm10 dataset
    # lq_folder: data/udm10/BDx4
    # gt_folder: data/udm10/GT
    lq_folder: data/Vid4/BDx4
    gt_folder: data/Vid4/GT
    preprocess:
      - name: GetNeighboringFramesIdx
        interval_list: [1]
        # for udm10 dataset
        # filename_tmpl: '{:04d}.png'
        filename_tmpl: '{:08d}.png'
      - name: ReadImageSequence
        key: lq
      - name: ReadImageSequence
        key: gt
      - name: Transforms
        input_keys: [lq, gt]
        pipeline:
          - name: TransposeSequence
            keys: [image, image]
          - name: NormalizeSequence
            mean: [0., 0., 0.]
            std: [255., 255., 255.]
            keys: [image, image]

lr_scheduler:
  name: CosineAnnealingRestartLR
  learning_rate: !!float 1e-4
  periods: [600000]
  restart_weights: [1]
  eta_min: !!float 1e-7

optimizer:
  name: Adam
  # add parameters of net_name to optim
  # name should be in self.nets
  net_names:
    - generator
  beta1: 0.9
  beta2: 0.99

validate:
  interval: 5000
  save_img: false

  metrics:
    psnr: # metric name, can be arbitrary
      name: PSNR
      crop_border: 0
      test_y_channel: true
    ssim:
      name: SSIM
      crop_border: 0
      test_y_channel: true

log_config:
  interval: 10
  visiual_interval: 500

snapshot_config:
  interval: 5000
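For reference, a minimal sketch (not part of this commit) of how the nested dataset section above resolves: the train entry wraps VSRVimeo90KDataset in a RepeatDataset, so the effective epoch length is the clip count multiplied by `times`. It assumes PyYAML and a local copy of the config saved under the hypothetical name `basicvsr_vimeo90k.yaml`.

import yaml  # PyYAML assumed available

with open('basicvsr_vimeo90k.yaml') as f:  # hypothetical local filename
    cfg = yaml.safe_load(f)

train = cfg['dataset']['train']
assert train['name'] == 'RepeatDataset'                  # outer wrapper
assert train['dataset']['name'] == 'VSRVimeo90KDataset'  # wrapped dataset
print(train['times'])  # 1000: effective length = len(inner dataset) * times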
@@ -26,4 +26,6 @@ from .firstorder_dataset import FirstOrderDataset
from .lapstyle_dataset import LapStyleDataset
from .sr_reds_multiple_gt_dataset import SRREDSMultipleGTDataset
from .mpr_dataset import MPRTrain, MPRVal, MPRTest
from .vsr_vimeo90k_dataset import VSRVimeo90KDataset
from .vsr_folder_dataset import VSRFolderDataset
from .photopen_dataset import PhotoPenDataset

-from .io import LoadImageFromFile
+from .io import LoadImageFromFile, ReadImageSequence, GetNeighboringFramesIdx
from .transforms import (PairedRandomCrop, PairedRandomHorizontalFlip,
                         PairedRandomVerticalFlip, PairedRandomTransposeHW,
-                        SRPairedRandomCrop, SplitPairedImage, SRNoise)
+                        SRPairedRandomCrop, SplitPairedImage, SRNoise,
+                        NormalizeSequence, MirrorVideoSequence,
+                        TransposeSequence)
from .builder import build_preprocess
# code referenced from mmcv
import os
import cv2
import numpy as np

from .builder import PREPROCESS

@@ -9,12 +10,12 @@ class LoadImageFromFile(object):
    """Load image from file.

    Args:
-       key (str): Keys in results to find corresponding path. Default: 'image'.
+       key (str): Keys in datas to find corresponding path. Default: 'image'.
        flag (str): Loading flag for images. Default: -1.
        to_rgb (bool): Convert img to 'rgb' format. Default: True.
        backend (str): io backend where images are stored. Default: None.
        save_original_img (bool): If True, maintain a copy of the image in
-           `results` dict with name of `f'ori_{key}'`. Default: False.
+           `datas` dict with name of `f'ori_{key}'`. Default: False.
        kwargs (dict): Args for file client.
    """
    def __init__(self,
@@ -31,28 +32,150 @@ class LoadImageFromFile(object):
        self.save_original_img = save_original_img
        self.kwargs = kwargs

-   def __call__(self, results):
+   def __call__(self, datas):
        """Call function.

        Args:
-           results (dict): A dict containing the necessary information and
+           datas (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
-       filepath = str(results[f'{self.key}_path'])
+       filepath = str(datas[f'{self.key}_path'])
        # TODO: use a file client to manage io backends
        # such as opencv, pil, lmdb
        img = cv2.imread(filepath, self.flag)
        if self.to_rgb:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

-       results[self.key] = img
-       results[f'{self.key}_path'] = filepath
-       results[f'{self.key}_ori_shape'] = img.shape
+       datas[self.key] = img
+       datas[f'{self.key}_path'] = filepath
+       datas[f'{self.key}_ori_shape'] = img.shape
        if self.save_original_img:
-           results[f'ori_{self.key}'] = img.copy()
+           datas[f'ori_{self.key}'] = img.copy()

+       return datas
@PREPROCESS.register()
class ReadImageSequence(LoadImageFromFile):
    """Read an image sequence.

    It accepts a list of paths and reads one frame from each path. A list
    of frames is returned.

    Args:
        key (str): Keys in datas to find corresponding path. Default: 'gt'.
        flag (str): Loading flag for images. Default: 'color'.
        to_rgb (bool): Convert img to 'rgb' format. Default: True.
        save_original_img (bool): If True, maintain a copy of the image in
            `datas` dict with name of `f'ori_{key}'`. Default: False.
        kwargs (dict): Args for file client.
    """
    def __call__(self, datas):
        """Call function.

        Args:
            datas (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        filepaths = datas[f'{self.key}_path']
        if not isinstance(filepaths, list):
            raise TypeError(
                f'filepath should be list, but got {type(filepaths)}')

        filepaths = [str(v) for v in filepaths]

        imgs = []
        shapes = []
        if self.save_original_img:
            ori_imgs = []
        for filepath in filepaths:
            img = cv2.imread(filepath, self.flag)
            if self.to_rgb:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if img.ndim == 2:
                img = np.expand_dims(img, axis=2)
            imgs.append(img)
            shapes.append(img.shape)
            if self.save_original_img:
                ori_imgs.append(img.copy())

        datas[self.key] = imgs
        datas[f'{self.key}_path'] = filepaths
        datas[f'{self.key}_ori_shape'] = shapes
        if self.save_original_img:
            datas[f'ori_{self.key}'] = ori_imgs

        return datas
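A hedged usage sketch for ReadImageSequence: the frame files below are tiny placeholders written on the fly, and the constructor is assumed to be inherited unchanged from LoadImageFromFile.

import cv2
import numpy as np

# Write two tiny placeholder frames so the sketch is self-contained.
cv2.imwrite('/tmp/00000000.png', np.zeros((4, 4, 3), np.uint8))
cv2.imwrite('/tmp/00000001.png', np.zeros((4, 4, 3), np.uint8))

reader = ReadImageSequence(key='lq')  # key/flag/to_rgb assumed inherited
datas = {'lq_path': ['/tmp/00000000.png', '/tmp/00000001.png']}
datas = reader(datas)
# datas['lq'] is now a list of HWC RGB arrays; datas['lq_ori_shape'] the shapes.
assert len(datas['lq']) == 2 and datas['lq'][0].shape == (4, 4, 3)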
@PREPROCESS.register()
class GetNeighboringFramesIdx:
    """Get neighboring frame indices for a video. It also performs temporal
    augmentation with a random interval.

    Args:
        interval_list (list[int]): Interval list for temporal augmentation.
            It will randomly pick an interval from interval_list and sample
            frame indices with that interval.
        start_idx (int): The index corresponding to the first frame in the
            sequence. Default: 0.
        filename_tmpl (str): Template for file name. Default: '{:08d}.png'.
    """
    def __init__(self, interval_list, start_idx=0, filename_tmpl='{:08d}.png'):
        self.interval_list = interval_list
        self.filename_tmpl = filename_tmpl
        self.start_idx = start_idx

    def __call__(self, datas):
        """Call function.

        Args:
            datas (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        clip_name = datas['key']
        interval = np.random.choice(self.interval_list)

        self.sequence_length = datas['sequence_length']
        num_frames = datas.get('num_frames', self.sequence_length)

        if self.sequence_length - num_frames * interval < 0:
            raise ValueError('The input sequence is not long enough to '
                             'support the current choice of [interval] or '
                             '[num_frames].')

        start_frame_idx = np.random.randint(
            0, self.sequence_length - num_frames * interval + 1)

        end_frame_idx = start_frame_idx + num_frames * interval
        neighbor_list = list(range(start_frame_idx, end_frame_idx, interval))
        neighbor_list = [v + self.start_idx for v in neighbor_list]

        lq_path_root = datas['lq_path']
        gt_path_root = datas['gt_path']

        lq_path = [
            os.path.join(lq_path_root, clip_name, self.filename_tmpl.format(v))
            for v in neighbor_list
        ]
        gt_path = [
            os.path.join(gt_path_root, clip_name, self.filename_tmpl.format(v))
            for v in neighbor_list
        ]

        datas['lq_path'] = lq_path
        datas['gt_path'] = gt_path
        datas['interval'] = interval

-       return results
+       return datas
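A worked sketch of the sampling logic above, with illustrative numbers: for sequence_length=100, num_frames=15 and interval 1, start_frame_idx is drawn from [0, 85] and neighbor_list holds 15 consecutive indices.

import numpy as np

sequence_length, num_frames = 100, 15
interval = np.random.choice([1])  # interval_list: [1], as in the test config
start = np.random.randint(0, sequence_length - num_frames * interval + 1)
neighbor_list = list(range(start, start + num_frames * interval, interval))
assert len(neighbor_list) == num_frames  # 15 evenly spaced frame indices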
@@ -55,6 +55,7 @@ class Transforms():
    def __call__(self, datas):
        data = []
        for k in self.input_keys:
            data.append(datas[k])

        data = tuple(data)
@@ -133,7 +134,10 @@ class PairedRandomHorizontalFlip(T.RandomHorizontalFlip):
    def _apply_image(self, image):
        if self.params['flip']:
-           return F.hflip(image)
+           if isinstance(image, list):
+               image = [F.hflip(v) for v in image]
+           else:
+               return F.hflip(image)
        return image
@@ -149,7 +153,10 @@ class PairedRandomVerticalFlip(T.RandomHorizontalFlip):
    def _apply_image(self, image):
        if self.params['flip']:
-           return F.hflip(image)
+           if isinstance(image, list):
+               image = [F.vflip(v) for v in image]
+           else:
+               return F.vflip(image)
        return image
@@ -180,10 +187,108 @@ class PairedRandomTransposeHW(T.BaseTransform):
    def _apply_image(self, image):
        if self.params['transpose']:
-           image = image.transpose(1, 0, 2)
+           if isinstance(image, list):
+               image = [v.transpose(1, 0, 2) for v in image]
+           else:
+               image = image.transpose(1, 0, 2)
        return image
@TRANSFORMS.register()
class TransposeSequence(T.Transpose):
    """Transpose input data or a video sequence to a target format.

    For example, most transforms use HWC-mode images, while the neural
    network might expect CHW-mode input tensors. The output image will be
    an instance of numpy.ndarray.

    Args:
        order (list|tuple, optional): Target order of input data. Default: (2, 0, 1).
        keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None.

    Examples:

        .. code-block:: python

            import numpy as np
            from PIL import Image

            transform = TransposeSequence()

            fake_img = Image.fromarray((np.random.rand(300, 320, 3) * 255.).astype(np.uint8))
            fake_img_seq = [fake_img, fake_img, fake_img]
            fake_img_seq = transform(fake_img_seq)
    """
    def _apply_image(self, img):
        if isinstance(img, list):
            imgs = []
            for im in img:
                if F._is_tensor_image(im):
                    # transpose tensors and keep iterating rather than
                    # returning early and dropping the rest of the sequence
                    imgs.append(im.transpose(self.order))
                    continue
                if F._is_pil_image(im):
                    im = np.asarray(im)
                if len(im.shape) == 2:
                    im = im[..., np.newaxis]
                imgs.append(im.transpose(self.order))
            return imgs
        else:
            if F._is_tensor_image(img):
                return img.transpose(self.order)
            if F._is_pil_image(img):
                img = np.asarray(img)
            if len(img.shape) == 2:
                img = img[..., np.newaxis]
            return img.transpose(self.order)
@TRANSFORMS.register()
class NormalizeSequence(T.Normalize):
    """Normalize the input data with mean and standard deviation.

    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels,
    this transform will normalize each channel of the input data.
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    Args:
        mean (int|float|list|tuple): Sequence of means for each channel.
        std (int|float|list|tuple): Sequence of standard deviations for each channel.
        data_format (str, optional): Data format of img, should be 'HWC' or
            'CHW'. Default: 'CHW'.
        to_rgb (bool, optional): Whether to convert to rgb. Default: False.
        keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None.

    Examples:

        .. code-block:: python

            import numpy as np
            from PIL import Image

            normalize_seq = NormalizeSequence(mean=[127.5, 127.5, 127.5],
                                              std=[127.5, 127.5, 127.5],
                                              data_format='HWC')

            fake_img = Image.fromarray((np.random.rand(300, 320, 3) * 255.).astype(np.uint8))
            fake_img_seq = [fake_img, fake_img, fake_img]
            fake_img_seq = normalize_seq(fake_img_seq)
    """
    def _apply_image(self, img):
        if isinstance(img, list):
            imgs = [
                F.normalize(v, self.mean, self.std, self.data_format,
                            self.to_rgb) for v in img
            ]
            return np.stack(imgs, axis=0).astype('float32')

        return F.normalize(img, self.mean, self.std, self.data_format,
                           self.to_rgb)
@TRANSFORMS.register()
class SRPairedRandomCrop(T.BaseTransform):
    """Super resolution random crop.

@@ -204,15 +309,19 @@ class SRPairedRandomCrop(T.BaseTransform):
        self.scale_list = scale_list

    def __call__(self, inputs):
-       """inputs must be (lq_img, gt_img)"""
+       """inputs must be (lq_img or list[lq_img], gt_img or list[gt_img])"""
        scale = self.scale
        lq_patch_size = self.gt_patch_size // scale

        lq = inputs[0]
        gt = inputs[1]

-       h_lq, w_lq, _ = lq.shape
-       h_gt, w_gt, _ = gt.shape
+       if isinstance(lq, list):
+           h_lq, w_lq, _ = lq[0].shape
+           h_gt, w_gt, _ = gt[0].shape
+       else:
+           h_lq, w_lq, _ = lq.shape
+           h_gt, w_gt, _ = gt.shape

        if h_gt != h_lq * scale or w_gt != w_lq * scale:
            raise ValueError('scale size not match')

@@ -222,18 +331,30 @@ class SRPairedRandomCrop(T.BaseTransform):
        # randomly choose top and left coordinates for lq patch
        top = random.randint(0, h_lq - lq_patch_size)
        left = random.randint(0, w_lq - lq_patch_size)

-       # crop lq patch
-       lq = lq[top:top + lq_patch_size, left:left + lq_patch_size, ...]
-       # crop corresponding gt patch
-       top_gt, left_gt = int(top * scale), int(left * scale)
-       gt = gt[top_gt:top_gt + self.gt_patch_size,
-               left_gt:left_gt + self.gt_patch_size, ...]
-
-       if self.scale_list and self.scale == 4:
-           lqx2 = F.resize(gt, (lq_patch_size * 2, lq_patch_size * 2),
-                           'bicubic')
-           outputs = (lq, lqx2, gt)
-           return outputs
+       if isinstance(lq, list):
+           lq = [
+               v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
+               for v in lq
+           ]
+           top_gt, left_gt = int(top * scale), int(left * scale)
+           gt = [
+               v[top_gt:top_gt + self.gt_patch_size,
+                 left_gt:left_gt + self.gt_patch_size, ...] for v in gt
+           ]
+       else:
+           # crop lq patch
+           lq = lq[top:top + lq_patch_size, left:left + lq_patch_size, ...]
+           # crop corresponding gt patch
+           top_gt, left_gt = int(top * scale), int(left * scale)
+           gt = gt[top_gt:top_gt + self.gt_patch_size,
+                   left_gt:left_gt + self.gt_patch_size, ...]
+
+           if self.scale_list and self.scale == 4:
+               lqx2 = F.resize(gt, (lq_patch_size * 2, lq_patch_size * 2),
+                               'bicubic')
+               outputs = (lq, lqx2, gt)
+               return outputs

        outputs = (lq, gt)
        return outputs
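A quick alignment check for the new list branch, as a sketch with illustrative shapes; the constructor is assumed to accept gt_patch_size and scale as keyword arguments.

import numpy as np

scale, gt_patch_size = 4, 256
lq_seq = [np.zeros((64, 112, 3), dtype=np.uint8) for _ in range(7)]
gt_seq = [np.zeros((256, 448, 3), dtype=np.uint8) for _ in range(7)]

crop = SRPairedRandomCrop(gt_patch_size=gt_patch_size, scale=scale)
lq_out, gt_out = crop((lq_seq, gt_seq))

# Every lq patch is gt_patch_size // scale on a side, and the gt patch
# starts at scale * (top, left), so the pair stays pixel-aligned.
assert lq_out[0].shape[:2] == (64, 64)
assert gt_out[0].shape[:2] == (256, 256)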
@@ -411,3 +532,36 @@ class PairedColorJitter(T.BaseTransform):
        for f in self.params:
            img = f(img)
        return img

@TRANSFORMS.register()
class MirrorVideoSequence:
    """Double a short video sequence by mirroring it.

    Example:
        Given a sequence with N frames (x1, ..., xN), extend the
        sequence to (x1, ..., xN, xN, ..., x1).

    Args:
        keys (list[str]): The frame lists to be extended.
    """
    def __init__(self, keys=None):
        self.keys = keys

    def __call__(self, datas):
        """Call function.

        Args:
            datas (tuple): A pair (lrs, hrs) of frame lists.

        Returns:
            tuple: The mirrored (lrs, hrs) pair.
        """
        lrs, hrs = datas
        assert isinstance(lrs, list) and isinstance(hrs, list)

        lrs = lrs + lrs[::-1]
        hrs = hrs + hrs[::-1]

        return (lrs, hrs)
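A sketch of the mirroring above: a 3-frame pair becomes a 6-frame palindrome, giving the recurrent propagation longer clips at no extra IO cost. The strings stand in for frame arrays.

mirror = MirrorVideoSequence()
lrs = ['lr0', 'lr1', 'lr2']  # stand-ins for low-resolution frames
hrs = ['hr0', 'hr1', 'hr2']  # stand-ins for ground-truth frames
lrs, hrs = mirror((lrs, hrs))
assert lrs == ['lr0', 'lr1', 'lr2', 'lr2', 'lr1', 'lr0']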
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import glob
import random
import logging
import numpy as np
from paddle.io import Dataset
from .base_sr_dataset import BaseDataset
from .builder import DATASETS
logger = logging.getLogger(__name__)
@DATASETS.register()
class VSRFolderDataset(BaseDataset):
    """Video super-resolution dataset for folder format.

    Args:
        lq_folder (str): Path to a low quality image folder.
        gt_folder (str): Path to a ground truth image folder.
        preprocess (list[dict|callable]): A list of functions for data transformations.
        num_frames (int): Number of frames of each input clip.
        times (int): Repeat times of dataset length.
    """
    def __init__(self,
                 lq_folder,
                 gt_folder,
                 preprocess,
                 num_frames=None,
                 times=1):
        super().__init__(preprocess)

        self.lq_folder = str(lq_folder)
        self.gt_folder = str(gt_folder)
        self.num_frames = num_frames
        self.times = times

        self.data_infos = self.prepare_data_infos()

    def prepare_data_infos(self):
        sequences = sorted(glob.glob(os.path.join(self.lq_folder, '*')))

        data_infos = []
        for sequence in sequences:
            sequence_length = len(glob.glob(os.path.join(sequence, '*.png')))
            if self.num_frames is None:
                num_frames = sequence_length
            else:
                num_frames = self.num_frames

            data_infos.append(
                dict(lq_path=self.lq_folder,
                     gt_path=self.gt_folder,
                     key=sequence.replace(f'{self.lq_folder}/', ''),
                     num_frames=num_frames,
                     sequence_length=sequence_length))

        return data_infos
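For orientation, a hedged sketch of what one data_infos entry looks like for a hypothetical Vid4 clip (the 41-frame count is illustrative, not checked against the real data):

# One entry per clip folder; GetNeighboringFramesIdx later expands this
# into concrete per-frame paths.
example_info = dict(lq_path='data/Vid4/BDx4',
                    gt_path='data/Vid4/GT',
                    key='calendar',          # clip folder name
                    num_frames=41,           # illustrative frame count
                    sequence_length=41)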
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import glob
import random
import logging
import numpy as np
from paddle.io import Dataset
from .base_sr_dataset import BaseDataset
from .builder import DATASETS
@DATASETS.register()
class VSRVimeo90KDataset(BaseDataset):
    """Vimeo90K dataset for video super-resolution with recurrent networks.

    The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth)
    frames. It then applies the specified transforms and finally returns a
    dict containing paired data and other information.

    It reads Vimeo90K keys from the txt file. Each line contains a video
    frame folder, for example:

        00001/0233
        00001/0234

    Args:
        lq_folder (str): Path to a low quality image folder.
        gt_folder (str): Path to a ground truth image folder.
        ann_file (str): Path to the annotation file.
        preprocess (list[dict|callable]): A list of functions for data transformations.
    """
    def __init__(self, lq_folder, gt_folder, ann_file, preprocess):
        super().__init__(preprocess)

        self.lq_folder = str(lq_folder)
        self.gt_folder = str(gt_folder)
        self.ann_file = str(ann_file)

        self.data_infos = self.prepare_data_infos()

    def prepare_data_infos(self):
        with open(self.ann_file, 'r') as fin:
            keys = [line.strip() for line in fin]

        data_infos = []
        for key in keys:
            lq_paths = sorted(
                glob.glob(os.path.join(self.lq_folder, key, '*.png')))
            gt_paths = sorted(
                glob.glob(os.path.join(self.gt_folder, key, '*.png')))

            data_infos.append(dict(lq_path=lq_paths, gt_path=gt_paths,
                                   key=key))

        return data_infos
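For comparison with the folder dataset, a sketch of one Vimeo90K entry, assuming the standard septuplet frame naming im1.png through im7.png:

key = '00001/0233'
example_info = dict(
    lq_path=[
        f'data/vimeo90k/vimeo_septuplet_BD_matlabLRx4/sequences/{key}/im{i}.png'
        for i in range(1, 8)
    ],
    gt_path=[
        f'data/vimeo90k/vimeo_septuplet/sequences/{key}/im{i}.png'
        for i in range(1, 8)
    ],
    key=key)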
@@ -30,17 +30,26 @@ class PSNR(paddle.metric.Metric):
    def reset(self):
        self.results = []

-   def update(self, preds, gts):
+   def update(self, preds, gts, is_seq=False):
        if not isinstance(preds, (list, tuple)):
            preds = [preds]

        if not isinstance(gts, (list, tuple)):
            gts = [gts]

        if is_seq:
            single_seq = []

        for pred, gt in zip(preds, gts):
            value = calculate_psnr(pred, gt, self.crop_border, self.input_order,
                                   self.test_y_channel)
-           self.results.append(value)
+           if is_seq:
+               single_seq.append(value)
+           else:
+               self.results.append(value)

        if is_seq:
            self.results.append(np.mean(single_seq))

    def accumulate(self):
        if paddle.distributed.get_world_size() > 1:
@@ -59,17 +68,26 @@ class PSNR(paddle.metric.Metric):

@METRICS.register()
class SSIM(PSNR):
-   def update(self, preds, gts):
+   def update(self, preds, gts, is_seq=False):
        if not isinstance(preds, (list, tuple)):
            preds = [preds]

        if not isinstance(gts, (list, tuple)):
            gts = [gts]

        if is_seq:
            single_seq = []

        for pred, gt in zip(preds, gts):
            value = calculate_ssim(pred, gt, self.crop_border, self.input_order,
                                   self.test_y_channel)
-           self.results.append(value)
+           if is_seq:
+               single_seq.append(value)
+           else:
+               self.results.append(value)

        if is_seq:
            self.results.append(np.mean(single_seq))

    def name(self):
        return 'SSIM'
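A sketch of what the is_seq path changes: per-frame values of one clip collapse to a single mean, so accumulate() weights each clip equally regardless of its length.

import numpy as np

per_frame_psnr = [30.1, 29.8, 30.4]      # stand-in values for one clip
results = []
results.append(np.mean(per_frame_psnr))  # is_seq=True: one entry per clip
# With is_seq=False, each frame value would have been appended individually.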
@@ -25,7 +25,7 @@ from ..utils.visual import tensor2img

class BaseModel(ABC):
-   """This class is an abstract base class (ABC) for models.
+   r"""This class is an abstract base class (ABC) for models.

    To create a subclass, you need to implement the following five functions:
        -- <__init__>: initialize the class.
        -- <setup_input>: unpack data from dataset and apply preprocessing.
@@ -103,7 +103,7 @@ class BasicVSRModel(BaseSRModel):
        if metrics is not None:
            for metric in metrics.values():
-               metric.update(out_img, gt_img)
+               metric.update(out_img, gt_img, is_seq=True)

def init_basicvsr_weight(net):