Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
6e3dad37
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
1 年多 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6e3dad37
编写于
11月 18, 2021
作者:
L
LielinJiang
提交者:
GitHub
11月 18, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add vimeo90k dataset and vsr_folder test dataset (#485)
* add vimeo90k dataset, vsr folder test dataset
上级
c21b08f5
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
602 addition
and
37 deletion
+602
-37
configs/basicvsr++_vimeo90k_BD.yaml
configs/basicvsr++_vimeo90k_BD.yaml
+122
-0
ppgan/datasets/__init__.py
ppgan/datasets/__init__.py
+2
-0
ppgan/datasets/preprocess/__init__.py
ppgan/datasets/preprocess/__init__.py
+4
-2
ppgan/datasets/preprocess/io.py
ppgan/datasets/preprocess/io.py
+134
-11
ppgan/datasets/preprocess/transforms.py
ppgan/datasets/preprocess/transforms.py
+172
-18
ppgan/datasets/vsr_folder_dataset.py
ppgan/datasets/vsr_folder_dataset.py
+73
-0
ppgan/datasets/vsr_vimeo90k_dataset.py
ppgan/datasets/vsr_vimeo90k_dataset.py
+71
-0
ppgan/metrics/psnr_ssim.py
ppgan/metrics/psnr_ssim.py
+22
-4
ppgan/models/base_model.py
ppgan/models/base_model.py
+1
-1
ppgan/models/basicvsr_model.py
ppgan/models/basicvsr_model.py
+1
-1
未找到文件。
configs/basicvsr++_vimeo90k_BD.yaml
0 → 100644
浏览文件 @
6e3dad37
total_iters
:
600000
output_dir
:
output_dir
find_unused_parameters
:
True
checkpoints_dir
:
checkpoints
# tensor range for function tensor2img
min_max
:
(0., 1.)
model
:
name
:
BasicVSRModel
fix_iter
:
5000
lr_mult
:
0.25
generator
:
name
:
BasicVSRPlusPlus
mid_channels
:
64
num_blocks
:
7
is_low_res_input
:
True
pixel_criterion
:
name
:
CharbonnierLoss
reduction
:
mean
dataset
:
train
:
name
:
RepeatDataset
times
:
1000
num_workers
:
4
batch_size
:
1
#4 gpus
dataset
:
name
:
VSRVimeo90KDataset
# mode: train
lq_folder
:
data/vimeo90k/vimeo_septuplet_BD_matlabLRx4/sequences
gt_folder
:
data/vimeo90k/vimeo_septuplet/sequences
ann_file
:
data/vimeo90k/vimeo_septuplet/sep_trainlist.txt
preprocess
:
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
SRPairedRandomCrop
gt_patch_size
:
256
scale
:
4
keys
:
[
image
,
image
]
-
name
:
PairedRandomHorizontalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomVerticalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomTransposeHW
keys
:
[
image
,
image
]
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
MirrorVideoSequence
-
name
:
NormalizeSequence
mean
:
[
0.
,
.0
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
test
:
name
:
VSRFolderDataset
# for udm10 dataset
# lq_folder: data/udm10/BDx4
# gt_folder: data/udm10/GT
lq_folder
:
data/Vid4/BDx4
gt_folder
:
data/Vid4/GT
preprocess
:
-
name
:
GetNeighboringFramesIdx
interval_list
:
[
1
]
# for udm10 dataset
# filename_tmpl: '{:04d}.png'
filename_tmpl
:
'
{:08d}.png'
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
.0
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
lr_scheduler
:
name
:
CosineAnnealingRestartLR
learning_rate
:
!!float
1e-4
periods
:
[
600000
]
restart_weights
:
[
1
]
eta_min
:
!!float
1e-7
optimizer
:
name
:
Adam
# add parameters of net_name to optim
# name should in self.nets
net_names
:
-
generator
beta1
:
0.9
beta2
:
0.99
validate
:
interval
:
5000
save_img
:
false
metrics
:
psnr
:
# metric name, can be arbitrary
name
:
PSNR
crop_border
:
0
test_y_channel
:
true
ssim
:
name
:
SSIM
crop_border
:
0
test_y_channel
:
true
log_config
:
interval
:
10
visiual_interval
:
500
snapshot_config
:
interval
:
5000
ppgan/datasets/__init__.py
浏览文件 @
6e3dad37
...
...
@@ -26,4 +26,6 @@ from .firstorder_dataset import FirstOrderDataset
from
.lapstyle_dataset
import
LapStyleDataset
from
.sr_reds_multiple_gt_dataset
import
SRREDSMultipleGTDataset
from
.mpr_dataset
import
MPRTrain
,
MPRVal
,
MPRTest
from
.vsr_vimeo90k_dataset
import
VSRVimeo90KDataset
from
.vsr_folder_dataset
import
VSRFolderDataset
from
.photopen_dataset
import
PhotoPenDataset
ppgan/datasets/preprocess/__init__.py
浏览文件 @
6e3dad37
from
.io
import
LoadImageFromFile
from
.io
import
LoadImageFromFile
,
ReadImageSequence
,
GetNeighboringFramesIdx
from
.transforms
import
(
PairedRandomCrop
,
PairedRandomHorizontalFlip
,
PairedRandomVerticalFlip
,
PairedRandomTransposeHW
,
SRPairedRandomCrop
,
SplitPairedImage
,
SRNoise
)
SRPairedRandomCrop
,
SplitPairedImage
,
SRNoise
,
NormalizeSequence
,
MirrorVideoSequence
,
TransposeSequence
)
from
.builder
import
build_preprocess
ppgan/datasets/preprocess/io.py
浏览文件 @
6e3dad37
# code was reference to mmcv
import
os
import
cv2
import
numpy
as
np
from
.builder
import
PREPROCESS
...
...
@@ -9,12 +10,12 @@ class LoadImageFromFile(object):
"""Load image from file.
Args:
key (str): Keys in
result
s to find corresponding path. Default: 'image'.
key (str): Keys in
data
s to find corresponding path. Default: 'image'.
flag (str): Loading flag for images. Default: -1.
to_rgb (str): Convert img to 'rgb' format. Default: True.
backend (str): io backend where images are store. Default: None.
save_original_img (bool): If True, maintain a copy of the image in
`
result
s` dict with name of `f'ori_{key}'`. Default: False.
`
data
s` dict with name of `f'ori_{key}'`. Default: False.
kwargs (dict): Args for file client.
"""
def
__init__
(
self
,
...
...
@@ -31,28 +32,150 @@ class LoadImageFromFile(object):
self
.
save_original_img
=
save_original_img
self
.
kwargs
=
kwargs
def
__call__
(
self
,
result
s
):
def
__call__
(
self
,
data
s
):
"""Call function.
Args:
result
s (dict): A dict containing the necessary information and
data
s (dict): A dict containing the necessary information and
data for augmentation.
Returns:
dict: A dict containing the processed data and information.
"""
filepath
=
str
(
result
s
[
f
'
{
self
.
key
}
_path'
])
filepath
=
str
(
data
s
[
f
'
{
self
.
key
}
_path'
])
#TODO: use file client to manage io backend
# such as opencv, pil, imdb
img
=
cv2
.
imread
(
filepath
,
self
.
flag
)
if
self
.
to_rgb
:
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
)
result
s
[
self
.
key
]
=
img
result
s
[
f
'
{
self
.
key
}
_path'
]
=
filepath
result
s
[
f
'
{
self
.
key
}
_ori_shape'
]
=
img
.
shape
data
s
[
self
.
key
]
=
img
data
s
[
f
'
{
self
.
key
}
_path'
]
=
filepath
data
s
[
f
'
{
self
.
key
}
_ori_shape'
]
=
img
.
shape
if
self
.
save_original_img
:
results
[
f
'ori_
{
self
.
key
}
'
]
=
img
.
copy
()
datas
[
f
'ori_
{
self
.
key
}
'
]
=
img
.
copy
()
return
datas
@
PREPROCESS
.
register
()
class
ReadImageSequence
(
LoadImageFromFile
):
"""Read image sequence.
It accepts a list of path and read each frame from each path. A list
of frames will be returned.
Args:
key (str): Keys in datas to find corresponding path. Default: 'gt'.
flag (str): Loading flag for images. Default: 'color'.
to_rgb (str): Convert img to 'rgb' format. Default: True.
save_original_img (bool): If True, maintain a copy of the image in
`datas` dict with name of `f'ori_{key}'`. Default: False.
kwargs (dict): Args for file client.
"""
def
__call__
(
self
,
datas
):
"""Call function.
Args:
datas (dict): A dict containing the necessary information and
data for augmentation.
Returns:
dict: A dict containing the processed data and information.
"""
filepaths
=
datas
[
f
'
{
self
.
key
}
_path'
]
if
not
isinstance
(
filepaths
,
list
):
raise
TypeError
(
f
'filepath should be list, but got
{
type
(
filepaths
)
}
'
)
filepaths
=
[
str
(
v
)
for
v
in
filepaths
]
imgs
=
[]
shapes
=
[]
if
self
.
save_original_img
:
ori_imgs
=
[]
for
filepath
in
filepaths
:
img
=
cv2
.
imread
(
filepath
,
self
.
flag
)
if
self
.
to_rgb
:
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
)
if
img
.
ndim
==
2
:
img
=
np
.
expand_dims
(
img
,
axis
=
2
)
imgs
.
append
(
img
)
shapes
.
append
(
img
.
shape
)
if
self
.
save_original_img
:
ori_imgs
.
append
(
img
.
copy
())
datas
[
self
.
key
]
=
imgs
datas
[
f
'
{
self
.
key
}
_path'
]
=
filepaths
datas
[
f
'
{
self
.
key
}
_ori_shape'
]
=
shapes
if
self
.
save_original_img
:
datas
[
f
'ori_
{
self
.
key
}
'
]
=
ori_imgs
return
datas
@
PREPROCESS
.
register
()
class
GetNeighboringFramesIdx
:
"""Get neighboring frame indices for a video. It also performs temporal
augmention with random interval.
Args:
interval_list (list[int]): Interval list for temporal augmentation.
It will randomly pick an interval from interval_list and sample
frame index with the interval.
start_idx (int): The index corresponds to the first frame in the
sequence. Default: 0.
filename_tmpl (str): Template for file name. Default: '{:08d}.png'.
"""
def
__init__
(
self
,
interval_list
,
start_idx
=
0
,
filename_tmpl
=
'{:08d}.png'
):
self
.
interval_list
=
interval_list
self
.
filename_tmpl
=
filename_tmpl
self
.
start_idx
=
start_idx
def
__call__
(
self
,
datas
):
"""Call function.
Args:
datas (dict): A dict containing the necessary information and
data for augmentation.
Returns:
dict: A dict containing the processed data and information.
"""
clip_name
=
datas
[
'key'
]
interval
=
np
.
random
.
choice
(
self
.
interval_list
)
self
.
sequence_length
=
datas
[
'sequence_length'
]
num_frames
=
datas
.
get
(
'num_frames'
,
self
.
sequence_length
)
if
self
.
sequence_length
-
num_frames
*
interval
<
0
:
raise
ValueError
(
'The input sequence is not long enough to '
'support the current choice of [interval] or '
'[num_frames].'
)
start_frame_idx
=
np
.
random
.
randint
(
0
,
self
.
sequence_length
-
num_frames
*
interval
+
1
)
end_frame_idx
=
start_frame_idx
+
num_frames
*
interval
neighbor_list
=
list
(
range
(
start_frame_idx
,
end_frame_idx
,
interval
))
neighbor_list
=
[
v
+
self
.
start_idx
for
v
in
neighbor_list
]
lq_path_root
=
datas
[
'lq_path'
]
gt_path_root
=
datas
[
'gt_path'
]
lq_path
=
[
os
.
path
.
join
(
lq_path_root
,
clip_name
,
self
.
filename_tmpl
.
format
(
v
))
for
v
in
neighbor_list
]
gt_path
=
[
os
.
path
.
join
(
gt_path_root
,
clip_name
,
self
.
filename_tmpl
.
format
(
v
))
for
v
in
neighbor_list
]
datas
[
'lq_path'
]
=
lq_path
datas
[
'gt_path'
]
=
gt_path
datas
[
'interval'
]
=
interval
return
result
s
return
data
s
ppgan/datasets/preprocess/transforms.py
浏览文件 @
6e3dad37
...
...
@@ -55,6 +55,7 @@ class Transforms():
def
__call__
(
self
,
datas
):
data
=
[]
for
k
in
self
.
input_keys
:
data
.
append
(
datas
[
k
])
data
=
tuple
(
data
)
...
...
@@ -133,7 +134,10 @@ class PairedRandomHorizontalFlip(T.RandomHorizontalFlip):
def
_apply_image
(
self
,
image
):
if
self
.
params
[
'flip'
]:
return
F
.
hflip
(
image
)
if
isinstance
(
image
,
list
):
image
=
[
F
.
hflip
(
v
)
for
v
in
image
]
else
:
return
F
.
hflip
(
image
)
return
image
...
...
@@ -149,7 +153,10 @@ class PairedRandomVerticalFlip(T.RandomHorizontalFlip):
def
_apply_image
(
self
,
image
):
if
self
.
params
[
'flip'
]:
return
F
.
hflip
(
image
)
if
isinstance
(
image
,
list
):
image
=
[
F
.
vflip
(
v
)
for
v
in
image
]
else
:
return
F
.
vflip
(
image
)
return
image
...
...
@@ -180,10 +187,108 @@ class PairedRandomTransposeHW(T.BaseTransform):
def
_apply_image
(
self
,
image
):
if
self
.
params
[
'transpose'
]:
image
=
image
.
transpose
(
1
,
0
,
2
)
if
isinstance
(
image
,
list
):
image
=
[
v
.
transpose
(
1
,
0
,
2
)
for
v
in
image
]
else
:
image
=
image
.
transpose
(
1
,
0
,
2
)
return
image
@
TRANSFORMS
.
register
()
class
TransposeSequence
(
T
.
Transpose
):
"""Transpose input data or a video sequence to a target format.
For example, most transforms use HWC mode image,
while the Neural Network might use CHW mode input tensor.
output image will be an instance of numpy.ndarray.
Args:
order (list|tuple, optional): Target order of input data. Default: (2, 0, 1).
keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None.
Examples:
.. code-block:: python
import numpy as np
from PIL import Image
transform = TransposeSequence()
fake_img = Image.fromarray((np.random.rand(300, 320, 3) * 255.).astype(np.uint8))
fake_img_seq = [fake_img, fake_img, fake_img]
fake_img_seq = transform(fake_img_seq)
"""
def
_apply_image
(
self
,
img
):
if
isinstance
(
img
,
list
):
imgs
=
[]
for
im
in
img
:
if
F
.
_is_tensor_image
(
im
):
return
im
.
transpose
(
self
.
order
)
if
F
.
_is_pil_image
(
im
):
im
=
np
.
asarray
(
im
)
if
len
(
im
.
shape
)
==
2
:
im
=
im
[...,
np
.
newaxis
]
imgs
.
append
(
im
.
transpose
(
self
.
order
))
return
imgs
else
:
if
F
.
_is_tensor_image
(
img
):
return
img
.
transpose
(
self
.
order
)
if
F
.
_is_pil_image
(
img
):
img
=
np
.
asarray
(
img
)
if
len
(
img
.
shape
)
==
2
:
img
=
img
[...,
np
.
newaxis
]
return
img
.
transpose
(
self
.
order
)
@
TRANSFORMS
.
register
()
class
NormalizeSequence
(
T
.
Normalize
):
"""Normalize the input data with mean and standard deviation.
Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels,
this transform will normalize each channel of the input data.
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
Args:
mean (int|float|list|tuple): Sequence of means for each channel.
std (int|float|list|tuple): Sequence of standard deviations for each channel.
data_format (str, optional): Data format of img, should be 'HWC' or
'CHW'. Default: 'CHW'.
to_rgb (bool, optional): Whether to convert to rgb. Default: False.
keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None.
Examples:
.. code-block:: python
import numpy as np
from PIL import Image
normalize_seq = NormalizeSequence(mean=[127.5, 127.5, 127.5],
std=[127.5, 127.5, 127.5],
data_format='HWC')
fake_img = Image.fromarray((np.random.rand(300, 320, 3) * 255.).astype(np.uint8))
fake_img_seq = [fake_img, fake_img, fake_img]
fake_img_seq = normalize_seq(fake_img_seq)
"""
def
_apply_image
(
self
,
img
):
if
isinstance
(
img
,
list
):
imgs
=
[
F
.
normalize
(
v
,
self
.
mean
,
self
.
std
,
self
.
data_format
,
self
.
to_rgb
)
for
v
in
img
]
return
np
.
stack
(
imgs
,
axis
=
0
).
astype
(
'float32'
)
return
F
.
normalize
(
img
,
self
.
mean
,
self
.
std
,
self
.
data_format
,
self
.
to_rgb
)
@
TRANSFORMS
.
register
()
class
SRPairedRandomCrop
(
T
.
BaseTransform
):
"""Super resolution random crop.
...
...
@@ -204,15 +309,19 @@ class SRPairedRandomCrop(T.BaseTransform):
self
.
scale_list
=
scale_list
def
__call__
(
self
,
inputs
):
"""inputs must be (lq_img
, gt_img
)"""
"""inputs must be (lq_img
or list[lq_img], gt_img or list[gt_img]
)"""
scale
=
self
.
scale
lq_patch_size
=
self
.
gt_patch_size
//
scale
lq
=
inputs
[
0
]
gt
=
inputs
[
1
]
h_lq
,
w_lq
,
_
=
lq
.
shape
h_gt
,
w_gt
,
_
=
gt
.
shape
if
isinstance
(
lq
,
list
):
h_lq
,
w_lq
,
_
=
lq
[
0
].
shape
h_gt
,
w_gt
,
_
=
gt
[
0
].
shape
else
:
h_lq
,
w_lq
,
_
=
lq
.
shape
h_gt
,
w_gt
,
_
=
gt
.
shape
if
h_gt
!=
h_lq
*
scale
or
w_gt
!=
w_lq
*
scale
:
raise
ValueError
(
'scale size not match'
)
...
...
@@ -222,18 +331,30 @@ class SRPairedRandomCrop(T.BaseTransform):
# randomly choose top and left coordinates for lq patch
top
=
random
.
randint
(
0
,
h_lq
-
lq_patch_size
)
left
=
random
.
randint
(
0
,
w_lq
-
lq_patch_size
)
# crop lq patch
lq
=
lq
[
top
:
top
+
lq_patch_size
,
left
:
left
+
lq_patch_size
,
...]
# crop corresponding gt patch
top_gt
,
left_gt
=
int
(
top
*
scale
),
int
(
left
*
scale
)
gt
=
gt
[
top_gt
:
top_gt
+
self
.
gt_patch_size
,
left_gt
:
left_gt
+
self
.
gt_patch_size
,
...]
if
self
.
scale_list
and
self
.
scale
==
4
:
lqx2
=
F
.
resize
(
gt
,
(
lq_patch_size
*
2
,
lq_patch_size
*
2
),
'bicubic'
)
outputs
=
(
lq
,
lqx2
,
gt
)
return
outputs
if
isinstance
(
lq
,
list
):
lq
=
[
v
[
top
:
top
+
lq_patch_size
,
left
:
left
+
lq_patch_size
,
...]
for
v
in
lq
]
top_gt
,
left_gt
=
int
(
top
*
scale
),
int
(
left
*
scale
)
gt
=
[
v
[
top_gt
:
top_gt
+
self
.
gt_patch_size
,
left_gt
:
left_gt
+
self
.
gt_patch_size
,
...]
for
v
in
gt
]
else
:
# crop lq patch
lq
=
lq
[
top
:
top
+
lq_patch_size
,
left
:
left
+
lq_patch_size
,
...]
# crop corresponding gt patch
top_gt
,
left_gt
=
int
(
top
*
scale
),
int
(
left
*
scale
)
gt
=
gt
[
top_gt
:
top_gt
+
self
.
gt_patch_size
,
left_gt
:
left_gt
+
self
.
gt_patch_size
,
...]
if
self
.
scale_list
and
self
.
scale
==
4
:
lqx2
=
F
.
resize
(
gt
,
(
lq_patch_size
*
2
,
lq_patch_size
*
2
),
'bicubic'
)
outputs
=
(
lq
,
lqx2
,
gt
)
return
outputs
outputs
=
(
lq
,
gt
)
return
outputs
...
...
@@ -411,3 +532,36 @@ class PairedColorJitter(T.BaseTransform):
for
f
in
self
.
params
:
img
=
f
(
img
)
return
img
@
TRANSFORMS
.
register
()
class
MirrorVideoSequence
:
"""Double a short video sequences by mirroring the sequences
Example:
Given a sequence with N frames (x1, ..., xN), extend the
sequence to (x1, ..., xN, xN, ..., x1).
Args:
keys (list[str]): The frame lists to be extended.
"""
def
__init__
(
self
,
keys
=
None
):
self
.
keys
=
keys
def
__call__
(
self
,
datas
):
"""Call function.
Args:
datas (dict): A dict containing the necessary information and
data for augmentation.
Returns:
dict: A dict containing the processed data and information.
"""
lrs
,
hrs
=
datas
assert
isinstance
(
lrs
,
list
)
and
isinstance
(
hrs
,
list
)
lrs
=
lrs
+
lrs
[::
-
1
]
hrs
=
hrs
+
hrs
[::
-
1
]
return
(
lrs
,
hrs
)
ppgan/datasets/vsr_folder_dataset.py
0 → 100644
浏览文件 @
6e3dad37
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
cv2
import
glob
import
random
import
logging
import
numpy
as
np
from
paddle.io
import
Dataset
from
.base_sr_dataset
import
BaseDataset
from
.builder
import
DATASETS
logger
=
logging
.
getLogger
(
__name__
)
@
DATASETS
.
register
()
class
VSRFolderDataset
(
BaseDataset
):
"""Video super-resolution for folder format.
Args:
lq_folder (str): Path to a low quality image folder.
gt_folder (str): Path to a ground truth image folder.
ann_file (str): Path to the annotation file.
preprocess (list[dict|callable]): A list functions of data transformations.
num_frames (int): Number of frames of each input clip.
times (int): Repeat times of datset length.
"""
def
__init__
(
self
,
lq_folder
,
gt_folder
,
preprocess
,
num_frames
=
None
,
times
=
1
):
super
().
__init__
(
preprocess
)
self
.
lq_folder
=
str
(
lq_folder
)
self
.
gt_folder
=
str
(
gt_folder
)
self
.
num_frames
=
num_frames
self
.
times
=
times
self
.
data_infos
=
self
.
prepare_data_infos
()
def
prepare_data_infos
(
self
):
sequences
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
self
.
lq_folder
,
'*'
)))
data_infos
=
[]
for
sequence
in
sequences
:
sequence_length
=
len
(
glob
.
glob
(
os
.
path
.
join
(
sequence
,
'*.png'
)))
if
self
.
num_frames
is
None
:
num_frames
=
sequence_length
else
:
num_frames
=
self
.
num_frames
data_infos
.
append
(
dict
(
lq_path
=
self
.
lq_folder
,
gt_path
=
self
.
gt_folder
,
key
=
sequence
.
replace
(
f
'
{
self
.
lq_folder
}
/'
,
''
),
num_frames
=
num_frames
,
sequence_length
=
sequence_length
))
return
data_infos
ppgan/datasets/vsr_vimeo90k_dataset.py
0 → 100644
浏览文件 @
6e3dad37
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
cv2
import
glob
import
random
import
logging
import
numpy
as
np
from
paddle.io
import
Dataset
from
.base_sr_dataset
import
BaseDataset
from
.builder
import
DATASETS
@
DATASETS
.
register
()
class
VSRVimeo90KDataset
(
BaseDataset
):
"""Vimeo90K dataset for video super resolution for recurrent networks.
The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth)
frames. Then it applies specified transforms and finally returns a dict
containing paired data and other information.
It reads Vimeo90K keys from the txt file. Each line contains video frame folder
Examples:
00001/0233
00001/0234
Args:
lq_folder (str): Path to a low quality image folder.
gt_folder (str): Path to a ground truth image folder.
ann_file (str): Path to the annotation file.
preprocess (list[dict|callable]): A list functions of data transformations.
"""
def
__init__
(
self
,
lq_folder
,
gt_folder
,
ann_file
,
preprocess
):
super
().
__init__
(
preprocess
)
self
.
lq_folder
=
str
(
lq_folder
)
self
.
gt_folder
=
str
(
gt_folder
)
self
.
ann_file
=
str
(
ann_file
)
self
.
data_infos
=
self
.
prepare_data_infos
()
def
prepare_data_infos
(
self
):
with
open
(
self
.
ann_file
,
'r'
)
as
fin
:
keys
=
[
line
.
strip
()
for
line
in
fin
]
data_infos
=
[]
for
key
in
keys
:
lq_paths
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
self
.
lq_folder
,
key
,
'*.png'
)))
gt_paths
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
self
.
gt_folder
,
key
,
'*.png'
)))
data_infos
.
append
(
dict
(
lq_path
=
lq_paths
,
gt_path
=
gt_paths
,
key
=
key
))
return
data_infos
ppgan/metrics/psnr_ssim.py
浏览文件 @
6e3dad37
...
...
@@ -30,17 +30,26 @@ class PSNR(paddle.metric.Metric):
def
reset
(
self
):
self
.
results
=
[]
def
update
(
self
,
preds
,
gts
):
def
update
(
self
,
preds
,
gts
,
is_seq
=
False
):
if
not
isinstance
(
preds
,
(
list
,
tuple
)):
preds
=
[
preds
]
if
not
isinstance
(
gts
,
(
list
,
tuple
)):
gts
=
[
gts
]
if
is_seq
:
single_seq
=
[]
for
pred
,
gt
in
zip
(
preds
,
gts
):
value
=
calculate_psnr
(
pred
,
gt
,
self
.
crop_border
,
self
.
input_order
,
self
.
test_y_channel
)
self
.
results
.
append
(
value
)
if
is_seq
:
single_seq
.
append
(
value
)
else
:
self
.
results
.
append
(
value
)
if
is_seq
:
self
.
results
.
append
(
np
.
mean
(
single_seq
))
def
accumulate
(
self
):
if
paddle
.
distributed
.
get_world_size
()
>
1
:
...
...
@@ -59,17 +68,26 @@ class PSNR(paddle.metric.Metric):
@
METRICS
.
register
()
class
SSIM
(
PSNR
):
def
update
(
self
,
preds
,
gts
):
def
update
(
self
,
preds
,
gts
,
is_seq
=
False
):
if
not
isinstance
(
preds
,
(
list
,
tuple
)):
preds
=
[
preds
]
if
not
isinstance
(
gts
,
(
list
,
tuple
)):
gts
=
[
gts
]
if
is_seq
:
single_seq
=
[]
for
pred
,
gt
in
zip
(
preds
,
gts
):
value
=
calculate_ssim
(
pred
,
gt
,
self
.
crop_border
,
self
.
input_order
,
self
.
test_y_channel
)
self
.
results
.
append
(
value
)
if
is_seq
:
single_seq
.
append
(
value
)
else
:
self
.
results
.
append
(
value
)
if
is_seq
:
self
.
results
.
append
(
np
.
mean
(
single_seq
))
def
name
(
self
):
return
'SSIM'
...
...
ppgan/models/base_model.py
浏览文件 @
6e3dad37
...
...
@@ -25,7 +25,7 @@ from ..utils.visual import tensor2img
class
BaseModel
(
ABC
):
"""This class is an abstract base class (ABC) for models.
r
"""This class is an abstract base class (ABC) for models.
To create a subclass, you need to implement the following five functions:
-- <__init__>: initialize the class.
-- <setup_input>: unpack data from dataset and apply preprocessing.
...
...
ppgan/models/basicvsr_model.py
浏览文件 @
6e3dad37
...
...
@@ -103,7 +103,7 @@ class BasicVSRModel(BaseSRModel):
if
metrics
is
not
None
:
for
metric
in
metrics
.
values
():
metric
.
update
(
out_img
,
gt_img
)
metric
.
update
(
out_img
,
gt_img
,
is_seq
=
True
)
def
init_basicvsr_weight
(
net
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录