Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
b6204126
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
1 年多 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b6204126
编写于
12月 17, 2021
作者:
L
LielinJiang
提交者:
GitHub
12月 17, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix mprnet train bug and add docs (#506)
* fix mprnet train and add docs * update config
上级
5bf728df
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
81 addition
and
32 deletion
+81
-32
configs/mprnet_deblurring.yaml
configs/mprnet_deblurring.yaml
+14
-13
docs/en_US/tutorials/single_image_super_resolution.md
docs/en_US/tutorials/single_image_super_resolution.md
+4
-0
docs/zh_CN/tutorials/single_image_super_resolution.md
docs/zh_CN/tutorials/single_image_super_resolution.md
+4
-0
ppgan/engine/trainer.py
ppgan/engine/trainer.py
+29
-8
ppgan/models/criterions/pixel_loss.py
ppgan/models/criterions/pixel_loss.py
+10
-8
ppgan/models/mpr_model.py
ppgan/models/mpr_model.py
+20
-3
未找到文件。
configs/mprnet_deblurring.yaml
浏览文件 @
b6204126
total_iters
:
100000
# epoch: 3000 for total batch size=16
total_iters
:
400000
output_dir
:
output_dir
output_dir
:
output_dir
model
:
model
:
...
@@ -15,38 +16,38 @@ dataset:
...
@@ -15,38 +16,38 @@ dataset:
train
:
train
:
name
:
MPRTrain
name
:
MPRTrain
rgb_dir
:
'
data/GoPro/train'
rgb_dir
:
'
data/GoPro/train'
num_workers
:
16
num_workers
:
4
batch_size
:
4
batch_size
:
2
img_options
:
img_options
:
patch_size
:
256
patch_size
:
256
test
:
test
:
name
:
MPR
Train
name
:
MPR
Val
rgb_dir
:
'
data/GoPro/test'
rgb_dir
:
'
data/GoPro/test'
num_workers
:
16
num_workers
:
4
batch_size
:
4
batch_size
:
2
img_options
:
img_options
:
patch_size
:
256
patch_size
:
256
lr_scheduler
:
lr_scheduler
:
name
:
CosineAnnealingRestartLR
name
:
CosineAnnealingRestartLR
learning_rate
:
!!float
2
e-4
learning_rate
:
!!float
1
e-4
periods
:
[
25000
,
25000
,
25000
,
25
000
]
periods
:
[
400
000
]
restart_weights
:
[
1
,
1
,
1
,
1
]
restart_weights
:
[
1
]
eta_min
:
!!float
1e-6
eta_min
:
!!float
1e-6
validate
:
validate
:
interval
:
1
0
interval
:
500
0
save_img
:
false
save_img
:
false
metrics
:
metrics
:
psnr
:
# metric name, can be arbitrary
psnr
:
# metric name, can be arbitrary
name
:
PSNR
name
:
PSNR
crop_border
:
4
crop_border
:
4
test_y_channel
:
Tru
e
test_y_channel
:
fals
e
ssim
:
ssim
:
name
:
SSIM
name
:
SSIM
crop_border
:
4
crop_border
:
4
test_y_channel
:
Tru
e
test_y_channel
:
fals
e
optimizer
:
optimizer
:
name
:
Adam
name
:
Adam
...
@@ -59,7 +60,7 @@ optimizer:
...
@@ -59,7 +60,7 @@ optimizer:
epsilon
:
1e-8
epsilon
:
1e-8
log_config
:
log_config
:
interval
:
10
interval
:
10
0
visiual_interval
:
5000
visiual_interval
:
5000
snapshot_config
:
snapshot_config
:
...
...
docs/en_US/tutorials/single_image_super_resolution.md
浏览文件 @
b6204126
...
@@ -130,6 +130,10 @@ The metrics are PSNR / SSIM.
...
@@ -130,6 +130,10 @@ The metrics are PSNR / SSIM.
| pan_x4 | 30.4574 / 0.8643 | 26.7204 / 0.7434 | 28.9187 / 0.8176 |
| pan_x4 | 30.4574 / 0.8643 | 26.7204 / 0.7434 | 28.9187 / 0.8176 |
| drns_x4 | 32.6684 / 0.8999 | 28.9037 / 0.7885 | - |
| drns_x4 | 32.6684 / 0.8999 | 28.9037 / 0.7885 | - |
Deblur models zoo
| model | GoPro | Download Link |
|---|---|---|
| MPRNet | 33.4360 / 0.9410 |
[
link
](
https://paddlegan.bj.bcebos.com/models/MPR_Deblurring.pdparams
)
|
<!-- ![](../../imgs/horse2zebra.png) -->
<!-- ![](../../imgs/horse2zebra.png) -->
...
...
docs/zh_CN/tutorials/single_image_super_resolution.md
浏览文件 @
b6204126
...
@@ -120,6 +120,10 @@ paddle模型使用DIV2K数据集训练,torch模型使用df2k和DIV2K数据集
...
@@ -120,6 +120,10 @@ paddle模型使用DIV2K数据集训练,torch模型使用df2k和DIV2K数据集
| paddle | 30.4574 / 0.8643 | 26.7204 / 0.7434 |
| paddle | 30.4574 / 0.8643 | 26.7204 / 0.7434 |
| torch | 30.2183 / 0.8643 | 26.8035 / 0.7445 |
| torch | 30.2183 / 0.8643 | 26.8035 / 0.7445 |
去模糊模型
| 模型 | GoPro | 下载地址 |
|---|---|---|
| MPRNet | 33.4360 / 0.9410 |
[
链接
](
https://paddlegan.bj.bcebos.com/models/MPR_Deblurring.pdparams
)
|
<!-- ![](../../imgs/horse2zebra.png) -->
<!-- ![](../../imgs/horse2zebra.png) -->
...
...
ppgan/engine/trainer.py
浏览文件 @
b6204126
...
@@ -29,6 +29,7 @@ from ..utils.filesystem import makedirs, save, load
...
@@ -29,6 +29,7 @@ from ..utils.filesystem import makedirs, save, load
from
..utils.timer
import
TimeAverager
from
..utils.timer
import
TimeAverager
from
..utils.profiler
import
add_profiler_step
from
..utils.profiler
import
add_profiler_step
class
IterLoader
:
class
IterLoader
:
def
__init__
(
self
,
dataloader
):
def
__init__
(
self
,
dataloader
):
self
.
_dataloader
=
dataloader
self
.
_dataloader
=
dataloader
...
@@ -429,6 +430,17 @@ class Trainer:
...
@@ -429,6 +430,17 @@ class Trainer:
def
load
(
self
,
weight_path
):
def
load
(
self
,
weight_path
):
state_dicts
=
load
(
weight_path
)
state_dicts
=
load
(
weight_path
)
def
is_dict_in_dict_weight
(
state_dict
):
if
isinstance
(
state_dict
,
dict
)
and
len
(
state_dict
)
>
0
:
val
=
list
(
state_dict
.
values
())[
0
]
if
isinstance
(
val
,
dict
):
return
True
else
:
return
False
else
:
return
False
if
is_dict_in_dict_weight
(
state_dicts
):
for
net_name
,
net
in
self
.
model
.
nets
.
items
():
for
net_name
,
net
in
self
.
model
.
nets
.
items
():
if
net_name
in
state_dicts
:
if
net_name
in
state_dicts
:
net
.
set_state_dict
(
state_dicts
[
net_name
])
net
.
set_state_dict
(
state_dicts
[
net_name
])
...
@@ -438,6 +450,15 @@ class Trainer:
...
@@ -438,6 +450,15 @@ class Trainer:
self
.
logger
.
warning
(
self
.
logger
.
warning
(
'Can not find state dict of net {}. Skip load pretrained weight for net {}'
'Can not find state dict of net {}. Skip load pretrained weight for net {}'
.
format
(
net_name
,
net_name
))
.
format
(
net_name
,
net_name
))
else
:
assert
len
(
self
.
model
.
nets
)
==
1
,
'checkpoint only contain weight of one net,
\
but model contains more than one net!'
net_name
,
net
=
list
(
self
.
model
.
nets
.
items
())[
0
]
net
.
set_state_dict
(
state_dicts
)
self
.
logger
.
info
(
'Loaded pretrained weight for net {}'
.
format
(
net_name
))
def
close
(
self
):
def
close
(
self
):
"""
"""
...
...
ppgan/models/criterions/pixel_loss.py
浏览文件 @
b6204126
...
@@ -249,23 +249,25 @@ class CalcStyleLoss():
...
@@ -249,23 +249,25 @@ class CalcStyleLoss():
class
EdgeLoss
():
class
EdgeLoss
():
def
__init__
(
self
):
def
__init__
(
self
):
k
=
paddle
.
to_tensor
([[.
05
,
.
25
,
.
4
,
.
25
,
.
05
]])
k
=
paddle
.
to_tensor
([[.
05
,
.
25
,
.
4
,
.
25
,
.
05
]])
self
.
kernel
=
paddle
.
matmul
(
k
.
t
(),
k
).
unsqueeze
(
0
).
tile
([
3
,
1
,
1
,
1
])
self
.
kernel
=
paddle
.
matmul
(
k
.
t
(),
k
).
unsqueeze
(
0
).
tile
([
3
,
1
,
1
,
1
])
self
.
loss
=
CharbonnierLoss
()
self
.
loss
=
CharbonnierLoss
()
def
conv_gauss
(
self
,
img
):
def
conv_gauss
(
self
,
img
):
n_channels
,
_
,
kw
,
kh
=
self
.
kernel
.
shape
n_channels
,
_
,
kw
,
kh
=
self
.
kernel
.
shape
img
=
F
.
pad
(
img
,
[
kw
//
2
,
kh
//
2
,
kw
//
2
,
kh
//
2
],
mode
=
'replicate'
)
img
=
F
.
pad
(
img
,
[
kw
//
2
,
kh
//
2
,
kw
//
2
,
kh
//
2
],
mode
=
'replicate'
)
return
F
.
conv2d
(
img
,
self
.
kernel
,
groups
=
n_channels
)
return
F
.
conv2d
(
img
,
self
.
kernel
,
groups
=
n_channels
)
def
laplacian_kernel
(
self
,
current
):
def
laplacian_kernel
(
self
,
current
):
filtered
=
self
.
conv_gauss
(
current
)
# filter
filtered
=
self
.
conv_gauss
(
current
)
# filter
down
=
filtered
[:,:,::
2
,::
2
]
# downsample
down
=
filtered
[:,
:,
::
2
,
::
2
]
# downsample
new_filter
=
paddle
.
zeros_like
(
filtered
)
new_filter
=
paddle
.
zeros_like
(
filtered
)
new_filter
[:,:,::
2
,::
2
]
=
down
*
4
# upsample
new_filter
.
stop_gradient
=
True
new_filter
[:,
:,
::
2
,
::
2
]
=
down
*
4
# upsample
filtered
=
self
.
conv_gauss
(
new_filter
)
# filter
filtered
=
self
.
conv_gauss
(
new_filter
)
# filter
diff
=
current
-
filtered
diff
=
current
-
filtered
return
diff
return
diff
def
__call__
(
self
,
x
,
y
):
def
__call__
(
self
,
x
,
y
):
y
.
stop_gradient
=
True
loss
=
self
.
loss
(
self
.
laplacian_kernel
(
x
),
self
.
laplacian_kernel
(
y
))
loss
=
self
.
loss
(
self
.
laplacian_kernel
(
x
),
self
.
laplacian_kernel
(
y
))
return
loss
return
loss
ppgan/models/mpr_model.py
浏览文件 @
b6204126
...
@@ -20,6 +20,7 @@ from .base_model import BaseModel
...
@@ -20,6 +20,7 @@ from .base_model import BaseModel
from
.generators.builder
import
build_generator
from
.generators.builder
import
build_generator
from
.criterions.builder
import
build_criterion
from
.criterions.builder
import
build_criterion
from
..modules.init
import
reset_parameters
,
init_weights
from
..modules.init
import
reset_parameters
,
init_weights
from
..utils.visual
import
tensor2img
@
MODELS
.
register
()
@
MODELS
.
register
()
...
@@ -50,12 +51,12 @@ class MPRModel(BaseModel):
...
@@ -50,12 +51,12 @@ class MPRModel(BaseModel):
def
setup_input
(
self
,
input
):
def
setup_input
(
self
,
input
):
self
.
target
=
input
[
0
]
self
.
target
=
input
[
0
]
self
.
input_
=
input
[
1
]
self
.
lq
=
input
[
1
]
def
train_iter
(
self
,
optims
=
None
):
def
train_iter
(
self
,
optims
=
None
):
optims
[
'optim'
].
clear_gradients
()
optims
[
'optim'
].
clear_gradients
()
restored
=
self
.
nets
[
'generator'
](
self
.
input_
)
restored
=
self
.
nets
[
'generator'
](
self
.
lq
)
loss_char
=
[]
loss_char
=
[]
loss_edge
=
[]
loss_edge
=
[]
...
@@ -75,5 +76,21 @@ class MPRModel(BaseModel):
...
@@ -75,5 +76,21 @@ class MPRModel(BaseModel):
self
.
losses
[
'loss'
]
=
loss
.
numpy
()
self
.
losses
[
'loss'
]
=
loss
.
numpy
()
def
forward
(
self
):
def
forward
(
self
):
"""Run forward pass; called by both functions <train_iter> and <test_iter>."""
pass
pass
def
test_iter
(
self
,
metrics
=
None
):
self
.
nets
[
'generator'
].
eval
()
with
paddle
.
no_grad
():
self
.
output
=
self
.
nets
[
'generator'
](
self
.
lq
)[
0
]
self
.
visual_items
[
'output'
]
=
self
.
output
self
.
nets
[
'generator'
].
train
()
out_img
=
[]
gt_img
=
[]
for
out_tensor
,
gt_tensor
in
zip
(
self
.
output
,
self
.
target
):
out_img
.
append
(
tensor2img
(
out_tensor
,
(
0.
,
1.
)))
gt_img
.
append
(
tensor2img
(
gt_tensor
,
(
0.
,
1.
)))
if
metrics
is
not
None
:
for
metric
in
metrics
.
values
():
metric
.
update
(
out_img
,
gt_img
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录