Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1dd14a70
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1dd14a70
编写于
7月 19, 2018
作者:
T
tangwei12
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
bug fix
上级
f9f8fbaa
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
148 addition
and
63 deletion
+148
-63
python/paddle/fluid/trainer.py
python/paddle/fluid/trainer.py
+126
-63
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+22
-0
未找到文件。
python/paddle/fluid/trainer.py
浏览文件 @
1dd14a70
...
@@ -360,6 +360,7 @@ class Trainer(object):
...
@@ -360,6 +360,7 @@ class Trainer(object):
self
.
train_program
=
t
.
get_pserver_program
(
current_endpoint
)
self
.
train_program
=
t
.
get_pserver_program
(
current_endpoint
)
self
.
startup_program
=
t
.
get_startup_program
(
current_endpoint
,
self
.
startup_program
=
t
.
get_startup_program
(
current_endpoint
,
self
.
train_program
)
self
.
train_program
)
self
.
slice_vars
=
t
.
get_slice_vars_and_atts
(
current_endpoint
)
elif
training_role
==
"TRAINER"
:
elif
training_role
==
"TRAINER"
:
self
.
train_program
=
t
.
get_trainer_program
()
self
.
train_program
=
t
.
get_trainer_program
()
else
:
else
:
...
@@ -474,8 +475,10 @@ class Trainer(object):
...
@@ -474,8 +475,10 @@ class Trainer(object):
self
.
_clean_checkpoint
()
self
.
_clean_checkpoint
()
return
return
if
self
.
checkpoint_cfg
and
self
.
checkpoint_cfg
.
load_serial
\
if
self
.
checkpoint_cfg
and
\
and
self
.
checkpoint_cfg
.
step_id
>=
step_id
and
self
.
checkpoint_cfg
.
epoch_id
==
epoch_id
:
self
.
checkpoint_cfg
.
load_serial
is
not
None
and
\
self
.
checkpoint_cfg
.
step_id
>=
step_id
and
\
self
.
checkpoint_cfg
.
epoch_id
==
epoch_id
:
continue
continue
begin_event
=
BeginStepEvent
(
epoch_id
,
step_id
)
begin_event
=
BeginStepEvent
(
epoch_id
,
step_id
)
...
@@ -569,36 +572,58 @@ class Trainer(object):
...
@@ -569,36 +572,58 @@ class Trainer(object):
def
_load_checkpoint
(
self
):
def
_load_checkpoint
(
self
):
with
self
.
_prog_and_scope_guard
():
with
self
.
_prog_and_scope_guard
():
exe
=
executor
.
Executor
(
self
.
place
)
exe
=
executor
.
Executor
(
self
.
place
)
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
main_program
=
self
.
startup_program
)
if
not
self
.
checkpoint_cfg
.
pserver_id
:
checkpoint_dir
=
_get_serial_dir
(
self
.
checkpoint_cfg
.
checkpoint_dir
,
load_trainer_args
=
self
.
_get_checkpoint_load_args
()
self
.
checkpoint_cfg
.
load_serial
)
trainer_args
=
load_checkpoint
(
# Trainer Load
if
self
.
checkpoint_cfg
.
pserver_id
is
None
:
# load model
load_checkpoint
(
executor
=
exe
,
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
checkpoint_dir
=
checkpoint_dir
,
main_program
=
self
.
startup_program
,
main_program
=
self
.
startup_program
,
role_id
=
self
.
trainer_id
,
role_id
=
self
.
trainer_id
,
is_trainer
=
True
,
is_trainer
=
True
,
load_
trainer_args
=
load_trainer_args
)
load_
models
=
True
)
if
len
(
trainer_args
)
!=
2
:
# load trainer_args
trainer_args
=
self
.
_get_checkpoint_load_args
()
trainer_args_ret
=
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
checkpoint_dir
,
main_program
=
self
.
startup_program
,
role_id
=
self
.
trainer_id
,
is_trainer
=
True
,
load_trainer_args
=
trainer_args
)
if
len
(
trainer_args_ret
)
!=
2
:
raise
ValueError
(
raise
ValueError
(
"the return trainer_args length do not equal _get_checkpoint_load_args"
"the return trainer_args length do not equal _get_checkpoint_load_args"
)
)
self
.
checkpoint_cfg
.
epoch_id
=
int
(
trainer_args
[
0
])
self
.
checkpoint_cfg
.
epoch_id
=
int
(
trainer_args_ret
[
0
])
self
.
checkpoint_cfg
.
step_id
=
int
(
trainer_args
[
1
])
self
.
checkpoint_cfg
.
step_id
=
int
(
trainer_args_ret
[
1
])
# Pserver Load
else
:
else
:
# load slice_vars
if
self
.
slice_vars
!=
None
and
len
(
self
.
slice_vars
)
!=
0
:
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
checkpoint_dir
,
main_program
=
self
.
startup_program
,
role_id
=
self
.
checkpoint_cfg
.
pserver_id
,
is_trainer
=
False
,
load_slice_up_vars
=
self
.
slice_vars
)
# load lookup table
if
self
.
checkpoint_cfg
.
lookup_table_name
:
if
self
.
checkpoint_cfg
.
lookup_table_name
:
load_checkpoint
(
load_checkpoint
(
executor
=
exe
,
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
checkpoint_dir
=
checkpoint_dir
,
main_program
=
self
.
startup_program
,
main_program
=
self
.
startup_program
,
role_id
=
self
.
checkpoint_cfg
.
pserver_id
,
role_id
=
self
.
checkpoint_cfg
.
pserver_id
,
is_trainer
=
False
,
is_trainer
=
False
,
load_trainer_args
=
None
,
load_lookup_table
=
self
.
checkpoint_cfg
.
lookup_table_name
)
load_lookup_table
=
self
.
checkpoint_cfg
.
lookup_table_name
)
...
@@ -640,7 +665,7 @@ def save_checkpoint(executor,
...
@@ -640,7 +665,7 @@ def save_checkpoint(executor,
main_program
,
main_program
,
trainer_args
=
None
,
trainer_args
=
None
,
max_num_checkpoints
=
3
,
max_num_checkpoints
=
3
,
lookup_table
=
None
,
save_
lookup_table
=
None
,
pserver_endpoints
=
None
):
pserver_endpoints
=
None
):
"""
"""
This function filters out all checkpoint variables from the give
This function filters out all checkpoint variables from the give
...
@@ -673,7 +698,7 @@ def save_checkpoint(executor,
...
@@ -673,7 +698,7 @@ def save_checkpoint(executor,
max_num_checkpoints(int): The max number of total number of existing
max_num_checkpoints(int): The max number of total number of existing
checkpoints.
checkpoints.
Default: 3
Default: 3
lookup_table(string|None): the lookup table name, when use distribute
save_
lookup_table(string|None): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
table_name
pserver_endpoints(list|None): the parameter server ip:port list.
pserver_endpoints(list|None): the parameter server ip:port list.
...
@@ -704,7 +729,7 @@ def save_checkpoint(executor,
...
@@ -704,7 +729,7 @@ def save_checkpoint(executor,
trainer_args=trainer_args,
trainer_args=trainer_args,
main_program=prog,
main_program=prog,
max_num_checkpoints=3,
max_num_checkpoints=3,
lookup_table=table_name,
save_
lookup_table=table_name,
pserver_endpoints = ps_endpoints)
pserver_endpoints = ps_endpoints)
"""
"""
if
checkpoint_dir
is
None
:
if
checkpoint_dir
is
None
:
...
@@ -720,15 +745,15 @@ def save_checkpoint(executor,
...
@@ -720,15 +745,15 @@ def save_checkpoint(executor,
_make_chekcpoint_dirs
(
checkpoint_dir
)
_make_chekcpoint_dirs
(
checkpoint_dir
)
serial
=
_get_latest_checkpoint_serial
(
checkpoint_dir
)
+
1
serial
=
_get_latest_checkpoint_serial
(
checkpoint_dir
)
+
1
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
,
True
)
_save_trainer_args
(
cur_dir
,
trainer_id
,
trainer_args
)
_save_trainer_args
(
cur_dir
,
trainer_id
,
trainer_args
)
if
is_chief
:
if
is_chief
:
_save_persist
_vars_without_grad
(
executor
,
cur_dir
,
main_program
)
_save_persist
able_vars
(
executor
,
cur_dir
,
main_program
)
if
is_chief
and
lookup_table
and
pserver_endpoints
:
if
is_chief
and
save_
lookup_table
and
pserver_endpoints
:
_save_pserver_vars_by_notify
(
executor
,
cur_dir
,
lookup_table
,
_save_pserver_vars_by_notify
(
executor
,
cur_dir
,
save_
lookup_table
,
pserver_endpoints
)
pserver_endpoints
)
_scroll_delete
(
checkpoint_dir
,
max_num_checkpoints
)
_scroll_delete
(
checkpoint_dir
,
max_num_checkpoints
)
...
@@ -736,10 +761,12 @@ def save_checkpoint(executor,
...
@@ -736,10 +761,12 @@ def save_checkpoint(executor,
def
load_checkpoint
(
executor
,
def
load_checkpoint
(
executor
,
checkpoint_dir
,
checkpoint_dir
,
main_program
,
main_program
=
None
,
role_id
=
0
,
role_id
=
0
,
is_trainer
=
True
,
is_trainer
=
True
,
load_models
=
True
,
load_trainer_args
=
None
,
load_trainer_args
=
None
,
load_slice_up_vars
=
None
,
load_lookup_table
=
None
):
load_lookup_table
=
None
):
"""
"""
This function filters out all checkpoint variables from the give
This function filters out all checkpoint variables from the give
...
@@ -762,7 +789,7 @@ def load_checkpoint(executor,
...
@@ -762,7 +789,7 @@ def load_checkpoint(executor,
executor(Executor): The executor to run for loading checkpoint.
executor(Executor): The executor to run for loading checkpoint.
checkpoint_dir(str): The folder where all checkpoints are.
checkpoint_dir(str): The folder where all checkpoints are.
serial(int): The serial of checkpoint you would like to load.
serial(int): The serial of checkpoint you would like to load.
main_program(Program): The program whose checkpoint variables will
main_program(Program
|None
): The program whose checkpoint variables will
be loaded.
be loaded.
role_id(int): the trainer id or the parameter server id.
role_id(int): the trainer id or the parameter server id.
is_trainer(bool): trainer is True and parameter server is False.
is_trainer(bool): trainer is True and parameter server is False.
...
@@ -794,27 +821,23 @@ def load_checkpoint(executor,
...
@@ -794,27 +821,23 @@ def load_checkpoint(executor,
if
checkpoint_dir
is
None
:
if
checkpoint_dir
is
None
:
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
serial
=
_get_latest_checkpoint_serial
(
checkpoint_dir
)
# trainer load
if
is_trainer
:
# there are nothing need to be loaded
if
load_models
:
if
serial
is
None
or
serial
<
0
:
_load_persistable_vars
(
executor
,
checkpoint_dir
,
main_program
,
True
)
return
return
if
load_trainer_args
:
if
main_program
is
None
:
trainer_args_ret
=
_load_trainer_args
(
checkpoint_dir
,
role_id
,
raise
ValueError
(
'main_program should not be None.'
)
load_trainer_args
)
return
trainer_args_ret
if
is_trainer
and
load_trainer_args
is
None
:
# pserver load
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
else
:
_load_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
,
True
)
if
load_slice_up_vars
:
return
_load_slice_up_vars
(
executor
,
checkpoint_dir
,
load_slice_up_vars
)
return
if
is_trainer
and
load_trainer_args
:
if
load_lookup_table
:
return
_load_trainer_args
(
checkpoint_dir
,
serial
,
role_id
,
_load_lookup_table_vars
(
executor
,
checkpoint_dir
,
main_program
,
load_trainer_args
)
role_id
,
load_lookup_table
)
if
not
is_trainer
and
load_lookup_table
:
_load_lookup_table_vars
(
executor
,
checkpoint_dir
,
main_program
,
role_id
,
load_lookup_table
)
def
clean_checkpoint
(
checkpoint_dir
,
delete_dir
=
False
):
def
clean_checkpoint
(
checkpoint_dir
,
delete_dir
=
False
):
...
@@ -835,10 +858,7 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
...
@@ -835,10 +858,7 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
os
.
rmdir
(
checkpoint_dir
)
os
.
rmdir
(
checkpoint_dir
)
def
_load_persist_vars_without_grad
(
executor
,
def
_load_persistable_vars
(
executor
,
dirname
,
program
,
has_model_dir
=
False
):
dirname
,
program
,
has_model_dir
=
False
):
"""
"""
This function filters out all checkpoint variables from the give
This function filters out all checkpoint variables from the give
program and then trys to load these variables from the given directory.
program and then trys to load these variables from the given directory.
...
@@ -867,10 +887,10 @@ def _load_persist_vars_without_grad(executor,
...
@@ -867,10 +887,10 @@ def _load_persist_vars_without_grad(executor,
exe = fluid.Executor(fluid.CPUPlace())
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
prog = fluid.default_main_program()
_load_persist
_vars_without_grad
(executor=exe,
_load_persist
able_vars
(executor=exe,
dirname=param_path, program=prog, has_model_dir=True)
dirname=param_path, program=prog, has_model_dir=True)
# In this example, `_load_persist
_vars_without_grad
` function
# In this example, `_load_persist
able_vars
` function
# will first filters out all checkpoint variables in the default
# will first filters out all checkpoint variables in the default
# main program, and then trys to load these variables form the
# main program, and then trys to load these variables form the
# folder "./my_paddle_model/__model__".
# folder "./my_paddle_model/__model__".
...
@@ -887,6 +907,51 @@ def _load_persist_vars_without_grad(executor,
...
@@ -887,6 +907,51 @@ def _load_persist_vars_without_grad(executor,
filename
=
None
)
filename
=
None
)
def
_load_slice_up_vars
(
executor
,
dirname
,
slice_vars
):
if
slice_vars
==
None
or
len
(
slice_vars
)
==
0
:
return
dirname
=
_get_model_dir
(
dirname
)
load_prog
=
framework
.
Program
()
load_block
=
load_prog
.
global_block
()
for
var_tuple
in
slice_vars
:
orig_var
=
var_tuple
[
0
]
start
=
var_tuple
[
1
]
slice_var
=
var_tuple
[
2
]
end
=
start
+
reduce
(
lambda
x
,
y
:
x
*
y
,
slice_var
.
shape
)
clone_orig_var
=
load_block
.
create_var
(
name
=
orig_var
.
name
,
type
=
orig_var
.
type
,
shape
=
orig_var
.
shape
,
dtype
=
orig_var
.
dtype
,
persistable
=
True
)
clone_slice_var
=
load_block
.
create_var
(
name
=
slice_var
.
name
,
type
=
slice_var
.
type
,
shape
=
slice_var
.
shape
,
dtype
=
slice_var
.
dtype
,
persistable
=
True
)
load_block
.
append_op
(
type
=
'load'
,
inputs
=
{},
outputs
=
{
'Out'
:
[
clone_orig_var
]},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
dirname
,
clone_orig_var
.
name
)})
load_block
.
append_op
(
type
=
"slice"
,
inputs
=
{
'Input'
:
clone_orig_var
},
outputs
=
{
'Out'
:
clone_slice_var
},
attrs
=
{
'axes'
:
[
0
],
'starts'
:
[
start
],
'ends'
:
[
end
]})
executor
.
run
(
load_prog
)
def
_load_lookup_table_vars
(
executor
,
dirname
,
program
,
pserver_id
,
table_name
):
def
_load_lookup_table_vars
(
executor
,
dirname
,
program
,
pserver_id
,
table_name
):
"""
"""
The parameter server will load lookup table's local file in
The parameter server will load lookup table's local file in
...
@@ -937,7 +1002,7 @@ def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
...
@@ -937,7 +1002,7 @@ def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
executor
.
run
(
load_prog
)
executor
.
run
(
load_prog
)
def
_save_persist
_vars_without_grad
(
executor
,
dirname
,
program
):
def
_save_persist
able_vars
(
executor
,
dirname
,
program
):
"""
"""
This function filters out all checkpoint variables from the give
This function filters out all checkpoint variables from the give
program and then save these variables to a sub-folder '__model__' of
program and then save these variables to a sub-folder '__model__' of
...
@@ -964,10 +1029,10 @@ def _save_persist_vars_without_grad(executor, dirname, program):
...
@@ -964,10 +1029,10 @@ def _save_persist_vars_without_grad(executor, dirname, program):
exe = fluid.Executor(fluid.CPUPlace())
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
prog = fluid.default_main_program()
_save_persist
_vars_without_grad
(executor=exe,
_save_persist
able_vars
(executor=exe,
dirname=param_path, program=prog)
dirname=param_path, program=prog)
# In this example, `_save_persist
_vars_without_grad
` function
# In this example, `_save_persist
able_vars
` function
# will first filters out all checkpoint variables in the default
# will first filters out all checkpoint variables in the default
# main program, and then saves these variables to the folder
# main program, and then saves these variables to the folder
# "./my_paddle_model/__model__".
# "./my_paddle_model/__model__".
...
@@ -1043,7 +1108,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args):
...
@@ -1043,7 +1108,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args):
_write_success
(
cur_dir
)
_write_success
(
cur_dir
)
def
_load_trainer_args
(
checkpoint_dir
,
serial
,
trainer_id
,
trainer_args
):
def
_load_trainer_args
(
checkpoint_dir
,
trainer_id
,
trainer_args
):
"""
"""
trainer will load some args from it's independent directory,
trainer will load some args from it's independent directory,
such as epoch_id and step_id.
such as epoch_id and step_id.
...
@@ -1069,8 +1134,7 @@ def _load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
...
@@ -1069,8 +1134,7 @@ def _load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
"""
"""
assert
isinstance
(
trainer_args
,
list
)
assert
isinstance
(
trainer_args
,
list
)
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
cur_dir
=
_get_trainer_dir
(
checkpoint_dir
,
trainer_id
)
cur_dir
=
_get_trainer_dir
(
cur_dir
,
trainer_id
)
ret_values
=
[]
ret_values
=
[]
...
@@ -1125,20 +1189,19 @@ def _make_chekcpoint_dirs(dirs):
...
@@ -1125,20 +1189,19 @@ def _make_chekcpoint_dirs(dirs):
def
_get_dir_serial
(
dirname
):
def
_get_dir_serial
(
dirname
):
_
,
serial
=
dirname
.
split
(
CHECKPOINT_SEPARATOR
)
try
:
try
:
_
,
serial
=
dirname
.
split
(
CHECKPOINT_SEPARATOR
)
serial_num
=
int
(
serial
)
serial_num
=
int
(
serial
)
except
ValueError
:
except
ValueError
:
serial_num
=
-
1
serial_num
=
-
1
return
serial_num
return
serial_num
def
_get_serial_dir
(
dirname
,
serial
):
def
_get_serial_dir
(
dirname
,
serial
,
makedirs
=
False
):
serial_folder
=
CHECKPOINT_PREFIX
+
CHECKPOINT_SEPARATOR
+
str
(
serial
)
serial_folder
=
CHECKPOINT_PREFIX
+
CHECKPOINT_SEPARATOR
+
str
(
serial
)
serial_dir
=
os
.
path
.
join
(
dirname
,
serial_folder
)
serial_dir
=
os
.
path
.
join
(
dirname
,
serial_folder
)
_make_chekcpoint_dirs
(
serial_dir
)
if
makedirs
:
_make_chekcpoint_dirs
(
serial_dir
)
return
serial_dir
return
serial_dir
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
1dd14a70
...
@@ -719,6 +719,28 @@ class DistributeTranspiler(object):
...
@@ -719,6 +719,28 @@ class DistributeTranspiler(object):
})
for
ep
in
self
.
pserver_endpoints
})
for
ep
in
self
.
pserver_endpoints
]
]
def
get_slice_vars_and_atts
(
self
,
endpoint
):
slice_vars_and_atts
=
[]
block_suffix
=
".block"
for
param
in
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]:
suff_idx
=
param
.
name
.
find
(
block_suffix
)
if
suff_idx
<=
0
:
continue
orig_var_name
=
param
.
name
[:
suff_idx
]
block_idx
=
int
(
param
.
name
[
suff_idx
+
len
(
block_suffix
):])
orig_var
=
self
.
origin_program
.
global_block
().
vars
[
orig_var_name
]
skip_numel
=
0
slice_vars
=
self
.
param_var_mapping
[
orig_var_name
]
for
slice_var
in
slice_vars
[:
block_idx
]:
skip_numel
+=
reduce
(
lambda
x
,
y
:
x
*
y
,
slice_var
.
shape
)
slice_vars_and_atts
.
append
([
orig_var
,
skip_numel
,
param
])
return
slice_vars_and_atts
# transpiler function for dis lookup_table
# transpiler function for dis lookup_table
def
_replace_lookup_table_op_with_prefetch
(
self
,
program
,
def
_replace_lookup_table_op_with_prefetch
(
self
,
program
,
pserver_endpoints
):
pserver_endpoints
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录