Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
605b34ba
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
1 年多 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
605b34ba
编写于
10月 28, 2020
作者:
L
LielinJiang
提交者:
GitHub
10月 28, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix resume and multiple gpu train bug (#57)
上级
2d17703b
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
23 addition
and
22 deletion
+23
-22
ppgan/engine/trainer.py
ppgan/engine/trainer.py
+18
-21
ppgan/utils/filesystem.py
ppgan/utils/filesystem.py
+5
-1
未找到文件。
ppgan/engine/trainer.py
浏览文件 @
605b34ba
...
@@ -55,11 +55,8 @@ class Trainer:
...
@@ -55,11 +55,8 @@ class Trainer:
def
distributed_data_parallel
(
self
):
def
distributed_data_parallel
(
self
):
strategy
=
paddle
.
distributed
.
prepare_context
()
strategy
=
paddle
.
distributed
.
prepare_context
()
for
name
in
self
.
model
.
model_names
:
for
net_name
,
net
in
self
.
model
.
nets
.
items
():
if
isinstance
(
name
,
str
):
self
.
model
.
nets
[
net_name
]
=
paddle
.
DataParallel
(
net
,
strategy
)
net
=
getattr
(
self
.
model
,
'net'
+
name
)
setattr
(
self
.
model
,
'net'
+
name
,
paddle
.
DataParallel
(
net
,
strategy
))
def
train
(
self
):
def
train
(
self
):
reader_cost_averager
=
TimeAverager
()
reader_cost_averager
=
TimeAverager
()
...
@@ -77,9 +74,9 @@ class Trainer:
...
@@ -77,9 +74,9 @@ class Trainer:
self
.
model
.
set_input
(
data
)
self
.
model
.
set_input
(
data
)
self
.
model
.
optimize_parameters
()
self
.
model
.
optimize_parameters
()
batch_cost_averager
.
record
(
batch_cost_averager
.
record
(
time
.
time
()
-
step_start_time
,
time
.
time
()
-
step_start_time
,
num_samples
=
self
.
cfg
.
get
(
num_samples
=
self
.
cfg
.
get
(
'batch_size'
,
1
))
'batch_size'
,
1
))
if
i
%
self
.
log_interval
==
0
:
if
i
%
self
.
log_interval
==
0
:
self
.
data_time
=
reader_cost_averager
.
get_average
()
self
.
data_time
=
reader_cost_averager
.
get_average
()
self
.
step_time
=
batch_cost_averager
.
get_average
()
self
.
step_time
=
batch_cost_averager
.
get_average
()
...
@@ -94,8 +91,8 @@ class Trainer:
...
@@ -94,8 +91,8 @@ class Trainer:
step_start_time
=
time
.
time
()
step_start_time
=
time
.
time
()
self
.
logger
.
info
(
self
.
logger
.
info
(
'train one epoch time: {}'
.
format
(
time
.
time
()
-
'train one epoch time: {}'
.
format
(
time
.
time
()
-
start_time
))
start_time
))
if
self
.
validate_interval
>
-
1
and
epoch
%
self
.
validate_interval
:
if
self
.
validate_interval
>
-
1
and
epoch
%
self
.
validate_interval
:
self
.
validate
()
self
.
validate
()
self
.
model
.
lr_scheduler
.
step
()
self
.
model
.
lr_scheduler
.
step
()
...
@@ -105,8 +102,8 @@ class Trainer:
...
@@ -105,8 +102,8 @@ class Trainer:
def
validate
(
self
):
def
validate
(
self
):
if
not
hasattr
(
self
,
'val_dataloader'
):
if
not
hasattr
(
self
,
'val_dataloader'
):
self
.
val_dataloader
=
build_dataloader
(
self
.
val_dataloader
=
build_dataloader
(
self
.
cfg
.
dataset
.
val
,
self
.
cfg
.
dataset
.
val
,
is_train
=
False
)
is_train
=
False
)
metric_result
=
{}
metric_result
=
{}
...
@@ -152,8 +149,8 @@ class Trainer:
...
@@ -152,8 +149,8 @@ class Trainer:
self
.
visual
(
'visual_val'
,
visual_results
=
visual_results
)
self
.
visual
(
'visual_val'
,
visual_results
=
visual_results
)
if
i
%
self
.
log_interval
==
0
:
if
i
%
self
.
log_interval
==
0
:
self
.
logger
.
info
(
self
.
logger
.
info
(
'val iter: [%d/%d]'
%
'val iter: [%d/%d]'
%
(
i
,
len
(
self
.
val_dataloader
)))
(
i
,
len
(
self
.
val_dataloader
)))
for
metric_name
in
metric_result
.
keys
():
for
metric_name
in
metric_result
.
keys
():
metric_result
[
metric_name
]
/=
len
(
self
.
val_dataloader
.
dataset
)
metric_result
[
metric_name
]
/=
len
(
self
.
val_dataloader
.
dataset
)
...
@@ -163,8 +160,8 @@ class Trainer:
...
@@ -163,8 +160,8 @@ class Trainer:
def
test
(
self
):
def
test
(
self
):
if
not
hasattr
(
self
,
'test_dataloader'
):
if
not
hasattr
(
self
,
'test_dataloader'
):
self
.
test_dataloader
=
build_dataloader
(
self
.
test_dataloader
=
build_dataloader
(
self
.
cfg
.
dataset
.
test
,
self
.
cfg
.
dataset
.
test
,
is_train
=
False
)
is_train
=
False
)
# data[0]: img, data[1]: img path index
# data[0]: img, data[1]: img path index
# test batch size must be 1
# test batch size must be 1
...
@@ -188,8 +185,8 @@ class Trainer:
...
@@ -188,8 +185,8 @@ class Trainer:
self
.
visual
(
'visual_test'
,
visual_results
=
visual_results
)
self
.
visual
(
'visual_test'
,
visual_results
=
visual_results
)
if
i
%
self
.
log_interval
==
0
:
if
i
%
self
.
log_interval
==
0
:
self
.
logger
.
info
(
self
.
logger
.
info
(
'Test iter: [%d/%d]'
%
'Test iter: [%d/%d]'
%
(
i
,
len
(
self
.
test_dataloader
)))
(
i
,
len
(
self
.
test_dataloader
)))
def
print_log
(
self
):
def
print_log
(
self
):
losses
=
self
.
model
.
get_current_losses
()
losses
=
self
.
model
.
get_current_losses
()
...
@@ -277,13 +274,13 @@ class Trainer:
...
@@ -277,13 +274,13 @@ class Trainer:
self
.
start_epoch
=
state_dicts
[
'epoch'
]
+
1
self
.
start_epoch
=
state_dicts
[
'epoch'
]
+
1
for
net_name
,
net
in
self
.
model
.
nets
.
items
():
for
net_name
,
net
in
self
.
model
.
nets
.
items
():
net
.
set_dict
(
state_dicts
[
net_name
])
net
.
set_
state_
dict
(
state_dicts
[
net_name
])
for
opt_name
,
opt
in
self
.
model
.
optimizers
.
items
():
for
opt_name
,
opt
in
self
.
model
.
optimizers
.
items
():
opt
.
set_dict
(
state_dicts
[
opt_name
])
opt
.
set_
state_
dict
(
state_dicts
[
opt_name
])
def
load
(
self
,
weight_path
):
def
load
(
self
,
weight_path
):
state_dicts
=
load
(
weight_path
)
state_dicts
=
load
(
weight_path
)
for
net_name
,
net
in
self
.
model
.
nets
.
items
():
for
net_name
,
net
in
self
.
model
.
nets
.
items
():
net
.
set_dict
(
state_dicts
[
net_name
])
net
.
set_
state_
dict
(
state_dicts
[
net_name
])
ppgan/utils/filesystem.py
浏览文件 @
605b34ba
...
@@ -6,7 +6,11 @@ import paddle
...
@@ -6,7 +6,11 @@ import paddle
def
makedirs
(
dir
):
def
makedirs
(
dir
):
if
not
os
.
path
.
exists
(
dir
):
if
not
os
.
path
.
exists
(
dir
):
# avoid error when train with multiple gpus
try
:
os
.
makedirs
(
dir
)
os
.
makedirs
(
dir
)
except
:
pass
def
save
(
state_dicts
,
file_name
):
def
save
(
state_dicts
,
file_name
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录