Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
cc6ef41d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cc6ef41d
编写于
11月 25, 2018
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update dist transpiler
上级
47280ef8
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
56 addition
and
23 deletion
+56
-23
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+56
-23
未找到文件。
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
cc6ef41d
...
...
@@ -236,6 +236,22 @@ class DistributeTranspiler(object):
else
:
raise
ValueError
(
"must set trainer_id > 0"
)
def
_get_all_sparse_update_op
(
self
,
main_program
):
sparse_update_ops
=
[]
sparse_update_op_types
=
[
"lookup_table"
]
for
op
in
main_program
.
global_block
().
ops
:
if
op
.
type
in
sparse_update_op_types
and
op
.
attr
(
'is_sparse'
)
is
True
and
not
op
.
attr
(
'is_distributed'
):
sparse_update_ops
.
append
(
op
)
return
sparse_update_ops
def
_update_sparse_update_op
(
self
,
param_varname
,
height_sections
,
endpint_map
):
for
op
in
self
.
sparse_update_ops
:
if
param_varname
in
op
.
input_arg_names
:
op
.
_set_attr
(
'epmap'
,
endpint_map
)
op
.
_set_attr
(
'height_sections'
,
height_sections
)
def
transpile
(
self
,
trainer_id
,
program
=
None
,
...
...
@@ -299,6 +315,11 @@ class DistributeTranspiler(object):
self
.
param_name_to_grad_name
[
param_var
.
name
]
=
grad_var
.
name
self
.
grad_name_to_param_name
[
grad_var
.
name
]
=
param_var
.
name
# get all sparse update ops
self
.
sparse_update_ops
=
self
.
_get_all_sparse_update_op
(
self
.
origin_program
)
self
.
sparse_param_to_height_sections
=
dict
()
# add distributed attrs to program
self
.
origin_program
.
_is_distributed
=
True
self
.
origin_program
.
_endpoints
=
self
.
pserver_endpoints
...
...
@@ -425,6 +446,12 @@ class DistributeTranspiler(object):
if
len
(
splited_trainer_grad
)
==
1
:
recv_op_role_var_name
=
splited_trainer_grad
[
0
].
name
if
param_varname
in
self
.
sparse_param_to_height_sections
:
height_sections
=
self
.
sparse_param_to_height_sections
[
param_varname
]
self
.
_update_sparse_update_op
(
param_varname
,
height_sections
,
eps
)
else
:
program
.
global_block
().
append_op
(
type
=
"recv"
,
inputs
=
{
"X"
:
[
recv_dep_in
]},
...
...
@@ -454,6 +481,9 @@ class DistributeTranspiler(object):
if
len
(
splited_var
)
<=
1
:
continue
orig_param
=
program
.
global_block
().
vars
[
param_varname
]
print
(
"sparse_param_to_height_sections: "
+
str
(
self
.
sparse_param_to_height_sections
))
if
param_varname
not
in
self
.
sparse_param_to_height_sections
:
program
.
global_block
().
append_op
(
type
=
"concat"
,
inputs
=
{
"X"
:
splited_var
},
...
...
@@ -1237,9 +1267,8 @@ to transpile() call.")
# create table param and grad var in pserver program
# create table optimize block in pserver program
table_opt_op
=
[
op
for
op
in
self
.
optimize_ops
if
'Param'
in
op
.
input_names
and
op
.
input
(
"Param"
)[
0
]
==
self
.
table_name
op
for
op
in
self
.
optimize_ops
if
'Param'
in
op
.
input_names
and
op
.
input
(
"Param"
)[
0
]
==
self
.
table_name
][
0
]
origin_param_var
=
self
.
origin_program
.
global_block
().
vars
[
...
...
@@ -1418,6 +1447,10 @@ to transpile() call.")
height_sections
=
[]
for
v
in
splited_vars
:
height_sections
.
append
(
v
.
shape
[
0
])
sparse_param_name
=
self
.
grad_name_to_param_name
[
orig_var
.
name
]
if
sparse_param_name
!=
self
.
table_name
:
self
.
sparse_param_to_height_sections
[
sparse_param_name
]
=
height_sections
program
.
global_block
().
_insert_op
(
index
=
index
+
1
,
type
=
"split_selected_rows"
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录