Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
6e5635fd
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6e5635fd
编写于
5月 10, 2018
作者:
Y
Yancey1989
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update
上级
b1e51836
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
60 addition
and
19 deletion
+60
-19
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+60
-19
未找到文件。
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
6e5635fd
...
...
@@ -279,11 +279,20 @@ class DistributeTranspiler:
grad_blocks
=
split_dense_variable
(
grad_list
,
len
(
pserver_endpoints
))
param_blocks
=
split_dense_variable
(
param_list
,
len
(
pserver_endpoints
))
assert
(
len
(
grad_blocks
)
==
len
(
param_blocks
))
# step2: Create new vars for the parameters and gradients blocks and
# add ops to do the split.
grad_var_mapping
=
self
.
_append_split_op
(
program
,
grad_blocks
)
param_var_mapping
=
self
.
_create_vars_from_blocklist
(
program
,
param_blocks
)
grad_var_mapping
=
self
.
_create_vars_from_blocklist
(
program
,
grad_blocks
,
add_trainer_suffix
=
self
.
trainer_num
>
1
)
grad_param_mapping
=
dict
()
for
g
,
p
in
zip
(
grad_blocks
,
param_blocks
):
g_name
,
g_bid
,
_
=
g
.
split
(
":"
)
p_name
,
p_bid
,
_
=
p
.
split
(
":"
)
grad_param_mapping
[
grad_var_mapping
[
g_name
][
int
(
g_bid
)]]
=
\
param_var_mapping
[
p_name
][
int
(
p_bid
)]
rpc_client_var
=
program
.
global_block
().
create_var
(
name
=
RPC_CLIENT_VAR_NAME
,
persistable
=
True
,
...
...
@@ -304,15 +313,21 @@ class DistributeTranspiler:
# step 3.1: insert send op to send gradient vars to parameter servers
ps_dispatcher
.
reset
()
for
varname
,
send_vars
in
grad_var_mapping
.
items
():
send_vars
=
[]
for
varname
,
splited_vars
in
grad_var_mapping
.
items
():
index
=
find_op_by_output_arg
(
program
.
global_block
(),
varname
)
eplist
=
ps_dispatcher
.
dispatch
(
send_vars
)
eplist
=
ps_dispatcher
.
dispatch
(
splited_vars
)
if
len
(
splited_vars
)
>
1
:
self
.
_insert_split_op
(
program
,
varname
,
splited_vars
)
index
+=
1
program
.
global_block
().
insert_op
(
index
=
index
,
index
=
index
+
1
,
type
=
"send_vars"
,
inputs
=
{
"X"
:
s
en
d_vars
},
inputs
=
{
"X"
:
s
plite
d_vars
},
outputs
=
{
"RPCClient"
:
rpc_client_var
},
attrs
=
{
"epmap"
:
eplist
})
for
_
,
var
in
enumerate
(
splited_vars
):
send_vars
.
append
(
var
)
if
self
.
sync_mode
:
program
.
global_block
().
append_op
(
...
...
@@ -322,21 +337,12 @@ class DistributeTranspiler:
attrs
=
{
"endpoints"
:
pserver_endpoints
})
# step 3.2: insert recv op to receive parameters from parameter server
ps_dispatcher
.
reset
()
recv_vars
=
[]
for
b
in
param_blocks
:
varname
,
block_id
,
_
=
b
.
split
(
":"
)
recv_vars
.
append
(
param_var_mapping
[
varname
][
int
(
block_id
)])
for
b
in
grad_blocks
:
varname
,
block_id
,
_
=
b
.
split
(
":"
)
send_vars
.
append
(
grad_var_mapping
[
varname
][
int
(
block_id
)])
for
_
,
var
in
enumerate
(
send_vars
):
recv_vars
.
append
(
grad_param_mapping
[
var
])
ps_dispatcher
.
reset
()
eplist
=
ps_dispatcher
.
dispatch
(
recv_vars
)
for
i
,
ep
in
enumerate
(
eplist
):
self
.
param_grad_ep_mapping
[
ep
][
"params"
].
append
(
recv_vars
[
i
])
self
.
param_grad_ep_mapping
[
ep
][
"grads"
].
append
(
send_vars
[
i
])
program
.
global_block
().
append_op
(
type
=
"recv"
,
inputs
=
{},
...
...
@@ -344,6 +350,10 @@ class DistributeTranspiler:
"RPCClient"
:
rpc_client_var
},
attrs
=
{
"epmap"
:
eplist
})
for
i
,
ep
in
enumerate
(
eplist
):
self
.
param_grad_ep_mapping
[
ep
][
"params"
].
append
(
recv_vars
[
i
])
self
.
param_grad_ep_mapping
[
ep
][
"grads"
].
append
(
send_vars
[
i
])
# TODO(Yancey1989): check dist lookup table
if
self
.
has_distributed_lookup_table
:
self
.
_replace_lookup_table_op_with_prefetch
(
program
,
rpc_client_var
,
...
...
@@ -848,6 +858,34 @@ class DistributeTranspiler:
lod_level
=
var
.
lod_level
,
persistable
=
persistable
)
def
_insert_split_op
(
self
,
program
,
orig_varname
,
splited_vars
):
orig_var
=
program
.
global_block
().
vars
[
orig_varname
]
index
=
find_op_by_output_arg
(
program
.
global_block
(),
orig_varname
)
if
orig_var
.
type
==
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
:
height_sections
=
[]
for
v
in
splited_vars
:
height_sections
.
append
(
v
.
shape
[
0
])
program
.
global_block
().
insert_op
(
index
=
index
+
1
,
type
=
"split_selected_rows"
,
inputs
=
{
"X"
:
orig_var
},
outputs
=
{
"Out"
:
splited_vars
},
attrs
=
{
"height_sections"
:
height_sections
})
elif
orig_var
.
type
==
core
.
VarDesc
.
VarType
.
LOD_TENSOR
:
sections
=
[]
for
v
in
splited_vars
:
sections
.
append
(
v
.
shape
[
0
])
program
.
global_block
().
insert_op
(
index
=
index
+
1
,
type
=
"split_byref"
,
inputs
=
{
"X"
:
orig_var
},
outputs
=
{
"Out"
:
splited_vars
},
attrs
=
{
"sections"
:
sections
}
# assume split evenly
)
else
:
AssertionError
(
"Variable type should be in set "
"[LOD_TENSOR, SELECTED_ROWS]"
)
def
_append_split_op
(
self
,
program
,
gradblocks
):
# Split variables that need to be split and append respective ops
add_suffix
=
False
...
...
@@ -860,11 +898,13 @@ class DistributeTranspiler:
if
len
(
splited_vars
)
<=
1
:
continue
orig_var
=
program
.
global_block
().
vars
[
varname
]
index
=
find_op_by_output_arg
(
program
.
global_block
(),
orig_var
.
name
)
if
orig_var
.
type
==
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
:
height_sections
=
[]
for
v
in
splited_vars
:
height_sections
.
append
(
v
.
shape
[
0
])
program
.
global_block
().
append_op
(
program
.
global_block
().
insert_op
(
index
=
index
+
1
,
type
=
"split_selected_rows"
,
inputs
=
{
"X"
:
orig_var
},
outputs
=
{
"Out"
:
splited_vars
},
...
...
@@ -873,7 +913,8 @@ class DistributeTranspiler:
sections
=
[]
for
v
in
splited_vars
:
sections
.
append
(
v
.
shape
[
0
])
program
.
global_block
().
append_op
(
program
.
global_block
().
insert_op
(
index
=
index
+
1
,
type
=
"split_byref"
,
inputs
=
{
"X"
:
orig_var
},
outputs
=
{
"Out"
:
splited_vars
},
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录