Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
424ab9eb
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
大约 2 年 前同步成功
通知
100
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
424ab9eb
编写于
12月 15, 2021
作者:
W
wangna11BD
提交者:
GitHub
12月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix reds dataset (#527)
* fix reds dataset * fix CI
上级
a8ae6b7a
变更
26
隐藏空白更改
内联
并排
Showing
26 changed file
with
968 addition
and
855 deletion
+968
-855
configs/basicvsr++_reds.yaml
configs/basicvsr++_reds.yaml
+49
-21
configs/basicvsr++_vimeo90k_BD.yaml
configs/basicvsr++_vimeo90k_BD.yaml
+2
-2
configs/basicvsr_reds.yaml
configs/basicvsr_reds.yaml
+51
-24
configs/edvr_l_blur_w_tsa.yaml
configs/edvr_l_blur_w_tsa.yaml
+62
-30
configs/edvr_l_blur_wo_tsa.yaml
configs/edvr_l_blur_wo_tsa.yaml
+62
-30
configs/edvr_l_w_tsa.yaml
configs/edvr_l_w_tsa.yaml
+60
-28
configs/edvr_l_wo_tsa.yaml
configs/edvr_l_wo_tsa.yaml
+60
-28
configs/edvr_m_w_tsa.yaml
configs/edvr_m_w_tsa.yaml
+60
-28
configs/edvr_m_wo_tsa.yaml
configs/edvr_m_wo_tsa.yaml
+60
-28
configs/esrgan_psnr_x4_div2k.yaml
configs/esrgan_psnr_x4_div2k.yaml
+2
-2
configs/esrgan_x4_div2k.yaml
configs/esrgan_x4_div2k.yaml
+2
-2
configs/iconvsr_reds.yaml
configs/iconvsr_reds.yaml
+50
-22
configs/lesrcnn_psnr_x4_div2k.yaml
configs/lesrcnn_psnr_x4_div2k.yaml
+2
-2
configs/msvsr_l_reds.yaml
configs/msvsr_l_reds.yaml
+49
-22
configs/msvsr_reds.yaml
configs/msvsr_reds.yaml
+50
-25
configs/msvsr_vimeo90k_BD.yaml
configs/msvsr_vimeo90k_BD.yaml
+2
-3
configs/realsr_bicubic_noise_x4_df2k.yaml
configs/realsr_bicubic_noise_x4_df2k.yaml
+2
-2
configs/realsr_kernel_noise_x4_dped.yaml
configs/realsr_kernel_noise_x4_dped.yaml
+2
-2
ppgan/datasets/__init__.py
ppgan/datasets/__init__.py
+2
-2
ppgan/datasets/edvr_dataset.py
ppgan/datasets/edvr_dataset.py
+0
-313
ppgan/datasets/preprocess/__init__.py
ppgan/datasets/preprocess/__init__.py
+1
-1
ppgan/datasets/preprocess/io.py
ppgan/datasets/preprocess/io.py
+148
-0
ppgan/datasets/sr_reds_multiple_gt_dataset.py
ppgan/datasets/sr_reds_multiple_gt_dataset.py
+0
-237
ppgan/datasets/vsr_reds_dataset.py
ppgan/datasets/vsr_reds_dataset.py
+97
-0
ppgan/datasets/vsr_reds_multiple_gt_dataset.py
ppgan/datasets/vsr_reds_multiple_gt_dataset.py
+92
-0
ppgan/models/edvr_model.py
ppgan/models/edvr_model.py
+1
-1
未找到文件。
configs/basicvsr++_reds.yaml
浏览文件 @
424ab9eb
...
...
@@ -27,33 +27,61 @@ dataset:
num_workers
:
4
batch_size
:
2
#4 gpus
dataset
:
name
:
SRREDSMultipleGTDataset
mode
:
train
name
:
VSRREDSMultipleGTDataset
lq_folder
:
data/REDS/train_sharp_bicubic/X4
gt_folder
:
data/REDS/train_sharp/X4
crop_size
:
256
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
30
use_flip
:
True
use_rot
:
True
scale
:
4
val_partition
:
REDS4
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
30
preprocess
:
-
name
:
GetNeighboringFramesIdx
interval_list
:
[
1
]
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
SRPairedRandomCrop
gt_patch_size
:
256
scale
:
4
keys
:
[
image
,
image
]
-
name
:
PairedRandomHorizontalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomVerticalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomTransposeHW
keys
:
[
image
,
image
]
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
test
:
name
:
SRREDSMultipleGTDataset
mode
:
test
name
:
VSRREDSMultipleGTDataset
lq_folder
:
data/REDS/REDS4_test_sharp_bicubic/X4
gt_folder
:
data/REDS/REDS4_test_sharp/X4
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
100
use_flip
:
False
use_rot
:
False
scale
:
4
val_partition
:
REDS4
num_workers
:
0
batch_size
:
1
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
100
test_mode
:
True
preprocess
:
-
name
:
GetNeighboringFramesIdx
interval_list
:
[
1
]
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
lr_scheduler
:
name
:
CosineAnnealingRestartLR
...
...
configs/basicvsr++_vimeo90k_BD.yaml
浏览文件 @
424ab9eb
...
...
@@ -53,7 +53,7 @@ dataset:
keys
:
[
image
,
image
]
-
name
:
MirrorVideoSequence
-
name
:
NormalizeSequence
mean
:
[
0.
,
.0
,
0.
]
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
...
...
@@ -80,7 +80,7 @@ dataset:
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
.0
,
0.
]
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
...
...
configs/basicvsr_reds.yaml
浏览文件 @
424ab9eb
...
...
@@ -24,37 +24,64 @@ dataset:
name
:
RepeatDataset
times
:
1000
num_workers
:
4
batch_size
:
2
#4
GPU
s
batch_size
:
2
#4
gpu
s
dataset
:
name
:
SRREDSMultipleGTDataset
mode
:
train
name
:
VSRREDSMultipleGTDataset
lq_folder
:
data/REDS/train_sharp_bicubic/X4
gt_folder
:
data/REDS/train_sharp/X4
crop_size
:
256
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
15
use_flip
:
True
use_rot
:
True
scale
:
4
val_partition
:
REDS4
num_clips
:
270
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
15
preprocess
:
-
name
:
GetNeighboringFramesIdx
interval_list
:
[
1
]
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
SRPairedRandomCrop
gt_patch_size
:
256
scale
:
4
keys
:
[
image
,
image
]
-
name
:
PairedRandomHorizontalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomVerticalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomTransposeHW
keys
:
[
image
,
image
]
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
test
:
name
:
SRREDSMultipleGTDataset
mode
:
test
name
:
VSRREDSMultipleGTDataset
lq_folder
:
data/REDS/REDS4_test_sharp_bicubic/X4
gt_folder
:
data/REDS/REDS4_test_sharp/X4
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
100
use_flip
:
False
use_rot
:
False
scale
:
4
val_partition
:
REDS4
num_workers
:
0
batch_size
:
1
num_clips
:
270
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
100
test_mode
:
True
preprocess
:
-
name
:
GetNeighboringFramesIdx
interval_list
:
[
1
]
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
lr_scheduler
:
name
:
CosineAnnealingRestartLR
...
...
configs/edvr_l_blur_w_tsa.yaml
浏览文件 @
424ab9eb
...
...
@@ -20,47 +20,79 @@ model:
front_RBs
:
5
back_RBs
:
40
center
:
2
predeblur
:
True
#False
HR_in
:
True
#False
predeblur
:
True
HR_in
:
True
w_TSA
:
True
pixel_criterion
:
name
:
CharbonnierLoss
dataset
:
train
:
name
:
REDSDataset
mode
:
train
gt_folder
:
data/REDS/train_sharp/X4
lq_folder
:
data/REDS/train_blur/X4
img_format
:
png
crop_size
:
256
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
5
use_flip
:
True
use_rot
:
True
buf_size
:
1024
scale
:
1
fix_random_seed
:
10
name
:
RepeatDataset
times
:
1000
num_workers
:
6
batch_size
:
8
batch_size
:
8
#4 gpus
dataset
:
name
:
VSRREDSDataset
lq_folder
:
data/REDS/train_blur/X4
gt_folder
:
data/REDS/train_sharp/X4
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
5
val_partition
:
REDS4
test_mode
:
False
preprocess
:
-
name
:
GetFrameIdx
interval_list
:
[
1
]
frames_per_clip
:
99
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
SRPairedRandomCrop
gt_patch_size
:
256
scale
:
1
keys
:
[
image
,
image
]
-
name
:
PairedRandomHorizontalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomVerticalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomTransposeHW
keys
:
[
image
,
image
]
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
test
:
name
:
REDSDataset
mode
:
test
gt_folder
:
data/REDS/REDS4_test_sharp/X4
name
:
VSRREDSDataset
lq_folder
:
data/REDS/REDS4_test_blur/X4
img_format
:
png
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
5
batch_size
:
1
use_flip
:
False
use_rot
:
False
buf_size
:
1024
scale
:
1
fix_random_seed
:
10
gt_folder
:
data/REDS/REDS4_test_sharp/X4
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
5
val_partition
:
REDS4
test_mode
:
True
preprocess
:
-
name
:
GetFrameIdxwithPadding
padding
:
reflection_circle
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
lr_scheduler
:
name
:
CosineAnnealingRestartLR
...
...
configs/edvr_l_blur_wo_tsa.yaml
浏览文件 @
424ab9eb
...
...
@@ -19,47 +19,79 @@ model:
front_RBs
:
5
back_RBs
:
40
center
:
2
predeblur
:
True
#False
HR_in
:
True
#False
predeblur
:
True
HR_in
:
True
w_TSA
:
False
pixel_criterion
:
name
:
CharbonnierLoss
dataset
:
train
:
name
:
REDSDataset
mode
:
train
gt_folder
:
data/REDS/train_sharp/X4
lq_folder
:
data/REDS/train_blur/X4
img_format
:
png
crop_size
:
256
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
5
use_flip
:
True
use_rot
:
True
buf_size
:
1024
scale
:
1
fix_random_seed
:
10
name
:
RepeatDataset
times
:
1000
num_workers
:
6
batch_size
:
8
batch_size
:
8
#4 gpus
dataset
:
name
:
VSRREDSDataset
lq_folder
:
data/REDS/train_blur/X4
gt_folder
:
data/REDS/train_sharp/X4
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
5
val_partition
:
REDS4
test_mode
:
False
preprocess
:
-
name
:
GetFrameIdx
interval_list
:
[
1
]
frames_per_clip
:
99
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
SRPairedRandomCrop
gt_patch_size
:
256
scale
:
1
keys
:
[
image
,
image
]
-
name
:
PairedRandomHorizontalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomVerticalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomTransposeHW
keys
:
[
image
,
image
]
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
test
:
name
:
REDSDataset
mode
:
test
gt_folder
:
data/REDS/REDS4_test_sharp/X4
name
:
VSRREDSDataset
lq_folder
:
data/REDS/REDS4_test_blur/X4
img_format
:
png
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
5
batch_size
:
1
use_flip
:
False
use_rot
:
False
buf_size
:
1024
scale
:
1
fix_random_seed
:
10
gt_folder
:
data/REDS/REDS4_test_sharp/X4
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
5
val_partition
:
REDS4
test_mode
:
True
preprocess
:
-
name
:
GetFrameIdxwithPadding
padding
:
reflection_circle
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
lr_scheduler
:
name
:
CosineAnnealingRestartLR
...
...
configs/edvr_l_w_tsa.yaml
浏览文件 @
424ab9eb
...
...
@@ -28,39 +28,71 @@ model:
dataset
:
train
:
name
:
REDSDataset
mode
:
train
gt_folder
:
data/REDS/train_sharp/X4
lq_folder
:
data/REDS/train_sharp_bicubic/X4
img_format
:
png
crop_size
:
256
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
5
use_flip
:
True
use_rot
:
True
buf_size
:
1024
scale
:
4
fix_random_seed
:
10
name
:
RepeatDataset
times
:
1000
num_workers
:
3
batch_size
:
4
# 8GPUs
batch_size
:
4
#8 gpus
dataset
:
name
:
VSRREDSDataset
lq_folder
:
data/REDS/train_sharp_bicubic/X4
gt_folder
:
data/REDS/train_sharp/X4
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
5
val_partition
:
REDS4
test_mode
:
False
preprocess
:
-
name
:
GetFrameIdx
interval_list
:
[
1
]
frames_per_clip
:
99
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
SRPairedRandomCrop
gt_patch_size
:
256
scale
:
4
keys
:
[
image
,
image
]
-
name
:
PairedRandomHorizontalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomVerticalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomTransposeHW
keys
:
[
image
,
image
]
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
test
:
name
:
REDSDataset
mode
:
test
gt_folder
:
data/REDS/REDS4_test_sharp/X4
name
:
VSRREDSDataset
lq_folder
:
data/REDS/REDS4_test_sharp_bicubic/X4
img_format
:
png
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
5
batch_size
:
1
use_flip
:
False
use_rot
:
False
buf_size
:
1024
scale
:
4
fix_random_seed
:
10
gt_folder
:
data/REDS/REDS4_test_sharp/X4
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
5
val_partition
:
REDS4
test_mode
:
True
preprocess
:
-
name
:
GetFrameIdxwithPadding
padding
:
reflection_circle
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
lr_scheduler
:
name
:
CosineAnnealingRestartLR
...
...
configs/edvr_l_wo_tsa.yaml
浏览文件 @
424ab9eb
...
...
@@ -27,39 +27,71 @@ model:
dataset
:
train
:
name
:
REDSDataset
mode
:
train
gt_folder
:
data/REDS/train_sharp/X4
lq_folder
:
data/REDS/train_sharp_bicubic/X4
img_format
:
png
crop_size
:
256
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
5
use_flip
:
True
use_rot
:
True
buf_size
:
1024
scale
:
4
fix_random_seed
:
10
name
:
RepeatDataset
times
:
1000
num_workers
:
3
batch_size
:
4
# 8GPUs
batch_size
:
4
#8 gpus
dataset
:
name
:
VSRREDSDataset
lq_folder
:
data/REDS/train_sharp_bicubic/X4
gt_folder
:
data/REDS/train_sharp/X4
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
5
val_partition
:
REDS4
test_mode
:
False
preprocess
:
-
name
:
GetFrameIdx
interval_list
:
[
1
]
frames_per_clip
:
99
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
SRPairedRandomCrop
gt_patch_size
:
256
scale
:
4
keys
:
[
image
,
image
]
-
name
:
PairedRandomHorizontalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomVerticalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomTransposeHW
keys
:
[
image
,
image
]
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
test
:
name
:
REDSDataset
mode
:
test
gt_folder
:
data/REDS/REDS4_test_sharp/X4
name
:
VSRREDSDataset
lq_folder
:
data/REDS/REDS4_test_sharp_bicubic/X4
img_format
:
png
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
5
batch_size
:
1
use_flip
:
False
use_rot
:
False
buf_size
:
1024
scale
:
4
fix_random_seed
:
10
gt_folder
:
data/REDS/REDS4_test_sharp/X4
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
5
val_partition
:
REDS4
test_mode
:
True
preprocess
:
-
name
:
GetFrameIdxwithPadding
padding
:
reflection_circle
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
lr_scheduler
:
name
:
CosineAnnealingRestartLR
...
...
configs/edvr_m_w_tsa.yaml
浏览文件 @
424ab9eb
...
...
@@ -31,39 +31,71 @@ export_model:
dataset
:
train
:
name
:
REDSDataset
mode
:
train
gt_folder
:
data/REDS/train_sharp/X4
lq_folder
:
data/REDS/train_sharp_bicubic/X4
img_format
:
png
crop_size
:
256
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
5
use_flip
:
True
use_rot
:
True
buf_size
:
1024
scale
:
4
fix_random_seed
:
10
name
:
RepeatDataset
times
:
1000
num_workers
:
3
batch_size
:
4
# 8GPUs
batch_size
:
4
#8 gpus
dataset
:
name
:
VSRREDSDataset
lq_folder
:
data/REDS/train_sharp_bicubic/X4
gt_folder
:
data/REDS/train_sharp/X4
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
5
val_partition
:
REDS4
test_mode
:
False
preprocess
:
-
name
:
GetFrameIdx
interval_list
:
[
1
]
frames_per_clip
:
99
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
SRPairedRandomCrop
gt_patch_size
:
256
scale
:
4
keys
:
[
image
,
image
]
-
name
:
PairedRandomHorizontalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomVerticalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomTransposeHW
keys
:
[
image
,
image
]
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
test
:
name
:
REDSDataset
mode
:
test
gt_folder
:
data/REDS/REDS4_test_sharp/X4
name
:
VSRREDSDataset
lq_folder
:
data/REDS/REDS4_test_sharp_bicubic/X4
img_format
:
png
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
5
batch_size
:
1
use_flip
:
False
use_rot
:
False
buf_size
:
1024
scale
:
4
fix_random_seed
:
10
gt_folder
:
data/REDS/REDS4_test_sharp/X4
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
5
val_partition
:
REDS4
test_mode
:
True
preprocess
:
-
name
:
GetFrameIdxwithPadding
padding
:
reflection_circle
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
lr_scheduler
:
name
:
CosineAnnealingRestartLR
...
...
configs/edvr_m_wo_tsa.yaml
浏览文件 @
424ab9eb
...
...
@@ -27,39 +27,71 @@ model:
dataset
:
train
:
name
:
REDSDataset
mode
:
train
gt_folder
:
data/REDS/train_sharp/X4
lq_folder
:
data/REDS/train_sharp_bicubic/X4
img_format
:
png
crop_size
:
256
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
5
use_flip
:
True
use_rot
:
True
buf_size
:
1024
scale
:
4
fix_random_seed
:
10
name
:
RepeatDataset
times
:
1000
num_workers
:
3
batch_size
:
4
# 8GPUs
batch_size
:
4
#8 gpus
dataset
:
name
:
VSRREDSDataset
lq_folder
:
data/REDS/train_sharp_bicubic/X4
gt_folder
:
data/REDS/train_sharp/X4
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
5
val_partition
:
REDS4
test_mode
:
False
preprocess
:
-
name
:
GetFrameIdx
interval_list
:
[
1
]
frames_per_clip
:
99
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
SRPairedRandomCrop
gt_patch_size
:
256
scale
:
4
keys
:
[
image
,
image
]
-
name
:
PairedRandomHorizontalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomVerticalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomTransposeHW
keys
:
[
image
,
image
]
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
test
:
name
:
REDSDataset
mode
:
test
gt_folder
:
data/REDS/REDS4_test_sharp/X4
name
:
VSRREDSDataset
lq_folder
:
data/REDS/REDS4_test_sharp_bicubic/X4
img_format
:
png
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
5
batch_size
:
1
use_flip
:
False
use_rot
:
False
buf_size
:
1024
scale
:
4
fix_random_seed
:
10
gt_folder
:
data/REDS/REDS4_test_sharp/X4
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
5
val_partition
:
REDS4
test_mode
:
True
preprocess
:
-
name
:
GetFrameIdxwithPadding
padding
:
reflection_circle
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
lr_scheduler
:
name
:
CosineAnnealingRestartLR
...
...
configs/esrgan_psnr_x4_div2k.yaml
浏览文件 @
424ab9eb
...
...
@@ -44,7 +44,7 @@ dataset:
-
name
:
Transpose
keys
:
[
image
,
image
]
-
name
:
Normalize
mean
:
[
0.
,
.0
,
0.
]
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
test
:
...
...
@@ -63,7 +63,7 @@ dataset:
-
name
:
Transpose
keys
:
[
image
,
image
]
-
name
:
Normalize
mean
:
[
0.
,
.0
,
0.
]
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
...
...
configs/esrgan_x4_div2k.yaml
浏览文件 @
424ab9eb
...
...
@@ -64,7 +64,7 @@ dataset:
-
name
:
Transpose
keys
:
[
image
,
image
]
-
name
:
Normalize
mean
:
[
0.
,
.0
,
0.
]
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
test
:
...
...
@@ -83,7 +83,7 @@ dataset:
-
name
:
Transpose
keys
:
[
image
,
image
]
-
name
:
Normalize
mean
:
[
0.
,
.0
,
0.
]
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
...
...
configs/iconvsr_reds.yaml
浏览文件 @
424ab9eb
...
...
@@ -24,35 +24,63 @@ dataset:
name
:
RepeatDataset
times
:
1000
num_workers
:
4
batch_size
:
2
#4
GPU
s
batch_size
:
2
#4
gpu
s
dataset
:
name
:
SRREDSMultipleGTDataset
mode
:
train
name
:
VSRREDSMultipleGTDataset
lq_folder
:
data/REDS/train_sharp_bicubic/X4
gt_folder
:
data/REDS/train_sharp/X4
crop_size
:
256
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
15
use_flip
:
True
use_rot
:
True
scale
:
4
val_partition
:
REDS4
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
15
preprocess
:
-
name
:
GetNeighboringFramesIdx
interval_list
:
[
1
]
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
SRPairedRandomCrop
gt_patch_size
:
256
scale
:
4
keys
:
[
image
,
image
]
-
name
:
PairedRandomHorizontalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomVerticalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomTransposeHW
keys
:
[
image
,
image
]
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
test
:
name
:
SRREDSMultipleGTDataset
mode
:
test
name
:
VSRREDSMultipleGTDataset
lq_folder
:
data/REDS/REDS4_test_sharp_bicubic/X4
gt_folder
:
data/REDS/REDS4_test_sharp/X4
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
100
use_flip
:
False
use_rot
:
False
scale
:
4
val_partition
:
REDS4
num_workers
:
0
batch_size
:
1
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
100
test_mode
:
True
preprocess
:
-
name
:
GetNeighboringFramesIdx
interval_list
:
[
1
]
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
lr_scheduler
:
name
:
CosineAnnealingRestartLR
...
...
configs/lesrcnn_psnr_x4_div2k.yaml
浏览文件 @
424ab9eb
...
...
@@ -40,7 +40,7 @@ dataset:
-
name
:
Transpose
keys
:
[
image
,
image
]
-
name
:
Normalize
mean
:
[
0.
,
.0
,
0.
]
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
test
:
...
...
@@ -59,7 +59,7 @@ dataset:
-
name
:
Transpose
keys
:
[
image
,
image
]
-
name
:
Normalize
mean
:
[
0.
,
.0
,
0.
]
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
...
...
configs/msvsr_l_reds.yaml
浏览文件 @
424ab9eb
...
...
@@ -34,35 +34,62 @@ dataset:
times
:
1000
num_workers
:
4
batch_size
:
2
#8 gpus
use_shared_memory
:
True
dataset
:
name
:
SRREDSMultipleGTDataset
mode
:
train
name
:
VSRREDSMultipleGTDataset
lq_folder
:
data/REDS/train_sharp_bicubic/X4
gt_folder
:
data/REDS/train_sharp/X4
crop_size
:
256
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
30
use_flip
:
True
use_rot
:
True
scale
:
4
val_partition
:
REDS4
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
30
preprocess
:
-
name
:
GetNeighboringFramesIdx
interval_list
:
[
1
]
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
SRPairedRandomCrop
gt_patch_size
:
256
scale
:
4
keys
:
[
image
,
image
]
-
name
:
PairedRandomHorizontalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomVerticalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomTransposeHW
keys
:
[
image
,
image
]
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
test
:
name
:
SRREDSMultipleGTDataset
mode
:
test
name
:
VSRREDSMultipleGTDataset
lq_folder
:
data/REDS/REDS4_test_sharp_bicubic/X4
gt_folder
:
data/REDS/REDS4_test_sharp/X4
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
100
use_flip
:
False
use_rot
:
False
scale
:
4
val_partition
:
REDS4
num_workers
:
0
batch_size
:
1
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
100
test_mode
:
True
preprocess
:
-
name
:
GetNeighboringFramesIdx
interval_list
:
[
1
]
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
lr_scheduler
:
name
:
CosineAnnealingRestartLR
...
...
configs/msvsr_reds.yaml
浏览文件 @
424ab9eb
...
...
@@ -34,37 +34,62 @@ dataset:
times
:
1000
num_workers
:
6
batch_size
:
2
#8 gpus
use_shared_memory
:
True
dataset
:
name
:
SRREDSMultipleGTDataset
mode
:
train
name
:
VSRREDSMultipleGTDataset
lq_folder
:
data/REDS/train_sharp_bicubic/X4
gt_folder
:
data/REDS/train_sharp/X4
crop_size
:
256
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
20
use_flip
:
True
use_rot
:
True
scale
:
4
val_partition
:
REDS4
num_clips
:
270
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
20
preprocess
:
-
name
:
GetNeighboringFramesIdx
interval_list
:
[
1
]
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
SRPairedRandomCrop
gt_patch_size
:
256
scale
:
4
keys
:
[
image
,
image
]
-
name
:
PairedRandomHorizontalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomVerticalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomTransposeHW
keys
:
[
image
,
image
]
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
test
:
name
:
SRREDSMultipleGTDataset
mode
:
test
name
:
VSRREDSMultipleGTDataset
lq_folder
:
data/REDS/REDS4_test_sharp_bicubic/X4
gt_folder
:
data/REDS/REDS4_test_sharp/X4
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
100
use_flip
:
False
use_rot
:
False
scale
:
4
val_partition
:
REDS4
num_workers
:
0
batch_size
:
1
num_clips
:
270
ann_file
:
data/REDS/meta_info_REDS_GT.txt
num_frames
:
100
test_mode
:
True
preprocess
:
-
name
:
GetNeighboringFramesIdx
interval_list
:
[
1
]
-
name
:
ReadImageSequence
key
:
lq
-
name
:
ReadImageSequence
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
lr_scheduler
:
name
:
CosineAnnealingRestartLR
...
...
@@ -104,4 +129,4 @@ snapshot_config:
interval
:
5000
export_model
:
-
{
name
:
'
generator'
,
inputs_num
:
1
}
\ No newline at end of file
-
{
name
:
'
generator'
,
inputs_num
:
1
}
configs/msvsr_vimeo90k_BD.yaml
浏览文件 @
424ab9eb
...
...
@@ -36,7 +36,6 @@ dataset:
batch_size
:
2
#8 gpus
dataset
:
name
:
VSRVimeo90KDataset
# mode: train
lq_folder
:
data/vimeo90k/vimeo_septuplet_BD_matlabLRx4/sequences
gt_folder
:
data/vimeo90k/vimeo_septuplet/sequences
ann_file
:
data/vimeo90k/vimeo_septuplet/sep_trainlist.txt
...
...
@@ -62,7 +61,7 @@ dataset:
keys
:
[
image
,
image
]
-
name
:
MirrorVideoSequence
-
name
:
NormalizeSequence
mean
:
[
0.
,
.0
,
0.
]
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
...
...
@@ -89,7 +88,7 @@ dataset:
-
name
:
TransposeSequence
keys
:
[
image
,
image
]
-
name
:
NormalizeSequence
mean
:
[
0.
,
.0
,
0.
]
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
...
...
configs/realsr_bicubic_noise_x4_df2k.yaml
浏览文件 @
424ab9eb
...
...
@@ -61,7 +61,7 @@ dataset:
-
name
:
Transpose
keys
:
[
image
,
image
]
-
name
:
Normalize
mean
:
[
0.
,
.0
,
0.
]
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
-
name
:
SRNoise
...
...
@@ -84,7 +84,7 @@ dataset:
-
name
:
Transpose
keys
:
[
image
,
image
]
-
name
:
Normalize
mean
:
[
0.
,
.0
,
0.
]
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
...
...
configs/realsr_kernel_noise_x4_dped.yaml
浏览文件 @
424ab9eb
...
...
@@ -61,7 +61,7 @@ dataset:
-
name
:
Transpose
keys
:
[
image
,
image
]
-
name
:
Normalize
mean
:
[
0.
,
.0
,
0.
]
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
-
name
:
SRNoise
...
...
@@ -84,7 +84,7 @@ dataset:
-
name
:
Transpose
keys
:
[
image
,
image
]
-
name
:
Normalize
mean
:
[
0.
,
.0
,
0.
]
mean
:
[
0.
,
0.
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
...
...
ppgan/datasets/__init__.py
浏览文件 @
424ab9eb
...
...
@@ -21,11 +21,11 @@ from .common_vision_dataset import CommonVisionDataset
from
.animeganv2_dataset
import
AnimeGANV2Dataset
from
.wav2lip_dataset
import
Wav2LipDataset
from
.starganv2_dataset
import
StarGANv2Dataset
from
.edvr_dataset
import
REDSDataset
from
.firstorder_dataset
import
FirstOrderDataset
from
.lapstyle_dataset
import
LapStyleDataset
from
.sr_reds_multiple_gt_dataset
import
SRREDSMultipleGTDataset
from
.mpr_dataset
import
MPRTrain
,
MPRVal
,
MPRTest
from
.vsr_reds_dataset
import
VSRREDSDataset
from
.vsr_reds_multiple_gt_dataset
import
VSRREDSMultipleGTDataset
from
.vsr_vimeo90k_dataset
import
VSRVimeo90KDataset
from
.vsr_folder_dataset
import
VSRFolderDataset
from
.photopen_dataset
import
PhotoPenDataset
ppgan/datasets/edvr_dataset.py
已删除
100644 → 0
浏览文件 @
a8ae6b7a
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
logging
import
os
import
random
import
numpy
as
np
import
scipy.io
as
scio
import
cv2
import
paddle
from
paddle.io
import
Dataset
,
DataLoader
from
.builder
import
DATASETS
logger
=
logging
.
getLogger
(
__name__
)
@
DATASETS
.
register
()
class
REDSDataset
(
Dataset
):
"""
REDS dataset for EDVR model
"""
def
__init__
(
self
,
mode
,
lq_folder
,
gt_folder
,
img_format
=
"png"
,
crop_size
=
256
,
interval_list
=
[
1
],
random_reverse
=
False
,
number_frames
=
5
,
batch_size
=
32
,
use_flip
=
False
,
use_rot
=
False
,
buf_size
=
1024
,
scale
=
4
,
fix_random_seed
=
False
):
super
(
REDSDataset
,
self
).
__init__
()
self
.
format
=
img_format
self
.
mode
=
mode
self
.
crop_size
=
crop_size
self
.
interval_list
=
interval_list
self
.
random_reverse
=
random_reverse
self
.
number_frames
=
number_frames
self
.
batch_size
=
batch_size
self
.
fileroot
=
lq_folder
self
.
use_flip
=
use_flip
self
.
use_rot
=
use_rot
self
.
buf_size
=
buf_size
self
.
fix_random_seed
=
fix_random_seed
if
self
.
mode
!=
'infer'
:
self
.
gtroot
=
gt_folder
self
.
scale
=
scale
self
.
LR_input
=
(
self
.
scale
>
1
)
if
self
.
fix_random_seed
:
random
.
seed
(
10
)
np
.
random
.
seed
(
10
)
self
.
num_reader_threads
=
1
self
.
_init_
()
def
_init_
(
self
):
logger
.
info
(
'initialize reader ... '
)
print
(
"initialize reader"
)
self
.
filelist
=
[]
for
video_name
in
os
.
listdir
(
self
.
fileroot
):
if
(
self
.
mode
==
'train'
)
and
(
video_name
in
[
'000'
,
'011'
,
'015'
,
'020'
]):
#These four videos are used as val
continue
for
frame_name
in
os
.
listdir
(
os
.
path
.
join
(
self
.
fileroot
,
video_name
)):
frame_idx
=
frame_name
.
split
(
'.'
)[
0
]
video_frame_idx
=
video_name
+
'_'
+
str
(
frame_idx
)
# for each item in self.filelist is like '010_00000015', '260_00000090'
self
.
filelist
.
append
(
video_frame_idx
)
if
self
.
mode
==
'test'
:
self
.
filelist
.
sort
()
print
(
len
(
self
.
filelist
))
def
__getitem__
(
self
,
index
):
"""Get training sample
return: lq:[5,3,W,H],
gt:[3,W,H],
lq_path:str
"""
item
=
self
.
filelist
[
index
]
img_LQs
,
img_GT
=
self
.
get_sample_data
(
item
,
self
.
number_frames
,
self
.
interval_list
,
self
.
random_reverse
,
self
.
gtroot
,
self
.
fileroot
,
self
.
LR_input
,
self
.
crop_size
,
self
.
scale
,
self
.
use_flip
,
self
.
use_rot
,
self
.
mode
)
return
{
'lq'
:
img_LQs
,
'gt'
:
img_GT
,
'lq_path'
:
self
.
filelist
[
index
]}
def
get_sample_data
(
self
,
item
,
number_frames
,
interval_list
,
random_reverse
,
gtroot
,
fileroot
,
LR_input
,
crop_size
,
scale
,
use_flip
,
use_rot
,
mode
=
'train'
):
video_name
=
item
.
split
(
'_'
)[
0
]
frame_name
=
item
.
split
(
'_'
)[
1
]
if
(
mode
==
'train'
)
or
(
mode
==
'valid'
):
ngb_frames
,
name_b
=
self
.
get_neighbor_frames
(
frame_name
,
\
number_frames
=
number_frames
,
\
interval_list
=
interval_list
,
\
random_reverse
=
random_reverse
)
elif
mode
==
'test'
:
ngb_frames
,
name_b
=
self
.
get_test_neighbor_frames
(
int
(
frame_name
),
number_frames
)
else
:
raise
NotImplementedError
(
'mode {} not implemented'
.
format
(
mode
))
frame_name
=
name_b
img_GT
=
self
.
read_img
(
os
.
path
.
join
(
gtroot
,
video_name
,
frame_name
+
'.png'
))
frame_list
=
[]
for
ngb_frm
in
ngb_frames
:
ngb_name
=
"%08d"
%
ngb_frm
img
=
self
.
read_img
(
os
.
path
.
join
(
fileroot
,
video_name
,
ngb_name
+
'.png'
))
frame_list
.
append
(
img
)
H
,
W
,
C
=
frame_list
[
0
].
shape
# add random crop
if
(
mode
==
'train'
)
or
(
mode
==
'valid'
):
if
LR_input
:
LQ_size
=
crop_size
//
scale
rnd_h
=
random
.
randint
(
0
,
max
(
0
,
H
-
LQ_size
))
rnd_w
=
random
.
randint
(
0
,
max
(
0
,
W
-
LQ_size
))
frame_list
=
[
v
[
rnd_h
:
rnd_h
+
LQ_size
,
rnd_w
:
rnd_w
+
LQ_size
,
:]
for
v
in
frame_list
]
rnd_h_HR
,
rnd_w_HR
=
int
(
rnd_h
*
scale
),
int
(
rnd_w
*
scale
)
img_GT
=
img_GT
[
rnd_h_HR
:
rnd_h_HR
+
crop_size
,
rnd_w_HR
:
rnd_w_HR
+
crop_size
,
:]
else
:
rnd_h
=
random
.
randint
(
0
,
max
(
0
,
H
-
crop_size
))
rnd_w
=
random
.
randint
(
0
,
max
(
0
,
W
-
crop_size
))
frame_list
=
[
v
[
rnd_h
:
rnd_h
+
crop_size
,
rnd_w
:
rnd_w
+
crop_size
,
:]
for
v
in
frame_list
]
img_GT
=
img_GT
[
rnd_h
:
rnd_h
+
crop_size
,
rnd_w
:
rnd_w
+
crop_size
,
:]
# add random flip and rotation
frame_list
.
append
(
img_GT
)
if
(
mode
==
'train'
)
or
(
mode
==
'valid'
):
rlt
=
self
.
img_augment
(
frame_list
,
use_flip
,
use_rot
)
else
:
rlt
=
frame_list
frame_list
=
rlt
[
0
:
-
1
]
img_GT
=
rlt
[
-
1
]
# stack LQ images to NHWC, N is the frame number
img_LQs
=
np
.
stack
(
frame_list
,
axis
=
0
)
# BGR to RGB, HWC to CHW, numpy to tensor
img_GT
=
img_GT
[:,
:,
[
2
,
1
,
0
]]
img_LQs
=
img_LQs
[:,
:,
:,
[
2
,
1
,
0
]]
img_GT
=
np
.
transpose
(
img_GT
,
(
2
,
0
,
1
)).
astype
(
'float32'
)
img_LQs
=
np
.
transpose
(
img_LQs
,
(
0
,
3
,
1
,
2
)).
astype
(
'float32'
)
return
img_LQs
,
img_GT
def
get_neighbor_frames
(
self
,
frame_name
,
number_frames
,
interval_list
,
random_reverse
,
max_frame
=
99
,
bordermode
=
False
):
center_frame_idx
=
int
(
frame_name
)
half_N_frames
=
number_frames
//
2
interval
=
random
.
choice
(
interval_list
)
if
bordermode
:
direction
=
1
if
random_reverse
and
random
.
random
()
<
0.5
:
direction
=
random
.
choice
([
0
,
1
])
if
center_frame_idx
+
interval
*
(
number_frames
-
1
)
>
max_frame
:
direction
=
0
elif
center_frame_idx
-
interval
*
(
number_frames
-
1
)
<
0
:
direction
=
1
if
direction
==
1
:
neighbor_list
=
list
(
range
(
center_frame_idx
,
center_frame_idx
+
interval
*
number_frames
,
interval
))
else
:
neighbor_list
=
list
(
range
(
center_frame_idx
,
center_frame_idx
-
interval
*
number_frames
,
-
interval
))
name_b
=
'{:08d}'
.
format
(
neighbor_list
[
0
])
else
:
# ensure not exceeding the borders
while
(
center_frame_idx
+
half_N_frames
*
interval
>
max_frame
)
or
(
center_frame_idx
-
half_N_frames
*
interval
<
0
):
center_frame_idx
=
random
.
randint
(
0
,
max_frame
)
neighbor_list
=
list
(
range
(
center_frame_idx
-
half_N_frames
*
interval
,
center_frame_idx
+
half_N_frames
*
interval
+
1
,
interval
))
if
random_reverse
and
random
.
random
()
<
0.5
:
neighbor_list
.
reverse
()
name_b
=
'{:08d}'
.
format
(
neighbor_list
[
half_N_frames
])
assert
len
(
neighbor_list
)
==
number_frames
,
\
"frames slected have length({}), but it should be ({})"
.
format
(
len
(
neighbor_list
),
number_frames
)
return
neighbor_list
,
name_b
def
read_img
(
self
,
path
,
size
=
None
):
"""read image by cv2
return: Numpy float32, HWC, BGR, [0,1]
"""
img
=
cv2
.
imread
(
path
,
cv2
.
IMREAD_UNCHANGED
)
img
=
img
.
astype
(
np
.
float32
)
/
255.
if
img
.
ndim
==
2
:
img
=
np
.
expand_dims
(
img
,
axis
=
2
)
# some images have 4 channels
if
img
.
shape
[
2
]
>
3
:
img
=
img
[:,
:,
:
3
]
return
img
def
img_augment
(
self
,
img_list
,
hflip
=
True
,
rot
=
True
):
"""horizontal flip OR rotate (0, 90, 180, 270 degrees)
"""
hflip
=
hflip
and
random
.
random
()
<
0.5
vflip
=
rot
and
random
.
random
()
<
0.5
rot90
=
rot
and
random
.
random
()
<
0.5
def
_augment
(
img
):
if
hflip
:
img
=
img
[:,
::
-
1
,
:]
if
vflip
:
img
=
img
[::
-
1
,
:,
:]
if
rot90
:
img
=
img
.
transpose
(
1
,
0
,
2
)
return
img
return
[
_augment
(
img
)
for
img
in
img_list
]
def
get_test_neighbor_frames
(
self
,
crt_i
,
N
,
max_n
=
100
,
padding
=
'new_info'
):
"""Generate an index list for reading N frames from a sequence of images
Args:
crt_i (int): current center index
max_n (int): max number of the sequence of images (calculated from 1)
N (int): reading N frames
padding (str): padding mode, one of replicate | reflection | new_info | circle
Example: crt_i = 0, N = 5
replicate: [0, 0, 0, 1, 2]
reflection: [2, 1, 0, 1, 2]
new_info: [4, 3, 0, 1, 2]
circle: [3, 4, 0, 1, 2]
Returns:
return_l (list [int]): a list of indexes
"""
max_n
=
max_n
-
1
n_pad
=
N
//
2
return_l
=
[]
for
i
in
range
(
crt_i
-
n_pad
,
crt_i
+
n_pad
+
1
):
if
i
<
0
:
if
padding
==
'replicate'
:
add_idx
=
0
elif
padding
==
'reflection'
:
add_idx
=
-
i
elif
padding
==
'new_info'
:
add_idx
=
(
crt_i
+
n_pad
)
+
(
-
i
)
elif
padding
==
'circle'
:
add_idx
=
N
+
i
else
:
raise
ValueError
(
'Wrong padding mode'
)
elif
i
>
max_n
:
if
padding
==
'replicate'
:
add_idx
=
max_n
elif
padding
==
'reflection'
:
add_idx
=
max_n
*
2
-
i
elif
padding
==
'new_info'
:
add_idx
=
(
crt_i
-
n_pad
)
-
(
i
-
max_n
)
elif
padding
==
'circle'
:
add_idx
=
i
-
N
else
:
raise
ValueError
(
'Wrong padding mode'
)
else
:
add_idx
=
i
return_l
.
append
(
add_idx
)
name_b
=
'{:08d}'
.
format
(
crt_i
)
return
return_l
,
name_b
def
__len__
(
self
):
"""Return the total number of images in the dataset.
"""
return
len
(
self
.
filelist
)
ppgan/datasets/preprocess/__init__.py
浏览文件 @
424ab9eb
from
.io
import
LoadImageFromFile
,
ReadImageSequence
,
GetNeighboringFramesIdx
from
.io
import
LoadImageFromFile
,
ReadImageSequence
,
GetNeighboringFramesIdx
,
GetFrameIdx
,
GetFrameIdxwithPadding
from
.transforms
import
(
PairedRandomCrop
,
PairedRandomHorizontalFlip
,
PairedRandomVerticalFlip
,
PairedRandomTransposeHW
,
SRPairedRandomCrop
,
SplitPairedImage
,
SRNoise
,
...
...
ppgan/datasets/preprocess/io.py
浏览文件 @
424ab9eb
...
...
@@ -179,3 +179,151 @@ class GetNeighboringFramesIdx:
datas
[
'interval'
]
=
interval
return
datas
@
PREPROCESS
.
register
()
class
GetFrameIdx
:
"""Generate frame index for REDS datasets.
Args:
interval_list (list[int]): Interval list for temporal augmentation.
It will randomly pick an interval from interval_list and sample
frame index with the interval.
frames_per_clip(int): Number of frames per clips. Default: 99 for
REDS dataset.
"""
def
__init__
(
self
,
interval_list
,
frames_per_clip
=
99
):
self
.
interval_list
=
interval_list
self
.
frames_per_clip
=
frames_per_clip
def
__call__
(
self
,
results
):
"""Call function.
Args:
results (dict): A dict containing the necessary information and
data for augmentation.
Returns:
dict: A dict containing the processed data and information.
"""
clip_name
,
frame_name
=
results
[
'key'
].
split
(
'/'
)
center_frame_idx
=
int
(
frame_name
)
num_half_frames
=
results
[
'num_frames'
]
//
2
interval
=
np
.
random
.
choice
(
self
.
interval_list
)
# ensure not exceeding the borders
start_frame_idx
=
center_frame_idx
-
num_half_frames
*
interval
end_frame_idx
=
center_frame_idx
+
num_half_frames
*
interval
while
(
start_frame_idx
<
0
)
or
(
end_frame_idx
>
self
.
frames_per_clip
):
center_frame_idx
=
np
.
random
.
randint
(
0
,
self
.
frames_per_clip
+
1
)
start_frame_idx
=
center_frame_idx
-
num_half_frames
*
interval
end_frame_idx
=
center_frame_idx
+
num_half_frames
*
interval
frame_name
=
f
'
{
center_frame_idx
:
08
d
}
'
neighbor_list
=
list
(
range
(
center_frame_idx
-
num_half_frames
*
interval
,
center_frame_idx
+
num_half_frames
*
interval
+
1
,
interval
))
lq_path_root
=
results
[
'lq_path'
]
gt_path_root
=
results
[
'gt_path'
]
lq_path
=
[
os
.
path
.
join
(
lq_path_root
,
clip_name
,
f
'
{
v
:
08
d
}
.png'
)
for
v
in
neighbor_list
]
gt_path
=
[
os
.
path
.
join
(
gt_path_root
,
clip_name
,
f
'
{
frame_name
}
.png'
)]
results
[
'lq_path'
]
=
lq_path
results
[
'gt_path'
]
=
gt_path
results
[
'interval'
]
=
interval
return
results
def
__repr__
(
self
):
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
(
f
'(interval_list=
{
self
.
interval_list
}
, '
f
'frames_per_clip=
{
self
.
frames_per_clip
}
)'
)
return
repr_str
@
PREPROCESS
.
register
()
class
GetFrameIdxwithPadding
:
"""Generate frame index with padding for REDS dataset and Vid4 dataset
during testing.
Args:
padding (str): padding mode, one of
'replicate' | 'reflection' | 'reflection_circle' | 'circle'.
Examples: current_idx = 0, num_frames = 5
The generated frame indices under different padding mode:
replicate: [0, 0, 0, 1, 2]
reflection: [2, 1, 0, 1, 2]
reflection_circle: [4, 3, 0, 1, 2]
circle: [3, 4, 0, 1, 2]
filename_tmpl (str): Template for file name. Default: '{:08d}'.
"""
def
__init__
(
self
,
padding
,
filename_tmpl
=
'{:08d}'
):
if
padding
not
in
(
'replicate'
,
'reflection'
,
'reflection_circle'
,
'circle'
):
raise
ValueError
(
f
'Wrong padding mode
{
padding
}
.'
'Should be "replicate", "reflection", '
'"reflection_circle", "circle"'
)
self
.
padding
=
padding
self
.
filename_tmpl
=
filename_tmpl
def
__call__
(
self
,
results
):
"""Call function.
Args:
results (dict): A dict containing the necessary information and
data for augmentation.
Returns:
dict: A dict containing the processed data and information.
"""
clip_name
,
frame_name
=
results
[
'key'
].
split
(
'/'
)
current_idx
=
int
(
frame_name
)
max_frame_num
=
results
[
'max_frame_num'
]
-
1
# start from 0
num_frames
=
results
[
'num_frames'
]
num_pad
=
num_frames
//
2
frame_list
=
[]
for
i
in
range
(
current_idx
-
num_pad
,
current_idx
+
num_pad
+
1
):
if
i
<
0
:
if
self
.
padding
==
'replicate'
:
pad_idx
=
0
elif
self
.
padding
==
'reflection'
:
pad_idx
=
-
i
elif
self
.
padding
==
'reflection_circle'
:
pad_idx
=
current_idx
+
num_pad
-
i
else
:
pad_idx
=
num_frames
+
i
elif
i
>
max_frame_num
:
if
self
.
padding
==
'replicate'
:
pad_idx
=
max_frame_num
elif
self
.
padding
==
'reflection'
:
pad_idx
=
max_frame_num
*
2
-
i
elif
self
.
padding
==
'reflection_circle'
:
pad_idx
=
(
current_idx
-
num_pad
)
-
(
i
-
max_frame_num
)
else
:
pad_idx
=
i
-
num_frames
else
:
pad_idx
=
i
frame_list
.
append
(
pad_idx
)
lq_path_root
=
results
[
'lq_path'
]
gt_path_root
=
results
[
'gt_path'
]
lq_paths
=
[
os
.
path
.
join
(
lq_path_root
,
clip_name
,
f
'
{
self
.
filename_tmpl
.
format
(
idx
)
}
.png'
)
for
idx
in
frame_list
]
gt_paths
=
[
os
.
path
.
join
(
gt_path_root
,
clip_name
,
f
'
{
frame_name
}
.png'
)]
results
[
'lq_path'
]
=
lq_paths
results
[
'gt_path'
]
=
gt_paths
return
results
def
__repr__
(
self
):
repr_str
=
self
.
__class__
.
__name__
+
f
"(padding='
{
self
.
padding
}
')"
return
repr_str
ppgan/datasets/sr_reds_multiple_gt_dataset.py
已删除
100644 → 0
浏览文件 @
a8ae6b7a
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
logging
import
os
import
random
import
numpy
as
np
import
cv2
from
paddle.io
import
Dataset
from
.builder
import
DATASETS
logger
=
logging
.
getLogger
(
__name__
)
@
DATASETS
.
register
()
class
SRREDSMultipleGTDataset
(
Dataset
):
"""REDS dataset for video super resolution for recurrent networks.
The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth)
frames. Then it applies specified transforms and finally returns a dict
containing paired data and other information.
Args:
lq_folder (str | :obj:`Path`): Path to a lq folder.
gt_folder (str | :obj:`Path`): Path to a gt folder.
num_input_frames (int): Number of input frames.
pipeline (list[dict | callable]): A sequence of data transformations.
scale (int): Upsampling scale ratio.
val_partition (str): Validation partition mode. Choices ['official' or
'REDS4']. Default: 'official'.
test_mode (bool): Store `True` when building test dataset.
Default: `False`.
"""
def
__init__
(
self
,
mode
,
lq_folder
,
gt_folder
,
crop_size
=
256
,
interval_list
=
[
1
],
random_reverse
=
False
,
number_frames
=
15
,
use_flip
=
False
,
use_rot
=
False
,
scale
=
4
,
val_partition
=
'REDS4'
,
batch_size
=
4
,
num_clips
=
270
):
super
(
SRREDSMultipleGTDataset
,
self
).
__init__
()
self
.
mode
=
mode
self
.
fileroot
=
str
(
lq_folder
)
self
.
gtroot
=
str
(
gt_folder
)
self
.
crop_size
=
crop_size
self
.
interval_list
=
interval_list
self
.
random_reverse
=
random_reverse
self
.
number_frames
=
number_frames
self
.
use_flip
=
use_flip
self
.
use_rot
=
use_rot
self
.
scale
=
scale
self
.
val_partition
=
val_partition
self
.
batch_size
=
batch_size
self
.
num_clips
=
num_clips
# training num of LQ and GT pairs
self
.
data_infos
=
self
.
load_annotations
()
def
__getitem__
(
self
,
idx
):
"""Get item at each call.
Args:
idx (int): Index for getting each item.
"""
item
=
self
.
data_infos
[
idx
]
idt
=
random
.
randint
(
0
,
100
-
self
.
number_frames
)
item
=
item
+
'_'
+
f
'
{
idt
:
03
d
}
'
img_LQs
,
img_GTs
=
self
.
get_sample_data
(
item
,
self
.
number_frames
,
self
.
interval_list
,
self
.
random_reverse
,
self
.
gtroot
,
self
.
fileroot
,
self
.
crop_size
,
self
.
scale
,
self
.
use_flip
,
self
.
use_rot
,
self
.
mode
)
return
{
'lq'
:
img_LQs
,
'gt'
:
img_GTs
,
'lq_path'
:
self
.
data_infos
[
idx
]}
def
load_annotations
(
self
):
"""Load annoations for REDS dataset.
Returns:
dict: Returned dict for LQ and GT pairs.
"""
# generate keys
keys
=
[
f
'
{
i
:
03
d
}
'
for
i
in
range
(
0
,
self
.
num_clips
)]
if
self
.
val_partition
==
'REDS4'
:
val_partition
=
[
'000'
,
'011'
,
'015'
,
'020'
]
elif
self
.
val_partition
==
'official'
:
val_partition
=
[
f
'
{
i
:
03
d
}
'
for
i
in
range
(
240
,
270
)]
else
:
raise
ValueError
(
f
'Wrong validation partition
{
self
.
val_partition
}
.'
f
'Supported ones are ["official", "REDS4"]'
)
if
self
.
mode
==
'train'
:
keys
=
[
v
for
v
in
keys
if
v
not
in
val_partition
]
else
:
keys
=
[
v
for
v
in
keys
if
v
in
val_partition
]
data_infos
=
[]
for
key
in
keys
:
data_infos
.
append
(
key
)
return
data_infos
def
get_sample_data
(
self
,
item
,
number_frames
,
interval_list
,
random_reverse
,
gtroot
,
fileroot
,
crop_size
,
scale
,
use_flip
,
use_rot
,
mode
=
'train'
):
video_name
=
item
.
split
(
'_'
)[
0
]
frame_name
=
item
.
split
(
'_'
)[
1
]
frame_idxs
=
self
.
get_neighbor_frames
(
frame_name
,
number_frames
=
number_frames
,
interval_list
=
interval_list
,
random_reverse
=
random_reverse
)
frame_list
=
[]
gt_list
=
[]
for
frame_idx
in
frame_idxs
:
frame_idx_name
=
"%08d"
%
frame_idx
img
=
self
.
read_img
(
os
.
path
.
join
(
fileroot
,
video_name
,
frame_idx_name
+
'.png'
))
frame_list
.
append
(
img
)
gt_img
=
self
.
read_img
(
os
.
path
.
join
(
gtroot
,
video_name
,
frame_idx_name
+
'.png'
))
gt_list
.
append
(
gt_img
)
H
,
W
,
C
=
frame_list
[
0
].
shape
# add random crop
if
(
mode
==
'train'
)
or
(
mode
==
'valid'
):
LQ_size
=
crop_size
//
scale
rnd_h
=
random
.
randint
(
0
,
max
(
0
,
H
-
LQ_size
))
rnd_w
=
random
.
randint
(
0
,
max
(
0
,
W
-
LQ_size
))
frame_list
=
[
v
[
rnd_h
:
rnd_h
+
LQ_size
,
rnd_w
:
rnd_w
+
LQ_size
,
:]
for
v
in
frame_list
]
rnd_h_HR
,
rnd_w_HR
=
int
(
rnd_h
*
scale
),
int
(
rnd_w
*
scale
)
gt_list
=
[
v
[
rnd_h_HR
:
rnd_h_HR
+
crop_size
,
rnd_w_HR
:
rnd_w_HR
+
crop_size
,
:]
for
v
in
gt_list
]
# add random flip and rotation
for
v
in
gt_list
:
frame_list
.
append
(
v
)
if
(
mode
==
'train'
)
or
(
mode
==
'valid'
):
rlt
=
self
.
img_augment
(
frame_list
,
use_flip
,
use_rot
)
else
:
rlt
=
frame_list
frame_list
=
rlt
[
0
:
number_frames
]
gt_list
=
rlt
[
number_frames
:]
# stack LQ images to NHWC, N is the frame number
frame_list
=
[
v
.
transpose
(
2
,
0
,
1
).
astype
(
'float32'
)
for
v
in
frame_list
]
gt_list
=
[
v
.
transpose
(
2
,
0
,
1
).
astype
(
'float32'
)
for
v
in
gt_list
]
img_LQs
=
np
.
stack
(
frame_list
,
axis
=
0
)
img_GTs
=
np
.
stack
(
gt_list
,
axis
=
0
)
return
img_LQs
,
img_GTs
def
get_neighbor_frames
(
self
,
frame_name
,
number_frames
,
interval_list
,
random_reverse
):
frame_idx
=
int
(
frame_name
)
interval
=
random
.
choice
(
interval_list
)
neighbor_list
=
list
(
range
(
frame_idx
,
frame_idx
+
number_frames
,
interval
))
if
random_reverse
and
random
.
random
()
<
0.5
:
neighbor_list
.
reverse
()
assert
len
(
neighbor_list
)
==
number_frames
,
\
"frames slected have length({}), but it should be ({})"
.
format
(
len
(
neighbor_list
),
number_frames
)
return
neighbor_list
def
read_img
(
self
,
path
,
size
=
None
):
"""read image by cv2
return: Numpy float32, HWC, BGR, [0,1]
"""
img
=
cv2
.
imread
(
path
,
cv2
.
IMREAD_UNCHANGED
)
img
=
img
.
astype
(
np
.
float32
)
/
255.
if
img
.
ndim
==
2
:
img
=
np
.
expand_dims
(
img
,
axis
=
2
)
# some images have 4 channels
if
img
.
shape
[
2
]
>
3
:
img
=
img
[:,
:,
:
3
]
return
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
)
def
img_augment
(
self
,
img_list
,
hflip
=
True
,
rot
=
True
):
"""horizontal flip OR rotate (0, 90, 180, 270 degrees)
"""
hflip
=
hflip
and
random
.
random
()
<
0.5
vflip
=
rot
and
random
.
random
()
<
0.5
rot90
=
rot
and
random
.
random
()
<
0.5
def
_augment
(
img
):
if
hflip
:
img
=
img
[:,
::
-
1
,
:]
if
vflip
:
img
=
img
[::
-
1
,
:,
:]
if
rot90
:
img
=
img
.
transpose
(
1
,
0
,
2
)
return
img
return
[
_augment
(
img
)
for
img
in
img_list
]
def
__len__
(
self
):
"""Length of the dataset.
Returns:
int: Length of the dataset.
"""
return
len
(
self
.
data_infos
)
ppgan/datasets/vsr_reds_dataset.py
0 → 100644
浏览文件 @
424ab9eb
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
logging
from
.builder
import
DATASETS
from
.base_sr_dataset
import
BaseDataset
logger
=
logging
.
getLogger
(
__name__
)
@
DATASETS
.
register
()
class
VSRREDSDataset
(
BaseDataset
):
"""REDS dataset for video super resolution for Sliding-window networks.
The dataset loads several LQ (Low-Quality) frames and a center GT
(Ground-Truth) frame. Then it applies specified transforms and finally
returns a dict containing paired data and other information.
It reads REDS keys from the txt file. Each line contains video frame folder
Examples:
000/00000000.png (720, 1280, 3)
000/00000001.png (720, 1280, 3)
Args:
lq_folder (str): Path to a low quality image folder.
gt_folder (str): Path to a ground truth image folder.
ann_file (str): Path to the annotation file.
num_frames (int): Window size for input frames.
preprocess (list[dict|callable]): A list functions of data transformations.
val_partition (str): Validation partition mode. Choices ['official' or 'REDS4']. Default: 'REDS4'.
test_mode (bool): Store `True` when building test dataset. Default: `False`.
"""
def
__init__
(
self
,
lq_folder
,
gt_folder
,
ann_file
,
num_frames
,
preprocess
,
val_partition
=
'REDS4'
,
test_mode
=
False
):
super
().
__init__
(
preprocess
)
assert
num_frames
%
2
==
1
,
(
f
'num_frames should be odd numbers, '
f
'but received
{
num_frames
}
.'
)
self
.
lq_folder
=
str
(
lq_folder
)
self
.
gt_folder
=
str
(
gt_folder
)
self
.
ann_file
=
str
(
ann_file
)
self
.
num_frames
=
num_frames
self
.
val_partition
=
val_partition
self
.
test_mode
=
test_mode
self
.
data_infos
=
self
.
prepare_data_infos
()
def
prepare_data_infos
(
self
):
"""Load annoations for REDS dataset.
Returns:
dict: Returned dict for LQ and GT pairs.
"""
# get keys
with
open
(
self
.
ann_file
,
'r'
)
as
fin
:
keys
=
[
v
.
strip
().
split
(
'.'
)[
0
]
for
v
in
fin
]
if
self
.
val_partition
==
'REDS4'
:
val_partition
=
[
'000'
,
'011'
,
'015'
,
'020'
]
elif
self
.
val_partition
==
'official'
:
val_partition
=
[
f
'
{
v
:
03
d
}
'
for
v
in
range
(
240
,
270
)]
else
:
raise
ValueError
(
f
'Wrong validation partition
{
self
.
val_partition
}
.'
f
'Supported ones are ["official", "REDS4"]'
)
if
self
.
test_mode
:
keys
=
[
v
for
v
in
keys
if
v
.
split
(
'/'
)[
0
]
in
val_partition
]
else
:
keys
=
[
v
for
v
in
keys
if
v
.
split
(
'/'
)[
0
]
not
in
val_partition
]
data_infos
=
[]
for
key
in
keys
:
data_infos
.
append
(
dict
(
lq_path
=
self
.
lq_folder
,
gt_path
=
self
.
gt_folder
,
key
=
key
,
max_frame_num
=
100
,
num_frames
=
self
.
num_frames
))
return
data_infos
ppgan/datasets/vsr_reds_multiple_gt_dataset.py
0 → 100644
浏览文件 @
424ab9eb
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
logging
from
.builder
import
DATASETS
from
.base_sr_dataset
import
BaseDataset
logger
=
logging
.
getLogger
(
__name__
)
@
DATASETS
.
register
()
class
VSRREDSMultipleGTDataset
(
BaseDataset
):
"""REDS dataset for video super resolution for recurrent networks.
The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth) frames.
Then it applies specified transforms and finally returns a dict containing
paired data and other information.
Args:
Args:
lq_folder (str): Path to a low quality image folder.
gt_folder (str): Path to a ground truth image folder.
ann_file (str): Path to the annotation file.
num_frames (int): Window size for input frames.
preprocess (list[dict|callable]): A list functions of data transformations.
val_partition (str): Validation partition mode. Choices ['official' or 'REDS4'].
Default: 'REDS4'.
test_mode (bool): Store `True` when building test dataset. Default: `False`.
"""
def
__init__
(
self
,
lq_folder
,
gt_folder
,
ann_file
,
num_frames
,
preprocess
,
val_partition
=
'REDS4'
,
test_mode
=
False
):
super
().
__init__
(
preprocess
)
self
.
lq_folder
=
str
(
lq_folder
)
self
.
gt_folder
=
str
(
gt_folder
)
self
.
ann_file
=
str
(
ann_file
)
self
.
num_frames
=
num_frames
self
.
val_partition
=
val_partition
self
.
test_mode
=
test_mode
self
.
data_infos
=
self
.
prepare_data_infos
()
def
prepare_data_infos
(
self
):
"""Load annoations for REDS dataset.
Returns:
dict: Returned dict for LQ and GT pairs.
"""
# get keys
with
open
(
self
.
ann_file
,
'r'
)
as
fin
:
keys
=
[
v
.
strip
().
split
(
'/'
)[
0
]
for
v
in
fin
]
keys
=
list
(
set
(
keys
))
if
self
.
val_partition
==
'REDS4'
:
val_partition
=
[
'000'
,
'011'
,
'015'
,
'020'
]
elif
self
.
val_partition
==
'official'
:
val_partition
=
[
f
'
{
v
:
03
d
}
'
for
v
in
range
(
240
,
270
)]
else
:
raise
ValueError
(
f
'Wrong validation partition
{
self
.
val_partition
}
.'
f
'Supported ones are ["official", "REDS4"]'
)
if
self
.
test_mode
:
keys
=
[
v
for
v
in
keys
if
v
in
val_partition
]
else
:
keys
=
[
v
for
v
in
keys
if
v
not
in
val_partition
]
data_infos
=
[]
for
key
in
keys
:
data_infos
.
append
(
dict
(
lq_path
=
self
.
lq_folder
,
gt_path
=
self
.
gt_folder
,
key
=
key
,
sequence_length
=
100
,
num_frames
=
self
.
num_frames
))
return
data_infos
ppgan/models/edvr_model.py
浏览文件 @
424ab9eb
...
...
@@ -48,7 +48,7 @@ class EDVRModel(BaseSRModel):
self
.
visual_items
[
'lq+1'
]
=
self
.
lq
[:,
3
,
:,
:,
:]
self
.
visual_items
[
'lq+2'
]
=
self
.
lq
[:,
4
,
:,
:,
:]
if
'gt'
in
input
:
self
.
gt
=
input
[
'gt'
]
self
.
gt
=
input
[
'gt'
]
[:,
0
,
:,
:,
:]
self
.
visual_items
[
'gt'
]
=
self
.
gt
self
.
image_paths
=
input
[
'lq_path'
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录