机器未来 / Paddle · forked from PaddlePaddle / Paddle

Commit b31647c6 (unverified)
Merge branch 'develop' into update_simple_distranspiler

Authored May 31, 2018 by Qiyang Min; committed via GitHub, May 31, 2018.
Parents: 0abf173e, f437c46f
Showing 6 changed files with 372 additions and 263 deletions (+372, −263):
  paddle/contrib/inference/test_paddle_inference_api_impl.cc   (+1, −2)
  python/paddle/fluid/tests/unittests/test_split_var.py        (+2, −2)
  python/paddle/fluid/transpiler/details/__init__.py           (+16, −0, new file)
  python/paddle/fluid/transpiler/details/program_utils.py      (+37, −0, new file)
  python/paddle/fluid/transpiler/details/ufind.py              (+64, −0, new file)
  python/paddle/fluid/transpiler/distribute_transpiler.py      (+252, −259)
paddle/contrib/inference/test_paddle_inference_api_impl.cc (+1, −2)

```diff
@@ -144,8 +144,7 @@ TEST(paddle_inference_api_impl, image_classification) {
   float* data = static_cast<float*>(outputs[0].data.data);
   float* lod_data = output1.data<float>();
   for (size_t j = 0; j < len / sizeof(float); ++j) {
-    EXPECT_LT(lod_data[j] - data[j], 1e-10);
-    EXPECT_GT(lod_data[j] - data[j], -1e-10);
+    EXPECT_NEAR(lod_data[j], data[j], 1e-3);
  }
  free(data);
}
```
python/paddle/fluid/tests/unittests/test_split_var.py (+2, −2)

```diff
@@ -14,7 +14,7 @@
 import math
 import unittest
-from paddle.fluid.transpiler.distribute_transpiler import split_dense_variable
+from paddle.fluid.transpiler.distribute_transpiler import split_variable
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 import random
@@ -31,7 +31,7 @@ class TestSplitVar(unittest.TestCase):
                 # dtype=core.VarDesc.VarType.LOD_TENSOR,
                 shape=shape)
             var_list.append(var)
-        blocks = split_dense_variable(var_list, 10, min_size)
+        blocks = split_variable(var_list, 10, min_size)
         all_sizes = []
         for s in expected_sizes:
             for s2 in s:
```
python/paddle/fluid/transpiler/details/__init__.py (new file, mode 100644, +16)

```python
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from program_utils import *
from ufind import *
```
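A side note, not part of the commit: the bare `from program_utils import *` form relies on Python 2 implicit relative imports, which matches the Python 2 codebase of the time. Under Python 3 the equivalent re-exports would be spelled explicitly:

```python
# Python 3 spelling of the same re-exports (hypothetical; the commit itself
# targets Python 2, where the bare form resolves within the package):
from .program_utils import *
from .ufind import *
```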
python/paddle/fluid/transpiler/details/program_utils.py (new file, mode 100644, +37)

```python
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def delete_ops(block, ops):
    try:
        start = list(block.ops).index(ops[0])
        end = list(block.ops).index(ops[-1])
        [block.remove_op(start) for _ in xrange(end - start + 1)]
    except Exception, e:
        raise e
    block.program.sync_with_cpp()


def find_op_by_input_arg(block, arg_name):
    for index, op in enumerate(block.ops):
        if arg_name in op.input_arg_names:
            return index
    return -1


def find_op_by_output_arg(block, arg_name):
    for index, op in enumerate(block.ops):
        if arg_name in op.output_arg_names:
            return index
    return -1
```
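To make the helpers' contracts concrete, here is a minimal usage sketch. It is not from the commit; `block` stands for a fluid `Block`, and the argument name is illustrative:

```python
# Locate the first op that consumes, or produces, a given argument name;
# both helpers return -1 when no op matches.
in_idx = find_op_by_input_arg(block, "fc_w@GRAD")
out_idx = find_op_by_output_arg(block, "fc_w@GRAD")

if in_idx != -1:
    # Remove a contiguous run of ops. delete_ops repeatedly removes the op
    # at the start index (the list shifts after each removal), then re-syncs
    # the Python-side program with the C++ desc via sync_with_cpp().
    delete_ops(block, block.ops[in_idx:in_idx + 2])
```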
python/paddle/fluid/transpiler/details/ufind.py (new file, mode 100644, +64)

```python
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class UnionFind(object):
    """ Union-find data structure.

    Union-find is a data structure that keeps track of a set of elements
    partitioned into a number of disjoint (non-overlapping) subsets.

    Reference:
    https://en.wikipedia.org/wiki/Disjoint-set_data_structure

    Args:
      elements(list): The initialize element list.
    """

    def __init__(self, elementes=None):
        self._parents = []  # index -> parent index
        self._index = {}  # element -> index
        self._curr_idx = 0
        if not elementes:
            elementes = []
        for ele in elementes:
            self._parents.append(self._curr_idx)
            self._index.update({ele: self._curr_idx})
            self._curr_idx += 1

    def find(self, x):
        # Find the root index of given element x,
        # execute the path compress while findind the root index
        if not x in self._index:
            return -1
        idx = self._index[x]
        while idx != self._parents[idx]:
            t = self._parents[idx]
            self._parents[idx] = self._parents[t]
            idx = t
        return idx

    def union(self, x, y):
        # Union two given element
        x_root = self.find(x)
        y_root = self.find(y)
        if x_root == y_root:
            return
        self._parents[x_root] = y_root

    def is_connected(self, x, y):
        # If two given elements have the same root index,
        # then they are connected.
        return self.find(x) == self.find(y)
```
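The class is self-contained, so its behaviour is easy to show in isolation. A minimal sketch (element names made up; the transpiler itself feeds it operators and unions the ones that share variables, then queries is_connected):

```python
uf = UnionFind(["op_a", "op_b", "op_c", "op_d"])

uf.union("op_a", "op_b")
uf.union("op_b", "op_c")                 # op_a, op_b, op_c now share one root

print(uf.is_connected("op_a", "op_c"))   # True
print(uf.is_connected("op_a", "op_d"))   # False
print(uf.find("not_registered"))         # -1 for unknown elements
```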
python/paddle/fluid/transpiler/distribute_transpiler.py (+252, −259)

```diff
@@ -11,6 +11,30 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+Transpile the program to distributed data-parallelism programs.
+The main_program will be transformed to use a remote parameter server
+to do parameter optimization. And the optimization graph will be put
+into a parameter server program.
+
+Use different methods to split trainable variables to different
+parameter servers.
+
+Steps to transpile trainer:
+1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
+2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
+3. modify trainer program add split_op to each grad variable.
+4. append send_op to send splited variables to server and fetch
+    params(splited blocks or origin param) from server.
+5. append concat_op to merge splited blocks to update local weights.
+
+Steps to transpile pserver:
+1. create new program for parameter server.
+2. create params and grad variables that assigned to current server instance.
+3. create a sub-block in the server side program
+4. append ops that should run on current server instance.
+5. add listen_and_serv op
+"""

 from __future__ import print_function

@@ -22,9 +46,11 @@ from .. import core, framework
 from ..framework import Program, default_main_program, \
                         default_startup_program, \
                         Variable, Parameter, grad_var_name
+from details import *

 LOOKUP_TABLE_TYPE = "lookup_table"
 LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
+OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
 RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
 )
 RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
```
```diff
@@ -41,62 +67,11 @@ class VarBlock:
         return "%s:%d:%d" % (self.varname, self.offset, self.size)


-class UnionFind(object):
-    """ Union-find data structure.
-
-    Union-find is a data structure that keeps track of a set of elements
-    partitioned into a number of disjoint (non-overlapping) subsets.
-
-    Reference:
-    https://en.wikipedia.org/wiki/Disjoint-set_data_structure
-
-    Args:
-      elements(list): The initialize element list.
-    """
-
-    def __init__(self, elementes=None):
-        self._parents = []  # index -> parent index
-        self._index = {}  # element -> index
-        self._curr_idx = 0
-        if not elementes:
-            elementes = []
-        for ele in elementes:
-            self._parents.append(self._curr_idx)
-            self._index.update({ele: self._curr_idx})
-            self._curr_idx += 1
-
-    def find(self, x):
-        # Find the root index of given element x,
-        # execute the path compress while findind the root index
-        if not x in self._index:
-            return -1
-        idx = self._index[x]
-        while idx != self._parents[idx]:
-            t = self._parents[idx]
-            self._parents[idx] = self._parents[t]
-            idx = t
-        return idx
-
-    def union(self, x, y):
-        # Union two given element
-        x_root = self.find(x)
-        y_root = self.find(y)
-        if x_root == y_root:
-            return
-        self._parents[x_root] = y_root
-
-    def is_connected(self, x, y):
-        # If two given elements have the same root index,
-        # then they are connected.
-        return self.find(x) == self.find(y)
-
-
 def same_or_split_var(p_name, var_name):
     return p_name == var_name or p_name.startswith(var_name + ".block")


-def split_dense_variable(var_list, service_count, min_block_size=8192):
+def split_variable(var_list, service_count, min_block_size=8192):
     """
         We may need to split dense tensor to one or more blocks and put
         them equally onto parameter server. One block is a sub-tensor
```

(The UnionFind class removed here is the one added verbatim as details/ufind.py above.)
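For orientation: each block this function produces is described by a VarBlock, whose string form is "varname:blockid:size" (see VarBlock.__str__ in the context above). A sketch with made-up numbers, since the hunk elides the function body:

```python
# split_variable(var_list, service_count, min_block_size=8192) yields
# strings like the following for one large dense parameter (illustrative
# values only; the real split aligns block sizes to the var's width, i.e.
# product(dim[1:]), per the module docstring):
#   ["fc_w:0:262144", "fc_w:1:262144", "fc_w:2:262144"]
# while a small variable stays a single block:
#   ["fc_b:0:1000"]
# The second field is consumed downstream as a block index.
```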
```diff
@@ -142,101 +117,15 @@ def split_dense_variable(var_list, service_count, min_block_size=8192):
     return blocks


-def delete_ops(block, ops):
-    try:
-        start = list(block.ops).index(ops[0])
-        end = list(block.ops).index(ops[-1])
-        [block.remove_op(start) for _ in xrange(end - start + 1)]
-    except Exception, e:
-        raise e
-    block.program.sync_with_cpp()
-
-
-def find_op_by_input_arg(block, arg_name):
-    for index, op in enumerate(block.ops):
-        if arg_name in op.input_arg_names:
-            return index
-    return -1
-
-
-def find_op_by_output_arg(block, arg_name):
-    for index, op in enumerate(block.ops):
-        if arg_name in op.output_arg_names:
-            return index
-    return -1
-
-
 class DistributeTranspiler:
-    def transpile(self,
-                  trainer_id,
-                  program=None,
-                  pservers="127.0.0.1:6174",
-                  trainers=1,
-                  align_var_to_block=True,
-                  split_method=RoundRobin,
-                  sync_mode=True):
-        """
-        Transpile the program to distributed data-parallelism programs.
-        The main_program will be transformed to use a remote parameter server
-        to do parameter optimization. And the optimization graph will be put
-        into a parameter server program.
-
-        Use different methods to split trainable variables to different
-        parameter servers.
-
-        Steps to transpile trainer:
-        1. split variable to multiple blocks, aligned by product(dim[1:]) (width)
-            if align_var_to_block is True
-        2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
-        3. modify trainer program add split_op to each grad variable.
-        4. append send_op to send splited variables to server and fetch
-            params(splited blocks or origin param) from server.
-        5. append concat_op to merge splited blocks to update local weights.
-
-        Steps to transpile pserver:
-        1. create new program for parameter server.
-        2. create params and grad variables that assigned to current server instance.
-        3. create a sub-block in the server side program
-        4. append ops that should run on current server instance.
-        5. add listen_and_serv op
-
-        :param trainer_id: one unique id for each trainer in a job.
-        :type trainer_id: int
-        :param program: program to transpile, default is default_main_program
-        :type program: Program
-        :param pservers: parameter server endpoints like "m1:6174,m2:6174"
-        :type pservers: string
-        :param trainers: total number of workers/trainers in the job
-        :type trainers: int
-        :param split_method: A function to determin how to split variables
-            to different servers equally.
-        :type split_method: function
-        :type sync_mode: boolean default True
-        :param sync_mode: if sync_mode is set True, it means that dist transpiler
-        will transpile the program into sync_mode pserver and trainer program.
-        """
-        assert (split_method.__bases__[0] == PSDispatcher)
-        if program is None:
-            program = default_main_program()
-        self.origin_program = program
-        self.trainer_num = trainers
-        self.sync_mode = sync_mode
-        # TODO(typhoonzero): currently trainer_id is fetched from cluster system
-        # like Kubernetes, we should port this to use etcd later when developing
-        # fluid distributed training with fault-tolerance.
-        self.trainer_id = trainer_id
-        pserver_endpoints = pservers.split(",")
-        self.pserver_endpoints = pserver_endpoints
-        self.optimize_ops, params_grads = self._get_optimize_pass()
-        ps_dispatcher = split_method(pserver_endpoints)
-
+    def _has_distributed_lookup_table(self):
         # process lookup_table_op
         # 1. check all lookup_table_op is distributed
         # 2. check all lookup_table_op share the same table.
         distributed_lookup_table_ops = []
         # support only one distributed_lookup_table now
         self.table_name = None
-        for op in program.global_block().ops:
+        for op in self.origin_program.global_block().ops:
             if op.type == LOOKUP_TABLE_TYPE:
                 if op.attrs['is_distributed'] is True:
                     if self.table_name is None:
```

(The delete_ops / find_op_by_input_arg / find_op_by_output_arg helpers removed here reappear verbatim in the new details/program_utils.py above.)
```diff
@@ -249,20 +138,13 @@ class DistributeTranspiler:
         if self.table_name is not None:
             assert op.input("W")[0] != self.table_name

-        self.has_distributed_lookup_table = len(
-            distributed_lookup_table_ops) > 0
-
-        # step1: For large parameters and gradients, split them into smaller
-        # blocks.
-        param_list = []
-        grad_list = []
-        for p, g in params_grads:
-            # skip parameter marked not trainable
-            if type(p) == Parameter and p.trainable == False:
-                continue
-            param_list.append(p)
-            grad_list.append(g)
+        return len(distributed_lookup_table_ops) > 0

+    def _update_dist_lookup_table_vars(self, param_list, grad_list,
+                                       params_grads):
+        # TODO(wuyi): put find a way to put dist lookup table stuff all together.
+        # update self.table_param_grad and self.trainer_side_table_grad_list
+        program = self.origin_program
         if self.has_distributed_lookup_table:
             param_list = [
                 param for param in param_list if param.name != self.table_name

@@ -280,7 +162,7 @@ class DistributeTranspiler:
             self.trainer_side_table_grad_list = [
                 program.global_block().create_var(
                     name="%s.trainer_%d.pserver_%d" %
-                    (table_grad_var.name, trainer_id, index),
+                    (table_grad_var.name, self.trainer_id, index),
                     type=table_grad_var.type,
                     shape=table_grad_var.shape,
                     dtype=table_grad_var.dtype)

@@ -296,6 +178,25 @@ class DistributeTranspiler:
                 for index in range(len(self.pserver_endpoints))
             ]

+    def _init_splited_vars(self, split_method):
+        # update these mappings for further transpile:
+        # 1. param_var_mapping: param var name -> [splited params vars]
+        # 2. grad_var_mapping: grad var name -> [splited grads vars]
+        # 3. grad_param_mapping: grad.blockx -> param.blockx
+        # 4. param_grad_ep_mapping: ep -> {"params": [], "grads": []}
+
+        param_list = []
+        grad_list = []
+        for p, g in self.params_grads:
+            # skip parameter marked not trainable
+            if type(p) == Parameter and p.trainable == False:
+                continue
+            param_list.append(p)
+            grad_list.append(g)
+
+        self._update_dist_lookup_table_vars(param_list, grad_list,
+                                            self.params_grads)
+
         if align_var_to_block:
             grad_blocks = split_dense_variable(grad_list,
                                                len(pserver_endpoints))
```
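The four mappings listed in the new comment are the state the rest of the transpile pass reads. An illustrative sketch of their shapes (variable names hypothetical, values stand for fluid Variable objects):

```python
# self.param_var_mapping: origin param name -> its splited vars
#   {"fc_w": [fc_w.block0, fc_w.block1], "fc_b": [fc_b]}
# self.grad_var_mapping: origin grad name -> its splited vars
#   {"fc_w@GRAD": [fc_w@GRAD.block0, fc_w@GRAD.block1]}
# self.grad_param_mapping: splited grad var -> matching splited param var
#   {fc_w@GRAD.block0: fc_w.block0, ...}
# self.param_grad_ep_mapping: pserver endpoint -> its share of both
#   {"127.0.0.1:6174": {"params": [...], "grads": [...]}}
```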
```diff
@@ -307,21 +208,19 @@ class DistributeTranspiler:
             grad_blocks = split_dense_variable(grad_list, 1)
             param_blocks = split_dense_variable(param_list, 1)
         assert (len(grad_blocks) == len(param_blocks))
-        # step2: Create new vars for the parameters and gradients blocks and
-        # add ops to do the split.
-        param_var_mapping = self._create_vars_from_blocklist(program,
-                                                             param_blocks)
-        grad_var_mapping = self._create_vars_from_blocklist(
-            program,
-            grad_blocks,
-            add_trainer_suffix=self.trainer_num > 1)
-        grad_param_mapping = dict()
+
+        # origin_varname -> [splited_var]
+        self.param_var_mapping = self._create_vars_from_blocklist(
+            self.origin_program, param_blocks)
+        self.grad_var_mapping = self._create_vars_from_blocklist(
+            self.origin_program,
+            grad_blocks,
+            add_trainer_suffix=self.trainer_num > 1)
+        self.grad_param_mapping = dict()
         for g, p in zip(grad_blocks, param_blocks):
             g_name, g_bid, _ = g.split(":")
             p_name, p_bid, _ = p.split(":")
-            grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \
-                param_var_mapping[p_name][int(p_bid)]
+            self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \
+                self.param_var_mapping[p_name][int(p_bid)]

-        # step 3: transpile trainer side program, insert recv op and send op.
         # create mapping of endpoint -> split var to create pserver side program
         self.param_grad_ep_mapping = dict()
```
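The zip over grad_blocks and param_blocks works because both lists are parallel "name:blockid:size" strings. A small sketch (names illustrative):

```python
g = "fc_w@GRAD:0:262144"
p = "fc_w:0:262144"
g_name, g_bid, _ = g.split(":")   # ("fc_w@GRAD", "0", "262144")
p_name, p_bid, _ = p.split(":")   # ("fc_w", "0", "262144")
# grad_var_mapping[g_name][int(g_bid)] is then paired with
# param_var_mapping[p_name][int(p_bid)] in grad_param_mapping.
```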
```diff
@@ -334,6 +233,47 @@ class DistributeTranspiler:
             })
             for ep in self.pserver_endpoints
         ]

+    def transpile(self,
+                  trainer_id,
+                  program=None,
+                  pservers="127.0.0.1:6174",
+                  trainers=1,
+                  align_var_to_block=True,
+                  split_method=RoundRobin,
+                  sync_mode=True):
+        """
+        :param trainer_id: one unique id for each trainer in a job.
+        :type trainer_id: int
+        :param program: program to transpile, default is default_main_program
+        :type program: Program
+        :param pservers: parameter server endpoints like "m1:6174,m2:6174"
+        :type pservers: string
+        :param trainers: total number of workers/trainers in the job
+        :type trainers: int
+        :param split_method: A function to determin how to split variables
+            to different servers equally.
+        :type split_method: function
+        :type sync_mode: boolean default True
+        :param sync_mode: if sync_mode is set True, it means that dist transpiler
+            will transpile the program into sync_mode pserver and trainer program.
+        """
+        assert (split_method.__bases__[0] == PSDispatcher)
+        if program is None:
+            program = default_main_program()
+        self.origin_program = program
+        self.trainer_num = trainers
+        self.sync_mode = sync_mode
+        self.trainer_id = trainer_id
+        pserver_endpoints = pservers.split(",")
+        self.pserver_endpoints = pserver_endpoints
+        self.optimize_ops, self.params_grads = self._get_optimize_pass()
+        ps_dispatcher = split_method(self.pserver_endpoints)
+        self.has_distributed_lookup_table = self._has_distributed_lookup_table()
+
+        # split and create vars, then put splited vars in dicts for later use.
+        self._init_splited_vars(split_method)
+
         # step 3.1: insert send op to send gradient vars to parameter servers
         ps_dispatcher.reset()
         send_vars = []
```
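For reference, a hedged sketch of how the reorganized transpile() entry point is driven from user code in the fluid API of that era; endpoints and ids are placeholders:

```python
import paddle.fluid as fluid

t = fluid.DistributeTranspiler()
t.transpile(
    trainer_id=0,                                  # unique per trainer
    pservers="192.168.0.1:6174,192.168.0.2:6174",  # comma-separated endpoints
    trainers=2,
    sync_mode=True)

# Companion accessors from the same module (assumed available as in
# contemporaneous fluid releases):
trainer_prog = t.get_trainer_program()
pserver_prog = t.get_pserver_program("192.168.0.1:6174")
```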
```diff
@@ -343,7 +283,7 @@ class DistributeTranspiler:
         # fc_w@GRAD_trainer_0, fc_w@GRAD_trainer_1 --> pserver1
         # fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2
         # shuffle the map will avoid the uneven distribution above
-        grad_var_mapping_items = grad_var_mapping.items()
+        grad_var_mapping_items = self.grad_var_mapping.items()
         if not align_var_to_block:
             np.random.shuffle(grad_var_mapping_items)

@@ -393,7 +333,7 @@ class DistributeTranspiler:
         # step 3.2: insert recv op to receive parameters from parameter server
         recv_vars = []
         for _, var in enumerate(send_vars):
-            recv_vars.append(grad_param_mapping[var])
+            recv_vars.append(self.grad_param_mapping[var])
         ps_dispatcher.reset()
         eplist = ps_dispatcher.dispatch(recv_vars)

@@ -401,7 +341,8 @@ class DistributeTranspiler:
             self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
             self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])

-        for varname, splited_var in param_var_mapping.iteritems():
+        # step4: Concat the parameters splits together after recv.
+        for varname, splited_var in self.param_var_mapping.iteritems():
             eps = []
             for var in splited_var:
                 index = [v.name for v in recv_vars].index(var.name)

@@ -425,8 +366,7 @@ class DistributeTranspiler:
                 RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
             })

-        # step4: Concat the parameters splits together after recv.
-        for varname, splited_var in param_var_mapping.iteritems():
+        for varname, splited_var in self.param_var_mapping.iteritems():
             if len(splited_var) <= 1:
                 continue
             orig_param = program.global_block().vars[varname]
```
```diff
@@ -467,7 +407,6 @@ class DistributeTranspiler:
         # we don't need to create them when grad arrives.
         # change client side var name to origin name by
         # removing ".trainer_%d" suffix
         suff_idx = v.name.find(".trainer_")
         if suff_idx >= 0:
             orig_var_name = v.name[:suff_idx]

@@ -504,24 +443,14 @@ class DistributeTranspiler:
         # located on current pserver
         opt_op_on_pserver = []
         for _, op in enumerate(self.optimize_ops):
-            if self._is_opt_op(op) and self._is_opt_op_on_pserver(endpoint, op):
+            if self._is_optimizer_op(op) and self._is_opt_op_on_pserver(
+                    endpoint, op):
                 opt_op_on_pserver.append(op)
         # step 3.3
         # Iterate through the ops, and if an op and the optimize ops
         # which located on current pserver are in one set, then
         # append it into the sub program.
-
-        # We try to put optimization program run parallelly, assume
-        # optimization program always looks like:
-        #
-        # prevop -> prevop -> opt op -> following op -> following op; ->
-        # prevop -> prevop -> opt op -> following op -> following op; ->
-        # global op -> global op
-        #
-        # we put operators that can run parallelly to many program blocks.
-        # in above example, we seperate ops by the ";". Global ops must run
-        # after all the optimize ops finished.
-
         global_ops = []
         # HACK: optimization global ops only used to scale beta1 and beta2
         # replace it with dependency engine.
```

(The first hunk's single deleted line is not recoverable from the rendered page; only its context survives.)
```diff
@@ -529,12 +458,18 @@ class DistributeTranspiler:
             if self._is_adam_connected_op(op):
                 global_ops.append(op)

-        def __append_optimize_op__(op, block, grad_to_block_id):
-            if self._is_opt_op(op):
+        def __append_optimize_op__(op, block, grad_to_block_id, merged_var):
+            if self._is_optimizer_op(op):
                 self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
-                                         self.origin_program)
+                                         self.origin_program, merged_var)
             else:
-                self._append_pserver_non_opt_ops(block, op)
+                self._append_pserver_non_opt_ops(block, op, endpoint)
+
+        def __op_have_grad_input__(op):
+            for varname in op.input_arg_names:
+                if varname.find("@GRAD") >= 0:
+                    return varname
+            return ""

         # append lr decay ops to the child block if exists
         lr_ops = self._get_lr_ops()
```
```diff
@@ -542,17 +477,26 @@ class DistributeTranspiler:
             lr_decay_block = pserver_program.create_block(
                 pserver_program.num_blocks - 1)
             for _, op in enumerate(lr_ops):
-                self._append_pserver_non_opt_ops(lr_decay_block, op)
+                self._append_pserver_non_opt_ops(lr_decay_block, op, endpoint)

         # append op to the current block
         grad_to_block_id = []
         pre_block_idx = pserver_program.num_blocks - 1
         for idx, opt_op in enumerate(opt_op_on_pserver):
             per_opt_block = pserver_program.create_block(pre_block_idx)
+            # append grad merging ops before clip and weight decay
+            for _, op in enumerate(self.optimize_ops):
+                # find the origin @GRAD var before clipping
+                grad_varname_for_block = __op_have_grad_input__(op)
+                if ufind.is_connected(op, opt_op) and grad_varname_for_block:
+                    merged_var = self._append_pserver_grad_merge_ops(
+                        per_opt_block, grad_varname_for_block, endpoint,
+                        grad_to_block_id, self.origin_program)
             for _, op in enumerate(self.optimize_ops):
                 # optimizer is connected to itself
                 if ufind.is_connected(op, opt_op) and op not in global_ops:
-                    __append_optimize_op__(op, per_opt_block, grad_to_block_id)
+                    __append_optimize_op__(op, per_opt_block, grad_to_block_id,
+                                           merged_var)

         # append global ops
         if global_ops:

@@ -560,15 +504,7 @@ class DistributeTranspiler:
                 pserver_program.num_blocks - 1)
             for glb_op in global_ops:
                 __append_optimize_op__(glb_op, opt_state_block,
-                                       grad_to_block_id)
-
-        # NOT USED: single block version:
-        #
-        # for _, op in enumerate(self.optimize_ops):
-        #     for _, opt_op in enumerate(opt_op_on_pserver):
-        #         if ufind.is_connected(op, opt_op):
-        #             __append_optimize_op__(glb_op, optimize_block)
-        #             break
+                                       grad_to_block_id, None)

         # process distributed lookup_table
         prefetch_block = None
```
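The net effect of this hunk, together with the new _append_pserver_grad_merge_ops defined below, is a reordering of the pserver-side block construction; roughly (a summary reading of the diff, not code from it):

```python
# Per optimizer sub-block, after this change:
#   1. grad merging runs first: sum the ".trainer_%d" copies of the grad
#      and, in sync mode, scale by 1.0 / trainer_num -> merged_var
#   2. every op connected to opt_op (clip, weight decay, the optimizer
#      itself) is appended afterwards and consumes merged_var, instead of
#      each _append_pserver_ops call re-doing the merge inline.
```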
```diff
@@ -658,6 +594,8 @@ class DistributeTranspiler:
                 attrs=op.attrs)
         return s_prog

+    # ====================== private transpiler functions =====================
+
     # transpiler function for dis lookup_table
     def _replace_lookup_table_op_with_prefetch(self, program,
                                                pserver_endpoints):

@@ -863,7 +801,6 @@ class DistributeTranspiler:
         return table_opt_block

-    # ====================== private transpiler functions =====================
     def _create_vars_from_blocklist(self,
                                     program,
                                     block_list,
```
```diff
@@ -1006,17 +943,74 @@ class DistributeTranspiler:
                 pass
         return orig_shape

-    def _orig_varname(self, varname):
-        suff_idx = varname.find(".trainer_")
-        orig_var_name = ""
-        if suff_idx >= 0:
-            orig_var_name = varname[:suff_idx]
+    def _get_varname_parts(self, varname):
+        # returns origin, blockid, trainerid
+        orig_var_name = ""
+        trainer_part = ""
+        block_part = ""
+        trainer_idx = varname.find(".trainer_")
+        if trainer_idx >= 0:
+            trainer_part = varname[trainer_idx + 1:]
+        else:
+            trainer_idx = len(varname)
+        block_index = varname.find(".block")
+        if block_index >= 0:
+            block_part = varname[block_index + 1:trainer_idx]
         else:
-            orig_var_name = varname
-        return orig_var_name
+            block_index = len(varname)
+        orig_var_name = varname[0:min(block_index, trainer_idx)]
+        return orig_var_name, block_part, trainer_part
+
+    def _orig_varname(self, varname):
+        orig, _, _ = self._get_varname_parts(varname)
+        return orig
+
+    def _append_pserver_grad_merge_ops(self, optimize_block,
+                                       grad_varname_for_block, endpoint,
+                                       grad_to_block_id, origin_program):
+        program = optimize_block.program
+        pserver_block = program.global_block()
+        grad_block = None
+        for g in self.param_grad_ep_mapping[endpoint]["grads"]:
+            if self._orig_varname(g.name) == \
+                    self._orig_varname(grad_varname_for_block):
+                grad_block = g
+                break
+        if not grad_block:
+            # do not append this op if current endpoint
+            # is not dealing with this grad block
+            return
+        orig_varname, block_name, trainer_name = self._get_varname_parts(
+            grad_block.name)
+        if block_name:
+            merged_var_name = '.'.join([orig_varname, block_name])
+        else:
+            merged_var_name = orig_varname
+        merged_var = \
+            pserver_block.vars[merged_var_name]
+        grad_to_block_id.append(merged_var.name + ":" + str(optimize_block.idx))
+        if self.sync_mode and self.trainer_num > 1:
+            vars2merge = []
+            for i in xrange(self.trainer_num):
+                per_trainer_name = "%s.trainer_%d" % \
+                    (merged_var_name, i)
+                vars2merge.append(pserver_block.vars[per_trainer_name])
+            optimize_block.append_op(
+                type="sum",
+                inputs={"X": vars2merge},
+                outputs={"Out": merged_var})
+            # TODO(panyx0718): What if it's SELECTED_ROWS.
+            if not merged_var.type == core.VarDesc.VarType.SELECTED_ROWS:
+                optimize_block.append_op(
+                    type="scale",
+                    inputs={"X": merged_var},
+                    outputs={"Out": merged_var},
+                    attrs={"scale": 1.0 / float(self.trainer_num)})
+        return merged_var

     def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
-                            grad_to_block_id, origin_program):
+                            grad_to_block_id, origin_program, merged_var):
         program = optimize_block.program
         pserver_block = program.global_block()
         new_inputs = dict()
```
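The naming convention _get_varname_parts decodes is the ".blockN[.trainer_M]" suffix scheme used throughout the transpiler. Tracing the implementation above on a few illustrative names gives:

```python
# _get_varname_parts("fc_w.block0.trainer_1") -> ("fc_w", "block0", "trainer_1")
# _get_varname_parts("fc_w.block0")           -> ("fc_w", "block0", "")
# _get_varname_parts("fc_w")                  -> ("fc_w", "", "")
```

Note that the sum followed by scale in _append_pserver_grad_merge_ops computes the average of the per-trainer gradient copies: merged = (g.trainer_0 + ... + g.trainer_{N-1}) / N.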
```diff
@@ -1024,40 +1018,6 @@ class DistributeTranspiler:
         # moment can use the updated shape
         for key in opt_op.input_names:
             if key == "Grad":
-                grad_block = None
-                for g in self.param_grad_ep_mapping[endpoint]["grads"]:
-                    if same_or_split_var(
-                            self._orig_varname(g.name),
-                            self._orig_varname(opt_op.input(key)[0])):
-                        grad_block = g
-                        break
-                if not grad_block:
-                    # do not append this op if current endpoint
-                    # is not dealing with this grad block
-                    return
-                merged_var = \
-                    pserver_block.vars[self._orig_varname(grad_block.name)]
-                grad_to_block_id.append(merged_var.name + ":" + str(
-                    optimize_block.idx))
-                if self.sync_mode and self.trainer_num > 1:
-                    vars2merge = []
-                    for i in xrange(self.trainer_num):
-                        per_trainer_name = "%s.trainer_%d" % \
-                            (self._orig_varname(grad_block.name), i)
-                        vars2merge.append(pserver_block.vars[per_trainer_name])
-                    optimize_block.append_op(
-                        type="sum",
-                        inputs={"X": vars2merge},
-                        outputs={"Out": merged_var})
-                    # TODO(panyx0718): What if it's SELECTED_ROWS.
-                    if not merged_var.type == core.VarDesc.VarType.SELECTED_ROWS:
-                        optimize_block.append_op(
-                            type="scale",
-                            inputs={"X": merged_var},
-                            outputs={"Out": merged_var},
-                            attrs={"scale": 1.0 / float(self.trainer_num)})
                 new_inputs[key] = merged_var
             elif key == "Param":
                 # param is already created on global program
```
```diff
@@ -1116,17 +1076,31 @@ class DistributeTranspiler:
             outputs=outputs,
             attrs=opt_op.attrs)

-    def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
+    def _is_splited_grad_var(self, var, var_dict):
+        grad_block = None
+        for _, g in var_dict.iteritems():
+            if self._orig_varname(g.name) == self._orig_varname(var.name):
+                if g.name.find(".trainer_") == -1:
+                    grad_block = g
+                    break
+        return grad_block
+
+    def _append_pserver_non_opt_ops(self, optimize_block, opt_op, endpoint):
         program = optimize_block.program
         # Append the ops for parameters that do not need to be optimized/updated
         inputs = self._get_input_map_from_op(
             self.origin_program.global_block().vars, opt_op)
-        for varlist in inputs.itervalues():
+        for key, varlist in inputs.iteritems():
             if not isinstance(varlist, list):
                 varlist = [varlist]
             for var in varlist:
-                if not program.global_block().vars.has_key(var.name):
+                # for ops like clipping and weight decay, get the splited var
+                # for inputs/outputs
+                grad_block = self._is_splited_grad_var(
+                    var, program.global_block().vars)
+                if grad_block:
+                    inputs[key] = grad_block
+                elif not program.global_block().vars.has_key(var.name):
                     program.global_block().create_var(
                         name=var.name,
                         persistable=var.persistable,

@@ -1135,13 +1109,16 @@ class DistributeTranspiler:
         outputs = self._get_output_map_from_op(
             self.origin_program.global_block().vars, opt_op)
-        for varlist in outputs.itervalues():
+        for key, varlist in outputs.iteritems():
             if not isinstance(varlist, list):
                 varlist = [varlist]
             for var in varlist:
-                program.global_block().clone_variable(var)
+                grad_block = self._is_splited_grad_var(
+                    var, program.global_block().vars)
+                if grad_block:
+                    outputs[key] = grad_block
+                elif not program.global_block().vars.has_key(var.name):
+                    program.global_block().clone_variable(var)

         optimize_block.append_op(
             type=opt_op.type,
```
```diff
@@ -1187,9 +1164,17 @@ class DistributeTranspiler:
                     ufind.union(op1, op2)
         return ufind

-    def _is_opt_op(self, op):
-        # NOTE: It's a HACK implement.
-        # optimize op: SGDOptimize, MomentumOptimizer, AdamOptimizer and etc...
+    def _is_opt_role_op(self, op):
+        # NOTE: depend on oprole to find out whether this op is for
+        # optimize
+        op_maker = core.op_proto_and_checker_maker
+        optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize
+        if op_maker.kOpRoleAttrName() in op.attrs and \
+                int(op.attrs[op_maker.kOpRoleAttrName()]) == int(optimize_role):
+            return True
+        return False
+
+    def _is_optimizer_op(self, op):
         if "Param" in op.input_names and \
             "LearningRate" in op.input_names:
             return True

@@ -1239,7 +1224,7 @@ class DistributeTranspiler:
         # find learning rate variables by optimize op
         lr_vars = set()
         for op in self.optimize_ops:
-            if self._is_opt_op(op):
+            if self._is_optimizer_op(op):
                 lr_vars.add(op.input("LearningRate")[0])

         find_ops = []

@@ -1256,7 +1241,7 @@ class DistributeTranspiler:
                 # NOTE: we need to skip all optimize ops, since it is connected
                 # with forward/backward ops and lr ops, we only need the lr ops.
                 if op1 != op2 and self._is_op_connected(op1, op2) and \
-                    not self._is_opt_op(op1) and not self._is_opt_op(op2):
+                        not self._is_optimizer_op(op1) and not self._is_optimizer_op(op2):
                     ufind.union(op1, op2)
         # find all ops which is related with lr var
         for op1 in block.ops:
```
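This hunk splits what used to be a single overloaded check into two predicates with distinct meanings; in summary (a reading of the diff, not code from it):

```python
# _is_optimizer_op(op): structural test. True for ops that take both a
#   "Param" and a "LearningRate" input, i.e. the actual update ops
#   (sgd, momentum, adam, ...).
# _is_opt_role_op(op): role-based test. True when the framework stamped the
#   op's OpRole attribute as Optimize; this also covers companion ops such
#   as gradient clipping that run in the optimize stage (see the
#   HACK(wuyi) note in the next hunk).
```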
```diff
@@ -1277,13 +1262,21 @@ class DistributeTranspiler:
         block = self.origin_program.global_block()
         opt_ops = []
         params_grads = []
+        origin_var_dict = self.origin_program.global_block().vars
         for op in block.ops:
-            if self._is_opt_op(op):
+            if self._is_opt_role_op(op):
                 opt_ops.append(op)
-                params_grads.append((self.origin_program.global_block().var(
-                    op.input("Param")[0]),
-                                     self.origin_program.global_block().var(
-                                         op.input("Grad")[0])))
+                # HACK(wuyi): if we find grad vars from input of optimize
+                # ops, we may get the output of clip op. Use syntax "@GRAD"
+                # and op_role_var to get the pair.
+                for input_name in op.input_arg_names:
+                    if input_name.find("@GRAD") != -1 and \
+                        op.attrs[RPC_OP_ROLE_ATTR_NAME]:
+                        param_name = op.attrs[OP_ROLE_VAR_ATTR_NAME][0]
+                        params_grads.append([
+                            origin_var_dict[param_name],
+                            origin_var_dict[input_name]
+                        ])
             elif self._is_adam_connected_op(op):
                 opt_ops.append(op)
             else:
```