Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
89dbb63f
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
大约 1 年 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
89dbb63f
编写于
1月 06, 2021
作者:
L
LielinJiang
提交者:
GitHub
1月 06, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix some bugs (#140)
* fix some bugs * update configs
上级
cd642c08
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
38 addition
and
23 deletion
+38
-23
configs/cyclegan_cityscapes.yaml
configs/cyclegan_cityscapes.yaml
+1
-1
configs/cyclegan_horse2zebra.yaml
configs/cyclegan_horse2zebra.yaml
+2
-2
configs/pix2pix_cityscapes.yaml
configs/pix2pix_cityscapes.yaml
+1
-1
configs/pix2pix_cityscapes_2gpus.yaml
configs/pix2pix_cityscapes_2gpus.yaml
+1
-1
configs/pix2pix_facades.yaml
configs/pix2pix_facades.yaml
+1
-1
ppgan/apps/realsr_predictor.py
ppgan/apps/realsr_predictor.py
+2
-1
ppgan/engine/trainer.py
ppgan/engine/trainer.py
+24
-13
ppgan/utils/options.py
ppgan/utils/options.py
+2
-1
ppgan/utils/setup.py
ppgan/utils/setup.py
+4
-2
未找到文件。
configs/cyclegan_cityscapes.yaml
浏览文件 @
89dbb63f
...
...
@@ -67,7 +67,7 @@ dataset:
batch_size
:
1
max_size
:
inf
is_train
:
False
load_pipeline
:
preprocess
:
-
name
:
LoadImageFromFile
key
:
A
-
name
:
LoadImageFromFile
...
...
configs/cyclegan_horse2zebra.yaml
浏览文件 @
89dbb63f
...
...
@@ -35,7 +35,7 @@ dataset:
batch_size
:
1
is_train
:
True
max_size
:
inf
load_pipeline
:
preprocess
:
-
name
:
LoadImageFromFile
key
:
A
-
name
:
LoadImageFromFile
...
...
@@ -67,7 +67,7 @@ dataset:
batch_size
:
1
max_size
:
inf
is_train
:
False
load_pipeline
:
preprocess
:
-
name
:
LoadImageFromFile
key
:
A
-
name
:
LoadImageFromFile
...
...
configs/pix2pix_cityscapes.yaml
浏览文件 @
89dbb63f
...
...
@@ -61,7 +61,7 @@ dataset:
dataroot
:
data/cityscapes/test
num_workers
:
4
batch_size
:
1
load_pipeline
:
preprocess
:
-
name
:
LoadImageFromFile
key
:
pair
-
name
:
SplitPairedImage
...
...
configs/pix2pix_cityscapes_2gpus.yaml
浏览文件 @
89dbb63f
...
...
@@ -61,7 +61,7 @@ dataset:
dataroot
:
data/cityscapes/test
num_workers
:
4
batch_size
:
1
load_pipeline
:
preprocess
:
-
name
:
LoadImageFromFile
key
:
pair
-
name
:
Transforms
...
...
configs/pix2pix_facades.yaml
浏览文件 @
89dbb63f
...
...
@@ -61,7 +61,7 @@ dataset:
dataroot
:
data/facades/test
num_workers
:
4
batch_size
:
1
load_pipeline
:
preprocess
:
-
name
:
LoadImageFromFile
key
:
pair
-
name
:
Transforms
...
...
ppgan/apps/realsr_predictor.py
浏览文件 @
89dbb63f
...
...
@@ -60,7 +60,8 @@ class RealSRPredictor(BasePredictor):
img
=
self
.
norm
(
ori_img
)
x
=
paddle
.
to_tensor
(
img
[
np
.
newaxis
,
...])
out
=
self
.
model
(
x
)
with
paddle
.
no_grad
():
out
=
self
.
model
(
x
)
pred_img
=
self
.
denorm
(
out
.
numpy
()[
0
])
pred_img
=
Image
.
fromarray
(
pred_img
)
...
...
ppgan/engine/trainer.py
浏览文件 @
89dbb63f
...
...
@@ -124,6 +124,9 @@ class Trainer:
self
.
weight_interval
=
cfg
.
snapshot_config
.
interval
self
.
log_interval
=
cfg
.
log_config
.
interval
self
.
visual_interval
=
cfg
.
log_config
.
visiual_interval
if
self
.
by_epoch
:
self
.
weight_interval
*=
self
.
iters_per_epoch
self
.
validate_interval
=
-
1
if
cfg
.
get
(
'validate'
,
None
)
is
not
None
:
self
.
validate_interval
=
cfg
.
validate
.
get
(
'interval'
,
-
1
)
...
...
@@ -177,16 +180,12 @@ class Trainer:
self
.
model
.
lr_scheduler
.
step
()
if
self
.
by_epoch
:
temp
=
self
.
current_epoch
else
:
temp
=
self
.
current_iter
if
self
.
validate_interval
>
-
1
and
temp
%
self
.
validate_interval
==
0
:
if
self
.
validate_interval
>
-
1
and
self
.
current_iter
%
self
.
validate_interval
==
0
:
self
.
test
()
if
temp
%
self
.
weight_interval
==
0
:
self
.
save
(
temp
,
'weight'
,
keep
=-
1
)
self
.
save
(
temp
)
if
self
.
current_iter
%
self
.
weight_interval
==
0
:
self
.
save
(
self
.
current_iter
,
'weight'
,
keep
=-
1
)
self
.
save
(
self
.
current_iter
)
self
.
current_iter
+=
1
...
...
@@ -335,7 +334,12 @@ class Trainer:
assert
name
in
[
'checkpoint'
,
'weight'
]
state_dicts
=
{}
save_filename
=
'epoch_%s_%s.pdparams'
%
(
epoch
,
name
)
if
self
.
by_epoch
:
save_filename
=
'epoch_%s_%s.pdparams'
%
(
epoch
//
self
.
iters_per_epoch
,
name
)
else
:
save_filename
=
'iter_%s_%s.pdparams'
%
(
epoch
,
name
)
save_path
=
os
.
path
.
join
(
self
.
output_dir
,
save_filename
)
for
net_name
,
net
in
self
.
model
.
nets
.
items
():
state_dicts
[
net_name
]
=
net
.
state_dict
()
...
...
@@ -353,9 +357,16 @@ class Trainer:
if
keep
>
0
:
try
:
checkpoint_name_to_be_removed
=
os
.
path
.
join
(
self
.
output_dir
,
'epoch_%s_%s.pdparams'
%
(
epoch
-
keep
,
name
))
if
self
.
by_epoch
:
checkpoint_name_to_be_removed
=
os
.
path
.
join
(
self
.
output_dir
,
'epoch_%s_%s.pdparams'
%
((
epoch
-
keep
*
self
.
weight_interval
)
//
self
.
iters_per_epoch
,
name
))
else
:
checkpoint_name_to_be_removed
=
os
.
path
.
join
(
self
.
output_dir
,
'iter_%s_%s.pdparams'
%
(
epoch
-
keep
*
self
.
weight_interval
,
name
))
if
os
.
path
.
exists
(
checkpoint_name_to_be_removed
):
os
.
remove
(
checkpoint_name_to_be_removed
)
...
...
@@ -366,7 +377,7 @@ class Trainer:
state_dicts
=
load
(
checkpoint_path
)
if
state_dicts
.
get
(
'epoch'
,
None
)
is
not
None
:
self
.
start_epoch
=
state_dicts
[
'epoch'
]
+
1
self
.
global_steps
=
self
.
step
s_per_epoch
*
state_dicts
[
'epoch'
]
self
.
global_steps
=
self
.
iter
s_per_epoch
*
state_dicts
[
'epoch'
]
for
net_name
,
net
in
self
.
model
.
nets
.
items
():
net
.
set_state_dict
(
state_dicts
[
net_name
])
...
...
ppgan/utils/options.py
浏览文件 @
89dbb63f
...
...
@@ -17,7 +17,8 @@ import argparse
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'PaddleGAN'
)
parser
.
add_argument
(
'--config-file'
,
parser
.
add_argument
(
'-c'
,
'--config-file'
,
metavar
=
"FILE"
,
help
=
'config file path'
)
# cuda setting
...
...
ppgan/utils/setup.py
浏览文件 @
89dbb63f
...
...
@@ -26,8 +26,10 @@ def setup(args, cfg):
cfg
.
is_train
=
True
cfg
.
timestamp
=
time
.
strftime
(
'-%Y-%m-%d-%H-%M'
,
time
.
localtime
())
cfg
.
output_dir
=
os
.
path
.
join
(
cfg
.
output_dir
,
str
(
cfg
.
model
.
name
)
+
cfg
.
timestamp
)
cfg
.
output_dir
=
os
.
path
.
join
(
cfg
.
output_dir
,
os
.
path
.
splitext
(
os
.
path
.
basename
(
str
(
args
.
config_file
)))[
0
]
+
cfg
.
timestamp
)
logger
=
setup_logger
(
cfg
.
output_dir
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录