Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
1dd14a70
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1dd14a70
编写于
7月 19, 2018
作者:
T
tangwei12
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
bug fix
上级
f9f8fbaa
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
148 addition
and
63 deletion
+148
-63
python/paddle/fluid/trainer.py
python/paddle/fluid/trainer.py
+126
-63
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+22
-0
未找到文件。
python/paddle/fluid/trainer.py
浏览文件 @
1dd14a70
...
...
@@ -360,6 +360,7 @@ class Trainer(object):
self
.
train_program
=
t
.
get_pserver_program
(
current_endpoint
)
self
.
startup_program
=
t
.
get_startup_program
(
current_endpoint
,
self
.
train_program
)
self
.
slice_vars
=
t
.
get_slice_vars_and_atts
(
current_endpoint
)
elif
training_role
==
"TRAINER"
:
self
.
train_program
=
t
.
get_trainer_program
()
else
:
...
...
@@ -474,8 +475,10 @@ class Trainer(object):
self
.
_clean_checkpoint
()
return
if
self
.
checkpoint_cfg
and
self
.
checkpoint_cfg
.
load_serial
\
and
self
.
checkpoint_cfg
.
step_id
>=
step_id
and
self
.
checkpoint_cfg
.
epoch_id
==
epoch_id
:
if
self
.
checkpoint_cfg
and
\
self
.
checkpoint_cfg
.
load_serial
is
not
None
and
\
self
.
checkpoint_cfg
.
step_id
>=
step_id
and
\
self
.
checkpoint_cfg
.
epoch_id
==
epoch_id
:
continue
begin_event
=
BeginStepEvent
(
epoch_id
,
step_id
)
...
...
@@ -569,36 +572,58 @@ class Trainer(object):
def
_load_checkpoint
(
self
):
with
self
.
_prog_and_scope_guard
():
exe
=
executor
.
Executor
(
self
.
place
)
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
main_program
=
self
.
startup_program
)
if
not
self
.
checkpoint_cfg
.
pserver_id
:
load_trainer_args
=
self
.
_get_checkpoint_load_args
()
trainer_args
=
load_checkpoint
(
checkpoint_dir
=
_get_serial_dir
(
self
.
checkpoint_cfg
.
checkpoint_dir
,
self
.
checkpoint_cfg
.
load_serial
)
# Trainer Load
if
self
.
checkpoint_cfg
.
pserver_id
is
None
:
# load model
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
checkpoint_dir
=
checkpoint_dir
,
main_program
=
self
.
startup_program
,
role_id
=
self
.
trainer_id
,
is_trainer
=
True
,
load_
trainer_args
=
load_trainer_args
)
load_
models
=
True
)
if
len
(
trainer_args
)
!=
2
:
# load trainer_args
trainer_args
=
self
.
_get_checkpoint_load_args
()
trainer_args_ret
=
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
checkpoint_dir
,
main_program
=
self
.
startup_program
,
role_id
=
self
.
trainer_id
,
is_trainer
=
True
,
load_trainer_args
=
trainer_args
)
if
len
(
trainer_args_ret
)
!=
2
:
raise
ValueError
(
"the return trainer_args length do not equal _get_checkpoint_load_args"
)
self
.
checkpoint_cfg
.
epoch_id
=
int
(
trainer_args
[
0
])
self
.
checkpoint_cfg
.
step_id
=
int
(
trainer_args
[
1
])
self
.
checkpoint_cfg
.
epoch_id
=
int
(
trainer_args_ret
[
0
])
self
.
checkpoint_cfg
.
step_id
=
int
(
trainer_args_ret
[
1
])
# Pserver Load
else
:
# load slice_vars
if
self
.
slice_vars
!=
None
and
len
(
self
.
slice_vars
)
!=
0
:
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
checkpoint_dir
,
main_program
=
self
.
startup_program
,
role_id
=
self
.
checkpoint_cfg
.
pserver_id
,
is_trainer
=
False
,
load_slice_up_vars
=
self
.
slice_vars
)
# load lookup table
if
self
.
checkpoint_cfg
.
lookup_table_name
:
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
checkpoint_dir
=
checkpoint_dir
,
main_program
=
self
.
startup_program
,
role_id
=
self
.
checkpoint_cfg
.
pserver_id
,
is_trainer
=
False
,
load_trainer_args
=
None
,
load_lookup_table
=
self
.
checkpoint_cfg
.
lookup_table_name
)
...
...
@@ -640,7 +665,7 @@ def save_checkpoint(executor,
main_program
,
trainer_args
=
None
,
max_num_checkpoints
=
3
,
lookup_table
=
None
,
save_
lookup_table
=
None
,
pserver_endpoints
=
None
):
"""
This function filters out all checkpoint variables from the give
...
...
@@ -673,7 +698,7 @@ def save_checkpoint(executor,
max_num_checkpoints(int): The max number of total number of existing
checkpoints.
Default: 3
lookup_table(string|None): the lookup table name, when use distribute
save_
lookup_table(string|None): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
pserver_endpoints(list|None): the parameter server ip:port list.
...
...
@@ -704,7 +729,7 @@ def save_checkpoint(executor,
trainer_args=trainer_args,
main_program=prog,
max_num_checkpoints=3,
lookup_table=table_name,
save_
lookup_table=table_name,
pserver_endpoints = ps_endpoints)
"""
if
checkpoint_dir
is
None
:
...
...
@@ -720,15 +745,15 @@ def save_checkpoint(executor,
_make_chekcpoint_dirs
(
checkpoint_dir
)
serial
=
_get_latest_checkpoint_serial
(
checkpoint_dir
)
+
1
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
,
True
)
_save_trainer_args
(
cur_dir
,
trainer_id
,
trainer_args
)
if
is_chief
:
_save_persist
_vars_without_grad
(
executor
,
cur_dir
,
main_program
)
_save_persist
able_vars
(
executor
,
cur_dir
,
main_program
)
if
is_chief
and
lookup_table
and
pserver_endpoints
:
_save_pserver_vars_by_notify
(
executor
,
cur_dir
,
lookup_table
,
if
is_chief
and
save_
lookup_table
and
pserver_endpoints
:
_save_pserver_vars_by_notify
(
executor
,
cur_dir
,
save_
lookup_table
,
pserver_endpoints
)
_scroll_delete
(
checkpoint_dir
,
max_num_checkpoints
)
...
...
@@ -736,10 +761,12 @@ def save_checkpoint(executor,
def
load_checkpoint
(
executor
,
checkpoint_dir
,
main_program
,
main_program
=
None
,
role_id
=
0
,
is_trainer
=
True
,
load_models
=
True
,
load_trainer_args
=
None
,
load_slice_up_vars
=
None
,
load_lookup_table
=
None
):
"""
This function filters out all checkpoint variables from the give
...
...
@@ -762,7 +789,7 @@ def load_checkpoint(executor,
executor(Executor): The executor to run for loading checkpoint.
checkpoint_dir(str): The folder where all checkpoints are.
serial(int): The serial of checkpoint you would like to load.
main_program(Program): The program whose checkpoint variables will
main_program(Program
|None
): The program whose checkpoint variables will
be loaded.
role_id(int): the trainer id or the parameter server id.
is_trainer(bool): trainer is True and parameter server is False.
...
...
@@ -794,27 +821,23 @@ def load_checkpoint(executor,
if
checkpoint_dir
is
None
:
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
serial
=
_get_latest_checkpoint_serial
(
checkpoint_dir
)
# there are nothing need to be loaded
if
serial
is
None
or
serial
<
0
:
return
if
main_program
is
None
:
raise
ValueError
(
'main_program should not be None.'
)
if
is_trainer
and
load_trainer_args
is
None
:
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
_load_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
,
True
)
return
if
is_trainer
and
load_trainer_args
:
return
_load_trainer_args
(
checkpoint_dir
,
serial
,
role_id
,
load_trainer_args
)
if
not
is_trainer
and
load_lookup_table
:
_load_lookup_table_vars
(
executor
,
checkpoint_dir
,
main_program
,
role_id
,
load_lookup_table
)
# trainer load
if
is_trainer
:
if
load_models
:
_load_persistable_vars
(
executor
,
checkpoint_dir
,
main_program
,
True
)
return
if
load_trainer_args
:
trainer_args_ret
=
_load_trainer_args
(
checkpoint_dir
,
role_id
,
load_trainer_args
)
return
trainer_args_ret
# pserver load
else
:
if
load_slice_up_vars
:
_load_slice_up_vars
(
executor
,
checkpoint_dir
,
load_slice_up_vars
)
return
if
load_lookup_table
:
_load_lookup_table_vars
(
executor
,
checkpoint_dir
,
main_program
,
role_id
,
load_lookup_table
)
def
clean_checkpoint
(
checkpoint_dir
,
delete_dir
=
False
):
...
...
@@ -835,10 +858,7 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
os
.
rmdir
(
checkpoint_dir
)
def
_load_persist_vars_without_grad
(
executor
,
dirname
,
program
,
has_model_dir
=
False
):
def
_load_persistable_vars
(
executor
,
dirname
,
program
,
has_model_dir
=
False
):
"""
This function filters out all checkpoint variables from the give
program and then trys to load these variables from the given directory.
...
...
@@ -867,10 +887,10 @@ def _load_persist_vars_without_grad(executor,
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
_load_persist
_vars_without_grad
(executor=exe,
_load_persist
able_vars
(executor=exe,
dirname=param_path, program=prog, has_model_dir=True)
# In this example, `_load_persist
_vars_without_grad
` function
# In this example, `_load_persist
able_vars
` function
# will first filters out all checkpoint variables in the default
# main program, and then trys to load these variables form the
# folder "./my_paddle_model/__model__".
...
...
@@ -887,6 +907,51 @@ def _load_persist_vars_without_grad(executor,
filename
=
None
)
def
_load_slice_up_vars
(
executor
,
dirname
,
slice_vars
):
if
slice_vars
==
None
or
len
(
slice_vars
)
==
0
:
return
dirname
=
_get_model_dir
(
dirname
)
load_prog
=
framework
.
Program
()
load_block
=
load_prog
.
global_block
()
for
var_tuple
in
slice_vars
:
orig_var
=
var_tuple
[
0
]
start
=
var_tuple
[
1
]
slice_var
=
var_tuple
[
2
]
end
=
start
+
reduce
(
lambda
x
,
y
:
x
*
y
,
slice_var
.
shape
)
clone_orig_var
=
load_block
.
create_var
(
name
=
orig_var
.
name
,
type
=
orig_var
.
type
,
shape
=
orig_var
.
shape
,
dtype
=
orig_var
.
dtype
,
persistable
=
True
)
clone_slice_var
=
load_block
.
create_var
(
name
=
slice_var
.
name
,
type
=
slice_var
.
type
,
shape
=
slice_var
.
shape
,
dtype
=
slice_var
.
dtype
,
persistable
=
True
)
load_block
.
append_op
(
type
=
'load'
,
inputs
=
{},
outputs
=
{
'Out'
:
[
clone_orig_var
]},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
dirname
,
clone_orig_var
.
name
)})
load_block
.
append_op
(
type
=
"slice"
,
inputs
=
{
'Input'
:
clone_orig_var
},
outputs
=
{
'Out'
:
clone_slice_var
},
attrs
=
{
'axes'
:
[
0
],
'starts'
:
[
start
],
'ends'
:
[
end
]})
executor
.
run
(
load_prog
)
def
_load_lookup_table_vars
(
executor
,
dirname
,
program
,
pserver_id
,
table_name
):
"""
The parameter server will load lookup table's local file in
...
...
@@ -937,7 +1002,7 @@ def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
executor
.
run
(
load_prog
)
def
_save_persist
_vars_without_grad
(
executor
,
dirname
,
program
):
def
_save_persist
able_vars
(
executor
,
dirname
,
program
):
"""
This function filters out all checkpoint variables from the give
program and then save these variables to a sub-folder '__model__' of
...
...
@@ -964,10 +1029,10 @@ def _save_persist_vars_without_grad(executor, dirname, program):
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
_save_persist
_vars_without_grad
(executor=exe,
_save_persist
able_vars
(executor=exe,
dirname=param_path, program=prog)
# In this example, `_save_persist
_vars_without_grad
` function
# In this example, `_save_persist
able_vars
` function
# will first filters out all checkpoint variables in the default
# main program, and then saves these variables to the folder
# "./my_paddle_model/__model__".
...
...
@@ -1043,7 +1108,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args):
_write_success
(
cur_dir
)
def
_load_trainer_args
(
checkpoint_dir
,
serial
,
trainer_id
,
trainer_args
):
def
_load_trainer_args
(
checkpoint_dir
,
trainer_id
,
trainer_args
):
"""
trainer will load some args from it's independent directory,
such as epoch_id and step_id.
...
...
@@ -1069,8 +1134,7 @@ def _load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
"""
assert
isinstance
(
trainer_args
,
list
)
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
cur_dir
=
_get_trainer_dir
(
cur_dir
,
trainer_id
)
cur_dir
=
_get_trainer_dir
(
checkpoint_dir
,
trainer_id
)
ret_values
=
[]
...
...
@@ -1125,20 +1189,19 @@ def _make_chekcpoint_dirs(dirs):
def
_get_dir_serial
(
dirname
):
_
,
serial
=
dirname
.
split
(
CHECKPOINT_SEPARATOR
)
try
:
_
,
serial
=
dirname
.
split
(
CHECKPOINT_SEPARATOR
)
serial_num
=
int
(
serial
)
except
ValueError
:
serial_num
=
-
1
return
serial_num
def
_get_serial_dir
(
dirname
,
serial
):
def
_get_serial_dir
(
dirname
,
serial
,
makedirs
=
False
):
serial_folder
=
CHECKPOINT_PREFIX
+
CHECKPOINT_SEPARATOR
+
str
(
serial
)
serial_dir
=
os
.
path
.
join
(
dirname
,
serial_folder
)
_make_chekcpoint_dirs
(
serial_dir
)
if
makedirs
:
_make_chekcpoint_dirs
(
serial_dir
)
return
serial_dir
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
1dd14a70
...
...
@@ -719,6 +719,28 @@ class DistributeTranspiler(object):
})
for
ep
in
self
.
pserver_endpoints
]
def
get_slice_vars_and_atts
(
self
,
endpoint
):
slice_vars_and_atts
=
[]
block_suffix
=
".block"
for
param
in
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]:
suff_idx
=
param
.
name
.
find
(
block_suffix
)
if
suff_idx
<=
0
:
continue
orig_var_name
=
param
.
name
[:
suff_idx
]
block_idx
=
int
(
param
.
name
[
suff_idx
+
len
(
block_suffix
):])
orig_var
=
self
.
origin_program
.
global_block
().
vars
[
orig_var_name
]
skip_numel
=
0
slice_vars
=
self
.
param_var_mapping
[
orig_var_name
]
for
slice_var
in
slice_vars
[:
block_idx
]:
skip_numel
+=
reduce
(
lambda
x
,
y
:
x
*
y
,
slice_var
.
shape
)
slice_vars_and_atts
.
append
([
orig_var
,
skip_numel
,
param
])
return
slice_vars_and_atts
# transpiler function for dis lookup_table
def
_replace_lookup_table_op_with_prefetch
(
self
,
program
,
pserver_endpoints
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录