Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
2cfb2928
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2cfb2928
编写于
2月 11, 2018
作者:
T
typhoonzero
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix develop dist transpiler bug
上级
caf9a09d
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
34 addition
and
44 deletion
+34
-44
python/paddle/v2/fluid/distribute_transpiler.py
python/paddle/v2/fluid/distribute_transpiler.py
+34
-44
未找到文件。
python/paddle/v2/fluid/distribute_transpiler.py
浏览文件 @
2cfb2928
...
@@ -191,7 +191,6 @@ class DistributeTranspiler:
...
@@ -191,7 +191,6 @@ class DistributeTranspiler:
for
b
in
param_blocks
:
for
b
in
param_blocks
:
varname
,
block_id
,
_
=
b
.
split
(
":"
)
varname
,
block_id
,
_
=
b
.
split
(
":"
)
send_outputs
.
append
(
param_var_mapping
[
varname
][
int
(
block_id
)])
send_outputs
.
append
(
param_var_mapping
[
varname
][
int
(
block_id
)])
# let send_op know which endpoint to send which var to, eplist has the same
# let send_op know which endpoint to send which var to, eplist has the same
# order as send_inputs.
# order as send_inputs.
eplist
=
split_method
(
send_inputs
,
pserver_endpoints
)
eplist
=
split_method
(
send_inputs
,
pserver_endpoints
)
...
@@ -230,21 +229,6 @@ class DistributeTranspiler:
...
@@ -230,21 +229,6 @@ class DistributeTranspiler:
outputs
=
{
"Out"
:
[
orig_param
]},
outputs
=
{
"Out"
:
[
orig_param
]},
attrs
=
{
"axis"
:
0
})
attrs
=
{
"axis"
:
0
})
self
.
lr_param_mapping
=
self
.
_create_lr_param_mapping
()
def
_create_lr_param_mapping
(
self
):
lr_mapping
=
dict
()
for
_
,
opt_op
in
enumerate
(
self
.
optimize_ops
):
if
not
opt_op
.
inputs
or
not
opt_op
.
inputs
.
has_key
(
"LearningRate"
)
\
or
not
opt_op
.
inputs
.
has_key
(
"Param"
):
continue
lr
=
opt_op
.
inputs
[
"LearningRate"
].
name
param
=
opt_op
.
inputs
[
"Param"
].
name
if
not
lr_mapping
.
has_key
(
lr
):
lr_mapping
.
update
({
lr
:
list
()})
lr_mapping
[
lr
].
append
(
param
)
return
lr_mapping
def
_create_vars_from_blocklist
(
self
,
program
,
block_list
):
def
_create_vars_from_blocklist
(
self
,
program
,
block_list
):
# Create respective variables using the block_list
# Create respective variables using the block_list
block_map
=
dict
()
block_map
=
dict
()
...
@@ -369,18 +353,19 @@ class DistributeTranspiler:
...
@@ -369,18 +353,19 @@ class DistributeTranspiler:
pass
pass
return
orig_shape
return
orig_shape
def
_fetch_var_names
(
self
,
param_dict
):
#
def _fetch_var_names(self, param_dict):
res
=
[]
#
res = []
if
not
param_dict
:
#
if not param_dict:
return
res
#
return res
for
_
,
values
in
param_dict
.
iteritems
():
#
for _, values in param_dict.iteritems():
if
not
isinstance
(
values
,
list
):
#
if not isinstance(values, list):
values
=
[
values
]
#
values = [values]
res
+=
[
v
.
name
for
v
in
values
]
#
res += [v.name for v in values]
return
res
#
return res
def
_append_pserver_ops
(
self
,
optimize_block
,
opt_op
,
endpoint
):
def
_append_pserver_ops
(
self
,
optimize_block
,
opt_op
,
endpoint
):
program
=
optimize_block
.
program
program
=
optimize_block
.
program
pserver_block
=
program
.
global_block
()
new_inputs
=
dict
()
new_inputs
=
dict
()
# update param/grad shape first, then other inputs like
# update param/grad shape first, then other inputs like
# moment can use the updated shape
# moment can use the updated shape
...
@@ -395,11 +380,11 @@ class DistributeTranspiler:
...
@@ -395,11 +380,11 @@ class DistributeTranspiler:
# do not append this op if current endpoint
# do not append this op if current endpoint
# is not dealing with this grad block
# is not dealing with this grad block
return
return
merged_var
=
p
rogram
.
global_block
()
.
vars
[
grad_block
.
name
]
merged_var
=
p
server_block
.
vars
[
grad_block
.
name
]
# append merging ops if trainers > 1
# append merging ops if trainers > 1
if
self
.
trainers
>
1
:
if
self
.
trainers
>
1
:
vars2merge
=
self
.
_create_var_for_trainers
(
vars2merge
=
self
.
_create_var_for_trainers
(
p
rogram
.
global_block
()
,
grad_block
,
self
.
trainers
)
p
server_block
,
grad_block
,
self
.
trainers
)
optimize_block
.
append_op
(
optimize_block
.
append_op
(
type
=
"sum"
,
type
=
"sum"
,
inputs
=
{
"X"
:
vars2merge
},
inputs
=
{
"X"
:
vars2merge
},
...
@@ -419,29 +404,27 @@ class DistributeTranspiler:
...
@@ -419,29 +404,27 @@ class DistributeTranspiler:
break
break
if
not
param_block
:
if
not
param_block
:
return
return
tmpvar
=
p
rogram
.
global_block
()
.
create_var
(
tmpvar
=
p
server_block
.
create_var
(
name
=
param_block
.
name
,
name
=
param_block
.
name
,
persistable
=
True
,
persistable
=
True
,
dtype
=
param_block
.
dtype
,
dtype
=
param_block
.
dtype
,
shape
=
param_block
.
shape
)
shape
=
param_block
.
shape
)
new_inputs
[
key
]
=
tmpvar
new_inputs
[
key
]
=
tmpvar
elif
key
==
"LearningRate"
:
elif
key
==
"LearningRate"
:
# leraning rate variable has already be created by non-optimize op,
# leraning rate variable has already be created by non-optimize op,
# don't create it once again.
# don't create it once again.
new_inputs
[
key
]
=
program
.
global_block
().
vars
[
opt_op
.
input
(
key
)[
new_inputs
[
key
]
=
pserver_block
.
vars
[
opt_op
.
input
(
key
)[
0
]]
0
]]
for
key
in
opt_op
.
input_names
:
for
key
in
opt_op
.
input_names
:
new_shape
=
None
new_shape
=
None
if
key
in
[
"Param"
,
"Grad"
,
"LearningRate"
]:
if
key
in
[
"Param"
,
"Grad"
,
"LearningRate"
]:
continue
continue
var
=
program
.
global_block
().
vars
[
opt_op
.
input
(
key
)[
0
]]
var
=
self
.
program
.
global_block
().
vars
[
opt_op
.
input
(
key
)[
0
]]
# update accumulator variable shape
# update accumulator variable shape
param_shape
=
new_inputs
[
"Param"
].
shape
param_shape
=
new_inputs
[
"Param"
].
shape
new_shape
=
self
.
_get_optimizer_input_shape
(
opt_op
.
type
,
key
,
new_shape
=
self
.
_get_optimizer_input_shape
(
opt_op
.
type
,
key
,
var
.
shape
,
param_shape
)
var
.
shape
,
param_shape
)
tmpvar
=
p
rogram
.
global_block
()
.
create_var
(
tmpvar
=
p
server_block
.
create_var
(
name
=
var
.
name
,
name
=
var
.
name
,
persistable
=
var
.
persistable
,
persistable
=
var
.
persistable
,
dtype
=
var
.
dtype
,
dtype
=
var
.
dtype
,
...
@@ -449,11 +432,14 @@ class DistributeTranspiler:
...
@@ -449,11 +432,14 @@ class DistributeTranspiler:
new_inputs
[
key
]
=
tmpvar
new_inputs
[
key
]
=
tmpvar
# change output's ParamOut variable
# change output's ParamOut variable
outputs
=
self
.
_get_output_map_from_op
(
self
.
program
.
global_block
().
vars
,
opt_op
)
opt_op
.
outputs
[
"ParamOut"
]
=
new_inputs
[
"Param"
]
opt_op
.
outputs
[
"ParamOut"
]
=
new_inputs
[
"Param"
]
optimize_block
.
append_op
(
optimize_block
.
append_op
(
type
=
opt_op
.
type
,
type
=
opt_op
.
type
,
inputs
=
new_inputs
,
inputs
=
new_inputs
,
outputs
=
o
pt_op
.
o
utputs
,
outputs
=
outputs
,
attrs
=
opt_op
.
attrs
)
attrs
=
opt_op
.
attrs
)
def
_append_pserver_non_opt_ops
(
self
,
optimize_block
,
opt_op
):
def
_append_pserver_non_opt_ops
(
self
,
optimize_block
,
opt_op
):
...
@@ -497,11 +483,16 @@ class DistributeTranspiler:
...
@@ -497,11 +483,16 @@ class DistributeTranspiler:
# If one op's input is another op's output or
# If one op's input is another op's output or
# one op's output is another op's input, we say
# one op's output is another op's input, we say
# the two operator is connected.
# the two operator is connected.
op1_input_names
=
self
.
_fetch_var_names
(
op1
.
inputs
)
# op1_input_names = self._fetch_var_names(op1.inputs)
op1_output_names
=
self
.
_fetch_var_names
(
op1
.
outputs
)
# op1_output_names = self._fetch_var_names(op1.outputs)
op1_input_names
=
op1
.
desc
.
input_arg_names
()
op1_output_names
=
op1
.
desc
.
output_arg_names
()
# op2_input_names = self._fetch_var_names(op2.inputs)
# op2_output_names = self._fetch_var_names(op2.outputs)
op2_input_names
=
op2
.
desc
.
input_arg_names
()
op2_output_names
=
op2
.
desc
.
output_arg_names
()
op2_input_names
=
self
.
_fetch_var_names
(
op2
.
inputs
)
op2_output_names
=
self
.
_fetch_var_names
(
op2
.
outputs
)
if
set
(
op1_output_names
)
&
set
(
op2_input_names
)
or
\
if
set
(
op1_output_names
)
&
set
(
op2_input_names
)
or
\
set
(
op1_input_names
)
&
set
(
op2_output_names
):
set
(
op1_input_names
)
&
set
(
op2_output_names
):
return
True
return
True
...
@@ -521,8 +512,8 @@ class DistributeTranspiler:
...
@@ -521,8 +512,8 @@ class DistributeTranspiler:
def
_is_opt_op
(
self
,
op
):
def
_is_opt_op
(
self
,
op
):
# NOTE: It's a HACK implement.
# NOTE: It's a HACK implement.
# optimize op: SGDOptimize, MomentumOptimizer, AdamOptimizer and etc...
# optimize op: SGDOptimize, MomentumOptimizer, AdamOptimizer and etc...
if
op
.
inputs
and
op
.
inputs
.
has_key
(
"Param"
)
\
if
"Param"
in
op
.
input_names
and
\
and
op
.
inputs
.
has_key
(
"LearningRate"
)
:
"LearningRate"
in
op
.
input_names
:
return
True
return
True
return
False
return
False
...
@@ -530,12 +521,12 @@ class DistributeTranspiler:
...
@@ -530,12 +521,12 @@ class DistributeTranspiler:
param_names
=
[
param_names
=
[
p
.
name
for
p
in
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]
p
.
name
for
p
in
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]
]
]
if
op
.
input
s
[
"Param"
].
name
in
param_names
:
if
op
.
input
(
"Param"
)
in
param_names
:
return
True
return
True
else
:
else
:
for
n
in
param_names
:
for
n
in
param_names
:
param
=
op
.
input
s
[
"Param"
].
name
param
=
op
.
input
(
"Param"
)[
0
]
if
same_or_split_var
(
n
,
param
)
and
n
!=
op
.
inputs
[
"Param"
].
name
:
if
same_or_split_var
(
n
,
param
)
and
n
!=
param
:
return
True
return
True
return
False
return
False
return
False
return
False
...
@@ -564,7 +555,6 @@ class DistributeTranspiler:
...
@@ -564,7 +555,6 @@ class DistributeTranspiler:
persistable
=
True
,
persistable
=
True
,
dtype
=
v
.
dtype
,
dtype
=
v
.
dtype
,
shape
=
v
.
shape
)
shape
=
v
.
shape
)
# step6
# step6
optimize_block
=
pserver_program
.
create_block
(
0
)
optimize_block
=
pserver_program
.
create_block
(
0
)
# step 6.1
# step 6.1
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录