Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
9d9f69f0
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
1 年多 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9d9f69f0
编写于
1月 29, 2021
作者:
L
LielinJiang
提交者:
GitHub
1月 29, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix sr docs and add div2k process script (#154)
上级
95e5f4f3
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
367 addition
and
64 deletion
+367
-64
configs/drn_psnr_x4_div2k.yaml
configs/drn_psnr_x4_div2k.yaml
+2
-2
configs/esrgan_psnr_x4_div2k.yaml
configs/esrgan_psnr_x4_div2k.yaml
+4
-4
configs/esrgan_x4_div2k.yaml
configs/esrgan_x4_div2k.yaml
+2
-2
configs/lesrcnn_psnr_x4_div2k.yaml
configs/lesrcnn_psnr_x4_div2k.yaml
+2
-2
configs/realsr_bicubic_noise_x4_df2k.yaml
configs/realsr_bicubic_noise_x4_df2k.yaml
+2
-2
configs/realsr_kernel_noise_x4_dped.yaml
configs/realsr_kernel_noise_x4_dped.yaml
+2
-2
data/process_div2k_data.py
data/process_div2k_data.py
+290
-0
docs/en_US/tutorials/super_resolution.md
docs/en_US/tutorials/super_resolution.md
+31
-21
ppgan/engine/trainer.py
ppgan/engine/trainer.py
+32
-29
未找到文件。
configs/drn_psnr_x4_div2k.yaml
浏览文件 @
9d9f69f0
...
...
@@ -53,8 +53,8 @@ dataset:
keys
:
[
image
,
image
,
image
]
test
:
name
:
SRDataset
gt_folder
:
data/
DIV2K/val_set14/Set14
lq_folder
:
data/
DIV2K/val_set14/Set14_bicLR
x4
gt_folder
:
data/
Set14/GTmod12
lq_folder
:
data/
Set14/LRbic
x4
scale
:
4
preprocess
:
-
name
:
LoadImageFromFile
...
...
configs/esrgan_psnr_x4_div2k.yaml
浏览文件 @
9d9f69f0
...
...
@@ -18,8 +18,8 @@ model:
dataset
:
train
:
name
:
SRDataset
gt_folder
:
data/DIV2K/DIV2K_train_HR
_sub
lq_folder
:
data/DIV2K/DIV2K_train_LR_bicubic/X4
_sub
gt_folder
:
data/DIV2K/DIV2K_train_HR
lq_folder
:
data/DIV2K/DIV2K_train_LR_bicubic/X4
num_workers
:
4
batch_size
:
16
scale
:
4
...
...
@@ -49,8 +49,8 @@ dataset:
keys
:
[
image
,
image
]
test
:
name
:
SRDataset
gt_folder
:
data/
DIV2K/val_set14/Set14
lq_folder
:
data/
DIV2K/val_set14/Set14_bicLR
x4
gt_folder
:
data/
Set14/GTmod12
lq_folder
:
data/
Set14/LRbic
x4
scale
:
4
preprocess
:
-
name
:
LoadImageFromFile
...
...
configs/esrgan_x4_div2k.yaml
浏览文件 @
9d9f69f0
...
...
@@ -65,8 +65,8 @@ dataset:
keys
:
[
image
,
image
]
test
:
name
:
SRDataset
gt_folder
:
data/
DIV2K/val_set14/Set14
lq_folder
:
data/
DIV2K/val_set14/Set14_bicLR
x4
gt_folder
:
data/
Set14/GTmod12
lq_folder
:
data/
Set14/LRbic
x4
scale
:
4
preprocess
:
-
name
:
LoadImageFromFile
...
...
configs/lesrcnn_psnr_x4_div2k.yaml
浏览文件 @
9d9f69f0
...
...
@@ -45,8 +45,8 @@ dataset:
keys
:
[
image
,
image
]
test
:
name
:
SRDataset
gt_folder
:
data/
DIV2K/val_set14/Set14
lq_folder
:
data/
DIV2K/val_set14/Set14_bicLR
x4
gt_folder
:
data/
Set14/GTmod12
lq_folder
:
data/
Set14/LRbic
x4
scale
:
4
preprocess
:
-
name
:
LoadImageFromFile
...
...
configs/realsr_bicubic_noise_x4_df2k.yaml
浏览文件 @
9d9f69f0
...
...
@@ -69,8 +69,8 @@ dataset:
keys
:
[
image
]
test
:
name
:
SRDataset
gt_folder
:
data/
DIV2K/val_set14/Set14
lq_folder
:
data/
DIV2K/val_set14/Set14_bicLR
x4
gt_folder
:
data/
Set14/GTmod12
lq_folder
:
data/
Set14/LRbic
x4
scale
:
4
preprocess
:
-
name
:
LoadImageFromFile
...
...
configs/realsr_kernel_noise_x4_dped.yaml
浏览文件 @
9d9f69f0
...
...
@@ -69,8 +69,8 @@ dataset:
keys
:
[
image
]
test
:
name
:
SRDataset
gt_folder
:
data/
DIV2K/val_set14/Set14
lq_folder
:
data/
DIV2K/val_set14/Set14_bicLR
x4
gt_folder
:
data/
Set14/GTmod12
lq_folder
:
data/
Set14/LRbic
x4
scale
:
4
preprocess
:
-
name
:
LoadImageFromFile
...
...
data/process_div2k_data.py
0 → 100644
浏览文件 @
9d9f69f0
import
os
import
re
import
sys
import
cv2
import
argparse
import
numpy
as
np
import
os.path
as
osp
from
time
import
time
from
multiprocessing
import
Pool
from
shutil
import
get_terminal_size
from
ppgan.datasets.base_dataset
import
scandir
class
Timer
:
"""A flexible Timer class."""
def
__init__
(
self
,
start
=
True
,
print_tmpl
=
None
):
self
.
_is_running
=
False
self
.
print_tmpl
=
print_tmpl
if
print_tmpl
else
'{:.3f}'
if
start
:
self
.
start
()
@
property
def
is_running
(
self
):
"""bool: indicate whether the timer is running"""
return
self
.
_is_running
def
__enter__
(
self
):
self
.
start
()
return
self
def
__exit__
(
self
,
type
,
value
,
traceback
):
print
(
self
.
print_tmpl
.
format
(
self
.
since_last_check
()))
self
.
_is_running
=
False
def
start
(
self
):
"""Start the timer."""
if
not
self
.
_is_running
:
self
.
_t_start
=
time
()
self
.
_is_running
=
True
self
.
_t_last
=
time
()
def
since_start
(
self
):
"""Total time since the timer is started.
Returns (float): Time in seconds.
"""
if
not
self
.
_is_running
:
raise
ValueError
(
'timer is not running'
)
self
.
_t_last
=
time
()
return
self
.
_t_last
-
self
.
_t_start
def
since_last_check
(
self
):
"""Time since the last checking.
Either :func:`since_start` or :func:`since_last_check` is a checking
operation.
Returns (float): Time in seconds.
"""
if
not
self
.
_is_running
:
raise
ValueError
(
'timer is not running'
)
dur
=
time
()
-
self
.
_t_last
self
.
_t_last
=
time
()
return
dur
class
ProgressBar
:
"""A progress bar which can print the progress."""
def
__init__
(
self
,
task_num
=
0
,
bar_width
=
50
,
start
=
True
,
file
=
sys
.
stdout
):
self
.
task_num
=
task_num
self
.
bar_width
=
bar_width
self
.
completed
=
0
self
.
file
=
file
if
start
:
self
.
start
()
@
property
def
terminal_width
(
self
):
width
,
_
=
get_terminal_size
()
return
width
def
start
(
self
):
if
self
.
task_num
>
0
:
self
.
file
.
write
(
f
'[
{
" "
*
self
.
bar_width
}
] 0/
{
self
.
task_num
}
, '
'elapsed: 0s, ETA:'
)
else
:
self
.
file
.
write
(
'completed: 0, elapsed: 0s'
)
self
.
file
.
flush
()
self
.
timer
=
Timer
()
def
update
(
self
,
num_tasks
=
1
):
assert
num_tasks
>
0
self
.
completed
+=
num_tasks
elapsed
=
self
.
timer
.
since_start
()
if
elapsed
>
0
:
fps
=
self
.
completed
/
elapsed
else
:
fps
=
float
(
'inf'
)
if
self
.
task_num
>
0
:
percentage
=
self
.
completed
/
float
(
self
.
task_num
)
eta
=
int
(
elapsed
*
(
1
-
percentage
)
/
percentage
+
0.5
)
msg
=
f
'
\r
[{{}}]
{
self
.
completed
}
/
{
self
.
task_num
}
, '
\
f
'
{
fps
:.
1
f
}
task/s, elapsed:
{
int
(
elapsed
+
0.5
)
}
s, '
\
f
'ETA:
{
eta
:
5
}
s'
bar_width
=
min
(
self
.
bar_width
,
int
(
self
.
terminal_width
-
len
(
msg
))
+
2
,
int
(
self
.
terminal_width
*
0.6
))
bar_width
=
max
(
2
,
bar_width
)
mark_width
=
int
(
bar_width
*
percentage
)
bar_chars
=
'>'
*
mark_width
+
' '
*
(
bar_width
-
mark_width
)
self
.
file
.
write
(
msg
.
format
(
bar_chars
))
else
:
self
.
file
.
write
(
f
'completed:
{
self
.
completed
}
, elapsed:
{
int
(
elapsed
+
0.5
)
}
s,'
f
'
{
fps
:.
1
f
}
tasks/s'
)
self
.
file
.
flush
()
def
main_extract_subimages
(
args
):
"""A multi-thread tool to crop large images to sub-images for faster IO.
It is used for DIV2K dataset.
args (dict): Configuration dict. It contains:
n_thread (int): Thread number.
compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9.
A higher value means a smaller size and longer compression time.
Use 0 for faster CPU decompression. Default: 3, same in cv2.
input_folder (str): Path to the input folder.
save_folder (str): Path to save folder.
crop_size (int): Crop size.
step (int): Step for overlapped sliding window.
thresh_size (int): Threshold size. Patches whose size is lower
than thresh_size will be dropped.
Usage:
For each folder, run this script.
Typically, there are four folders to be processed for DIV2K dataset.
DIV2K_train_HR
DIV2K_train_LR_bicubic/X2
DIV2K_train_LR_bicubic/X3
DIV2K_train_LR_bicubic/X4
After process, each sub_folder should have the same number of
subimages.
Remember to modify opt configurations according to your settings.
"""
opt
=
{}
opt
[
'n_thread'
]
=
args
.
n_thread
opt
[
'compression_level'
]
=
args
.
compression_level
# HR images
opt
[
'input_folder'
]
=
osp
.
join
(
args
.
data_root
,
'DIV2K_train_HR'
)
opt
[
'save_folder'
]
=
osp
.
join
(
args
.
data_root
,
'DIV2K_train_HR_sub'
)
opt
[
'crop_size'
]
=
args
.
crop_size
opt
[
'step'
]
=
args
.
step
opt
[
'thresh_size'
]
=
args
.
thresh_size
extract_subimages
(
opt
)
for
scale
in
[
2
,
3
,
4
]:
opt
[
'input_folder'
]
=
osp
.
join
(
args
.
data_root
,
f
'DIV2K_train_LR_bicubic/X
{
scale
}
'
)
opt
[
'save_folder'
]
=
osp
.
join
(
args
.
data_root
,
f
'DIV2K_train_LR_bicubic/X
{
scale
}
_sub'
)
opt
[
'crop_size'
]
=
args
.
crop_size
//
scale
opt
[
'step'
]
=
args
.
step
//
scale
opt
[
'thresh_size'
]
=
args
.
thresh_size
//
scale
extract_subimages
(
opt
)
def
extract_subimages
(
opt
):
"""Crop images to subimages.
Args:
opt (dict): Configuration dict. It contains:
input_folder (str): Path to the input folder.
save_folder (str): Path to save folder.
n_thread (int): Thread number.
"""
input_folder
=
opt
[
'input_folder'
]
save_folder
=
opt
[
'save_folder'
]
if
not
osp
.
exists
(
save_folder
):
os
.
makedirs
(
save_folder
)
print
(
f
'mkdir
{
save_folder
}
...'
)
else
:
print
(
f
'Folder
{
save_folder
}
already exists. Exit.'
)
sys
.
exit
(
1
)
img_list
=
list
(
scandir
(
input_folder
))
img_list
=
[
osp
.
join
(
input_folder
,
v
)
for
v
in
img_list
]
prog_bar
=
ProgressBar
(
len
(
img_list
))
pool
=
Pool
(
opt
[
'n_thread'
])
for
path
in
img_list
:
pool
.
apply_async
(
worker
,
args
=
(
path
,
opt
),
callback
=
lambda
arg
:
prog_bar
.
update
())
pool
.
close
()
pool
.
join
()
print
(
'All processes done.'
)
def
worker
(
path
,
opt
):
"""Worker for each process.
Args:
path (str): Image path.
opt (dict): Configuration dict. It contains:
crop_size (int): Crop size.
step (int): Step for overlapped sliding window.
thresh_size (int): Threshold size. Patches whose size is smaller
than thresh_size will be dropped.
save_folder (str): Path to save folder.
compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.
Returns:
process_info (str): Process information displayed in progress bar.
"""
crop_size
=
opt
[
'crop_size'
]
step
=
opt
[
'step'
]
thresh_size
=
opt
[
'thresh_size'
]
img_name
,
extension
=
osp
.
splitext
(
osp
.
basename
(
path
))
# remove the x2, x3, x4 and x8 in the filename for DIV2K
img_name
=
re
.
sub
(
'x[2348]'
,
''
,
img_name
)
img
=
cv2
.
imread
(
path
,
cv2
.
IMREAD_UNCHANGED
)
if
img
.
ndim
==
2
or
img
.
ndim
==
3
:
h
,
w
=
img
.
shape
[:
2
]
else
:
raise
ValueError
(
f
'Image ndim should be 2 or 3, but got
{
img
.
ndim
}
'
)
h_space
=
np
.
arange
(
0
,
h
-
crop_size
+
1
,
step
)
if
h
-
(
h_space
[
-
1
]
+
crop_size
)
>
thresh_size
:
h_space
=
np
.
append
(
h_space
,
h
-
crop_size
)
w_space
=
np
.
arange
(
0
,
w
-
crop_size
+
1
,
step
)
if
w
-
(
w_space
[
-
1
]
+
crop_size
)
>
thresh_size
:
w_space
=
np
.
append
(
w_space
,
w
-
crop_size
)
index
=
0
for
x
in
h_space
:
for
y
in
w_space
:
index
+=
1
cropped_img
=
img
[
x
:
x
+
crop_size
,
y
:
y
+
crop_size
,
...]
cv2
.
imwrite
(
osp
.
join
(
opt
[
'save_folder'
],
f
'
{
img_name
}
_s
{
index
:
03
d
}{
extension
}
'
),
cropped_img
,
[
cv2
.
IMWRITE_PNG_COMPRESSION
,
opt
[
'compression_level'
]])
process_info
=
f
'Processing
{
img_name
}
...'
return
process_info
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Prepare DIV2K dataset'
,
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
)
parser
.
add_argument
(
'--data-root'
,
help
=
'dataset root'
)
parser
.
add_argument
(
'--crop-size'
,
nargs
=
'?'
,
default
=
480
,
help
=
'cropped size for HR images'
)
parser
.
add_argument
(
'--step'
,
nargs
=
'?'
,
default
=
240
,
help
=
'step size for HR images'
)
parser
.
add_argument
(
'--thresh-size'
,
nargs
=
'?'
,
default
=
0
,
help
=
'threshold size for HR images'
)
parser
.
add_argument
(
'--compression-level'
,
nargs
=
'?'
,
default
=
3
,
help
=
'compression level when save png images'
)
parser
.
add_argument
(
'--n-thread'
,
nargs
=
'?'
,
default
=
20
,
help
=
'thread number when using multiprocessing'
)
args
=
parser
.
parse_args
()
return
args
if
__name__
==
'__main__'
:
args
=
parse_args
()
# extract subimages
main_extract_subimages
(
args
)
docs/en_US/tutorials/super_resolution.md
浏览文件 @
9d9f69f0
...
...
@@ -20,27 +20,36 @@
| Classical SR Testing | Set5 | Set5 test dataset |
[
Google Drive
](
https://drive.google.com/drive/folders/1B3DJGQKB6eNdwuQIhdskA64qUuVKLZ9u
)
/
[
Baidu Drive
](
https://pan.baidu.com/s/1q_1ERCMqALH0xFwjLM0pTg#list/path=%2Fsharelink2016187762-785433459861126%2Fclassical_SR_datasets&parentPath=%2Fsharelink2016187762-785433459861126
)
|
| Classical SR Testing | Set14 | Set14 test dataset |
[
Google Drive
](
https://drive.google.com/drive/folders/1B3DJGQKB6eNdwuQIhdskA64qUuVKLZ9u
)
/
[
Baidu Drive
](
https://pan.baidu.com/s/1q_1ERCMqALH0xFwjLM0pTg#list/path=%2Fsharelink2016187762-785433459861126%2Fclassical_SR_datasets&parentPath=%2Fsharelink2016187762-785433459861126
)
|
The structure of DIV2K is as following:
```
DIV2K
├── DIV2K_train_HR
├── DIV2K_train_LR_bicubic
| ├──X2
| ├──X3
| └──X4
├── DIV2K_valid_HR
├── DIV2K_valid_LR_bicubic
...
```
The structures of Set5 and Set14 are similar. Taking Set5 as an example, the structure is as following:
```
Set5
├── GTmod12
├── LRbicx2
├── LRbicx3
├── LRbicx4
└── original
The structure of DIV2K, Set5 and Set14 is as following:
```
PaddleGAN
├── data
├── DIV2K
├── DIV2K_train_HR
├── DIV2K_train_LR_bicubic
| ├──X2
| ├──X3
| └──X4
├── DIV2K_valid_HR
├── DIV2K_valid_LR_bicubic
Set5
├── GTmod12
├── LRbicx2
├── LRbicx3
├── LRbicx4
└── original
Set14
├── GTmod12
├── LRbicx2
├── LRbicx3
├── LRbicx4
└── original
...
```
Use the following commands to process the DIV2K data set:
```
python data/process_div2k_data.py --data-root data/DIV2K
```
...
...
@@ -71,6 +80,7 @@ The metrics are PSNR / SSIM.
| lesrcnn_x4 | 31.9476 / 0.8909 | 28.4110 / 0.7770 | 30.231 / 0.8326 |
| esrgan_psnr_x4 | 32.5512 / 0.8991 | 28.8114 / 0.7871 | 30.7565 / 0.8449 |
| esrgan_x4 | 28.7647 / 0.8187 | 25.0065 / 0.6762 | 26.9013 / 0.7542 |
| drns_x4 | 32.6684 / 0.8999 | 28.9037 / 0.7885 | - |
<!-- ![](../../imgs/horse2zebra.png) -->
...
...
ppgan/engine/trainer.py
浏览文件 @
9d9f69f0
...
...
@@ -72,6 +72,23 @@ class Trainer:
# save checkpoint (model.nets) \/
"""
def
__init__
(
self
,
cfg
):
# base config
self
.
logger
=
logging
.
getLogger
(
__name__
)
self
.
cfg
=
cfg
self
.
output_dir
=
cfg
.
output_dir
self
.
max_eval_steps
=
cfg
.
model
.
get
(
'max_eval_steps'
,
None
)
self
.
local_rank
=
ParallelEnv
().
local_rank
self
.
log_interval
=
cfg
.
log_config
.
interval
self
.
visual_interval
=
cfg
.
log_config
.
visiual_interval
self
.
weight_interval
=
cfg
.
snapshot_config
.
interval
self
.
start_epoch
=
1
self
.
current_epoch
=
1
self
.
current_iter
=
1
self
.
inner_iter
=
1
self
.
batch_id
=
0
self
.
global_steps
=
0
# build model
self
.
model
=
build_model
(
cfg
.
model
)
...
...
@@ -79,6 +96,21 @@ class Trainer:
if
ParallelEnv
().
nranks
>
1
:
self
.
distributed_data_parallel
()
# build metrics
self
.
metrics
=
None
validate_cfg
=
cfg
.
get
(
'validate'
,
None
)
if
validate_cfg
and
'metrics'
in
validate_cfg
:
self
.
metrics
=
self
.
model
.
setup_metrics
(
validate_cfg
[
'metrics'
])
self
.
enable_visualdl
=
cfg
.
get
(
'enable_visualdl'
,
False
)
if
self
.
enable_visualdl
:
import
visualdl
self
.
vdl_logger
=
visualdl
.
LogWriter
(
logdir
=
cfg
.
output_dir
)
# evaluate only
if
not
cfg
.
is_train
:
return
# build train dataloader
self
.
train_dataloader
=
build_dataloader
(
cfg
.
dataset
.
train
)
self
.
iters_per_epoch
=
len
(
self
.
train_dataloader
)
...
...
@@ -93,21 +125,6 @@ class Trainer:
self
.
optimizers
=
self
.
model
.
setup_optimizers
(
self
.
lr_schedulers
,
cfg
.
optimizer
)
# build metrics
self
.
metrics
=
None
validate_cfg
=
cfg
.
get
(
'validate'
,
None
)
if
validate_cfg
and
'metrics'
in
validate_cfg
:
self
.
metrics
=
self
.
model
.
setup_metrics
(
validate_cfg
[
'metrics'
])
self
.
logger
=
logging
.
getLogger
(
__name__
)
self
.
enable_visualdl
=
cfg
.
get
(
'enable_visualdl'
,
False
)
if
self
.
enable_visualdl
:
import
visualdl
self
.
vdl_logger
=
visualdl
.
LogWriter
(
logdir
=
cfg
.
output_dir
)
# base config
self
.
output_dir
=
cfg
.
output_dir
self
.
max_eval_steps
=
cfg
.
model
.
get
(
'max_eval_steps'
,
None
)
self
.
epochs
=
cfg
.
get
(
'epochs'
,
None
)
if
self
.
epochs
:
self
.
total_iters
=
self
.
epochs
*
self
.
iters_per_epoch
...
...
@@ -116,26 +133,12 @@ class Trainer:
self
.
by_epoch
=
False
self
.
total_iters
=
cfg
.
total_iters
self
.
start_epoch
=
1
self
.
current_epoch
=
1
self
.
current_iter
=
1
self
.
inner_iter
=
1
self
.
batch_id
=
0
self
.
global_steps
=
0
self
.
weight_interval
=
cfg
.
snapshot_config
.
interval
if
self
.
by_epoch
:
self
.
weight_interval
*=
self
.
iters_per_epoch
self
.
log_interval
=
cfg
.
log_config
.
interval
self
.
visual_interval
=
cfg
.
log_config
.
visiual_interval
if
self
.
by_epoch
:
self
.
weight_interval
*=
self
.
iters_per_epoch
self
.
validate_interval
=
-
1
if
cfg
.
get
(
'validate'
,
None
)
is
not
None
:
self
.
validate_interval
=
cfg
.
validate
.
get
(
'interval'
,
-
1
)
self
.
cfg
=
cfg
self
.
local_rank
=
ParallelEnv
().
local_rank
self
.
time_count
=
{}
self
.
best_metric
=
{}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录