Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1dd14a70
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1dd14a70
编写于
7月 19, 2018
作者:
T
tangwei12
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
bug fix
上级
f9f8fbaa
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
148 addition
and
63 deletion
+148
-63
python/paddle/fluid/trainer.py
python/paddle/fluid/trainer.py
+126
-63
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+22
-0
未找到文件。
python/paddle/fluid/trainer.py
浏览文件 @
1dd14a70
...
...
@@ -360,6 +360,7 @@ class Trainer(object):
self
.
train_program
=
t
.
get_pserver_program
(
current_endpoint
)
self
.
startup_program
=
t
.
get_startup_program
(
current_endpoint
,
self
.
train_program
)
self
.
slice_vars
=
t
.
get_slice_vars_and_atts
(
current_endpoint
)
elif
training_role
==
"TRAINER"
:
self
.
train_program
=
t
.
get_trainer_program
()
else
:
...
...
@@ -474,8 +475,10 @@ class Trainer(object):
self
.
_clean_checkpoint
()
return
if
self
.
checkpoint_cfg
and
self
.
checkpoint_cfg
.
load_serial
\
and
self
.
checkpoint_cfg
.
step_id
>=
step_id
and
self
.
checkpoint_cfg
.
epoch_id
==
epoch_id
:
if
self
.
checkpoint_cfg
and
\
self
.
checkpoint_cfg
.
load_serial
is
not
None
and
\
self
.
checkpoint_cfg
.
step_id
>=
step_id
and
\
self
.
checkpoint_cfg
.
epoch_id
==
epoch_id
:
continue
begin_event
=
BeginStepEvent
(
epoch_id
,
step_id
)
...
...
@@ -569,36 +572,58 @@ class Trainer(object):
def
_load_checkpoint
(
self
):
with
self
.
_prog_and_scope_guard
():
exe
=
executor
.
Executor
(
self
.
place
)
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
main_program
=
self
.
startup_program
)
if
not
self
.
checkpoint_cfg
.
pserver_id
:
load_trainer_args
=
self
.
_get_checkpoint_load_args
()
trainer_args
=
load_checkpoint
(
checkpoint_dir
=
_get_serial_dir
(
self
.
checkpoint_cfg
.
checkpoint_dir
,
self
.
checkpoint_cfg
.
load_serial
)
# Trainer Load
if
self
.
checkpoint_cfg
.
pserver_id
is
None
:
# load model
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
checkpoint_dir
=
checkpoint_dir
,
main_program
=
self
.
startup_program
,
role_id
=
self
.
trainer_id
,
is_trainer
=
True
,
load_
trainer_args
=
load_trainer_args
)
load_
models
=
True
)
if
len
(
trainer_args
)
!=
2
:
# load trainer_args
trainer_args
=
self
.
_get_checkpoint_load_args
()
trainer_args_ret
=
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
checkpoint_dir
,
main_program
=
self
.
startup_program
,
role_id
=
self
.
trainer_id
,
is_trainer
=
True
,
load_trainer_args
=
trainer_args
)
if
len
(
trainer_args_ret
)
!=
2
:
raise
ValueError
(
"the return trainer_args length do not equal _get_checkpoint_load_args"
)
self
.
checkpoint_cfg
.
epoch_id
=
int
(
trainer_args
[
0
])
self
.
checkpoint_cfg
.
step_id
=
int
(
trainer_args
[
1
])
self
.
checkpoint_cfg
.
epoch_id
=
int
(
trainer_args_ret
[
0
])
self
.
checkpoint_cfg
.
step_id
=
int
(
trainer_args_ret
[
1
])
# Pserver Load
else
:
# load slice_vars
if
self
.
slice_vars
!=
None
and
len
(
self
.
slice_vars
)
!=
0
:
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
checkpoint_dir
,
main_program
=
self
.
startup_program
,
role_id
=
self
.
checkpoint_cfg
.
pserver_id
,
is_trainer
=
False
,
load_slice_up_vars
=
self
.
slice_vars
)
# load lookup table
if
self
.
checkpoint_cfg
.
lookup_table_name
:
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
checkpoint_dir
=
checkpoint_dir
,
main_program
=
self
.
startup_program
,
role_id
=
self
.
checkpoint_cfg
.
pserver_id
,
is_trainer
=
False
,
load_trainer_args
=
None
,
load_lookup_table
=
self
.
checkpoint_cfg
.
lookup_table_name
)
...
...
@@ -640,7 +665,7 @@ def save_checkpoint(executor,
main_program
,
trainer_args
=
None
,
max_num_checkpoints
=
3
,
lookup_table
=
None
,
save_
lookup_table
=
None
,
pserver_endpoints
=
None
):
"""
This function filters out all checkpoint variables from the give
...
...
@@ -673,7 +698,7 @@ def save_checkpoint(executor,
max_num_checkpoints(int): The max number of total number of existing
checkpoints.
Default: 3
lookup_table(string|None): the lookup table name, when use distribute
save_
lookup_table(string|None): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
pserver_endpoints(list|None): the parameter server ip:port list.
...
...
@@ -704,7 +729,7 @@ def save_checkpoint(executor,
trainer_args=trainer_args,
main_program=prog,
max_num_checkpoints=3,
lookup_table=table_name,
save_
lookup_table=table_name,
pserver_endpoints = ps_endpoints)
"""
if
checkpoint_dir
is
None
:
...
...
@@ -720,15 +745,15 @@ def save_checkpoint(executor,
_make_chekcpoint_dirs
(
checkpoint_dir
)
serial
=
_get_latest_checkpoint_serial
(
checkpoint_dir
)
+
1
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
,
True
)
_save_trainer_args
(
cur_dir
,
trainer_id
,
trainer_args
)
if
is_chief
:
_save_persist
_vars_without_grad
(
executor
,
cur_dir
,
main_program
)
_save_persist
able_vars
(
executor
,
cur_dir
,
main_program
)
if
is_chief
and
lookup_table
and
pserver_endpoints
:
_save_pserver_vars_by_notify
(
executor
,
cur_dir
,
lookup_table
,
if
is_chief
and
save_
lookup_table
and
pserver_endpoints
:
_save_pserver_vars_by_notify
(
executor
,
cur_dir
,
save_
lookup_table
,
pserver_endpoints
)
_scroll_delete
(
checkpoint_dir
,
max_num_checkpoints
)
...
...
@@ -736,10 +761,12 @@ def save_checkpoint(executor,
def
load_checkpoint
(
executor
,
checkpoint_dir
,
main_program
,
main_program
=
None
,
role_id
=
0
,
is_trainer
=
True
,
load_models
=
True
,
load_trainer_args
=
None
,
load_slice_up_vars
=
None
,
load_lookup_table
=
None
):
"""
This function filters out all checkpoint variables from the give
...
...
@@ -762,7 +789,7 @@ def load_checkpoint(executor,
executor(Executor): The executor to run for loading checkpoint.
checkpoint_dir(str): The folder where all checkpoints are.
serial(int): The serial of checkpoint you would like to load.
main_program(Program): The program whose checkpoint variables will
main_program(Program
|None
): The program whose checkpoint variables will
be loaded.
role_id(int): the trainer id or the parameter server id.
is_trainer(bool): trainer is True and parameter server is False.
...
...
@@ -794,27 +821,23 @@ def load_checkpoint(executor,
if
checkpoint_dir
is
None
:
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
serial
=
_get_latest_checkpoint_serial
(
checkpoint_dir
)
# there are nothing need to be loaded
if
serial
is
None
or
serial
<
0
:
return
if
main_program
is
None
:
raise
ValueError
(
'main_program should not be None.'
)
if
is_trainer
and
load_trainer_args
is
None
:
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
_load_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
,
True
)
return
if
is_trainer
and
load_trainer_args
:
return
_load_trainer_args
(
checkpoint_dir
,
serial
,
role_id
,
load_trainer_args
)
if
not
is_trainer
and
load_lookup_table
:
_load_lookup_table_vars
(
executor
,
checkpoint_dir
,
main_program
,
role_id
,
load_lookup_table
)
# trainer load
if
is_trainer
:
if
load_models
:
_load_persistable_vars
(
executor
,
checkpoint_dir
,
main_program
,
True
)
return
if
load_trainer_args
:
trainer_args_ret
=
_load_trainer_args
(
checkpoint_dir
,
role_id
,
load_trainer_args
)
return
trainer_args_ret
# pserver load
else
:
if
load_slice_up_vars
:
_load_slice_up_vars
(
executor
,
checkpoint_dir
,
load_slice_up_vars
)
return
if
load_lookup_table
:
_load_lookup_table_vars
(
executor
,
checkpoint_dir
,
main_program
,
role_id
,
load_lookup_table
)
def
clean_checkpoint
(
checkpoint_dir
,
delete_dir
=
False
):
...
...
@@ -835,10 +858,7 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
os
.
rmdir
(
checkpoint_dir
)
def
_load_persist_vars_without_grad
(
executor
,
dirname
,
program
,
has_model_dir
=
False
):
def
_load_persistable_vars
(
executor
,
dirname
,
program
,
has_model_dir
=
False
):
"""
This function filters out all checkpoint variables from the give
program and then trys to load these variables from the given directory.
...
...
@@ -867,10 +887,10 @@ def _load_persist_vars_without_grad(executor,
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
_load_persist
_vars_without_grad
(executor=exe,
_load_persist
able_vars
(executor=exe,
dirname=param_path, program=prog, has_model_dir=True)
# In this example, `_load_persist
_vars_without_grad
` function
# In this example, `_load_persist
able_vars
` function
# will first filters out all checkpoint variables in the default
# main program, and then trys to load these variables form the
# folder "./my_paddle_model/__model__".
...
...
@@ -887,6 +907,51 @@ def _load_persist_vars_without_grad(executor,
filename
=
None
)
def
_load_slice_up_vars
(
executor
,
dirname
,
slice_vars
):
if
slice_vars
==
None
or
len
(
slice_vars
)
==
0
:
return
dirname
=
_get_model_dir
(
dirname
)
load_prog
=
framework
.
Program
()
load_block
=
load_prog
.
global_block
()
for
var_tuple
in
slice_vars
:
orig_var
=
var_tuple
[
0
]
start
=
var_tuple
[
1
]
slice_var
=
var_tuple
[
2
]
end
=
start
+
reduce
(
lambda
x
,
y
:
x
*
y
,
slice_var
.
shape
)
clone_orig_var
=
load_block
.
create_var
(
name
=
orig_var
.
name
,
type
=
orig_var
.
type
,
shape
=
orig_var
.
shape
,
dtype
=
orig_var
.
dtype
,
persistable
=
True
)
clone_slice_var
=
load_block
.
create_var
(
name
=
slice_var
.
name
,
type
=
slice_var
.
type
,
shape
=
slice_var
.
shape
,
dtype
=
slice_var
.
dtype
,
persistable
=
True
)
load_block
.
append_op
(
type
=
'load'
,
inputs
=
{},
outputs
=
{
'Out'
:
[
clone_orig_var
]},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
dirname
,
clone_orig_var
.
name
)})
load_block
.
append_op
(
type
=
"slice"
,
inputs
=
{
'Input'
:
clone_orig_var
},
outputs
=
{
'Out'
:
clone_slice_var
},
attrs
=
{
'axes'
:
[
0
],
'starts'
:
[
start
],
'ends'
:
[
end
]})
executor
.
run
(
load_prog
)
def
_load_lookup_table_vars
(
executor
,
dirname
,
program
,
pserver_id
,
table_name
):
"""
The parameter server will load lookup table's local file in
...
...
@@ -937,7 +1002,7 @@ def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
executor
.
run
(
load_prog
)
def
_save_persist
_vars_without_grad
(
executor
,
dirname
,
program
):
def
_save_persist
able_vars
(
executor
,
dirname
,
program
):
"""
This function filters out all checkpoint variables from the give
program and then save these variables to a sub-folder '__model__' of
...
...
@@ -964,10 +1029,10 @@ def _save_persist_vars_without_grad(executor, dirname, program):
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
_save_persist
_vars_without_grad
(executor=exe,
_save_persist
able_vars
(executor=exe,
dirname=param_path, program=prog)
# In this example, `_save_persist
_vars_without_grad
` function
# In this example, `_save_persist
able_vars
` function
# will first filters out all checkpoint variables in the default
# main program, and then saves these variables to the folder
# "./my_paddle_model/__model__".
...
...
@@ -1043,7 +1108,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args):
_write_success
(
cur_dir
)
def
_load_trainer_args
(
checkpoint_dir
,
serial
,
trainer_id
,
trainer_args
):
def
_load_trainer_args
(
checkpoint_dir
,
trainer_id
,
trainer_args
):
"""
trainer will load some args from it's independent directory,
such as epoch_id and step_id.
...
...
@@ -1069,8 +1134,7 @@ def _load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
"""
assert
isinstance
(
trainer_args
,
list
)
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
cur_dir
=
_get_trainer_dir
(
cur_dir
,
trainer_id
)
cur_dir
=
_get_trainer_dir
(
checkpoint_dir
,
trainer_id
)
ret_values
=
[]
...
...
@@ -1125,20 +1189,19 @@ def _make_chekcpoint_dirs(dirs):
def
_get_dir_serial
(
dirname
):
_
,
serial
=
dirname
.
split
(
CHECKPOINT_SEPARATOR
)
try
:
_
,
serial
=
dirname
.
split
(
CHECKPOINT_SEPARATOR
)
serial_num
=
int
(
serial
)
except
ValueError
:
serial_num
=
-
1
return
serial_num
def
_get_serial_dir
(
dirname
,
serial
):
def
_get_serial_dir
(
dirname
,
serial
,
makedirs
=
False
):
serial_folder
=
CHECKPOINT_PREFIX
+
CHECKPOINT_SEPARATOR
+
str
(
serial
)
serial_dir
=
os
.
path
.
join
(
dirname
,
serial_folder
)
_make_chekcpoint_dirs
(
serial_dir
)
if
makedirs
:
_make_chekcpoint_dirs
(
serial_dir
)
return
serial_dir
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
1dd14a70
...
...
@@ -719,6 +719,28 @@ class DistributeTranspiler(object):
})
for
ep
in
self
.
pserver_endpoints
]
def
get_slice_vars_and_atts
(
self
,
endpoint
):
slice_vars_and_atts
=
[]
block_suffix
=
".block"
for
param
in
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]:
suff_idx
=
param
.
name
.
find
(
block_suffix
)
if
suff_idx
<=
0
:
continue
orig_var_name
=
param
.
name
[:
suff_idx
]
block_idx
=
int
(
param
.
name
[
suff_idx
+
len
(
block_suffix
):])
orig_var
=
self
.
origin_program
.
global_block
().
vars
[
orig_var_name
]
skip_numel
=
0
slice_vars
=
self
.
param_var_mapping
[
orig_var_name
]
for
slice_var
in
slice_vars
[:
block_idx
]:
skip_numel
+=
reduce
(
lambda
x
,
y
:
x
*
y
,
slice_var
.
shape
)
slice_vars_and_atts
.
append
([
orig_var
,
skip_numel
,
param
])
return
slice_vars_and_atts
# transpiler function for dis lookup_table
def
_replace_lookup_table_op_with_prefetch
(
self
,
program
,
pserver_endpoints
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录