Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
9f4b66f6
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
9f4b66f6
编写于
5月 28, 2018
作者:
Q
qiaolongfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
table gradient should be split and send to each pserver
上级
25f47fc0
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
30 addition
and
13 deletion
+30
-13
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+2
-1
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+28
-12
未找到文件。
python/paddle/fluid/framework.py
浏览文件 @
9f4b66f6
...
@@ -797,7 +797,7 @@ class Block(object):
...
@@ -797,7 +797,7 @@ class Block(object):
Rename variable in vars and ops' inputs and outputs
Rename variable in vars and ops' inputs and outputs
"""
"""
if
not
self
.
has_var
(
name
):
if
not
self
.
has_var
(
name
):
raise
ValueError
(
"var %s is not in current"
%
name
)
raise
ValueError
(
"var %s is not in current
block
"
%
name
)
v
=
self
.
var
(
name
)
v
=
self
.
var
(
name
)
if
type
(
v
)
==
Parameter
:
if
type
(
v
)
==
Parameter
:
var_type
=
"Parameter"
var_type
=
"Parameter"
...
@@ -843,6 +843,7 @@ class Block(object):
...
@@ -843,6 +843,7 @@ class Block(object):
self
.
vars
[
new_name
]
=
var
self
.
vars
[
new_name
]
=
var
del
self
.
vars
[
name
]
del
self
.
vars
[
name
]
self
.
sync_with_cpp
()
self
.
sync_with_cpp
()
return
var
def
remove_var
(
self
,
name
):
def
remove_var
(
self
,
name
):
self
.
sync_with_cpp
()
self
.
sync_with_cpp
()
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
9f4b66f6
...
@@ -256,15 +256,25 @@ class DistributeTranspiler:
...
@@ -256,15 +256,25 @@ class DistributeTranspiler:
if
param_grad
[
0
].
name
==
self
.
table_name
if
param_grad
[
0
].
name
==
self
.
table_name
][
0
]
][
0
]
table_grad_var
=
self
.
table_param_grad
[
1
]
table_grad_var
=
self
.
table_param_grad
[
1
]
self
.
table_grad_list
=
[
if
self
.
sync_mode
:
program
.
global_block
().
create_var
(
self
.
table_grad_list
=
[
name
=
"%s.trainer_%d.pserver_%d"
%
program
.
global_block
().
create_var
(
(
table_grad_var
.
name
,
trainer_id
,
index
),
name
=
"%s.trainer_%d.pserver_%d"
%
type
=
table_grad_var
.
type
,
(
table_grad_var
.
name
,
trainer_id
,
index
),
shape
=
table_grad_var
.
shape
,
type
=
table_grad_var
.
type
,
dtype
=
table_grad_var
.
dtype
)
shape
=
table_grad_var
.
shape
,
for
index
in
range
(
len
(
self
.
pserver_endpoints
))
dtype
=
table_grad_var
.
dtype
)
]
for
index
in
range
(
len
(
self
.
pserver_endpoints
))
]
else
:
self
.
table_grad_list
=
[
program
.
global_block
().
create_var
(
name
=
"%s.pserver_%d"
%
(
table_grad_var
.
name
,
index
),
type
=
table_grad_var
.
type
,
shape
=
table_grad_var
.
shape
,
dtype
=
table_grad_var
.
dtype
)
for
index
in
range
(
len
(
self
.
pserver_endpoints
))
]
grad_blocks
=
split_dense_variable
(
grad_list
,
len
(
pserver_endpoints
))
grad_blocks
=
split_dense_variable
(
grad_list
,
len
(
pserver_endpoints
))
param_blocks
=
split_dense_variable
(
param_list
,
len
(
pserver_endpoints
))
param_blocks
=
split_dense_variable
(
param_list
,
len
(
pserver_endpoints
))
...
@@ -328,7 +338,7 @@ class DistributeTranspiler:
...
@@ -328,7 +338,7 @@ class DistributeTranspiler:
if
self
.
has_distributed_lookup_table
:
if
self
.
has_distributed_lookup_table
:
self
.
_replace_lookup_table_op_with_prefetch
(
program
,
rpc_client_var
,
self
.
_replace_lookup_table_op_with_prefetch
(
program
,
rpc_client_var
,
eplist
)
pserver_endpoints
)
self
.
_split_table_grad_and_add_send_vars
(
program
,
rpc_client_var
,
self
.
_split_table_grad_and_add_send_vars
(
program
,
rpc_client_var
,
pserver_endpoints
)
pserver_endpoints
)
...
@@ -551,7 +561,7 @@ class DistributeTranspiler:
...
@@ -551,7 +561,7 @@ class DistributeTranspiler:
# transpiler function for dis lookup_table
# transpiler function for dis lookup_table
def
_replace_lookup_table_op_with_prefetch
(
self
,
program
,
rpc_client_var
,
def
_replace_lookup_table_op_with_prefetch
(
self
,
program
,
rpc_client_var
,
eplist
):
pserver_endpoints
):
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
self
.
prefetch_input_vars
=
None
self
.
prefetch_input_vars
=
None
self
.
prefetch_output_vars
=
None
self
.
prefetch_output_vars
=
None
...
@@ -602,7 +612,7 @@ class DistributeTranspiler:
...
@@ -602,7 +612,7 @@ class DistributeTranspiler:
"Out"
:
self
.
prefetch_output_vars
,
"Out"
:
self
.
prefetch_output_vars
,
"RPCClient"
:
rpc_client_var
"RPCClient"
:
rpc_client_var
},
},
attrs
=
{
"epmap"
:
eplist
})
attrs
=
{
"epmap"
:
pserver_endpoints
})
# insert concat_op
# insert concat_op
program
.
global_block
().
insert_op
(
program
.
global_block
().
insert_op
(
...
@@ -731,6 +741,12 @@ class DistributeTranspiler:
...
@@ -731,6 +741,12 @@ class DistributeTranspiler:
type
=
"sum"
,
type
=
"sum"
,
inputs
=
{
"X"
:
table_grad_list
},
inputs
=
{
"X"
:
table_grad_list
},
outputs
=
{
"Out"
:
[
grad_var
]})
outputs
=
{
"Out"
:
[
grad_var
]})
else
:
# in async_mode, for table gradient, it also need to be splited to each parameter server
old_name
=
grad_var
.
name
new_name
=
old_name
+
".pserver_"
+
str
(
pserver_index
)
grad_var
=
pserver_program
.
global_block
().
rename_var
(
old_name
,
new_name
)
lr_var
=
pserver_program
.
global_block
().
vars
[
table_opt_op
.
input
(
lr_var
=
pserver_program
.
global_block
().
vars
[
table_opt_op
.
input
(
"LearningRate"
)[
0
]]
"LearningRate"
)[
0
]]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录