PaddlePaddle / Paddle
Commit 1882a74f (unverified)

[RM FLUID] rm ps ir (#50691)

Authored by wangzhen38 on Feb 21, 2023; committed via GitHub on Feb 21, 2023
Parent commit: e84fa263

Showing 17 changed files with 30 additions and 5252 deletions (+30, -5252):
python/paddle/distributed/fleet/runtime/parameter_server_runtime.py  +12 -13
python/paddle/distributed/fleet/runtime/the_one_ps.py  +8 -8
python/paddle/fluid/incubate/fleet/parameter_server/ir/__init__.py  +0 -13
python/paddle/fluid/incubate/fleet/parameter_server/ir/ps_dispatcher.py  +0 -123
python/paddle/fluid/incubate/fleet/parameter_server/ir/pserver_pass.py  +0 -1124
python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py  +0 -1511
python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py  +0 -2147
python/paddle/fluid/incubate/fleet/parameter_server/ir/ufind.py  +0 -64
python/paddle/fluid/incubate/fleet/parameter_server/ir/vars_metatools.py  +0 -235
python/paddle/fluid/tests/unittests/ps/ps_dnn_trainer.py  +2 -2
python/paddle/fluid/tests/unittests/test_fleet_ps.py  +1 -1
python/paddle/fluid/tests/unittests/test_ps_dispatcher.py  +1 -1
python/paddle/incubate/fleet/parameter_server/ir/pserver_pass.py  +2 -2
python/paddle/incubate/fleet/parameter_server/ir/public.py  +2 -4
python/paddle/incubate/fleet/parameter_server/ir/trainer_pass.py  +2 -2
python/setup.py.in  +0 -1
setup.py  +0 -1

python/paddle/distributed/fleet/runtime/parameter_server_runtime.py  (view file @ 1882a74f)

@@ -16,8 +16,7 @@ import os
 import warnings
 
 import paddle
-import paddle.fluid as fluid
-from paddle.framework import core
+from paddle.fluid import core
 from paddle.static import (
     CompiledProgram,
     Executor,
@@ -73,7 +72,7 @@ class ParameterServerRuntime(RuntimeBase):
         return strategy
 
     def build_compiled_startegy(self):
-        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+        from paddle.incubate.fleet.parameter_server.ir.public import (
             CompileTimeStrategy,
         )
@@ -102,7 +101,7 @@ class ParameterServerRuntime(RuntimeBase):
         if main_program is None:
             main_program = self.origin_main_program
 
-        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+        from paddle.incubate.fleet.parameter_server.ir.public import (
             _get_varname_parts,
         )
@@ -111,7 +110,7 @@ class ParameterServerRuntime(RuntimeBase):
            origin_varname, _, _ = _get_varname_parts(each_var.name)
 
-           new_var = fluid.io._clone_var_in_block_(load_block, each_var)
+           new_var = paddle.static.io._clone_var_in_block(load_block, each_var)
            var_path = os.path.join(dirname, origin_varname)
            if not os.path.exists(var_path):
                raise ValueError(
@@ -138,7 +137,7 @@ class ParameterServerRuntime(RuntimeBase):
 
     def _load_distributed_params(self, dirname, varnames):
         from paddle.distributed.communicator import LargeScaleKV
-        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+        from paddle.incubate.fleet.parameter_server.ir.public import (
             _get_varname_parts,
         )
@@ -154,7 +153,7 @@ class ParameterServerRuntime(RuntimeBase):
            if var.name in exclude_var_names:
                return False
 
-           from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+           from paddle.incubate.fleet.parameter_server.ir.public import (
                _get_varname_parts,
            )
@@ -185,7 +184,7 @@ class ParameterServerRuntime(RuntimeBase):
            return kwargs
 
        def geo_strategy_envs():
-           from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+           from paddle.incubate.fleet.parameter_server.ir.public import (
                get_sparse_tablenames,
            )
@@ -239,14 +238,14 @@ class ParameterServerRuntime(RuntimeBase):
            kwargs["sparse_attrs"] = get_sparse_attrs()
            return kwargs
 
-       from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
-           _get_lr_ops,
-           _has_global_step,
-       )
        from paddle.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import (
            GeoStrategy,
            SyncStrategy,
        )
+       from paddle.incubate.fleet.parameter_server.ir.public import (
+           _get_lr_ops,
+           _has_global_step,
+       )
 
        trainer_config = self.async_strategy.get_trainer_runtime_config()
        print(trainer_config)
@@ -475,7 +474,7 @@ class ParameterServerRuntime(RuntimeBase):
         return reshaped_names, origin_names
 
     def _get_optimizer_op(self, param_name):
-        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+        from paddle.incubate.fleet.parameter_server.ir.public import (
            _get_optimize_ops,
        )

python/paddle/distributed/fleet/runtime/the_one_ps.py  (view file @ 1882a74f)

@@ -36,7 +36,7 @@ PSERVER_SAVE_SUFFIX = ".shard"
 
 def parse_table_class(varname, o_main_program):
-    from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+    from paddle.incubate.fleet.parameter_server.ir.public import (
         is_distributed_sparse_op,
         is_sparse_op,
     )
@@ -247,7 +247,7 @@ class CommonAccessor:
         self.opt_init_map = opt_init_map
 
     def parse_entry(self, varname, o_main_program):
-        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+        from paddle.incubate.fleet.parameter_server.ir.public import (
            is_distributed_sparse_op,
            is_sparse_op,
        )
@@ -304,7 +304,7 @@ class CommonAccessor:
        compiled_strategy,
        adam_d2sum,
    ):
-       from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+       from paddle.incubate.fleet.parameter_server.ir.public import (
            _get_optimize_ops,
        )
@@ -716,7 +716,7 @@ class TheOnePSRuntime(RuntimeBase):
         return strategy
 
     def build_compiled_startegy(self):
-        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+        from paddle.incubate.fleet.parameter_server.ir.public import (
             CompileTimeStrategy,
         )
@@ -1191,7 +1191,7 @@ class TheOnePSRuntime(RuntimeBase):
            proto_txt, string_hosts, role_id, trainers, self._server_sub_program
        )
 
-       from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+       from paddle.incubate.fleet.parameter_server.ir.public import (
            get_sparse_tablenames,
        )
@@ -1252,7 +1252,7 @@ class TheOnePSRuntime(RuntimeBase):
            if var.name in exclude_var_names:
                return False
 
-           from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+           from paddle.incubate.fleet.parameter_server.ir.public import (
                _get_varname_parts,
            )
@@ -1283,7 +1283,7 @@ class TheOnePSRuntime(RuntimeBase):
     def _save_sparse_params(
         self, executor, dirname, context, main_program, mode
     ):
-        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+        from paddle.incubate.fleet.parameter_server.ir.public import (
            get_sparse_tablenames,
        )
@@ -1479,7 +1479,7 @@ class TheOnePSRuntime(RuntimeBase):
         self._ps_inference_save_persistables(*args, **kwargs)
 
     def _load_sparse_params(self, dirname, context, main_program, mode):
-        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+        from paddle.incubate.fleet.parameter_server.ir.public import (
            get_sparse_tablenames,
        )
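
Both runtime files above keep their logic unchanged and only repoint imports from the deleted paddle.fluid.incubate.fleet.parameter_server.ir package to its paddle.incubate.fleet.parameter_server.ir counterpart. A minimal sketch of that migration pattern for downstream code, assuming the relocated public module exposes the same CompileTimeStrategy class as the deleted copy listed later on this page (the wrapper function here is hypothetical, not part of the commit):

# Old import path, removed by this commit:
#   from paddle.fluid.incubate.fleet.parameter_server.ir.public import CompileTimeStrategy
# New import path used throughout the diff above:
from paddle.incubate.fleet.parameter_server.ir.public import CompileTimeStrategy


def build_compile_time_strategy(main_program, startup_program, strategy, role_maker):
    # Hypothetical helper: the constructor arguments follow the deleted
    # public.py listed below; only the module path changes for callers.
    return CompileTimeStrategy(main_program, startup_program, strategy, role_maker)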

python/paddle/fluid/incubate/fleet/parameter_server/ir/__init__.py  (deleted, 100644 → 0, view file @ e84fa263)

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

python/paddle/fluid/incubate/fleet/parameter_server/ir/ps_dispatcher.py  (deleted, 100644 → 0, view file @ e84fa263)

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class PSDispatcher:
    """
    PSDispatcher is the base class for dispatching vars
    into different pserver instance.
    You need to implement the `dispatch` interface.
    """

    def __init__(self, pserver_endpoints):
        self._eps = pserver_endpoints
        self._step = 0

    @property
    def eps(self):
        return self._eps

    def reset(self):
        """
        reset the step counter, set it zero.
        """
        self._step = 0

    def dispatch(self, varlist):
        """
        Args:
            varlist(list): a list of Variables
        Returns:
            a map of pserver endpoint -> varname
        """
        raise NotImplementedError("Interface has not been implemented.")


class HashName(PSDispatcher):
    """
    Hash variable names to several endpoints using python
    "hash()" function.

    Args:
        pserver_endpoints (list): list of endpoint(ip:port).

    Examples:
        .. code-block:: python

            pserver_endpoints = ["127.0.0.1:6007", "127.0.0.1:6008"]
            vars = ["var1","var2","var3","var4","var5"]
            rr = RoundRobin(pserver_endpoints)
            rr.dispatch(vars)
    """

    def __init__(self, pserver_endpoints):
        super().__init__(pserver_endpoints)

    def _hash_block(self, block_str, total):
        return hash(block_str) % total

    def dispatch(self, varlist):
        """
        use `HashName` method to dispatch variables with each parameter server.
        Args:
            varlist (list): a list of Variables
        """
        eplist = []
        for var in varlist:
            server_id = self._hash_block(var.name(), len(self._eps))
            server_for_param = self._eps[server_id]
            eplist.append(server_for_param)
        return eplist


class RoundRobin(PSDispatcher):
    """
    Distribute variables to several endpoints using
    RondRobin<https://en.wikipedia.org/wiki/Round-robin_scheduling> method.

    Args:
        pserver_endpoints (list): list of endpoint(ip:port).

    Examples:
        .. code-block:: python

            pserver_endpoints = ["127.0.0.1:6007", "127.0.0.1:6008"]
            vars = ["var1","var2","var3","var4","var5"]
            rr = RoundRobin(pserver_endpoints)
            rr.dispatch(vars)
    """

    def __init__(self, pserver_endpoints):
        super().__init__(pserver_endpoints)

    def dispatch(self, varlist):
        """
        use `RoundRobin` method to dispatch variables with each parameter server.
        Args:
            varlist (list): a list of Variables
        """
        eplist = []
        for var in varlist:
            server_for_param = self._eps[self._step]
            eplist.append(server_for_param)
            self._step += 1
            if self._step >= len(self._eps):
                self._step = 0
        return eplist
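
The docstrings above already carry a usage example; spelled out as a runnable sketch against the classes listed above. As in the docstring, plain strings stand in for the variable list, which works for RoundRobin because its dispatch never inspects the variables; HashName would need real Variable objects since it calls var.name():

pserver_endpoints = ["127.0.0.1:6007", "127.0.0.1:6008"]
vars = ["var1", "var2", "var3", "var4", "var5"]

rr = RoundRobin(pserver_endpoints)
# Cycles through the endpoint list: 6007, 6008, 6007, 6008, 6007
eplist = rr.dispatch(vars)
assert eplist == ["127.0.0.1:6007", "127.0.0.1:6008",
                  "127.0.0.1:6007", "127.0.0.1:6008", "127.0.0.1:6007"]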

python/paddle/fluid/incubate/fleet/parameter_server/ir/pserver_pass.py  (deleted, 100644 → 0, view file @ e84fa263)

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections

from paddle.framework import core, Block
from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
    _get_optimize_ops,
)
from paddle.fluid.incubate.fleet.parameter_server.ir.public import _orig_varname
from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
    _get_varname_parts,
)
from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
    is_distributed_sparse_op,
)
from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
    get_sparse_tablename,
)
from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
    get_sparse_tablenames,
)
from paddle.fluid.incubate.fleet.parameter_server.ir.public import _get_lr_ops

LEARNING_RATE_DECAY_COUNTER = "@LR_DECAY_COUNTER@"
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleAttrName()
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched


def _is_optimizer_op(op):
    if "Param" in op.input_names and "LearningRate" in op.input_names:
        return True
    return False


def _same_or_split_var(p_name, var_name):
    return p_name == var_name or p_name.startswith(var_name + ".block")


def _get_optimizer_input_shape(op_type, varkey, orig_shape, param_shape):
    """
    Returns the shape for optimizer inputs that need to be reshaped when
    Param and Grad is split to multiple servers.
    """
    # HACK(typhoonzero): Should use functions of corresponding optimizer in
    # optimizer.py to get the shape, do not bind this in the transpiler.
    if op_type == "adam":
        if varkey in ["Moment1", "Moment2"]:
            return param_shape
    elif op_type == "adagrad":
        if varkey == "Moment":
            return param_shape
    elif op_type == "adamax":
        if varkey in ["Moment", "InfNorm"]:
            return param_shape
    elif op_type in ["momentum", "lars_momentum"]:
        if varkey == "Velocity":
            return param_shape
    elif op_type == "rmsprop":
        if varkey in ["Moment", "MeanSquare"]:
            return param_shape
    elif op_type == "decayed_adagrad":
        if varkey == "Moment":
            return param_shape
    elif op_type == "ftrl":
        if varkey in ["SquaredAccumulator", "LinearAccumulator"]:
            return param_shape
    elif op_type == "sgd":
        pass
    else:
        raise ValueError(
            "Not supported optimizer for distributed training: %s" % op_type
        )
    return orig_shape


def _append_pserver_non_opt_ops(optimize_block, opt_op, origin_program, config):
    def _get_pserver_grad_param_var(var, var_dict):
        """
        Return pserver side grad/param variable, return None
        if the variable is not grad/param, e.g.
            a@GRAD -> a@GRAD.block0
            a@GRAD -> a@GRAD (a is not split)
            fc_0.w_0 -> fc_0.w_0.block_0
            fc_0.w_0 -> fc_0.w_0 (weight is not split)
            _generated_var_123 -> None
        """
        grad_block = None
        for _, g in var_dict.items():
            if _orig_varname(g.name) == _orig_varname(var.name):
                # skip per trainer vars
                if g.name.find(".trainer_") == -1:
                    # only param or grads have split blocks
                    ovar_name = _orig_varname(g.name)
                    if ovar_name in config.param_grad_ep_mapping:
                        grad_block = g
                        break
                    elif ovar_name in config.grad_param_mapping:
                        grad_block = g
                        break
        return grad_block

    program = optimize_block.program
    # Append the ops for parameters that do not need to be optimized / updated
    inputs = _get_input_map_from_op(origin_program.global_block().vars, opt_op)
    for key, varlist in inputs.items():
        if not isinstance(varlist, list):
            varlist = [varlist]
        for i in range(len(varlist)):
            var = varlist[i]
            # for ops like clipping and weight decay, get the split var (xxx.block0)
            # for inputs / outputs
            grad_block = _get_pserver_grad_param_var(
                var, program.global_block().vars
            )
            if grad_block:
                varlist[i] = grad_block
            elif var.name not in program.global_block().vars:
                tmpvar = program.global_block()._clone_variable(var)
                varlist[i] = tmpvar
            else:
                varlist[i] = program.global_block().vars[var.name]
        inputs[key] = varlist

    outputs = _get_output_map_from_op(origin_program.global_block().vars, opt_op)
    for key, varlist in outputs.items():
        if not isinstance(varlist, list):
            varlist = [varlist]
        for i in range(len(varlist)):
            var = varlist[i]
            grad_block = _get_pserver_grad_param_var(
                var, program.global_block().vars
            )
            if grad_block:
                varlist[i] = grad_block
            elif var.name not in program.global_block().vars:
                tmpvar = program.global_block()._clone_variable(var)
                varlist[i] = tmpvar
            else:
                varlist[i] = program.global_block().vars[var.name]
        outputs[key] = varlist

    return optimize_block.append_op(
        type=opt_op.type,
        inputs=inputs,
        outputs=outputs,
        attrs=opt_op.all_attrs(),
    )


def _append_pserver_ops(
    optimize_block,
    opt_op,
    endpoint,
    grad_to_block_id,
    origin_program,
    merged_var,
    sparse_grad_to_param,
    config,
):
    program = optimize_block.program
    pserver_block = program.global_block()
    new_inputs = collections.OrderedDict()

    def _get_param_block(opt_op):
        # param is already created on global program
        unmerged_vars = []
        merged_vars = []
        merged_ordervars = []

        param_vars = [
            p for p in config.param_grad_ep_mapping[endpoint]["params"]
        ]

        for var in param_vars:
            name = var.name
            orig_varname = _orig_varname(name)

            for pairs in config.merged_variables_pairs:
                merged_p = pairs[0]
                if merged_p.merged_var.name == orig_varname:
                    if (
                        merged_p.merged_var.name
                        == merged_p.ordered_vars[0].name
                    ):
                        unmerged_vars.append(merged_p.ordered_vars[0])
                    else:
                        merged_vars.append(merged_p.merged_var)
                        merged_ordervars.append(merged_p.ordered_vars[0])
                    break

        param_name = opt_op.input("Param")[0]

        for i in range(len(unmerged_vars)):
            if _same_or_split_var(param_name, unmerged_vars[i].name):
                for var in param_vars:
                    if _same_or_split_var(var.name, unmerged_vars[i].name):
                        return var

        for i in range(len(merged_ordervars)):
            if _same_or_split_var(param_name, merged_ordervars[i].name):
                for var in param_vars:
                    if _same_or_split_var(var.name, merged_vars[i].name):
                        return var
        return None

    for key in opt_op.input_names:
        if key == "Grad":
            # Note!! This is for l2decay on sparse gradient,
            # because it will create a new tensor for
            # decayed gradient but not inplace modify the origin one
            origin_grad_name = opt_op.input(key)[0]
            if (
                core.kNewGradSuffix() in origin_grad_name
                and pserver_block.has_var(origin_grad_name)
            ):
                new_grad = pserver_block.var(origin_grad_name)
                new_inputs[key] = new_grad
            else:
                new_inputs[key] = merged_var
        elif key == "Param":
            param_block = _get_param_block(opt_op)

            if not param_block:
                return
            tmpvar = pserver_block.create_var(
                name=param_block.name,
                persistable=True,
                dtype=param_block.dtype,
                shape=param_block.shape,
            )
            new_inputs[key] = tmpvar
        elif key == "LearningRate":
            # learning rate variable has already be created by non-optimize op,
            # don't create it once again.
            lr_varname = opt_op.input(key)[0]
            if lr_varname in pserver_block.vars:
                new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]]
            else:
                origin_var = origin_program.global_block().vars[lr_varname]
                tmpvar = pserver_block.create_var(
                    name=origin_var.name,
                    persistable=origin_var.persistable,
                    dtype=origin_var.dtype,
                    shape=origin_var.shape,
                )
                new_inputs[key] = tmpvar

    for key in opt_op.input_names:
        new_shape = None
        if key in [
            "Param",
            "Grad",
            "LearningRate",
            "MasterParam",
            "Beta1Tensor",
            "Beta2Tensor",
        ]:
            continue
        var = origin_program.global_block().vars[opt_op.input(key)[0]]
        param_var = new_inputs["Param"]
        # update accumulator variable shape
        new_shape = _get_optimizer_input_shape(
            opt_op.type, key, var.shape, param_var.shape
        )
        tmpvar = pserver_block.create_var(
            name=var.name,
            persistable=var.persistable,
            dtype=var.dtype,
            shape=new_shape,
        )
        new_inputs[key] = tmpvar

    # change output's ParamOut variable
    outputs = _get_output_map_from_op(origin_program.global_block().vars, opt_op)
    outputs["ParamOut"] = new_inputs["Param"]
    optimize_block.append_op(
        type=opt_op.type,
        inputs=new_inputs,
        outputs=outputs,
        attrs=opt_op.all_attrs(),
    )

    # record sparse grad to param name
    if new_inputs["Grad"].type == core.VarDesc.VarType.SELECTED_ROWS:
        sparse_grad_to_param.append(
            str(new_inputs["Grad"].name) + ":" + str(new_inputs["Param"].name)
        )


def _get_input_map_from_op(varmap, op):
    """Returns a dict from op input name to the vars in varmap."""
    iomap = collections.OrderedDict()
    for key in op.input_names:
        vars = []
        for varname in op.input(key):
            vars.append(varmap[varname])
        if len(vars) == 1:
            iomap[key] = vars[0]
        else:
            iomap[key] = vars
    return iomap


def _get_output_map_from_op(varmap, op):
    """Returns a dict from op output name to the vars in varmap."""
    iomap = collections.OrderedDict()
    for key in op.output_names:
        vars = []
        for varname in op.output(key):
            vars.append(varmap[varname])
        if len(vars) == 1:
            iomap[key] = vars[0]
        else:
            iomap[key] = vars
    return iomap


def get_op_by_type(block, op_type):
    for op in block.ops:
        if op.type == op_type:
            return op
    raise ValueError("add_listen_and_serv_pass must at first")


def add_listen_and_serv_pass(program, config):
    attrs = {
        "grad_to_block_id": None,
        "sparse_grad_to_param": None,
        "lr_decay_block_id": None,
        "dense_optimize_blocks": None,
        "sparse_optimize_blocks": None,
        # runtime attribute
        "endpoint": config.get_ps_endpoint(),
        "pserver_id": config.get_role_id(),
        "Fanin": config.get_trainers(),
        "distributed_mode": config.get_distributed_mode(),
        "rpc_get_thread_num": -1,
        "rpc_send_thread_num": -1,
        "rpc_prefetch_thread_num": -1,
    }

    # step5 append the listen_and_serv op
    program.global_block().append_op(
        type="listen_and_serv", inputs={'X': []}, outputs={}, attrs=attrs
    )

    return program


def add_rpc_global_flags_pass(program, config):
    server_runtime = config.get_server_runtime_config()
    send_threads = server_runtime._rpc_send_thread_num
    get_threads = server_runtime._rpc_get_thread_num
    pull_threads = server_runtime._rpc_prefetch_thread_num

    op = get_op_by_type(program.global_block(), "listen_and_serv")

    if get_threads < 1 or send_threads < 1 or pull_threads < 1:
        raise ValueError(
            "error arguments in get_threads/send_threads/pull_threads"
        )

    op._set_attr("rpc_get_thread_num", get_threads)
    op._set_attr("rpc_send_thread_num", send_threads)
    op._set_attr("rpc_prefetch_thread_num", pull_threads)

    return program


def _clone_var(block, var, persistable=True):
    return block.create_var(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        lod_level=var.lod_level,
        persistable=persistable,
    )


def add_optimizer_pass(program, config):
    def _append_pserver_grad_merge_ops(
        optimize_block, grad_varname_for_block, endpoint, grad_to_block_id
    ):
        trainers = config.get_trainers()

        program = optimize_block.program
        pserver_block = program.global_block()
        grad_block = None

        for g in config.param_grad_ep_mapping[endpoint]["grads"]:
            if _orig_varname(g.name) == _orig_varname(grad_varname_for_block):
                grad_block = g
                break

        if not grad_block:
            # do not append this op if current endpoint
            # is not dealing with this grad block
            return None

        orig_varname, block_name, trainer_name = _get_varname_parts(
            grad_block.name
        )

        if block_name:
            merged_var_name = '.'.join([orig_varname, block_name])
        else:
            merged_var_name = orig_varname

        merged_var = pserver_block.create_var(
            name=grad_block.name,
            persistable=True,
            type=grad_block.type,
            dtype=grad_block.dtype,
            shape=grad_block.shape,
        )

        grad_to_block_id.append(merged_var.name + ":" + str(optimize_block.idx))

        if config.is_sync_mode() and trainers > 1:
            vars2merge = []
            for i in range(trainers):
                per_trainer_name = "%s.trainer_%d" % (merged_var_name, i)
                per_trainer_var = pserver_block.create_var(
                    name=per_trainer_name,
                    persistable=False,
                    type=grad_block.type,
                    dtype=grad_block.dtype,
                    shape=grad_block.shape,
                )
                vars2merge.append(per_trainer_var)
            optimize_block.append_op(
                type="sum",
                inputs={"X": vars2merge},
                outputs={"Out": merged_var},
                attrs={"use_mkldnn": False},
            )
            optimize_block.append_op(
                type="scale",
                inputs={"X": merged_var},
                outputs={"Out": merged_var},
                attrs={"scale": 1.0 / float(trainers)},
            )
        return merged_var

    origin_program = config.get_origin_main_program()
    origin_program = origin_program.clone()
    ps_endpoint = config.get_ps_endpoint()

    opt_op_on_pserver = []
    # Iterate through the ops, and if an op and the optimize ops
    # which located on current pserver are in one set, then
    # append it into the sub program.
    global_ops = []
    # sparse grad name to param name
    sparse_grad_to_param = []

    def _is_opt_op_on_pserver(endpoint, op):
        param_names = [
            p.name for p in config.param_grad_ep_mapping[endpoint]["params"]
        ]

        unmerged_varnames = []
        merged_varnames = []
        merged_ordernames = []

        for name in param_names:
            orig_varname = _orig_varname(name)

            for pairs in config.merged_variables_pairs:
                merged_p = pairs[0]
                if merged_p.merged_var.name == orig_varname:
                    if (
                        merged_p.merged_var.name
                        == merged_p.ordered_vars[0].name
                    ):
                        unmerged_varnames.append(merged_p.ordered_vars[0].name)
                    else:
                        merged_varnames.append(merged_p.merged_var.name)
                        merged_ordernames.append(merged_p.ordered_vars[0].name)
                    break

        param = op.input("Param")[0]

        if param in unmerged_varnames:
            return True

        for i in range(len(merged_ordernames)):
            if param == merged_ordernames[i]:
                merged_p = merged_varnames[i]
                merged_g = "{}@GRAD".format(merged_varnames[i])
                op._set_attr(OP_ROLE_VAR_ATTR_NAME, [merged_p, merged_g])
                return True
        return False

    def __append_optimize_op__(op, block, grad_to_block_id, merged_var, lr_ops):
        if _is_optimizer_op(op):
            _append_pserver_ops(
                block,
                op,
                ps_endpoint,
                grad_to_block_id,
                origin_program,
                merged_var,
                sparse_grad_to_param,
                config,
            )
        elif op not in lr_ops:
            _append_pserver_non_opt_ops(block, op, origin_program, config)

    optimize_ops = _get_optimize_ops(origin_program)
    for _, op in enumerate(optimize_ops):
        if _is_optimizer_op(op) and _is_opt_op_on_pserver(ps_endpoint, op):
            opt_op_on_pserver.append(op)

    # append lr decay ops to the child block if exists
    lr_ops = _get_lr_ops(origin_program)
    has_lr_decay = True if len(lr_ops) > 0 else False
    lr_decay_block_id = -1
    optimize_blocks = []

    if has_lr_decay > 0:
        counter_increment_idx = -1
        for idx, op in enumerate(lr_ops):
            if op.type != 'increment':
                continue
            counter = op.input("X")[0]
            if counter == LEARNING_RATE_DECAY_COUNTER:
                counter_increment_idx = idx
                break

        if counter_increment_idx != -1:
            lr_ops.pop(counter_increment_idx)

        lr_decay_block = program._create_block(program.num_blocks - 1)
        optimize_blocks.append(lr_decay_block)
        for op in lr_ops:
            cloned_op = _append_pserver_non_opt_ops(
                lr_decay_block, op, origin_program, config
            )
            # append sub blocks to pserver_program in lr_decay_op
            # todo(tangwei12): __clone_lr_op_sub_block__
        lr_decay_block_id = lr_decay_block.idx

    # append op to the current block
    grad_to_block_id = []
    pre_block_idx = program.num_blocks - 1

    for idx, opt_op in enumerate(opt_op_on_pserver):
        per_opt_block = program._create_block(pre_block_idx)
        optimize_blocks.append(per_opt_block)
        optimize_target_param_name = opt_op.attr(OP_ROLE_VAR_ATTR_NAME)[0]
        # append grad merging ops before clip and weight decay
        # e.g. merge grad -> L2Decay op -> clip op -> optimize
        merged_var = None

        for _, op in enumerate(optimize_ops):
            # find the origin grad var before clipping / L2Decay,
            # merged_var should be the input var name of L2Decay
            grad_varname_for_block = op.attr(OP_ROLE_VAR_ATTR_NAME)[1]
            if op.attr(OP_ROLE_VAR_ATTR_NAME)[0] == optimize_target_param_name:
                merged_var = _append_pserver_grad_merge_ops(
                    per_opt_block,
                    grad_varname_for_block,
                    ps_endpoint,
                    grad_to_block_id,
                )
                if merged_var:
                    break  # append optimize op once then append other ops.

        if merged_var:
            for _, op in enumerate(optimize_ops):
                # optimizer is connected to itself
                if (
                    op.attr(OP_ROLE_VAR_ATTR_NAME)[0]
                    == optimize_target_param_name
                    and op not in global_ops
                ):
                    __append_optimize_op__(
                        op, per_opt_block, grad_to_block_id, merged_var, lr_ops
                    )

    # dedup grad to ids list
    grad_to_block_id = list(set(grad_to_block_id))
    # append global ops
    if global_ops:
        opt_state_block = program._create_block(program.num_blocks - 1)
        optimize_blocks.append(opt_state_block)
        for glb_op in global_ops:
            __append_optimize_op__(
                glb_op, opt_state_block, grad_to_block_id, None, lr_ops
            )

    if len(optimize_blocks) == 0:
        pre_block_idx = program.num_blocks - 1
        empty_block = program._create_block(pre_block_idx)
        optimize_blocks.append(empty_block)

    op = get_op_by_type(program.global_block(), "listen_and_serv")
    op._set_attr("optimize_blocks", optimize_blocks)
    op._set_attr("grad_to_block_id", grad_to_block_id)
    op._set_attr("sparse_grad_to_param", sparse_grad_to_param)
    op._set_attr("lr_decay_block_id", lr_decay_block_id)
    return program


def large_scale_sparse_pass(program, main_program, config, is_startup=False):
    opt_value_map = {}
    opt_value_map["sgd"] = ["Param"]
    opt_value_map["adam"] = ["Param", "Moment1", "Moment2"]
    opt_value_map["adagrad"] = ["Param", "Moment"]
    opt_value_map["adamax"] = ["Param", "Moment", "InfNorm"]
    opt_value_map["momentum"] = ["Param", "Velocity"]
    opt_value_map["lars_momentum"] = ["Param", "Velocity"]
    opt_value_map["rmsprop"] = ["Param", "Moment", "MeanSquare"]
    opt_value_map["decayed_adagrad"] = ["Param", "Moment"]
    opt_value_map["ftrl"] = ["Param", "SquaredAccumulator", "LinearAccumulator"]

    geo_value_map = {}
    geo_value_map["sum"] = "Param"

    opt_init_map = {}
    opt_init_map["gaussian_random"] = ["seed", "mean", "std"]
    opt_init_map["fill_constant"] = ["value"]
    opt_init_map["uniform_random"] = ["seed", "min", "max"]
    opt_init_map["truncated_gaussian_random"] = ["seed", "mean", "std"]

    def get_entry_attr(param_name):
        origin_name = _orig_varname(param_name)
        o_main_program = config.get_origin_main_program()
        for op in o_main_program.global_block().ops:
            if (
                is_distributed_sparse_op(op)
                and get_sparse_tablename(op) == origin_name
            ):
                entry = op.attr("entry")
                return entry

    def get_initializer_attrs(acture_value_names):
        l_sep = ","
        l_in = "&"
        init_attrs = []
        o_startup_program = config.get_origin_startup_program()

        for value_name in acture_value_names:
            origin_var_name = _orig_varname(value_name)
            for op in o_startup_program.global_block().ops:
                if (
                    op.type in opt_init_map.keys()
                    and origin_var_name == op.output("Out")[0]
                ):
                    init_attr = [op.type]
                    for attr in opt_init_map[op.type]:
                        init_attr.append(str(op.attr(attr)))
                    init_attrs.append(l_in.join(init_attr))
                    break

        return l_sep.join(init_attrs)

    def get_optimizer_values(block):
        value_names = []
        acture_names = []
        value_dims = []
        grad = None
        opt_idx = -1
        fuse = False

        for op in block.ops:
            opt_idx += 1

            if op.type not in opt_value_map.keys():
                continue

            if op.type in ["sgd", "adam"]:
                fuse = True

            grad = main_program.global_block().vars[op.input("Grad")[0]]

            for value in opt_value_map[op.type]:
                var = main_program.global_block().vars[op.input(value)[0]]
                if len(var.shape) != 2:
                    raise ValueError("sparse param's dimension must be 2")

                value_names.append(value)
                value_dims.append(var.shape[1])
                acture_names.append(var.name)

            if value_names:
                break
        return grad, opt_idx, value_names, value_dims, acture_names, fuse

    def add_fuse_large_scale_op(
        block,
        global_block,
        table_name,
        value_names,
        acture_names,
        grad,
        is_entry,
        opt_idx,
    ):
        op = block.ops[opt_idx]

        if op.type == "sgd":
            grad = main_program.global_block().vars[op.input("Grad")[0]]
            lr = main_program.global_block().vars[op.input("LearningRate")[0]]

            block._insert_op(
                opt_idx,
                type="lookup_sparse_table_fuse_sgd",
                inputs={"Grad": grad, "LearningRate": lr},
                attrs={
                    "is_entry": is_entry,
                    "tablename": table_name,
                    "value_names": value_names,
                },
            )
        elif op.type == "adam":
            grad = main_program.global_block().vars[op.input("Grad")[0]]
            lr = main_program.global_block().vars[op.input("LearningRate")[0]]
            beta1_pow = main_program.global_block().vars[
                op.input("Beta1Pow")[0]
            ]
            beta2_pow = main_program.global_block().vars[
                op.input("Beta2Pow")[0]
            ]
            beta1_pow_o = main_program.global_block().vars[
                op.output("Beta1PowOut")[0]
            ]
            beta2_pow_o = main_program.global_block().vars[
                op.output("Beta2PowOut")[0]
            ]

            beta1 = op.attr('beta1')
            beta2 = op.attr('beta2')
            epsilon = op.attr('epsilon')

            block._insert_op(
                opt_idx,
                type="lookup_sparse_table_fuse_adam",
                inputs={
                    "Grad": grad,
                    "LearningRate": lr,
                    "Beta1Pow": beta1_pow,
                    "Beta2Pow": beta2_pow,
                },
                outputs={
                    "Beta1PowOut": beta1_pow_o,
                    "Beta2PowOut": beta2_pow_o,
                },
                attrs={
                    "beta1": beta1,
                    "beta2": beta2,
                    "epsilon": epsilon,
                    "is_entry": is_entry,
                    "tablename": table_name,
                    "value_names": value_names,
                },
            )
        else:
            raise ValueError("only support sgd/adam optimizer now")

    def add_large_scale_op(
        block,
        global_block,
        table_name,
        value_names,
        acture_names,
        grad,
        is_entry,
        opt_idx,
    ):
        ids = global_block.create_var(
            name="kSparseIDs@{}".format(table_name),
            persistable=False,
            dtype="int64",
            shape=[1, 1],
            lod_level=0,
        )

        # insert grad split to ids and tensor op
        block._insert_op(
            opt_idx,
            type="lookup_sparse_table_grad_split",
            inputs={"Grad": grad},
            outputs={"Row": ids, "Value": grad},
            attrs={"tablename": table_name, "is_entry": is_entry},
        )

        # insert read at first
        vars = [global_block.vars[acture_name] for acture_name in acture_names]
        block._insert_op(
            opt_idx + 1,
            type="lookup_sparse_table_read",
            inputs={"Ids": ids},
            outputs={"Out": vars},
            attrs={"tablename": table_name, "value_names": value_names},
        )

        # append write at last
        inputs = {"Ids": ids, "In": vars}

        block.append_op(
            type="lookup_sparse_table_write",
            inputs=inputs,
            outputs={},
            attrs={"tablename": table_name, "value_names": value_names},
        )

    op = get_op_by_type(main_program.global_block(), "listen_and_serv")

    param_blockid_map = {}
    grad_blockid_map = {}
    grad_to_params = op.attr('sparse_grad_to_param')
    grad_to_block_ids = op.attr('grad_to_block_id')

    origin_program = config.get_origin_main_program()
    sparse_varnames = get_sparse_tablenames(origin_program, False)

    for grad_to_block_id in grad_to_block_ids:
        grad, blockid = grad_to_block_id.split(":")
        grad_blockid_map[grad] = int(blockid)

    for grad_to_param in grad_to_params:
        grad, param = grad_to_param.split(":")

        if _orig_varname(param) in sparse_varnames:
            continue

        param_blockid_map[param] = grad_blockid_map[grad]

    if not is_startup:
        for param, blockid in param_blockid_map.items():
            opt_block = program.block(blockid)

            (
                grad,
                opt_idx,
                value_names,
                value_dims,
                acture_names,
                fuse,
            ) = get_optimizer_values(opt_block)

            entry_attr = get_entry_attr(param)
            is_entry = False if entry_attr == "none" else True

            if fuse:
                add_fuse_large_scale_op(
                    opt_block,
                    program.global_block(),
                    param,
                    value_names,
                    acture_names,
                    grad,
                    is_entry,
                    opt_idx,
                )
            else:
                add_large_scale_op(
                    opt_block,
                    program.global_block(),
                    param,
                    value_names,
                    acture_names,
                    grad,
                    is_entry,
                    opt_idx,
                )
    else:
        large_scale_kv_metas = []
        for param, blockid in param_blockid_map.items():
            opt_block = main_program.block(blockid)

            (
                grad,
                opt_idx,
                value_names,
                value_dims,
                acture_names,
                fuse,
            ) = get_optimizer_values(opt_block)

            entry_attr = get_entry_attr(param)

            if fuse:
                # remove origin optimzier op
                opt_block._remove_op(opt_idx)

            # training/infer
            mode = "0"
            names_str = ",".join(value_names)
            dims_str = ",".join([str(dim) for dim in value_dims])
            ids_name = "kSparseIDs@{}".format(param)
            cached_str = ",".join(acture_names + [ids_name])
            init_attr_str = get_initializer_attrs(acture_names)

            meta_str = ":".join(
                [
                    param,
                    names_str,
                    dims_str,
                    mode,
                    grad.name,
                    cached_str,
                    init_attr_str,
                    entry_attr,
                ]
            )
            print("large_scale_metas: {}".format(meta_str))
            large_scale_kv_metas.append(meta_str)

        program.global_block().append_op(
            type="lookup_sparse_table_init",
            inputs=None,
            outputs=None,
            attrs={"large_scale_metas": large_scale_kv_metas},
        )

    # todo: need delete unused var.
    return program


def get_distributed_from_listen_and_serv(program, origin_program):
    op = get_op_by_type(program.global_block(), "listen_and_serv")
    sparse_varnames = get_sparse_tablenames(origin_program, True)
    sparse_params = []
    grad_to_params = op.attr('sparse_grad_to_param')
    for grad_to_param in grad_to_params:
        _, param = grad_to_param.split(":")
        if _orig_varname(param) in sparse_varnames:
            sparse_params.append(param)
    return sparse_params


def delete_unused_in_main_pass(program, config):
    origin_program = config.get_origin_main_program()
    sparse_params = get_distributed_from_listen_and_serv(
        program, origin_program
    )

    for var in sparse_params:
        if program.global_block().has_var(var):
            program.global_block()._remove_var(var)
    return program


def delete_unused_in_startup_pass(program, main_program, config):
    origin_program = config.get_origin_main_program()
    sparse_params = get_distributed_from_listen_and_serv(
        main_program, origin_program
    )
    remove_ops = []

    for op in program.global_block().ops:
        if op.type in ["recv", "fetch_barrier", "concat"]:
            continue

        for key in op.output_names:
            if op.output(key)[0] in sparse_params:
                remove_ops.append(op)

    all_ops = program.global_block().ops
    op_idxs = [all_ops.index(op) for op in remove_ops]

    for idx in op_idxs[::-1]:
        program.global_block()._remove_op(idx)

    for var in sparse_params:
        if program.global_block().has_var(var):
            program.global_block()._remove_var(var)

    return program


def build_pserver_startup_program_pass(program, p_main_program, config):
    ps_endpoint = config.get_ps_endpoint()
    o_startup_program = config.get_origin_startup_program()
    program.random_seed = o_startup_program.random_seed
    params = config.param_grad_ep_mapping[ps_endpoint]["params"]
    merged_ordervars = []

    for var in params:
        name = var.name
        orig_varname = _orig_varname(name)

        for pairs in config.merged_variables_pairs:
            merged_p = pairs[0]
            if merged_p.merged_var.name == orig_varname:
                if merged_p.merged_var.name != merged_p.ordered_vars[0].name:
                    merged_ordervars.append(merged_p.ordered_vars[0])
                break

    def _get_splited_name_and_shape(varname):
        for splited_param in params:
            pname = splited_param.name
            if _same_or_split_var(pname, varname) and varname != pname:
                return pname, splited_param.shape

            for idx, ordered in enumerate(merged_ordervars):
                if _same_or_split_var(varname, ordered.name):
                    return pname, splited_param.shape

        return "", []

    # 1. create vars in pserver program to startup program
    pserver_vars = p_main_program.global_block().vars
    created_var_map = collections.OrderedDict()
    for _, var in pserver_vars.items():
        tmpvar = program.global_block()._clone_variable(var)
        created_var_map[var.name] = tmpvar

    # 2. rename op outputs
    for op in o_startup_program.global_block().ops:
        new_outputs = collections.OrderedDict()
        # do not append startup op if var is not on this pserver
        op_on_pserver = False
        # TODO(gongwb): remove this line.
        if op.type not in ["recv", "fetch_barrier", "concat"]:
            for key in op.output_names:
                newname, _ = _get_splited_name_and_shape(op.output(key)[0])
                if newname:
                    op_on_pserver = True
                    new_outputs[key] = created_var_map[newname]
                elif op.output(key)[0] in pserver_vars:
                    op_on_pserver = True
                    new_outputs[key] = pserver_vars[op.output(key)[0]]

        if op_on_pserver:
            # most startup program ops have no inputs
            new_inputs = _get_input_map_from_op(pserver_vars, op)

            if op.type in [
                "gaussian_random",
                "fill_constant",
                "uniform_random",
                "truncated_gaussian_random",
            ]:
                op._set_attr("shape", list(new_outputs["Out"].shape))

            program.global_block().append_op(
                type=op.type,
                inputs=new_inputs,
                outputs=new_outputs,
                attrs=op.all_attrs(),
            )

    return program


def add_geo_optimizer_pass(program, config):
    endpoint = config.get_ps_endpoint()
    params = [p for p in config.param_grad_ep_mapping[endpoint]["params"]]

    sparse_tablenames = get_sparse_tablenames(
        config.get_origin_main_program(), False
    )

    for param in params:
        _clone_var(program.global_block(), param)

    optimize_block = []
    sparse_grad_to_param = []
    param_to_block_id = []
    pre_block_idx = program.num_blocks - 1

    for param in params:
        per_opt_block = program._create_block(pre_block_idx)
        optimize_block.append(per_opt_block)
        var_name = param.name
        pserver_block = per_opt_block.program.global_block()
        param = pserver_block.vars[var_name]

        delta_var_name = "%s.delta" % (param.name)
        origin_varname = _orig_varname(param.name)

        if origin_varname in sparse_tablenames:
            sparse_grad_to_param.append(":".join([delta_var_name, param.name]))

        delta_var = pserver_block.create_var(
            name=delta_var_name,
            persistable=False,
            type=param.type,
            dtype=param.dtype,
            shape=param.shape,
        )

        per_opt_block.append_op(
            type="sum", inputs={"X": [param, delta_var]}, outputs={"Out": param}
        )

        param_to_block_id.append(delta_var_name + ":" + str(per_opt_block.idx))

    op = get_op_by_type(program.global_block(), "listen_and_serv")
    op._set_attr("optimize_blocks", optimize_block)
    op._set_attr("grad_to_block_id", param_to_block_id)
    op._set_attr("sparse_grad_to_param", sparse_grad_to_param)
    return program
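
One detail worth calling out in the listing above: _get_optimizer_input_shape decides which pieces of optimizer state follow a split parameter onto a pserver. A small self-contained check of that behaviour, using made-up shapes and the function exactly as defined above:

# Made-up shapes: a 1000x8 parameter split so that this pserver holds a 500x8 block.
orig_shape = [1000, 8]
param_shape = [500, 8]

# adam's moments are resized to the split parameter shape...
assert _get_optimizer_input_shape("adam", "Moment1", orig_shape, param_shape) == [500, 8]
# ...while inputs the function does not recognise keep their original shape,
assert _get_optimizer_input_shape("adam", "Beta1Pow", orig_shape, param_shape) == [1000, 8]
# and sgd has no extra state to reshape at all.
assert _get_optimizer_input_shape("sgd", "LearningRate", orig_shape, param_shape) == [1000, 8]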

python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py  (deleted, 100755 → 0, view file @ e84fa263)

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
functools
import
reduce
import
paddle
import
collections
import
math
import
os
import
warnings
import
logging
from
paddle.framework
import
core
from
paddle.fluid.incubate.fleet.parameter_server.mode
import
DistributedMode
from
paddle.fluid.incubate.fleet.parameter_server.ir
import
vars_metatools
from
paddle.fluid.incubate.fleet.parameter_server.ir.ps_dispatcher
import
(
RoundRobin
,
PSDispatcher
,
)
OP_NAME_SCOPE
=
"op_namescope"
CLIP_OP_NAME_SCOPE
=
"gradient_clip"
STEP_COUNTER
=
"@PS_STEP_COUNTER@"
LEARNING_RATE_DECAY_COUNTER
=
"@LR_DECAY_COUNTER@"
OP_ROLE_VAR_ATTR_NAME
=
core
.
op_proto_and_checker_maker
.
kOpRoleVarAttrName
()
RPC_OP_ROLE_ATTR_NAME
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
()
RPC_OP_ROLE_ATTR_VALUE
=
core
.
op_proto_and_checker_maker
.
OpRole
.
RPC
op_role_attr_name
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
()
LR_SCHED_OP_ROLE_ATTR_VALUE
=
core
.
op_proto_and_checker_maker
.
OpRole
.
LRSched
OPT_OP_ROLE_ATTR_VALUE
=
core
.
op_proto_and_checker_maker
.
OpRole
.
Optimize
SPARSE_OP_LIST
=
[
"lookup_table"
,
"lookup_table_v2"
]
SPARSE_OP_TYPE_DICT
=
{
"lookup_table"
:
"W"
,
"lookup_table_v2"
:
"W"
}
def
_get_lr_ops
(
program
):
lr_ops
=
[]
for
index
,
op
in
enumerate
(
program
.
global_block
().
ops
):
role_id
=
int
(
op
.
attr
(
RPC_OP_ROLE_ATTR_NAME
))
if
role_id
==
int
(
LR_SCHED_OP_ROLE_ATTR_VALUE
)
or
role_id
==
int
(
LR_SCHED_OP_ROLE_ATTR_VALUE
)
|
int
(
OPT_OP_ROLE_ATTR_VALUE
):
lr_ops
.
append
(
op
)
return
lr_ops
def
_has_global_step
(
lr_ops
):
if
len
(
lr_ops
)
>
0
:
for
idx
,
op
in
enumerate
(
lr_ops
):
if
op
.
type
!=
'increment'
:
continue
counter
=
op
.
input
(
"X"
)[
0
]
if
counter
==
LEARNING_RATE_DECAY_COUNTER
:
return
True
return
False
def
is_sparse_op
(
op
):
if
(
op
.
type
in
SPARSE_OP_LIST
and
op
.
attr
(
'is_sparse'
)
is
True
and
op
.
attr
(
'is_distributed'
)
is
False
):
return
True
if
(
op
.
type
==
"distributed_lookup_table"
and
op
.
attr
(
'is_distributed'
)
is
False
):
return
True
return
False
def
is_distributed_sparse_op
(
op
):
if
op
.
type
in
SPARSE_OP_LIST
and
op
.
attr
(
'is_distributed'
)
is
True
:
return
True
if
(
op
.
type
==
"distributed_lookup_table"
and
op
.
attr
(
'is_distributed'
)
is
True
):
return
True
return
False
def
get_sparse_tablename
(
op
):
return
op
.
input
(
"W"
)[
0
]
def
get_sparse_tablenames
(
program
,
is_distributed
):
tablenames
=
set
()
if
is_distributed
:
for
op
in
program
.
global_block
().
ops
:
if
is_distributed_sparse_op
(
op
):
tablenames
.
add
(
get_sparse_tablename
(
op
))
else
:
for
op
in
program
.
global_block
().
ops
:
if
is_sparse_op
(
op
):
tablenames
.
add
(
get_sparse_tablename
(
op
))
return
list
(
tablenames
)
class
MergedVariable
:
def
__init__
(
self
,
merged
,
ordered
,
offsets
):
self
.
merged_var
=
merged
self
.
ordered_vars
=
ordered
self
.
offsets
=
offsets
def
Singleton
(
cls
):
_instance
=
{}
def
_singleton
(
*
args
,
**
kargs
):
if
cls
not
in
_instance
:
_instance
[
cls
]
=
cls
(
*
args
,
**
kargs
)
return
_instance
[
cls
]
return
_singleton
@
Singleton
class
CompileTimeStrategy
:
def
__init__
(
self
,
main_program
,
startup_program
,
strategy
,
role_maker
):
self
.
min_block_size
=
81920
self
.
origin_main_program
=
main_program
self
.
origin_startup_program
=
startup_program
self
.
origin_ps_main_program
=
main_program
self
.
origin_ps_startup_program
=
startup_program
self
.
strategy
=
strategy
self
.
role_maker
=
role_maker
self
.
use_ps_gpu
=
False
try
:
self
.
is_heter_ps_mode
=
role_maker
.
_is_heter_parameter_server_mode
except
:
warnings
.
warn
(
"Using paddle.distributed.fleet instead of paddle.fluid.incubate.fleet"
)
self
.
is_heter_ps_mode
=
False
self
.
origin_sparse_pairs
=
[]
self
.
origin_dense_pairs
=
[]
self
.
merged_variables_pairs
=
[]
self
.
merged_dense_pairs
=
[]
self
.
merged_sparse_pairs
=
[]
self
.
merged_variable_map
=
{}
self
.
param_name_to_grad_name
=
{}
self
.
grad_name_to_param_name
=
{}
self
.
param_grad_ep_mapping
=
collections
.
OrderedDict
()
self
.
grad_param_mapping
=
collections
.
OrderedDict
()
self
.
_build_var_distributed
()
self
.
tensor_table_dict
=
{}
# for heter-ps save variables
self
.
origin_merged_variables_pairs
=
list
(
self
.
merged_variables_pairs
)
self
.
origin_merged_dense_pairs
=
list
(
self
.
merged_dense_pairs
)
self
.
origin_merged_sparse_pairs
=
list
(
self
.
merged_sparse_pairs
)
def
get_distributed_mode
(
self
):
trainer
=
self
.
strategy
.
get_trainer_runtime_config
()
return
trainer
.
mode
def
is_sync_mode
(
self
):
trainer
=
self
.
strategy
.
get_trainer_runtime_config
()
return
trainer
.
mode
==
DistributedMode
.
SYNC
def
is_geo_mode
(
self
):
trainer
=
self
.
strategy
.
get_trainer_runtime_config
()
return
trainer
.
mode
==
DistributedMode
.
GEO
def
is_async_mode
(
self
):
trainer
=
self
.
strategy
.
get_trainer_runtime_config
()
return
trainer
.
mode
==
DistributedMode
.
ASYNC
def
get_role_id
(
self
):
try
:
return
self
.
role_maker
.
_role_id
()
except
Exception
:
return
self
.
role_maker
.
role_id
()
def
get_trainers
(
self
):
try
:
return
self
.
role_maker
.
_worker_num
()
except
Exception
:
return
self
.
role_maker
.
worker_num
()
def
get_ps_endpoint
(
self
):
try
:
return
self
.
role_maker
.
_get_pserver_endpoints
()[
self
.
get_role_id
()]
except
Exception
:
return
self
.
role_maker
.
get_pserver_endpoints
()[
self
.
get_role_id
()]
def
get_ps_endpoints
(
self
):
try
:
return
self
.
role_maker
.
_get_pserver_endpoints
()
except
Exception
:
return
self
.
role_maker
.
get_pserver_endpoints
()
def
get_heter_worker_endpoints
(
self
):
try
:
return
self
.
role_maker
.
_get_heter_worker_endpoints
()
except
Exception
:
return
self
.
role_maker
.
get_heter_worker_endpoints
()
def
get_next_stage_trainers
(
self
):
try
:
return
self
.
role_maker
.
_get_next_trainers
()
except
Exception
:
return
self
.
role_maker
.
get_next_trainers
()
def
get_heter_worker_endpoint
(
self
):
try
:
return
self
.
role_maker
.
_get_heter_worker_endpoint
()
except
Exception
:
return
self
.
role_maker
.
get_heter_worker_endpoint
()
def
get_trainer_endpoints
(
self
):
try
:
return
self
.
role_maker
.
_get_trainer_endpoints
()
except
Exception
:
return
self
.
role_maker
.
get_trainer_endpoints
()
def
get_trainer_endpoint
(
self
):
try
:
return
self
.
role_maker
.
_get_trainer_endpoint
()
except
Exception
:
return
self
.
role_maker
.
get_trainer_endpoint
()
def
get_previous_stage_trainers
(
self
):
try
:
return
self
.
role_maker
.
_get_previous_trainers
()
except
Exception
:
return
self
.
role_maker
.
get_previous_trainers
()
def
get_origin_programs
(
self
):
return
self
.
origin_main_program
,
self
.
origin_startup_program
def
get_origin_main_program
(
self
):
return
self
.
origin_main_program
def
get_origin_startup_program
(
self
):
return
self
.
origin_startup_program
def
set_origin_ps_main_program
(
self
,
program
):
self
.
origin_ps_main_program
=
program
def
set_origin_ps_startup_program
(
self
,
program
):
self
.
origin_ps_startup_program
=
program
def
get_origin_ps_main_program
(
self
):
return
self
.
origin_ps_main_program
def
get_origin_ps_startup_program
(
self
):
return
self
.
origin_ps_startup_program
def
add_tensor_table
(
self
,
feed_var_name
,
fetch_var_name
=
""
,
startup_program
=
None
,
main_program
=
None
,
tensor_table_class
=
""
,
):
self
.
tensor_table_dict
[
feed_var_name
]
=
{}
self
.
tensor_table_dict
[
feed_var_name
][
"feed_var_name"
]
=
feed_var_name
self
.
tensor_table_dict
[
feed_var_name
][
"fetch_var_name"
]
=
fetch_var_name
self
.
tensor_table_dict
[
feed_var_name
][
"startup_program"
]
=
startup_program
self
.
tensor_table_dict
[
feed_var_name
][
"main_program"
]
=
main_program
self
.
tensor_table_dict
[
feed_var_name
][
"tensor_table_class"
]
=
tensor_table_class
def
get_tensor_table_dict
(
self
):
return
self
.
tensor_table_dict
def
get_sparse_varname_on_ps
(
self
,
is_distributed
,
endpoint
=
None
):
if
not
endpoint
:
endpoint
=
self
.
get_ps_endpoint
()
varnames
=
get_sparse_tablenames
(
self
.
get_origin_main_program
(),
is_distributed
)
ps_sparse_varnames
=
[]
for
varname
in
varnames
:
tables
=
self
.
get_var_distributed
(
varname
,
True
)
for
i
in
range
(
len
(
tables
)):
table
,
ep
,
_
=
tables
[
i
]
if
ep
==
endpoint
:
ps_sparse_varnames
.
append
(
table
)
return
ps_sparse_varnames
def
get_optimize_varname_on_ps
(
self
,
param_name
):
origin_param_name
,
_
,
_
=
_get_varname_parts
(
param_name
)
optimize_var_names
=
[]
for
op
in
self
.
get_origin_main_program
().
global_block
().
ops
:
# check all optimizer op
if
int
(
op
.
all_attrs
()[
"op_role"
])
==
2
:
# check param name
if
op
.
input
(
"Param"
)[
0
]
!=
origin_param_name
:
continue
# check all input
for
key
in
op
.
input_names
:
if
key
in
[
"Param"
,
"Grad"
,
"LearningRate"
,
"Beta1Tensor"
,
"Beta2Tensor"
,
]:
continue
# check varibale shape related param, e.g: Moment1
optimize_var_names
+=
(
self
.
_get_optimizer_param_related_var_name
(
op
,
op
.
type
,
key
)
)
return
optimize_var_names
def
_get_optimizer_param_related_var_name
(
self
,
op
,
op_type
,
varkey
):
"""
Returns the names for optimizer inputs that need to be load
"""
related_var_names
=
[]
if
op_type
==
"adam"
:
if
varkey
in
[
"Moment1"
,
"Moment2"
]:
related_var_names
.
append
(
op
.
input
(
varkey
)[
0
])
elif
op_type
==
"adagrad"
:
if
varkey
==
"Moment"
:
related_var_names
.
append
(
op
.
input
(
varkey
)[
0
])
elif
op_type
in
[
"momentum"
,
"lars_momentum"
]:
if
varkey
==
"Velocity"
:
related_var_names
.
append
(
op
.
input
(
varkey
)[
0
])
elif
op_type
==
"rmsprop"
:
if
varkey
in
[
"Moment"
,
"MeanSquare"
]:
related_var_names
.
append
(
op
.
input
(
varkey
)[
0
])
elif
op_type
==
"ftrl"
:
if
varkey
in
[
"SquaredAccumulator"
,
"LinearAccumulator"
]:
related_var_names
.
append
(
op
.
input
(
varkey
)[
0
])
elif
op_type
==
"sgd"
:
pass
else
:
raise
ValueError
(
"Not supported optimizer for distributed training: %s"
%
op_type
)
return
related_var_names
    def build_ctx(
        self, vars, mapping, is_grad, is_sparse, is_send, is_distributed=False
    ):
        def get_grad_var_ep(slices):
            names = []
            eps = []
            sections = []

            for slice in slices:
                if self.is_geo_mode():
                    if is_send:
                        names.append("{}.delta".format(slice.name))
                    else:
                        names.append(slice.name)
                elif (
                    is_grad and self.is_sync_mode() and self.get_trainers() > 1
                ):
                    names.append(
                        "{}.trainer_{}".format(slice.name, self.get_role_id())
                    )
                else:
                    names.append(slice.name)

                sections.append(slice.shape[0])

                for ep, pairs in self.param_grad_ep_mapping.items():
                    params, grads = pairs["params"], pairs["grads"]
                    for var in params + grads:
                        if slice.name == var.name:
                            eps.append(ep)
                            break
            return names, eps, sections

        if isinstance(vars, MergedVariable):
            name = vars.merged_var.name
            slices = mapping[name]
            names, eps, sections = get_grad_var_ep(slices)
            origin_varnames = [var.name for var in vars.ordered_vars]
        else:
            name = vars.name
            slices = mapping[name]
            names, eps, sections = get_grad_var_ep(slices)
            origin_varnames = [vars.name]

        trainer_id = self.get_role_id()
        aggregate = True
        ctx = core.CommContext(
            name,
            names,
            eps,
            sections,
            origin_varnames,
            trainer_id,
            aggregate,
            is_sparse,
            is_distributed,
            [],
        )
        return ctx
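
The naming rules inside get_grad_var_ep decide which remote variable a gradient slice is sent as; the standalone sketch below restates them with made-up slice and trainer values:

slice_name, trainer_id = "fc_0.w_0@GRAD.block0", 2   # hypothetical values

geo_send_name = "{}.delta".format(slice_name)               # GEO mode, send side
sync_name = "{}.trainer_{}".format(slice_name, trainer_id)  # sync mode with >1 trainer
other_name = slice_name                                     # every other case

print(geo_send_name)   # fc_0.w_0@GRAD.block0.delta
print(sync_name)       # fc_0.w_0@GRAD.block0.trainer_2
print(other_name)      # fc_0.w_0@GRAD.block0
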
    def get_trainer_send_context(self):
        send_ctx = {}
        distibuted_varnames = get_sparse_tablenames(
            self.origin_main_program, True
        )
        idx = 0

        if not self.is_geo_mode():
            for merged in self.merged_dense_pairs:
                grad = merged[1]
                ctx = self.build_ctx(
                    grad, self.grad_var_mapping, True, False, True
                )
                send_ctx[ctx.var_name()] = ctx

            for merged in self.merged_sparse_pairs:
                param = merged[0]
                grad = merged[1]

                param_name = param.merged_var.name
                is_distributed = (
                    True if param_name in distibuted_varnames else False
                )

                ctx = self.build_ctx(
                    grad,
                    self.grad_var_mapping,
                    True,
                    True,
                    True,
                    is_distributed,
                )
                send_ctx[ctx.var_name()] = ctx
                idx += 1

            if self.is_async_mode():
                name, ctx = self._step_ctx(idx)
                send_ctx[name] = ctx
        else:
            for pairs in self.origin_sparse_pairs:
                param, grad = pairs
                param_name = param.name
                is_distributed = (
                    True if param_name in distibuted_varnames else False
                )

                param_ctx = self.build_ctx(
                    param,
                    self.param_var_mapping,
                    False,
                    True,
                    True,
                    is_distributed,
                )
                grad_ctx = self.build_ctx(
                    grad,
                    self.grad_var_mapping,
                    True,
                    True,
                    True,
                    is_distributed,
                )

                ctx = core.CommContext(
                    param_ctx.var_name(),
                    param_ctx.split_varnames(),
                    param_ctx.split_endpoints(),
                    param_ctx.sections(),
                    grad_ctx.origin_varnames(),
                    param_ctx.trainer_id(),
                    param_ctx.aggregate(),
                    param_ctx.is_sparse(),
                    param_ctx.is_distributed(),
                    [],
                )
                send_ctx[ctx.var_name()] = ctx
                idx += 1

            name, ctx = self._step_ctx(idx)
            send_ctx[name] = ctx
        return send_ctx
    def get_communicator_send_context(self):
        send_ctx = {}
        distibuted_varnames = get_sparse_tablenames(
            self.origin_main_program, True
        )
        idx = 0

        if self.is_geo_mode():
            for pairs in self.merged_dense_pairs:
                param = pairs[0]
                ctx = self.build_ctx(
                    param, self.param_var_mapping, False, False, True
                )
                send_ctx[ctx.var_name()] = ctx

            for pairs in self.merged_sparse_pairs:
                param = pairs[0]
                param_name = param.merged_var.name
                is_distributed = (
                    True if param_name in distibuted_varnames else False
                )

                ctx = self.build_ctx(
                    param,
                    self.param_var_mapping,
                    False,
                    True,
                    True,
                    is_distributed,
                )
                send_ctx[ctx.var_name()] = ctx
                idx += 1

            name, ctx = self._step_ctx(idx)
            send_ctx[name] = ctx
        else:
            for merged in self.merged_dense_pairs:
                grad = merged[1]
                ctx = self.build_ctx(
                    grad, self.grad_var_mapping, True, False, True
                )
                send_ctx[ctx.var_name()] = ctx

            for merged in self.merged_sparse_pairs:
                param, grad = merged
                param_name = param.merged_var.name
                is_distributed = (
                    True if param_name in distibuted_varnames else False
                )

                ctx = self.build_ctx(
                    grad,
                    self.grad_var_mapping,
                    True,
                    True,
                    True,
                    is_distributed,
                )
                send_ctx[ctx.var_name()] = ctx
                idx += 1

            name, ctx = self._step_ctx(idx)
            send_ctx[name] = ctx
        return send_ctx
    def get_communicator_recv_context(
        self, recv_type=1, use_origin_program=False
    ):
        # recv_type
        # 1 : DENSE  2 : SPARSE  3 : DISTRIBUTED  4 : ALL
        distibuted_varnames = get_sparse_tablenames(
            self.origin_main_program, True
        )
        sparse_varnames = []
        for pairs in self.origin_sparse_pairs:
            param, grad = pairs
            sparse_varnames.append(param.name)

        dense_recv_ctx = {}
        sparse_recv_ctx = {}
        distributed_recv_ctx = {}

        variables_pairs = (
            self.merged_variables_pairs
            if not use_origin_program
            else self.origin_merged_variables_pairs
        )
        for merged in variables_pairs:
            params = merged[0]
            if params.merged_var.name in sparse_varnames:
                continue

            ctx = self.build_ctx(
                params, self.param_var_mapping, False, False, False, False
            )
            dense_recv_ctx[ctx.var_name()] = ctx

        for pairs in self.origin_sparse_pairs:
            param, grad = pairs

            if param.name in distibuted_varnames:
                ctx = self.build_ctx(
                    param, self.param_var_mapping, False, True, False, True
                )
                distributed_recv_ctx[ctx.var_name()] = ctx
            else:
                ctx = self.build_ctx(
                    param, self.param_var_mapping, False, True, False, False
                )
                sparse_recv_ctx[ctx.var_name()] = ctx

        if recv_type == 1:
            return dense_recv_ctx

        if recv_type == 2:
            return sparse_recv_ctx

        if recv_type == 3:
            return distributed_recv_ctx

        if recv_type == 4:
            dense_recv_ctx.update(sparse_recv_ctx)
            dense_recv_ctx.update(distributed_recv_ctx)
            return dense_recv_ctx

        raise ValueError(
            "recv_type can only be 1/2/3/4, 1 : DENSE 2. SPARSE 3. DISTRIBUTED 4. ALL"
        )
    def get_the_one_trainer_send_context(self, split_dense_table):
        if self.is_geo_mode():
            send_ctx = {}
            trainer_id = self.get_role_id()
            idx = 0

            distibuted_varnames = get_sparse_tablenames(
                self.origin_main_program, True
            )
            for merged in self.merged_sparse_pairs:
                param, grad = merged
                grad_name = grad.merged_var.name
                param_name = param.merged_var.name
                is_distributed = (
                    True if param_name in distibuted_varnames else False
                )

                var = self.origin_main_program.global_block().vars[
                    grad.merged_var.name
                ]
                var_numel = reduce(lambda x, y: x * y, var.shape[1:])

                sparse_ctx = core.CommContext(
                    grad_name,
                    [grad_name],
                    ["127.0.0.1:6071"],
                    [var_numel],
                    [grad_name],
                    trainer_id,
                    True,
                    True,
                    is_distributed,
                    idx,
                    False,
                    False,
                    -1,
                    [],
                )
                idx += 1
                send_ctx[sparse_ctx.var_name()] = sparse_ctx

            if len(send_ctx) == 0:
                raise ValueError(
                    "GeoSGD require sparse parameters in your net."
                )

            if (
                len(self.tensor_table_dict) > 0
                and self.role_maker._is_worker()
            ):
                name, ctx = self._step_ctx(idx)
                send_ctx[name] = ctx

            return send_ctx
        else:
            return self.get_the_one_send_context(split_dense_table)
    def get_dense_send_context(
        self,
        send_ctx,
        idx,
        merged_dense_pairs,
        trainer_id,
        split_dense_table=False,
    ):
        if len(merged_dense_pairs) < 1:
            return idx
        if not split_dense_table:
            origin_varnames = []
            var_numel = 0
            for merged in merged_dense_pairs:
                grad = merged[1]
                origin_varnames.append(grad.merged_var.name)
                var = self.origin_main_program.global_block().vars[
                    grad.merged_var.name
                ]
                var_numel += reduce(lambda x, y: x * y, var.shape)
            grad_name = "Dense@Grad"
            trainer_id = self.get_role_id()
            aggregate = True
            dense_ctx = core.CommContext(
                grad_name,
                [grad_name],
                ["127.0.0.1:6071"],
                [var_numel],
                origin_varnames,
                trainer_id,
                aggregate,
                False,
                False,
                idx,
                False,
                False,
                -1,
                [],
            )
            send_ctx[grad_name] = dense_ctx
            idx += 1
        else:
            for merged in merged_dense_pairs:
                grad = merged[1]
                origin_varname = grad.merged_var.name
                var = self.origin_main_program.global_block().vars[
                    origin_varname
                ]
                var_numel = reduce(lambda x, y: x * y, var.shape)
                grad_name = origin_varname
                aggregate = True
                dense_ctx = core.CommContext(
                    grad_name,
                    [grad_name],
                    ["127.0.0.1:6071"],
                    [var_numel],
                    [origin_varname],
                    trainer_id,
                    aggregate,
                    False,
                    False,
                    idx,
                    False,
                    False,
                    -1,
                    [],
                )
                send_ctx[grad_name] = dense_ctx
                idx += 1
        return idx
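
When split_dense_table is False, the branch above folds every dense gradient into a single "Dense@Grad" communication context whose one section is the sum of the flattened gradient sizes; a small standalone sketch of that size computation with made-up shapes:

from functools import reduce

dense_grad_shapes = [(128, 64), (64,), (64, 1)]   # hypothetical gradient shapes
var_numel = sum(reduce(lambda x, y: x * y, shape) for shape in dense_grad_shapes)
print("Dense@Grad section size:", var_numel)      # 8192 + 64 + 64 = 8320
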
    def get_the_one_send_context(
        self, split_dense_table=False, use_origin_program=False, ep_list=None
    ):
        if ep_list is None:
            ep_list = ["127.0.0.1:6071"]
        send_ctx = {}
        trainer_id = self.get_role_id()
        idx = 0

        merged_dense_pairs = (
            self.origin_merged_dense_pairs
            if use_origin_program
            else self.merged_dense_pairs
        )
        merged_sparse_pairs = (
            self.origin_merged_sparse_pairs
            if use_origin_program
            else self.merged_sparse_pairs
        )

        idx += self.get_dense_send_context(
            send_ctx, idx, merged_dense_pairs, trainer_id, split_dense_table
        )

        distibuted_varnames = get_sparse_tablenames(
            self.origin_main_program, True
        )
        for merged in merged_sparse_pairs:
            param, grad = merged
            grad_name = grad.merged_var.name
            param_name = param.merged_var.name
            splited_varname = []

            for i in range(len(ep_list)):
                splited_varname.append("{}.block{}".format(param_name, i))

            is_distributed = (
                True if param_name in distibuted_varnames else False
            )

            var = self.origin_main_program.global_block().vars[
                grad.merged_var.name
            ]

            shape = list(var.shape)
            shape[0] = 0 if is_distributed else shape[0]

            sparse_ctx = core.CommContext(
                grad_name,
                splited_varname,
                ep_list,
                shape,
                [grad_name],
                trainer_id,
                True,
                True,
                is_distributed,
                idx,
                False,
                False,
                -1,
                [],
            )

            idx += 1
            send_ctx[sparse_ctx.var_name()] = sparse_ctx

        if len(self.tensor_table_dict) > 0 and self.role_maker._is_worker():
            name, ctx = self._step_ctx(idx)
            send_ctx[name] = ctx

        return send_ctx
    def get_the_one_recv_context(
        self, is_dense=True, split_dense_table=False, use_origin_program=False
    ):
        recv_id_maps = {}
        if is_dense:
            send_ctx = self.get_the_one_send_context(
                split_dense_table=split_dense_table,
                use_origin_program=use_origin_program,
            )
            for idx, (name, ctx) in enumerate(send_ctx.items()):
                if ctx.is_sparse():
                    continue
                if ctx.is_tensor_table():
                    continue

                origin_grad_varnames = ctx.origin_varnames()

                param_names = []
                for grad_varname in origin_grad_varnames:
                    param_name = self.grad_name_to_param_name[grad_varname]
                    param_names.append(param_name)
                recv_id_maps[ctx.table_id()] = param_names
        else:
            send_ctx = self.get_the_one_send_context()
            for idx, (name, ctx) in enumerate(send_ctx.items()):
                if not ctx.is_sparse():
                    continue

                origin_grad_varnames = ctx.origin_varnames()

                param_names = []
                for grad_varname in origin_grad_varnames:
                    param_name = self.grad_name_to_param_name[grad_varname]
                    param_names.append(param_name)
                recv_id_maps[ctx.table_id()] = param_names
        return recv_id_maps
    def get_server_runtime_config(self):
        return self.strategy.get_server_runtime_config()

    def get_var_distributed(self, varname, is_param):
        var_distributed = []
        offset = 0
        if is_param:
            params = self.param_var_mapping[varname]
            param_varnames = [var.name for var in params]
            for ep, pairs in self.param_grad_ep_mapping.items():
                for p in pairs["params"]:
                    if p.name in param_varnames:
                        offset += p.shape[0]
                        var_distributed.append((p.name, ep, p.shape[0]))
        else:
            grads = self.grad_var_mapping[varname]
            grad_varnames = [var.name for var in grads]
            for ep, pairs in self.param_grad_ep_mapping.items():
                for g in pairs["grads"]:
                    if g.name in grad_varnames:
                        var_distributed.append((g.name, ep, g.shape[0]))
        return var_distributed

    def _step_ctx(self, idx):
        name = STEP_COUNTER
        trainer_id = self.get_role_id()
        endpoints = self.get_ps_endpoints()
        sections = [1] * len(endpoints)
        names = [name] * len(endpoints)
        ctx = core.CommContext(
            name,
            names,
            endpoints,
            sections,
            [name],
            trainer_id,
            True,
            False,
            False,
            idx,
            True,
            False,
            -1,
            [],
        )
        return name, ctx
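
The step-counter context built by _step_ctx sends a single "@PS_STEP_COUNTER@" slot of size 1 to every parameter server; the sketch below shows the names/sections it prepares for two hypothetical endpoints (the core.CommContext call itself is omitted because it needs the compiled core module):

STEP_COUNTER = "@PS_STEP_COUNTER@"
endpoints = ["127.0.0.1:6170", "127.0.0.1:6171"]   # placeholder pserver endpoints

sections = [1] * len(endpoints)          # one counter slot per pserver
names = [STEP_COUNTER] * len(endpoints)
print(names, endpoints, sections)
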
def
_create_vars_from_blocklist
(
self
,
block_list
):
"""
Create vars for each split.
NOTE: only grads need to be named for different trainers, use
add_trainer_suffix to rename the grad vars.
Args:
block_list (list[(varname, block_id, block_size)]): List of gradient blocks.
add_trainer_suffix (Bool): Add trainer suffix to new variable's name if set True.
Returns:
var_mapping (collections.OrderedDict(varname->[new_varname_variable])):A dict mapping
from original var name to each var split.
"""
# varname->[(block_id, current_block_size)]
block_map
=
collections
.
OrderedDict
()
var_mapping
=
collections
.
OrderedDict
()
for
block_str
in
block_list
:
varname
,
offset
,
size
=
block_str
.
split
(
":"
)
if
varname
not
in
block_map
:
block_map
[
varname
]
=
[]
block_map
[
varname
].
append
((
int
(
offset
),
int
(
size
)))
for
varname
,
split
in
block_map
.
items
():
orig_var
=
self
.
merged_variable_map
[
varname
]
if
len
(
split
)
==
1
:
var_mapping
[
varname
]
=
[
orig_var
]
self
.
var_distributed
.
add_distributed_var
(
origin_var
=
orig_var
,
slice_var
=
orig_var
,
block_id
=
0
,
offset
=
0
,
is_slice
=
False
,
vtype
=
"Param"
,
)
else
:
var_mapping
[
varname
]
=
[]
orig_shape
=
orig_var
.
shape
orig_dim1_flatten
=
1
if
len
(
orig_shape
)
>=
2
:
orig_dim1_flatten
=
reduce
(
lambda
x
,
y
:
x
*
y
,
orig_shape
[
1
:]
)
for
i
,
block
in
enumerate
(
split
):
size
=
block
[
1
]
rows
=
size
//
orig_dim1_flatten
splited_shape
=
[
rows
]
if
len
(
orig_shape
)
>=
2
:
splited_shape
.
extend
(
orig_shape
[
1
:])
new_var_name
=
"%s.block%d"
%
(
varname
,
i
)
slice_var
=
vars_metatools
.
VarStruct
(
name
=
new_var_name
,
shape
=
splited_shape
,
dtype
=
orig_var
.
dtype
,
type
=
orig_var
.
type
,
lod_level
=
orig_var
.
lod_level
,
persistable
=
False
,
)
var_mapping
[
varname
].
append
(
slice_var
)
self
.
var_distributed
.
add_distributed_var
(
origin_var
=
orig_var
,
slice_var
=
slice_var
,
block_id
=
i
,
offset
=-
1
,
is_slice
=
False
,
vtype
=
"Param"
,
)
return
var_mapping
def
_dispatcher
(
self
):
ps_dispatcher
=
RoundRobin
(
self
.
get_ps_endpoints
())
ps_dispatcher
.
reset
()
grad_var_mapping_items
=
list
(
self
.
grad_var_mapping
.
items
())
sparse_gradnames
=
[
grad
.
name
for
_
,
grad
in
self
.
origin_sparse_pairs
]
for
grad_varname
,
splited_vars
in
grad_var_mapping_items
:
if
grad_varname
in
sparse_gradnames
:
continue
send_vars
=
[]
for
_
,
var
in
enumerate
(
splited_vars
):
send_vars
.
append
(
var
)
recv_vars
=
[]
for
_
,
var
in
enumerate
(
send_vars
):
recv_vars
.
append
(
self
.
grad_param_mapping
[
var
])
eps
=
ps_dispatcher
.
dispatch
(
recv_vars
)
for
i
,
ep
in
enumerate
(
eps
):
self
.
param_grad_ep_mapping
[
ep
][
"params"
].
append
(
recv_vars
[
i
])
self
.
param_grad_ep_mapping
[
ep
][
"grads"
].
append
(
send_vars
[
i
])
for
grad_varname
,
splited_vars
in
grad_var_mapping_items
:
if
grad_varname
not
in
sparse_gradnames
:
continue
ps_dispatcher
.
reset
()
send_vars
=
[]
for
_
,
var
in
enumerate
(
splited_vars
):
send_vars
.
append
(
var
)
recv_vars
=
[]
for
_
,
var
in
enumerate
(
send_vars
):
recv_vars
.
append
(
self
.
grad_param_mapping
[
var
])
eps
=
ps_dispatcher
.
dispatch
(
recv_vars
)
for
i
,
ep
in
enumerate
(
eps
):
self
.
param_grad_ep_mapping
[
ep
][
"params"
].
append
(
recv_vars
[
i
])
self
.
param_grad_ep_mapping
[
ep
][
"grads"
].
append
(
send_vars
[
i
])
def
_slice_variable
(
self
,
var_list
,
slice_count
,
min_block_size
,
uniform
=
False
):
"""
We may need to split dense tensor to one or more blocks and put
them equally onto parameter server. One block is a sub-tensor
aligned by dim[0] of the tensor.
We need to have a minimal block size so that the calculations in
the parameter server side can gain better performance. By default
minimum block size 8K elements (maybe 16bit or 32bit or 64bit).
Args:
var_list (list): List of variables.
slice_count (int): Numel of count that variables will be sliced, which
could be the pserver services' count.
min_block_size (int): Minimum split block size.
Returns:
blocks (list[(varname, block_id, current_block_size)]): A list
of VarBlocks. Each VarBlock specifies a shard of the var.
"""
blocks
=
[]
for
var
in
var_list
:
if
not
uniform
:
var_numel
=
reduce
(
lambda
x
,
y
:
x
*
y
,
var
.
shape
)
split_count
=
1
if
min_block_size
==
-
1
:
split_count
=
1
else
:
split_count
=
slice_count
max_pserver_count
=
int
(
math
.
floor
(
var_numel
/
float
(
min_block_size
))
)
if
max_pserver_count
==
0
:
max_pserver_count
=
1
if
max_pserver_count
<
slice_count
:
split_count
=
max_pserver_count
block_size
=
int
(
math
.
ceil
(
var_numel
/
float
(
split_count
)))
if
len
(
var
.
shape
)
>=
2
:
# align by dim1(width)
dim1
=
reduce
(
lambda
x
,
y
:
x
*
y
,
var
.
shape
[
1
:])
remains
=
block_size
%
dim1
if
remains
!=
0
:
block_size
+=
dim1
-
remains
# update split_count after aligning
split_count
=
int
(
math
.
ceil
(
var_numel
/
float
(
block_size
)))
for
block_id
in
range
(
split_count
):
curr_block_size
=
min
(
block_size
,
var_numel
-
((
block_id
)
*
block_size
)
)
block
=
vars_metatools
.
VarBlock
(
var
.
name
,
block_id
,
curr_block_size
)
blocks
.
append
(
str
(
block
))
else
:
block_size
=
var
.
shape
[
0
]
/
slice_count
remainder
=
var
.
shape
[
0
]
%
slice_count
if
block_size
==
0
:
dim0s
=
[
block_size
]
*
remainder
else
:
dim0s
=
[
block_size
]
*
slice_count
for
i
in
range
(
remainder
):
dim0s
[
i
]
=
dim0s
[
i
]
+
1
dim1
=
reduce
(
lambda
x
,
y
:
x
*
y
,
var
.
shape
[
1
:])
for
block_id
in
range
(
len
(
dim0s
)):
numel
=
dim0s
[
block_id
]
*
dim1
block
=
vars_metatools
.
VarBlock
(
var
.
name
,
block_id
,
numel
)
blocks
.
append
(
str
(
block
))
return
blocks
def
_get_param_grad_blocks
(
self
,
pairs
,
min_block_size
,
uniform
=
False
):
param_list
=
[]
grad_list
=
[]
param_grad_set
=
set
()
for
p
,
g
in
pairs
:
# todo(tangwei12) skip parameter marked not trainable
# if type(p) == Parameter and p.trainable == False:
# continue
p
=
p
.
merged_var
g
=
g
.
merged_var
if
p
.
name
not
in
param_grad_set
:
param_list
.
append
(
p
)
param_grad_set
.
add
(
p
.
name
)
if
g
.
name
not
in
param_grad_set
:
grad_list
.
append
(
g
)
param_grad_set
.
add
(
g
.
name
)
# when we slice var up into blocks, we will slice the var according to
# pserver services' count. A pserver may have two or more listening ports.
grad_blocks
=
self
.
_slice_variable
(
grad_list
,
len
(
self
.
get_ps_endpoints
()),
min_block_size
,
uniform
)
param_blocks
=
self
.
_slice_variable
(
param_list
,
len
(
self
.
get_ps_endpoints
()),
min_block_size
,
uniform
)
return
param_blocks
,
grad_blocks
def
_var_slice_and_distribute
(
self
):
# update these mappings for further transpile:
# 1. param_var_mapping : param var name->[split params vars]
# 2. grad_var_mapping : grad var name->[split grads vars]
# 3. grad_param_mapping : grad.blockx->param.blockx
# 4. param_grad_ep_mapping : ep->{"params" : [], "grads" : [] }
dps
,
dgs
=
self
.
_get_param_grad_blocks
(
self
.
merged_dense_pairs
,
self
.
min_block_size
,
False
)
sps
,
sgs
=
self
.
_get_param_grad_blocks
(
self
.
merged_sparse_pairs
,
self
.
min_block_size
,
True
)
param_blocks
=
dps
+
sps
grad_blocks
=
dgs
+
sgs
assert
len
(
grad_blocks
)
==
len
(
param_blocks
)
# origin_param_name->[splited_param_vars]
self
.
param_var_mapping
=
self
.
_create_vars_from_blocklist
(
param_blocks
)
self
.
grad_var_mapping
=
self
.
_create_vars_from_blocklist
(
grad_blocks
)
# dict(grad_splited_var->param_splited_var)
self
.
grad_param_mapping
=
collections
.
OrderedDict
()
for
g
,
p
in
zip
(
grad_blocks
,
param_blocks
):
g_name
,
g_bid
,
_
=
g
.
split
(
":"
)
p_name
,
p_bid
,
_
=
p
.
split
(
":"
)
self
.
grad_param_mapping
[
self
.
grad_var_mapping
[
g_name
][
int
(
g_bid
)]
]
=
self
.
param_var_mapping
[
p_name
][
int
(
p_bid
)]
print_maps
=
{}
for
k
,
v
in
self
.
grad_param_mapping
.
items
():
print_maps
[
str
(
k
)]
=
str
(
v
)
# create mapping of endpoint->split var to create pserver side program
self
.
param_grad_ep_mapping
=
collections
.
OrderedDict
()
[
self
.
param_grad_ep_mapping
.
update
({
ep
:
{
"params"
:
[],
"grads"
:
[]}})
for
ep
in
self
.
get_ps_endpoints
()
]
def
_build_var_distributed
(
self
):
self
.
var_distributed
=
vars_metatools
.
VarsDistributed
()
sparse_pairs
,
dense_pairs
=
self
.
get_param_grads
()
origin_for_sparse
=
[]
origin_for_dense
=
[]
param_name_grad_name
=
dict
()
grad_name_to_param_name
=
dict
()
for
param
,
grad
in
sparse_pairs
:
param
=
vars_metatools
.
create_var_struct
(
param
)
grad
=
vars_metatools
.
create_var_struct
(
grad
)
origin_for_sparse
.
append
((
param
,
grad
))
for
param
,
grad
in
dense_pairs
:
param
=
vars_metatools
.
create_var_struct
(
param
)
grad
=
vars_metatools
.
create_var_struct
(
grad
)
origin_for_dense
.
append
((
param
,
grad
))
for
dense_pair
in
origin_for_dense
:
param
,
grad
=
dense_pair
m_param
=
MergedVariable
(
param
,
[
param
],
[
0
])
m_grad
=
MergedVariable
(
grad
,
[
grad
],
[
0
])
self
.
merged_variables_pairs
.
append
((
m_param
,
m_grad
))
self
.
merged_dense_pairs
.
append
((
m_param
,
m_grad
))
for
sparse_pair
in
origin_for_sparse
:
param
,
grad
=
sparse_pair
m_param
=
MergedVariable
(
param
,
[
param
],
[
0
])
m_grad
=
MergedVariable
(
grad
,
[
grad
],
[
0
])
self
.
merged_variables_pairs
.
append
((
m_param
,
m_grad
))
self
.
merged_sparse_pairs
.
append
((
m_param
,
m_grad
))
for
merged
in
self
.
merged_variables_pairs
:
m_param
,
m_grad
=
merged
self
.
merged_variable_map
[
m_param
.
merged_var
.
name
]
=
m_param
.
merged_var
self
.
merged_variable_map
[
m_grad
.
merged_var
.
name
]
=
m_grad
.
merged_var
param_merges
=
[]
param_merges
.
extend
(
origin_for_sparse
)
param_merges
.
extend
(
origin_for_dense
)
for
param
,
grad
in
param_merges
:
param_name_grad_name
[
param
.
name
]
=
grad
.
name
grad_name_to_param_name
[
grad
.
name
]
=
param
.
name
self
.
origin_sparse_pairs
=
origin_for_sparse
self
.
origin_dense_pairs
=
origin_for_dense
self
.
param_name_to_grad_name
=
param_name_grad_name
self
.
grad_name_to_param_name
=
grad_name_to_param_name
sparse_pair_map
=
collections
.
OrderedDict
()
for
pair
in
self
.
origin_sparse_pairs
+
self
.
origin_dense_pairs
:
param
,
grad
=
pair
sparse_pair_map
[
param
.
name
]
=
str
(
param
)
sparse_pair_map
[
grad
.
name
]
=
str
(
grad
)
self
.
_var_slice_and_distribute
()
self
.
_dispatcher
()
def
get_param_grads
(
self
):
origin_program
=
self
.
origin_main_program
def
_get_params_grads
(
sparse_varnames
):
block
=
origin_program
.
global_block
()
dense_param_grads
=
[]
sparse_param_grads
=
[]
optimize_params
=
set
()
origin_var_dict
=
origin_program
.
global_block
().
vars
role_id
=
int
(
core
.
op_proto_and_checker_maker
.
OpRole
.
Backward
)
for
op
in
block
.
ops
:
if
_is_opt_role_op
(
op
):
# delete clip op from opt_ops when run in Parameter Server mode
if
(
OP_NAME_SCOPE
in
op
.
all_attrs
()
and
CLIP_OP_NAME_SCOPE
in
op
.
attr
(
OP_NAME_SCOPE
)
):
op
.
_set_attr
(
"op_role"
,
role_id
)
continue
if
op
.
attr
(
OP_ROLE_VAR_ATTR_NAME
):
param_name
=
op
.
attr
(
OP_ROLE_VAR_ATTR_NAME
)[
0
]
grad_name
=
op
.
attr
(
OP_ROLE_VAR_ATTR_NAME
)[
1
]
if
param_name
not
in
optimize_params
:
optimize_params
.
add
(
param_name
)
param_grad
=
(
origin_var_dict
[
param_name
],
origin_var_dict
[
grad_name
],
)
if
param_name
in
sparse_varnames
:
sparse_param_grads
.
append
(
param_grad
)
else
:
dense_param_grads
.
append
(
param_grad
)
return
sparse_param_grads
,
dense_param_grads
def
_get_sparse_varnames
():
varnames
=
[]
for
op
in
origin_program
.
global_block
().
ops
:
if
(
op
.
type
in
SPARSE_OP_TYPE_DICT
.
keys
()
and
op
.
attr
(
'remote_prefetch'
)
is
True
):
param_name
=
op
.
input
(
SPARSE_OP_TYPE_DICT
[
op
.
type
])[
0
]
varnames
.
append
(
param_name
)
return
list
(
set
(
varnames
))
sparse_varnames
=
_get_sparse_varnames
()
sparse_param_grads
,
dense_param_grads
=
_get_params_grads
(
sparse_varnames
)
return
sparse_param_grads
,
dense_param_grads
def
remove_var_pair_by_grad
(
self
,
var_name
):
for
index
,
pair
in
enumerate
(
self
.
merged_variables_pairs
):
var
=
pair
[
0
]
var_grad
=
pair
[
1
]
if
var_grad
.
merged_var
.
name
==
var_name
:
del
self
.
merged_variables_pairs
[
index
]
for
index
,
pair
in
enumerate
(
self
.
merged_dense_pairs
):
var
=
pair
[
0
]
var_grad
=
pair
[
1
]
if
var_grad
.
merged_var
.
name
==
var_name
:
del
self
.
merged_dense_pairs
[
index
]
return
for
index
,
pair
in
enumerate
(
self
.
merged_sparse_pairs
):
var
=
pair
[
0
]
var_grad
=
pair
[
1
]
if
var_grad
.
merged_var
.
name
==
var_name
:
del
self
.
merged_sparse_pairs
[
index
]
return
print
(
"Not find {} in self.merge_pairs"
.
format
(
var_name
))
def
_is_opt_role_op
(
op
):
# NOTE : depend on oprole to find out whether this op is for
# optimize
op_maker
=
core
.
op_proto_and_checker_maker
optimize_role
=
core
.
op_proto_and_checker_maker
.
OpRole
.
Optimize
if
op_maker
.
kOpRoleAttrName
()
in
op
.
attr_names
and
int
(
op
.
all_attrs
()[
op_maker
.
kOpRoleAttrName
()]
)
==
int
(
optimize_role
):
return
True
return
False
def
_get_optimize_ops
(
_program
):
block
=
_program
.
global_block
()
opt_ops
=
[]
for
op
in
block
.
ops
:
if
_is_opt_role_op
(
op
):
# delete clip op from opt_ops when run in Parameter Server mode
if
(
OP_NAME_SCOPE
in
op
.
all_attrs
()
and
CLIP_OP_NAME_SCOPE
in
op
.
attr
(
OP_NAME_SCOPE
)
):
op
.
_set_attr
(
"op_role"
,
int
(
core
.
op_proto_and_checker_maker
.
OpRole
.
Backward
),
)
continue
opt_ops
.
append
(
op
)
return
opt_ops
def
_add_lr_decay_table_pass
(
main_program
,
compiled_config
,
lr_decay_steps
):
if
hasattr
(
compiled_config
.
origin_main_program
,
'lr_sheduler'
):
from
paddle.optimizer.lr
import
LRScheduler
assert
isinstance
(
compiled_config
.
origin_main_program
.
lr_sheduler
,
LRScheduler
),
"must be LRScheduler"
ops
=
_get_optimize_ops
(
compiled_config
.
origin_main_program
)
lr_param_dict
=
_get_lr_param_dict
(
ops
)
(
lr_decay_main_program
,
lr_decay_startup_program
,
lr_name
,
)
=
_get_lr_sheduler_program
(
compiled_config
.
origin_main_program
.
lr_sheduler
,
lr_param_dict
,
lr_decay_steps
,
)
compiled_config
.
add_tensor_table
(
"@LR_DECAY_COUNTER@"
,
lr_name
,
lr_decay_startup_program
,
lr_decay_main_program
,
"GlobalStepTable"
,
)
def
_get_lr_param_dict
(
opt_ops
):
lr_param_dict
=
{}
for
op
in
opt_ops
:
lr_name
=
op
.
input
(
"LearningRate"
)[
0
]
param_name
=
op
.
input
(
"Param"
)[
0
]
if
lr_name
not
in
lr_param_dict
:
lr_param_dict
[
lr_name
]
=
[]
lr_param_dict
[
lr_name
].
append
(
param_name
)
return
lr_param_dict
def
_get_lr_sheduler_program
(
lr_sheduler
,
lr_param_dict
,
lr_decay_steps
):
schedler_decay
=
[
'NoamDecay'
,
'NaturalExpDecay'
,
'InverseTimeDecay'
,
'ExponentialDecay'
,
]
from
paddle.optimizer.lr
import
(
ExponentialDecay
,
NoamDecay
,
PiecewiseDecay
,
NaturalExpDecay
,
InverseTimeDecay
,
)
from
paddle.static.learning_rate_scheduler
import
(
exponential_decay
,
noam_decay
,
piecewise_decay
,
natural_exp_decay
,
inverse_time_decay
,
)
decay_main_program
=
paddle
.
static
.
Program
()
decay_startup_program
=
paddle
.
static
.
Program
()
lr_name
=
""
if
isinstance
(
lr_sheduler
,
ExponentialDecay
):
with
paddle
.
static
.
program_guard
(
decay_main_program
,
decay_startup_program
):
lr
=
exponential_decay
(
1.0
,
lr_decay_steps
,
lr_sheduler
.
gamma
,
True
)
lr_name
=
lr
.
name
logging
.
warn
(
"ExponentialDecay is set, staircase = True, global learning rate decay step is [ %d ], Change decay steps as follow:
\n
"
"
\t
strategy = paddle.distributed.fleet.DistributedStrategy()
\n
"
"
\t
strategy.a_sync = True
\n
"
"
\t
strategy.a_sync_configs= { 'lr_decay_steps' : YOUR_DECAY_STEP }
\n
"
%
lr_decay_steps
)
elif
isinstance
(
lr_sheduler
,
NoamDecay
):
with
paddle
.
static
.
program_guard
(
decay_main_program
,
decay_startup_program
):
lr
=
noam_decay
(
lr_sheduler
.
d_model
,
lr_sheduler
.
warmup_steps
,
1.0
)
lr_name
=
lr
.
name
logging
.
warn
(
"NoamDecay is set, warmup steps is [ %d ]"
%
lr_sheduler
.
warmup_steps
)
elif
isinstance
(
lr_sheduler
,
NaturalExpDecay
):
with
paddle
.
static
.
program_guard
(
decay_main_program
,
decay_startup_program
):
lr
=
natural_exp_decay
(
1.0
,
lr_decay_steps
,
lr_sheduler
.
gamma
,
True
)
lr_name
=
lr
.
name
logging
.
warn
(
"NaturalExpDecay is set, staircase = True, global learning rate decay step is [ %d ], Change decay steps as follow:
\n
"
"
\t
strategy = paddle.distributed.fleet.DistributedStrategy()
\n
"
"
\t
strategy.a_sync = True
\n
"
"
\t
strategy.a_sync_configs= { 'lr_decay_steps' : YOUR_DECAY_STEP }
\n
"
%
lr_decay_steps
)
elif
isinstance
(
lr_sheduler
,
InverseTimeDecay
):
with
paddle
.
static
.
program_guard
(
decay_main_program
,
decay_startup_program
):
lr
=
inverse_time_decay
(
1.0
,
lr_decay_steps
,
lr_sheduler
.
gamma
,
True
)
lr_name
=
lr
.
name
logging
.
warn
(
"InverseTimeDecay is set, staircase = True, global learning rate decay step is [ %d ], Change decay steps as follow:
\n
"
"
\t
strategy = paddle.distributed.fleet.DistributedStrategy()
\n
"
"
\t
strategy.a_sync = True
\n
"
"
\t
strategy.a_sync_configs= { 'lr_decay_steps' : YOUR_DECAY_STEP }
\n
"
%
lr_decay_steps
)
else
:
raise
ValueError
(
"Not supported current LearningRate strategy, please use follow decay strategy: {}"
.
format
(
schedler_decay
)
)
return
decay_main_program
,
decay_startup_program
,
lr_name
def _get_varname_parts(varname):
    # returns origin, blockid, trainerid
    orig_var_name = ""
    trainer_part = ""
    block_part = ""
    trainer_idx = varname.find(".trainer_")
    if trainer_idx >= 0:
        trainer_part = varname[trainer_idx + 1 :]
    else:
        trainer_idx = len(varname)
    block_index = varname.find(".block")
    if block_index >= 0:
        block_part = varname[block_index + 1 : trainer_idx]
    else:
        block_index = len(varname)
    orig_var_name = varname[0 : min(block_index, trainer_idx)]
    return orig_var_name, block_part, trainer_part


def _orig_varname(varname):
    orig, _, _ = _get_varname_parts(varname)
    return orig
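
A worked example of the name convention these two helpers handle (it runs as-is when the helpers are in scope); the variable names are made up:

assert _get_varname_parts("embedding.block0.trainer_3") == (
    "embedding",
    "block0",
    "trainer_3",
)
assert _get_varname_parts("fc_0.w_0") == ("fc_0.w_0", "", "")
assert _orig_varname("fc_0.w_0@GRAD.block1") == "fc_0.w_0@GRAD"
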
python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py  (deleted, 100644 → 0, base e84fa263)
# -*- coding: UTF-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import collections
import warnings
import math
from functools import reduce

import paddle
from paddle.framework import core
import paddle.framework as framework
from paddle.distributed.transpiler.details.program_utils import delete_ops
from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
    _get_optimize_ops,
)
from paddle.fluid.incubate.fleet.parameter_server.ir.public import _get_lr_ops
from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
    get_sparse_tablenames,
)
from paddle.fluid.incubate.fleet.parameter_server.mode import DistributedMode

OP_NAME_SCOPE = "op_namescope"
CLIP_OP_NAME_SCOPE = "gradient_clip"
STEP_COUNTER = "@PS_STEP_COUNTER@"
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleAttrName()
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()

SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"}
SPARSE_GRAD_OP_TYPE_DICT = {
    "lookup_table_grad": "W",
    "lookup_table_v2_grad": "W",
}
DEVICE_LIST = ["cpu", "gpu", "xpu"]
COMMUNICATE_OPS_TYPE = ["send", "recv", "fetch_barrier", "send_barrier"]
DEFAULT_DEVICE = 'cpu'
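
SPARSE_OP_TYPE_DICT maps each sparse lookup op type to the input slot that holds its embedding table; the sketch below mirrors how the passes in this file locate remote-prefetch embeddings (it assumes a built paddle.static.Program is passed in, so it is illustrative rather than part of the original module):

def sparse_lookup_params(program):
    # Collect the embedding table names used by remote-prefetch lookups.
    names = []
    for op in program.global_block().ops:
        if (
            op.type in SPARSE_OP_TYPE_DICT
            and op.attr('remote_prefetch') is True
        ):
            names.append(op.input(SPARSE_OP_TYPE_DICT[op.type])[0])
    return list(set(names))
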
def delete_optimizer_pass(program, config):
    def _delete_optimizer_op_and_vars(_program, optimize_ops):
        optimize_vars = []
        optimize_op_role_vars = []
        optimize_need_delete_vars = []

        for op in optimize_ops:
            optimize_vars.extend(op.input_arg_names)
            optimize_op_role_vars.extend(op.attr("op_role_var"))

        optimize_vars = list(set(optimize_vars))
        optimize_op_role_vars = list(set(optimize_op_role_vars))

        for var in optimize_vars:
            if var not in optimize_op_role_vars:
                optimize_need_delete_vars.append(var)
        need_delete_optimize_vars = list(set(optimize_need_delete_vars))

        delete_ops(_program.global_block(), optimize_ops)
        for var in need_delete_optimize_vars:
            if _program.global_block().has_var(var):
                _program.global_block()._remove_var(var)

    def _add_lr_var(main_program, compiled_config):
        # Todo: hard code for pe
        lr_var = compiled_config.origin_main_program.global_block().vars[
            "learning_rate_0"
        ]
        main_program.global_block().create_var(
            name=lr_var.name,
            shape=lr_var.shape,
            dtype=lr_var.dtype,
            type=lr_var.type,
            lod_level=lr_var.lod_level,
            persistable=True,
        )

    optimizer_ops = _get_optimize_ops(program)
    lr_ops = _get_lr_ops(program)
    optimizer_ops.extend(lr_ops)
    _delete_optimizer_op_and_vars(program, optimizer_ops)

    if hasattr(config.origin_main_program, 'lr_sheduler'):
        _add_lr_var(program, config)

    return program
def
distributed_ops_pass
(
program
,
config
,
use_ps_gpu
=
False
):
trainer_id
=
config
.
get_role_id
()
send_ctx
=
config
.
get_the_one_send_context
(
split_dense_table
=
config
.
is_heter_ps_mode
)
w_2_table_id
=
{}
emb_size
=
{}
def
_get_pull_sparse_ops
(
_program
):
pull_sparse_ops
=
{}
pull_sparse_ids
=
{}
push_sparse_ops
=
{}
ops
=
{}
for
op
in
_program
.
global_block
().
ops
:
if
(
op
.
type
in
SPARSE_OP_TYPE_DICT
.
keys
()
and
op
.
attr
(
'remote_prefetch'
)
is
True
):
param_name
=
op
.
input
(
SPARSE_OP_TYPE_DICT
[
op
.
type
])[
0
]
if
config
.
is_heter_ps_mode
:
# trick for matchnet, need to modify
param_name
+=
op
.
input
(
"Ids"
)[
0
][
0
]
ops
=
pull_sparse_ops
.
get
(
param_name
,
[])
ops
.
append
(
op
)
pull_sparse_ops
[
param_name
]
=
ops
ids
=
pull_sparse_ids
.
get
(
param_name
,
[])
ids
.
append
(
op
.
input
(
"Ids"
)[
0
])
pull_sparse_ids
[
param_name
]
=
ids
for
op
in
_program
.
global_block
().
ops
:
if
op
.
type
in
SPARSE_GRAD_OP_TYPE_DICT
.
keys
():
param_name
=
op
.
input
(
SPARSE_GRAD_OP_TYPE_DICT
[
op
.
type
])[
0
]
if
(
param_name
in
pull_sparse_ids
and
op
.
input
(
"Ids"
)[
0
]
in
pull_sparse_ids
[
param_name
]
):
ops
=
push_sparse_ops
.
get
(
param_name
,
[])
ops
.
append
(
op
)
push_sparse_ops
[
param_name
]
=
ops
return
pull_sparse_ops
,
push_sparse_ops
def
_pull_sparse_fuse
(
_program
,
pull_sparse_ops
,
use_ps_gpu
):
def
dag_check_up_and_reorder
(
program
,
inputs
,
outputs
):
global_block
=
program
.
global_block
()
min_output_index
=
len
(
global_block
.
ops
)
max_input_index
=
-
1
input_indexes
=
[
0
]
*
len
(
global_block
.
ops
)
output_indexes
=
[
0
]
*
len
(
global_block
.
ops
)
for
idx
,
op
in
enumerate
(
global_block
.
ops
):
for
i
in
range
(
0
,
len
(
op
.
output_names
)):
if
input_indexes
[
idx
]
==
1
:
break
outs
=
op
.
output
(
op
.
output_names
[
i
])
for
in_id
,
in_var
in
enumerate
(
inputs
):
if
in_var
.
name
in
outs
:
input_indexes
[
idx
]
=
1
max_input_index
=
max
(
max_input_index
,
idx
)
break
for
i
in
range
(
0
,
len
(
op
.
input_names
)):
if
output_indexes
[
idx
]
==
1
:
break
ins
=
op
.
input
(
op
.
input_names
[
i
])
for
out_id
,
out_var
in
enumerate
(
outputs
):
if
out_var
.
name
in
ins
:
output_indexes
[
idx
]
=
1
min_output_index
=
min
(
min_output_index
,
idx
)
for
i
in
range
(
len
(
global_block
.
ops
)):
if
input_indexes
[
i
]
==
1
and
output_indexes
[
i
]
==
1
:
warnings
.
warn
(
"unable to re-arrange dags order to combine distributed embedding ops because a op both needs embedding table's output as input and produces ids as the same embedding table's input"
)
return
if
min_output_index
<
max_input_index
:
move_ops
=
[]
for
i
in
range
(
min_output_index
+
1
,
len
(
input_indexes
)):
if
input_indexes
[
i
]
==
1
:
move_ops
.
append
((
global_block
.
ops
[
i
],
i
))
for
i
,
op
in
enumerate
(
move_ops
):
queue
=
list
()
visited
=
set
()
queue
.
append
(
op
[
1
])
visited
.
add
(
op
[
0
])
start
=
0
while
start
<
len
(
queue
):
pos
=
queue
[
start
]
op
=
global_block
.
ops
[
pos
]
op_inputs
=
[]
for
k
in
range
(
0
,
len
(
op
.
input_names
)):
ins
=
op
.
input
(
op
.
input_names
[
k
])
op_inputs
.
append
(
ins
)
for
j
in
range
(
pos
-
1
,
min_output_index
-
1
,
-
1
):
op1
=
global_block
.
ops
[
j
]
if
op1
in
visited
:
continue
found
=
False
for
k
in
range
(
0
,
len
(
op1
.
output_names
)):
outs
=
op1
.
output
(
op1
.
output_names
[
k
])
for
t
in
range
(
len
(
op_inputs
)):
for
y
in
op_inputs
[
t
]:
if
y
in
outs
:
found
=
True
break
if
found
:
break
if
found
:
break
if
found
:
if
output_indexes
[
j
]
==
True
:
warnings
.
warn
(
"unable to re-arrange dags order to combine distributed embedding ops"
)
return
queue
.
append
(
j
)
visited
.
add
(
global_block
.
ops
[
j
])
start
=
start
+
1
queue
.
sort
()
for
index
in
queue
:
desc
=
global_block
.
desc
.
_insert_op
(
min_output_index
)
desc
.
copy_from
(
global_block
.
ops
[
index
].
desc
)
global_block
.
desc
.
_remove_op
(
index
+
1
,
index
+
2
)
global_block
.
ops
[
index
].
desc
=
desc
insert_op
=
global_block
.
ops
.
pop
(
index
)
input_state
=
input_indexes
.
pop
(
index
)
output_state
=
output_indexes
.
pop
(
index
)
global_block
.
ops
.
insert
(
min_output_index
,
insert_op
)
input_indexes
.
insert
(
min_output_index
,
input_state
)
output_indexes
.
insert
(
min_output_index
,
output_state
)
min_output_index
=
min_output_index
+
1
assert
global_block
.
desc
.
op_size
()
==
len
(
global_block
.
ops
)
for
i
in
range
(
len
(
global_block
.
ops
)):
assert
global_block
.
desc
.
op
(
i
)
==
global_block
.
ops
[
i
].
desc
for
param
,
ops
in
pull_sparse_ops
.
items
():
all_ops
=
program
.
global_block
().
ops
op_device
=
""
if
config
.
is_heter_ps_mode
:
op_device
=
ops
[
0
].
attr
(
"op_device"
)
inputs
=
[
program
.
global_block
().
vars
[
op
.
input
(
"Ids"
)[
0
]]
for
op
in
ops
]
w
=
program
.
global_block
().
vars
[
ops
[
0
].
input
(
"W"
)[
0
]]
emb_size
[
param
]
=
w
.
shape
[
1
]
grad_name
=
config
.
param_name_to_grad_name
[
w
.
name
]
table_id
=
-
1
for
name
,
ctx
in
send_ctx
.
items
():
if
grad_name
in
ctx
.
origin_varnames
():
table_id
=
ctx
.
table_id
()
if
table_id
==
-
1
:
raise
ValueError
(
"can not find suitable sparse table, please check"
)
w_2_table_id
[
param
]
=
table_id
padding_idx
=
ops
[
0
].
attr
(
"padding_idx"
)
is_distributed
=
ops
[
0
].
attr
(
"is_distributed"
)
op_type
=
ops
[
0
].
type
outputs
=
[
program
.
global_block
().
vars
[
op
.
output
(
"Out"
)[
0
]]
for
op
in
ops
]
dag_check_up_and_reorder
(
program
,
inputs
,
outputs
)
op_idxs
=
[
all_ops
.
index
(
op
)
for
op
in
ops
]
for
idx
in
op_idxs
[::
-
1
]:
program
.
global_block
().
_remove_op
(
idx
)
inputs_idxs
=
[
-
1
]
*
len
(
inputs
)
outputs_idxs
=
[
len
(
program
.
global_block
().
ops
)
+
1
]
*
len
(
outputs
)
for
idx
,
op
in
enumerate
(
program
.
global_block
().
ops
):
for
i
in
range
(
0
,
len
(
op
.
output_names
)):
outs
=
op
.
output
(
op
.
output_names
[
i
])
for
in_id
,
in_var
in
enumerate
(
inputs
):
if
in_var
.
name
in
outs
:
inputs_idxs
[
in_id
]
=
max
(
idx
,
inputs_idxs
[
in_id
])
for
i
in
range
(
0
,
len
(
op
.
input_names
)):
ins
=
op
.
input
(
op
.
input_names
[
i
])
for
out_id
,
out_var
in
enumerate
(
outputs
):
if
out_var
.
name
in
ins
:
outputs_idxs
[
out_id
]
=
min
(
idx
,
outputs_idxs
[
out_id
]
)
if
min
(
outputs_idxs
)
-
max
(
inputs_idxs
)
>=
1
:
if
max
(
inputs_idxs
)
==
-
1
:
distributed_idx
=
min
(
op_idxs
)
else
:
distributed_idx
=
max
(
inputs_idxs
)
+
1
if
use_ps_gpu
:
program
.
global_block
().
_insert_op
(
index
=
distributed_idx
,
type
=
"pull_gpups_sparse"
,
inputs
=
{
"Ids"
:
inputs
,
'W'
:
w
},
outputs
=
{
"Out"
:
outputs
},
attrs
=
{
"size"
:
[
w
.
shape
[
1
]
for
i
in
inputs
],
"is_distributed"
:
True
,
"is_sparse"
:
True
,
},
)
else
:
program
.
global_block
().
_insert_op
(
index
=
distributed_idx
,
type
=
"distributed_lookup_table"
,
inputs
=
{
"Ids"
:
inputs
,
'W'
:
w
},
outputs
=
{
"Outputs"
:
outputs
},
attrs
=
{
"is_distributed"
:
is_distributed
,
"padding_idx"
:
padding_idx
,
"table_id"
:
table_id
,
"lookup_table_version"
:
op_type
,
"op_device"
:
op_device
,
},
)
else
:
for
i
in
range
(
len
(
inputs_idxs
)):
distributed_idx
=
op_idxs
[
i
]
program
.
global_block
().
_insert_op
(
index
=
distributed_idx
,
type
=
"distributed_lookup_table"
,
inputs
=
{
"Ids"
:
[
inputs
[
i
]],
'W'
:
w
},
outputs
=
{
"Outputs"
:
[
outputs
[
i
]]},
attrs
=
{
"is_distributed"
:
is_distributed
,
"padding_idx"
:
padding_idx
,
"table_id"
:
table_id
,
"lookup_table_version"
:
op_type
,
"op_device"
:
op_device
,
},
)
def
_push_sparse_fuse
(
_program
,
push_sparse_ops
,
use_ps_gpu
):
if
use_ps_gpu
:
# in ps_gpu_pass
return
if
len
(
push_sparse_ops
)
==
0
:
return
show
=
None
clk
=
None
use_entry
=
False
for
param
,
ops
in
push_sparse_ops
.
items
():
op_first
=
ops
[
0
]
break
print
(
op_first
)
if
op_first
.
has_attr
(
"entry"
):
entry
=
op_first
.
attr
(
"entry"
)
entry
=
entry
.
split
(
':'
)
if
len
(
entry
)
==
3
and
entry
[
0
]
==
'show_click_entry'
:
show_var_name
=
entry
[
1
]
click_var_name
=
entry
[
2
]
if
(
show_var_name
in
program
.
global_block
().
vars
and
click_var_name
in
program
.
global_block
().
vars
):
show
=
program
.
global_block
().
vars
[
show_var_name
]
clk
=
program
.
global_block
().
vars
[
click_var_name
]
use_entry
=
True
else
:
warnings
.
warn
(
'ShowClickEntry configured, but cannot find show/click var, will not use'
)
if
not
use_entry
:
print
(
'ShowClickEntry not configured, will not use'
)
show
=
program
.
global_block
().
create_var
(
name
=
"show"
,
dtype
=
core
.
VarDesc
.
VarType
.
INT64
,
persistable
=
False
,
stop_gradient
=
True
,
)
program
.
global_block
().
_insert_op
(
index
=
0
,
type
=
'fill_constant'
,
inputs
=
{},
outputs
=
{
'Out'
:
show
},
attrs
=
{
'shape'
:
[
1
],
'dtype'
:
show
.
dtype
,
'value'
:
1
,
# OP_ROLE_KEY: OpRole.Forward
},
)
clk
=
program
.
global_block
().
create_var
(
name
=
"clk"
,
dtype
=
core
.
VarDesc
.
VarType
.
INT64
,
persistable
=
False
,
stop_gradient
=
True
,
)
program
.
global_block
().
_insert_op
(
index
=
0
,
type
=
'fill_constant'
,
inputs
=
{},
outputs
=
{
'Out'
:
clk
},
attrs
=
{
'shape'
:
[
1
],
'dtype'
:
clk
.
dtype
,
'value'
:
0
,
# OP_ROLE_KEY: OpRole.Forward
},
)
for
param
,
ops
in
push_sparse_ops
.
items
():
all_ops
=
program
.
global_block
().
ops
op_idxs
=
[
all_ops
.
index
(
op
)
for
op
in
ops
]
inputs
=
[
program
.
global_block
().
vars
[
op
.
input
(
"Ids"
)[
0
]]
for
op
in
ops
]
w
=
program
.
global_block
().
vars
[
ops
[
0
].
output
(
"W@GRAD"
)[
0
]]
table_id
=
w_2_table_id
[
param
]
padding_idx
=
ops
[
0
].
attr
(
"padding_idx"
)
is_distributed
=
ops
[
0
].
attr
(
"is_distributed"
)
op_type
=
ops
[
0
].
type
outputs
=
[
program
.
global_block
().
vars
[
op
.
input
(
"Out@GRAD"
)[
0
]]
for
op
in
ops
]
for
idx
in
op_idxs
[::
-
1
]:
program
.
global_block
().
_remove_op
(
idx
)
# if use_ps_gpu:
# program.global_block().append_op(
# type="push_box_sparse",
# inputs={"Ids": inputs,
# 'Out': outputs},
# outputs={"Out": outputs},
# attrs={
# "size": w.shape[1],
# "is_distributed": True,
# "is_sparse": True
# })
# else:
program
.
global_block
().
append_op
(
type
=
"distributed_push_sparse"
,
inputs
=
{
"Ids"
:
inputs
,
'W'
:
w
,
"Outputs"
:
outputs
,
"Shows"
:
show
,
"Clicks"
:
clk
,
},
outputs
=
{
"Outputs"
:
outputs
},
attrs
=
{
"is_distributed"
:
is_distributed
,
"padding_idx"
:
padding_idx
,
"table_id"
:
table_id
,
"size"
:
emb_size
[
param
],
},
)
pull_sparse_ops
,
push_sparse_ops
=
_get_pull_sparse_ops
(
program
)
_pull_sparse_fuse
(
program
,
pull_sparse_ops
,
use_ps_gpu
)
_push_sparse_fuse
(
program
,
push_sparse_ops
,
use_ps_gpu
)
return
program
def
append_send_ops_pass
(
program
,
config
):
mode
=
config
.
get_distributed_mode
()
trainer_id
=
config
.
get_role_id
()
def
_append_send_op
(
union_vars
,
queue
,
is_sparse
,
table_id
):
if
queue
==
STEP_COUNTER
:
send_input_vars
=
[]
else
:
send_input_vars
=
[
program
.
global_block
().
vars
[
union_var
]
for
union_var
in
union_vars
]
dummy_output
=
[]
if
mode
in
[
DistributedMode
.
SYNC
,
DistributedMode
.
HALF_ASYNC
]:
dummy_output
=
program
.
global_block
().
create_var
(
name
=
framework
.
generate_control_dev_var_name
()
)
program
.
global_block
().
append_op
(
type
=
"send"
,
inputs
=
{
"X"
:
send_input_vars
},
outputs
=
{
"Out"
:
dummy_output
},
attrs
=
{
"send_varnames"
:
[
queue
],
"is_sparse"
:
is_sparse
,
"table_id"
:
table_id
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
,
},
)
return
dummy_output
def
_append_barrier_op
(
dummys
):
program
.
global_block
().
append_op
(
type
=
"send_barrier"
,
inputs
=
{
"X"
:
dummys
},
outputs
=
{
"Out"
:
[]},
attrs
=
{
"trainer_id"
:
trainer_id
,
"half_async"
:
True
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
,
},
)
dummys
=
[]
sends
=
config
.
get_the_one_trainer_send_context
(
split_dense_table
=
config
.
is_heter_ps_mode
)
for
merged_name
,
send
in
sends
.
items
():
if
send
.
is_sparse
()
and
not
config
.
is_geo_mode
():
continue
is_sparse
=
1
if
send
.
is_sparse
()
else
0
is_sparse
=
2
if
send
.
is_distributed
()
else
is_sparse
dummys
.
append
(
_append_send_op
(
send
.
origin_varnames
(),
merged_name
,
is_sparse
,
send
.
table_id
()
)
)
if
mode
in
[
DistributedMode
.
SYNC
,
DistributedMode
.
HALF_ASYNC
]:
_append_barrier_op
(
dummys
)
return
program
def
init_from_server_pass
(
program
,
config
):
# 0' trainer do not need barrier, it will call barrier at the end init_worker
if
config
.
role_maker
.
_is_first_worker
():
return
program
fetch_barrier_out
=
program
.
global_block
().
create_var
(
name
=
framework
.
generate_control_dev_var_name
()
)
program
.
global_block
().
append_op
(
type
=
"fetch_barrier"
,
inputs
=
{},
outputs
=
{
"Out"
:
fetch_barrier_out
},
attrs
=
{
"endpoints"
:
config
.
get_ps_endpoints
(),
"trainer_id"
:
config
.
get_role_id
(),
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
,
},
)
return
program
def
fake_init_ops_pass
(
program
,
config
):
origin_program
=
config
.
get_origin_main_program
()
def
_get_sparse_table_names
():
dist_varnames
=
get_sparse_tablenames
(
origin_program
,
True
)
sparse_varnames
=
get_sparse_tablenames
(
origin_program
,
False
)
return
list
(
set
(
dist_varnames
+
sparse_varnames
))
def
_fake_init_sparsetable
(
sparse_table_names
):
# delete table init op
for
table_name
in
sparse_table_names
:
table_var
=
program
.
global_block
().
vars
[
table_name
]
table_param_init_op
=
[]
for
op
in
program
.
global_block
().
ops
:
if
table_name
in
op
.
output_arg_names
:
table_param_init_op
.
append
(
op
)
init_op_num
=
len
(
table_param_init_op
)
if
init_op_num
!=
1
:
raise
ValueError
(
"table init op num should be 1, now is "
+
str
(
init_op_num
)
)
table_init_op
=
table_param_init_op
[
0
]
program
.
global_block
().
append_op
(
type
=
"fake_init"
,
inputs
=
{},
outputs
=
{
"Out"
:
table_var
},
attrs
=
{
"shape"
:
table_init_op
.
attr
(
'shape'
)},
)
delete_ops
(
program
.
global_block
(),
table_param_init_op
)
sparse_tables
=
_get_sparse_table_names
()
_fake_init_sparsetable
(
sparse_tables
)
return
program
def
ps_gpu_pass
(
program
):
def
_add_push_box_sparse_op
(
program
):
op_role_attr_name
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
()
backward
=
core
.
op_proto_and_checker_maker
.
OpRole
.
Backward
for
op
in
program
.
global_block
().
ops
:
if
op
.
type
!=
"pull_box_sparse"
and
op
.
type
!=
"pull_gpups_sparse"
:
continue
grad_op_desc
,
op_grad_to_var
=
core
.
get_grad_op_desc
(
op
.
desc
,
set
(),
[]
)
for
op_desc
in
grad_op_desc
:
new_op_desc
=
program
.
global_block
().
desc
.
append_op
()
new_op_desc
.
copy_from
(
op_desc
)
new_op_desc
.
_set_attr
(
op_role_attr_name
,
backward
)
def
_remove_lookup_table_grad_op_and_var
(
program
):
lookup_table_grad_var
=
{}
remove_op_index
=
[]
remove_var
=
[]
for
idx
,
op
in
list
(
enumerate
(
program
.
global_block
().
ops
)):
if
op
.
type
==
"lookup_table_grad"
:
for
name
in
op
.
output
(
"W@GRAD"
):
lookup_table_grad_var
[
name
]
=
1
remove_op_index
.
append
(
idx
)
remove_var
.
append
(
name
)
for
name
in
op
.
input
(
"W"
):
lookup_table_grad_var
[
name
]
=
1
for
idx
,
op
in
list
(
enumerate
(
program
.
global_block
().
ops
)):
if
op
.
type
==
"pull_box_sparse"
or
op
.
type
==
"pull_gpups_sparse"
:
continue
for
key_name
in
op
.
input_names
:
for
var
in
op
.
input
(
key_name
):
if
var
in
lookup_table_grad_var
:
remove_op_index
.
append
(
idx
)
break
remove_op_index
=
list
(
set
(
remove_op_index
))
remove_op_index
.
sort
(
reverse
=
True
)
for
idx
in
remove_op_index
:
program
.
global_block
().
_remove_op
(
idx
)
for
name
in
remove_var
:
program
.
global_block
().
_remove_var
(
name
)
def
_remove_optimizer_var
(
program
):
embedding_w
=
{}
for
idx
,
op
in
list
(
enumerate
(
program
.
global_block
().
ops
)):
if
op
.
type
==
"lookup_table_grad"
:
for
name
in
op
.
input
(
"W"
):
embedding_w
[
name
]
=
1
optimize_vars
=
[]
optimize_op_role_vars
=
[]
optimize_need_delete_vars
=
[]
for
op
in
_get_optimize_ops
(
program
):
for
name
in
op
.
input
(
"Param"
):
if
name
in
embedding_w
:
optimize_op_role_vars
.
extend
(
op
.
attr
(
"op_role_var"
))
for
key_name
in
op
.
input_names
:
if
key_name
==
"LearningRate"
:
continue
for
var
in
op
.
input
(
key_name
):
optimize_vars
.
append
(
var
)
optimize_vars
=
list
(
set
(
optimize_vars
))
optimize_op_role_vars
=
list
(
set
(
optimize_op_role_vars
))
for
var
in
optimize_vars
:
if
var
not
in
optimize_op_role_vars
:
optimize_need_delete_vars
.
append
(
var
)
need_delete_optimize_vars
=
list
(
set
(
optimize_need_delete_vars
))
for
name
in
need_delete_optimize_vars
:
if
program
.
global_block
().
has_var
(
name
):
program
.
global_block
().
_remove_var
(
name
)
_add_push_box_sparse_op
(
program
)
_remove_optimizer_var
(
program
)
_remove_lookup_table_grad_op_and_var
(
program
)
return
program
def
delete_extra_optimizes_pass
(
program
,
config
):
optimize_vars
=
[]
optimize_op_role_vars
=
[]
optimize_need_delete_vars
=
[]
origin_program
=
config
.
get_origin_main_program
()
for
op
in
_get_optimize_ops
(
origin_program
):
optimize_vars
.
extend
(
op
.
input_arg_names
)
optimize_op_role_vars
.
extend
(
op
.
attr
(
"op_role_var"
))
optimize_vars
=
list
(
set
(
optimize_vars
))
optimize_op_role_vars
=
list
(
set
(
optimize_op_role_vars
))
for
var
in
optimize_vars
:
if
var
not
in
optimize_op_role_vars
:
optimize_need_delete_vars
.
append
(
var
)
need_delete_optimize_vars
=
list
(
set
(
optimize_need_delete_vars
))
init_ops
=
[]
for
var
in
need_delete_optimize_vars
:
param_init_op
=
[]
for
op
in
program
.
global_block
().
ops
:
if
var
in
op
.
output_arg_names
:
param_init_op
.
append
(
op
)
init_ops
.
extend
(
param_init_op
)
delete_ops
(
program
.
global_block
(),
init_ops
)
for
var
in
need_delete_optimize_vars
:
if
program
.
global_block
().
has_var
(
var
):
program
.
global_block
().
_remove_var
(
var
)
return
program
def
find_heter_ops
(
program
,
default_device
=
"cpu"
):
if
default_device
not
in
DEVICE_LIST
:
raise
ValueError
(
"Given device {} is not in device list {}"
.
format
(
default_device
,
DEVICE_LIST
)
)
def
_is_heter_op
(
op
,
current_heter_device
,
default_device
=
"cpu"
):
heter_devices
=
list
(
DEVICE_LIST
)
heter_devices
.
remove
(
default_device
)
op_device
=
op
.
attr
(
"op_device"
)
op_type
=
op
.
type
if
op_device
in
heter_devices
:
return
True
elif
(
op_type
in
COMMUNICATE_OPS_TYPE
and
current_heter_device
!=
default_device
):
# for distributed communciate ops: send & recv & barrier etc.
# Todo: need update this method
# op._set_attr('op_device', current_heter_device)
return
True
elif
op_device
is
None
or
op_device
==
default_device
:
op
.
_set_attr
(
'op_device'
,
default_device
)
return
False
return
False
def
_is_same_device
(
op
,
pre_device
,
default_device
=
"cpu"
):
op_device
=
op
.
attr
(
"op_device"
)
if
op_device
==
pre_device
:
return
True
if
pre_device
==
default_device
:
return
True
return
False
def
_append_heter_op
(
op
,
current_heter_block_ops
,
heter_ops
):
op_device
=
op
.
attr
(
"op_device"
)
if
op_device
not
in
heter_ops
:
heter_ops
[
op_device
]
=
{}
current_heter_block_ops
.
append
(
op
)
origin_porgram
=
program
.
clone
()
block
=
program
.
global_block
()
'''
re-place sum op to fix bug for union forward backward op
'''
var2idx
=
{}
op_list
=
list
(
block
.
ops
)
op_size
=
len
(
op_list
)
for
i
in
range
(
op_size
-
1
,
-
1
,
-
1
):
op_list
=
list
(
block
.
ops
)
op
=
op_list
[
i
]
if
"_grad"
in
op
.
type
:
forward_op_type
=
op
.
type
.
split
(
"_grad"
)[
0
]
if
(
forward_op_type
in
SPARSE_OP_TYPE_DICT
.
keys
()
and
op
.
attr
(
'remote_prefetch'
)
is
True
):
param_name
=
op
.
input
(
SPARSE_OP_TYPE_DICT
[
forward_op_type
])[
0
]
if
param_name
in
var2idx
:
## insert sum op & remove sum op from var2idx and origin place
op_list
=
list
(
block
.
ops
)
sum_op
=
op_list
[
var2idx
[
param_name
]]
sum_op_inputs
=
{
sum_op
.
input_names
[
0
]:
[
block
.
vars
[
input
]
for
input
in
sum_op
.
input_arg_names
]
}
sum_op_outputs
=
{
sum_op
.
output_names
[
0
]:
[
block
.
vars
[
output
]
for
output
in
sum_op
.
output_arg_names
]
}
block
.
_insert_op
(
index
=
i
+
1
,
type
=
sum_op
.
type
,
inputs
=
sum_op_inputs
,
outputs
=
sum_op_outputs
,
attrs
=
sum_op
.
all_attrs
(),
)
block
.
_remove_op
(
var2idx
[
param_name
]
+
1
)
var2idx
.
pop
(
param_name
)
for
var_
in
var2idx
:
var2idx
[
var_
]
+=
1
elif
forward_op_type
==
"elementwise_mul"
:
"""
get output varname of pre op
"""
output_vars_no_grad
=
[]
for
key
in
op
.
output_names
:
for
varname
in
op
.
output
(
key
):
if
varname
==
"@EMPTY@"
:
continue
if
"lod_tensor_blocking_queue"
in
varname
:
continue
output_vars_no_grad
.
append
(
varname
.
split
(
"@GRAD"
)[
0
])
for
no_grad_var
in
output_vars_no_grad
:
if
no_grad_var
in
var2idx
:
"""
insert sum op & remove sum op from var2idx and origin place
"""
op_list
=
list
(
block
.
ops
)
sum_op
=
op_list
[
var2idx
[
no_grad_var
]]
sum_op_inputs
=
{
sum_op
.
input_names
[
0
]:
[
block
.
vars
[
input
]
for
input
in
sum_op
.
input_arg_names
]
}
sum_op_outputs
=
{
sum_op
.
output_names
[
0
]:
[
block
.
vars
[
output
]
for
output
in
sum_op
.
output_arg_names
]
}
block
.
_insert_op
(
index
=
i
+
1
,
type
=
sum_op
.
type
,
inputs
=
sum_op_inputs
,
outputs
=
sum_op_outputs
,
attrs
=
sum_op
.
all_attrs
(),
)
block
.
_remove_op
(
var2idx
[
no_grad_var
]
+
1
)
var2idx
.
pop
(
no_grad_var
)
for
var_
in
var2idx
:
var2idx
[
var_
]
+=
1
else
:
if
op
.
type
==
"sum"
:
var
=
op
.
output
(
"Out"
)[
0
]
if
"@GRAD"
in
var
:
origin_var
=
var
.
split
(
"@GRAD"
)[
0
]
pre_op
=
op_list
[
i
-
1
]
if
"_grad"
in
pre_op
.
type
:
forward_op_type
=
pre_op
.
type
.
split
(
"_grad"
)[
0
]
if
(
forward_op_type
in
SPARSE_OP_TYPE_DICT
.
keys
()
and
pre_op
.
attr
(
'remote_prefetch'
)
is
True
):
param_name
=
pre_op
.
input
(
SPARSE_OP_TYPE_DICT
[
forward_op_type
]
)[
0
]
if
param_name
==
origin_var
and
op
.
attr
(
"op_device"
)
==
pre_op
.
attr
(
"op_device"
):
continue
else
:
var2idx
[
origin_var
]
=
i
elif
forward_op_type
==
"elementwise_mul"
:
output_vars
=
[]
for
key
in
pre_op
.
output_names
:
for
varname
in
pre_op
.
output
(
key
):
if
varname
==
"@EMPTY@"
:
continue
if
"lod_tensor_blocking_queue"
in
varname
:
continue
output_vars
.
append
(
varname
)
input_vars
=
[]
for
key
in
op
.
input_names
:
for
varname
in
op
.
input
(
key
):
if
varname
==
"@EMPTY@"
:
continue
if
"lod_tensor_blocking_queue"
in
varname
:
continue
input_vars
.
append
(
varname
)
is_match
=
False
for
varname
in
output_vars
:
if
varname
in
input_vars
:
is_match
=
True
break
if
is_match
:
continue
else
:
var2idx
[
origin_var
]
=
i
else
:
var2idx
[
origin_var
]
=
i
    origin_program = program.clone()
    block = program.global_block()

    program_block_ops = []
    default_ops = {default_device: {}}
    heter_ops = {}
    block_index = 0

    current_heter_block_ops = []
    current_default_block_ops = []
    current_heter_device = default_device
    is_heter = False
    for op in block.ops:
        if _is_heter_op(op, current_heter_device, default_device):
            # for gpu/xpu-op
            is_heter = True

            # for cpu-op block append
            if len(current_default_block_ops) > 1:
                default_ops[default_device][block_index] = current_default_block_ops
                program_block_ops.append(current_default_block_ops)
                current_default_block_ops = []
                block_index += 1

            if _is_same_device(op, current_heter_device, default_device):
                # for gpu-op, gpu-op -> gpu-op, ...
                current_heter_device = op.attr("op_device")
                _append_heter_op(op, current_heter_block_ops, heter_ops)
            else:
                # for gpu-op -> xpu-op, ...
                op_device = current_heter_block_ops[0].attr("op_device")
                heter_ops[op_device][block_index] = current_heter_block_ops
                program_block_ops.append(current_heter_block_ops)
                block_index += 1

                current_heter_block_ops = []
                current_heter_device = op.attr("op_device")
                _append_heter_op(op, current_heter_block_ops, heter_ops)
        elif is_heter:
            # for gpu/xpu-op -> cpu-op
            op_device = current_heter_block_ops[0].attr("op_device")
            heter_ops[op_device][block_index] = current_heter_block_ops
            program_block_ops.append(current_heter_block_ops)
            block_index += 1

            current_heter_block_ops = []
            current_heter_device = default_device
            is_heter = False
            current_default_block_ops.append(op)
        else:
            # for cpu-op
            current_default_block_ops.append(op)

    if current_default_block_ops != []:
        default_ops[default_device][block_index] = current_default_block_ops
        program_block_ops.append(current_default_block_ops)

    if current_heter_block_ops != []:
        op_device = current_heter_block_ops[0].attr("op_device")
        heter_ops[op_device][block_index] = current_heter_block_ops
        program_block_ops.append(current_heter_block_ops)

    if len(heter_ops) == 0:
        warnings.warn(
            "No heterogeneous OP was found in your program, "
            "please use paddle.static.device_guard() to run OPs on different devices."
        )

    total_heter_ops = 0
    heter_blocks = 0
    for device in heter_ops.keys():
        heter_block_dict = heter_ops[device]
        heter_blocks += len(heter_block_dict)
        for _, heter_block in heter_block_dict.items():
            total_heter_ops += len(heter_block)
    print(
        "There are {} OPs in your main_program, and contains {} heter-OPs which is made up of {} heter-blocks.".format(
            len(block.ops), total_heter_ops, heter_blocks
        )
    )
    return origin_program, heter_ops, default_ops, program_block_ops
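
The loop above decides block boundaries purely from each op's "op_device" attribute. The following standalone sketch is not part of the original file and is deliberately simplified (it ignores the "> 1" threshold on cpu-op runs); it groups a toy list of device tags the same way, just to make the block-splitting behavior easy to see. All names and the sample input are made up for illustration.

def group_by_device(op_devices, default_device="cpu"):
    """Group a sequence of per-op device tags into contiguous blocks."""
    blocks = []                      # list of (device, ops-in-block)
    current_device = default_device
    current_block = []
    for dev in op_devices:
        dev = dev or default_device  # ops without a device fall back to the default
        if dev != current_device and current_block:
            blocks.append((current_device, current_block))
            current_block = []
        current_device = dev
        current_block.append(dev)
    if current_block:
        blocks.append((current_device, current_block))
    return blocks

if __name__ == "__main__":
    # cpu,cpu,gpu:0,gpu:0,cpu -> three blocks: a cpu block, a gpu:0 heter block, a cpu block
    print(group_by_device(["cpu", "cpu", "gpu:0", "gpu:0", "cpu"]))
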
def create_heter_program(
    program,
    config,
    heter_program,
    program_block_ops_list,
    heter_ops,
    block_var_detail,
    current_device,
    stage_id,
):
    # This function mainly includes the following contents:
    # 1. For every heter block:
    #     a) copy heter device op from origin program
    #     b) create variables which belong to heter op:
    #         -> if variable is persistable, clone it in global_scope
    #         -> if variable is temp, create it in heter block
    #     c) create communicate related op as follows:
    #         joint_var.0_1 -> slice -> reshape -> origin_var
    #         origin_var -> origin_program
    #         reshape -> concat -> joint_var.1_2
    #     d) copy send op from origin program for var@grad which is located in current heter block
    #     e) re-check every op in current block if its device is not the current heter device
    # 2. Create send op for step counter in last heter-block
    # 3. Create Listen&Serv OP and Send&Recv OP for distributed training
    # 4. update CompileTimeStrategy for heter_program

    optimizer_block = []
    grad_to_block_id = []
    send_grad_var_list = []

    pre_block_idx = heter_program.num_blocks - 1

    stage_id = int(stage_id)
    print("stage id", stage_id)
    heter_block_ops_forward = program_block_ops_list[stage_id - 1]["forward"]
    heter_block_ops_backward = program_block_ops_list[stage_id - 1]["backward"]

    heter_block = heter_program._create_block(pre_block_idx)
    optimizer_block.append(heter_block)
    for _, op in enumerate(heter_block_ops_forward):
        block_append_op(heter_program, program, heter_block, op)

    entrance_vars = block_var_detail[stage_id - 1]["forward"]["entrance"]
    add_vars_by_var_list(entrance_vars, program, heter_program, heter_block)
    exit_vars = block_var_detail[stage_id - 1]["forward"]["exit"]
    add_vars_by_var_list(exit_vars, program, heter_program, heter_block)

    first_op_index_fp = len(heter_block.ops)

    if stage_id < len(program_block_ops_list):
        heter_block_bp = heter_program._create_block(pre_block_idx)
        optimizer_block.append(heter_block_bp)

        for _, op in enumerate(heter_block_ops_backward):
            block_append_op(heter_program, program, heter_block_bp, op)

        bp_entrance_vars = block_var_detail[stage_id - 1]["backward"]["entrance"]
        add_vars_by_var_list(bp_entrance_vars, program, heter_program, heter_block_bp)
        bp_exit_vars = block_var_detail[stage_id - 1]["backward"]["exit"]
        add_vars_by_var_list(bp_exit_vars, program, heter_program, heter_block_bp)
        backward_comm_info = get_communicate_var_info(
            program, stage_id, bp_entrance_vars, type="backward"
        )

        grad_to_block_id.append(
            backward_comm_info["block_input_var_name"] + ":" + str(heter_block_bp.idx)
        )
    else:
        for _, op in enumerate(heter_block_ops_backward):
            block_append_op(heter_program, program, heter_block, op)

        bp_entrance_vars = block_var_detail[stage_id - 1]["backward"]["entrance"]
        add_vars_by_var_list(bp_entrance_vars, program, heter_program, heter_block)
        bp_exit_vars = block_var_detail[stage_id - 1]["backward"]["exit"]
        add_vars_by_var_list(bp_exit_vars, program, heter_program, heter_block)

        heter_block_bp = heter_block

    forward_comm_info = get_communicate_var_info(
        program, stage_id, entrance_vars, type="forward"
    )

    grad_to_block_id.append(
        forward_comm_info["block_input_var_name"] + ":" + str(heter_block.idx)
    )

    first_op_index_bp = len(heter_block_bp.ops)

    if stage_id <= len(block_var_detail) - 1:
        static_var = insert_communicate_op(
            program,
            config,
            heter_block,
            stage_id,
            first_op_index_fp,
            block_var_detail,
            current_device,
        )
    static_var_bp = insert_communicate_op(
        program,
        config,
        heter_block_bp,
        stage_id,
        first_op_index_bp,
        block_var_detail,
        current_device,
        False,
    )

    # add send op
    send_grad_var_list = add_heter_send_op(
        program, heter_program, heter_block_bp, block_var_detail[stage_id - 1]
    )

    # ---------------
    # add step counter
    send_input_vars = []
    dummy_output = []
    pserver_endpoints = config.get_ps_endpoints()

    # optimizer_block[-1].append_op(
    #     type="send",
    #     inputs={"X": send_input_vars},
    #     outputs={"Out": dummy_output},
    #     attrs={
    #         "send_varnames": [STEP_COUNTER],
    #         "merge_add": True,
    #         "use_send_handler": False,
    #         "endpoints": pserver_endpoints
    #     })

    # add info in listen&serv
    attrs = {
        # "mode": "sync",
        # "trainers": config.get_trainers(),
        # "trainer_id": config.get_role_id() + config.get_trainers(),
        "message_to_block_id": grad_to_block_id,
        "optimize_blocks": optimizer_block,
        # runtime attribute
        "endpoint": config.get_heter_worker_endpoint(),
        "fanin": len(config.get_previous_stage_trainers()),
        "pserver_id": config.get_role_id(),
        "distributed_mode": config.get_distributed_mode(),
        "rpc_exec_thread_num": int(os.getenv("CPU_NUM", 32)),
        RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
    }

    # append the listen_and_serv op
    heter_program.global_block().append_op(
        type="heter_listen_and_serv", inputs={'X': []}, outputs={}, attrs=attrs
    )

    check_heter_compile_time_strategy(program, config, send_grad_var_list)
def check_heter_compile_time_strategy(program, config, send_grad_var_list):
    origin_grad_var_list = []
    for _, var_grad in config.merged_variables_pairs:
        origin_grad_var_list.append(var_grad.merged_var.name)

    origin_grad_var_list = list(set(origin_grad_var_list))
    send_grad_var_list = list(set(send_grad_var_list))
    useless_grad_var_list = list(
        set(origin_grad_var_list) - set(send_grad_var_list)
    )

    for useless_grad_var in useless_grad_var_list:
        config.remove_var_pair_by_grad(useless_grad_var)


def create_trainer_program(
    program, origin_program, config, program_block_ops_list, block_var_detail
):
    # This function mainly includes the following contents:
    # 1. For every heter block in origin program
    #     a) delete heter op and related variables
    #     b) add send&recv op
    #     c) add communicate ops as follows:
    #         origin_var -> reshape -> concat -> joint_var.0_1
    #         send&recv op (send joint_var.0_1; recv joint_var.1_2)
    #         joint_var.1_2 -> slice -> reshape -> origin_var
    #     d) remove send op whose related var@grad is not in trainer program
    # 2. check every op's device
    static_var = []
    for heter_block_index in range(1, len(program_block_ops_list)):
        ops_list = (
            program_block_ops_list[heter_block_index]["forward"]
            + program_block_ops_list[heter_block_index]["backward"]
        )
        static_var += replace_ops_by_communicate_op(
            program, config, heter_block_index, ops_list, block_var_detail
        )
        remove_trainer_send_op(program, config, heter_block_index, block_var_detail)

    optimizer_block = []
    grad_to_block_id = []

    bp_ops_list = program_block_ops_list[0]["backward"]
    delete_same_ops(program.global_block(), bp_ops_list)
    delete_trainer_useless_var(config, program, static_var)
    backward_block = create_backward_block(
        program, origin_program, config, bp_ops_list, block_var_detail
    )

    bp_entrance_vars = block_var_detail[0]["backward"]["entrance"]
    backward_comm_info = get_communicate_var_info(
        origin_program, 1, bp_entrance_vars, type="backward"
    )

    grad_to_block_id.append(
        backward_comm_info["block_input_var_name"] + ":" + str(backward_block.idx)
    )
    optimizer_block.append(backward_block)

    attrs = {
        # "mode": "sync",
        # "trainers": config.get_trainers(),
        # "trainer_id": config.get_role_id(),
        "message_to_block_id": grad_to_block_id,
        "optimize_blocks": optimizer_block,
        # runtime attribute
        "endpoint": config.get_trainer_endpoint(),  ## get trainer endpoint
        "fanin": 0,  ## get heter worker
        "pserver_id": config.get_role_id(),
        "distributed_mode": config.get_distributed_mode(),
        "rpc_exec_thread_num": int(os.getenv("CPU_NUM", 32)),
        RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
    }

    # append the listen_and_serv op
    program.global_block()._insert_op(
        index=0,
        type="heter_listen_and_serv",
        inputs={'X': []},
        outputs={},
        attrs=attrs,
    )

    ## TODO add check for bp block
    check_op_device(program.global_block(), DEFAULT_DEVICE)
def insert_communicate_op(
    origin_program,
    config,
    heter_block,
    stage_id,
    first_op_index,
    block_var_detail,
    device,
    is_forward=True,
):
    if is_forward:
        next_heter_worker_endpoints = config.get_next_stage_trainers()
        previous_heter_worker_endpoints = config.get_previous_stage_trainers()
        entrance_var = block_var_detail[stage_id]["forward"]["entrance"]
        comm_info = get_communicate_var_info(origin_program, stage_id + 1, entrance_var)
    else:
        next_heter_worker_endpoints = config.get_next_stage_trainers()
        # if next_heter_worker_endpoints == "":
        #     next_heter_worker_endpoints = []
        previous_heter_worker_endpoints = config.get_previous_stage_trainers()
        entrance_var = block_var_detail[stage_id - 1]["backward"]["exit"]
        comm_info = get_communicate_var_info(
            origin_program, stage_id - 1, entrance_var, "backward"
        )

    heter_block._insert_op(
        index=first_op_index,
        type="send_and_recv",
        inputs={"X": heter_block.vars[entrance_var[0]]},
        outputs={"Out": []},
        attrs={
            "mode": "forward" if is_forward else "backward",
            "send_var_name": entrance_var + ["microbatch_id"],
            "recv_var_name": [],
            "message_name": comm_info["block_input_var_name"],
            "next_endpoints": next_heter_worker_endpoints,
            "previous_endpoints": previous_heter_worker_endpoints,
            "trainer_id": config.get_role_id(),
            "op_device": device,
            RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
        },
    )

    return entrance_var
def create_backward_block(
    program, origin_program, config, bp_ops_list, block_var_detail
):
    pre_block_idx = program.num_blocks - 1
    heter_block = program._create_block(pre_block_idx)

    for _, op in enumerate(bp_ops_list):
        if op.type == "send":
            send_varnames = op.attr('send_varnames')
            is_skip = False
            for varname in send_varnames:
                if (
                    varname not in program.global_block().vars
                    and varname not in heter_block.vars
                ):
                    is_skip = True
                    break
            if is_skip:
                continue
        block_append_op(program, origin_program, heter_block, op)

    entrance_vars = block_var_detail[0]["backward"]["entrance"]
    add_vars_by_var_list(entrance_vars, origin_program, program, heter_block)
    exit_vars = block_var_detail[0]["backward"]["exit"]
    add_vars_by_var_list(exit_vars, origin_program, program, heter_block)
    return heter_block


def replace_ops_by_communicate_op(
    program, config, heter_block_index, ops_list, block_var_detail
):
    all_op = program.global_block().ops
    start_op = ops_list[0]
    first_op_idx = -1
    for op in all_op:
        if is_same_op(op, start_op):
            first_op_idx = all_op.index(op)
            break
    assert first_op_idx != -1
    delete_same_ops(program.global_block(), ops_list)

    entrance_var = []

    if heter_block_index == 1:
        mode = config.get_distributed_mode()
        next_heter_worker_endpoints = config.get_next_stage_trainers()

        entrance_var = block_var_detail[heter_block_index]["forward"]["entrance"]

        comm_info = get_communicate_var_info(
            program, heter_block_index + 1, entrance_var
        )
        program.global_block()._insert_op(
            index=first_op_idx,
            type="send_and_recv",
            inputs={"X": program.global_block().vars[entrance_var[0]]},
            outputs={"Out": []},
            attrs={
                "mode": "forward",
                "send_var_name": entrance_var + ["microbatch_id"],
                "recv_var_name": [],
                "message_name": comm_info["block_input_var_name"],
                "next_endpoints": next_heter_worker_endpoints,
                "previous_endpoints": [],
                "trainer_id": config.get_role_id(),
                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
            },
        )

    return entrance_var
def remove_trainer_send_op(program, config, heter_block_index, block_var_detail):
    # If the trainer does FF -> BP -> SEND, it has the pair vars: var and var@GRAD.
    # If the trainer only does SEND, it has just one var: var@GRAD.
    # Delete the send op if the trainer doesn't have the pair var (var <-> var@GRAD).
    persistables = (
        block_var_detail[heter_block_index]["forward"]["persistables"]
        + block_var_detail[heter_block_index]["backward"]["persistables"]
    )
    need_remove_send_op = []
    need_remove_grad_var = []
    for op in find_send_op(program):
        input_list, _ = find_op_input_output(program, program.global_block(), op)
        for var_name in input_list:
            origin_var_name = var_name.split("@GRAD")[0]
            if origin_var_name in persistables:
                need_remove_send_op.append(op)
                need_remove_grad_var.append(var_name)
    need_remove_send_op = list(set(need_remove_send_op))
    delete_ops(program.global_block(), need_remove_send_op)
    for grad_var_name in need_remove_grad_var:
        config.remove_var_pair_by_grad(grad_var_name)
def add_heter_send_op(program, heter_program, block, block_var_detail):
    def _get_send_op_dict():
        send_op_dict = {}
        send_op_list = find_send_op(program)
        for op in send_op_list:
            input_list, _ = find_op_input_output(program, program.global_block(), op)
            for var in input_list:
                send_op_dict[var] = op
        return send_op_dict

    # send_Op = { inputs{'X':[]},
    #             outputs{'Out':dummy_output},
    #             attrs{'send_varnames'"[]",
    #                   'is_sparse':int,
    #                   'table_id':int } }
    send_grad_var_list = []
    send_op_dict = _get_send_op_dict()
    table_dict = {}
    for persistable_var in block_var_detail["backward"]["persistables"]:
        # check var_name == var@GRAD
        if "@GRAD" not in persistable_var:
            continue
        if "GRAD" != persistable_var.split("@")[-1]:
            continue
        if persistable_var not in send_op_dict:
            continue
        send_op = send_op_dict[persistable_var]
        is_sparse = send_op.attr('is_sparse')
        table_id = send_op.attr('table_id')
        send_varnames = send_op.attr('send_varnames')
        send_grad_var_list.append(persistable_var)
        if table_id not in table_dict:
            table_dict[table_id] = {}
            table_dict[table_id]['var_list'] = []
            table_dict[table_id]['is_sparse'] = is_sparse
            table_dict[table_id]['send_varnames'] = send_varnames
        table_dict[table_id]['var_list'].append(persistable_var)

    for table_id in table_dict:
        dummy_output = block.create_var(
            name=framework.generate_control_dev_var_name()
        )
        send_input_vars = [
            block.vars[union_var]
            for union_var in table_dict[table_id]['var_list']
        ]
        block.append_op(
            type="send",
            inputs={"X": send_input_vars},
            outputs={"Out": dummy_output},
            attrs={
                "send_varnames": table_dict[table_id]['send_varnames'],
                "is_sparse": is_sparse,
                "table_id": table_id,
                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
            },
        )

    return send_grad_var_list


def find_send_op(program):
    send_op_list = []
    for op in program.global_block().ops:
        if op.type == "send":
            send_op_list.append(op)
    return send_op_list


def get_communicate_var_info(
    program, block_index, entrance_var_list, type="forward"
):
    input_var_reshape_dim = []
    input_var_reshape_name = []

    if type == "forward":
        block_input_var_name = "forward_joint_{}_{}@Heter".format(
            block_index - 1, block_index
        )
    else:
        block_input_var_name = "backward_joint_{}_{}@Heter".format(
            block_index + 1, block_index
        )

    entrance_var_list.sort()
    # input
    # Heter_SERVER_BLOCK_index@JOINT_VAR -> slice -> var@Heter_SERVER_BLOCK@INPUT_RESHAPE_VAR -> reshape -> var
    for name in entrance_var_list:
        var = program.global_block().vars[name]
        shape = var.shape
        # if len(shape) < 2 or shape[0] != -1:
        #     raise ValueError(
        #         "Variable {} not support heter training. its shape is {}".
        #         format(name, shape))
        recv_var_dim = -1 * reduce(lambda x, y: x * y, shape)
        input_var_reshape_dim.append(recv_var_dim)
        input_var_reshape_name.append("{}.input_reshape@Heter".format(name))

    # output
    # var -> reshape -> var@Heter_SERVER_BLOCK@INPUT_RESHAPE_VAR -> concat -> Heter_SERVER_BLOCK_index@JOINT_VAR
    # for var_name in exit_var_list:
    #     var = program.global_block().vars[var_name]
    #     shape = var.shape
    #     # if len(shape) < 2 or shape[0] != -1:
    #     #     raise ValueError(
    #     #         "Variable {} not support heter training. its shape is {}".
    #     #         format(var_name, shape))
    #     send_reshape_dim = -1 * reduce(lambda x, y: x * y, shape)
    #     output_var_reshape_dim.append(send_reshape_dim)
    #     output_var_reshape_name.append("{}.output_reshape@Heter".format(
    #         var_name))

    info = {
        "input_var_reshape_dim": input_var_reshape_dim,
        "input_var_reshape_name": input_var_reshape_name,
        "block_input_var_name": block_input_var_name,
        # "output_var_reshape_dim": output_var_reshape_dim,
        # "output_var_reshape_name": output_var_reshape_name,
        # "block_output_var_name": block_output_var_name
    }

    return info
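
The only arithmetic in get_communicate_var_info is the recv_var_dim line, which folds everything except the batch dimension into a single flattened axis. A small self-contained illustration follows; the sample shape is invented for the example.

from functools import reduce

def flatten_dim(shape):
    # Mirrors the recv_var_dim computation above: the leading batch dimension
    # is encoded as -1, and the remaining numel is folded into one axis.
    return -1 * reduce(lambda x, y: x * y, shape)

print(flatten_dim((-1, 32, 8)))  # -> 256, i.e. the reshape target stays batch-major
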
def union_forward_gradient_op(program_block_ops_list):
    """
    before analyzing the input & output of each block in program_block_list, we should
    union the forward op and corresponding gradient op to eliminate the unnecessary variable
    transmit
    """
    """
    fix for 2emb model, re-place sum op
    """
    block_length = len(program_block_ops_list)
    '''
    ## get the final part
    final_part_idx = -1
    for i in range(block_length):
        op_list = program_block_ops_list[i]
        for op in op_list:
            if "_grad" in op.type:
                final_part_idx = i
                break
        if final_part_idx != -1:
            break

    ## eliminate wrong partition because of sum op
    ## lookup_table_v2_grad
    ## every lookup_table_v2_grad op block should follow a sum op
    var2idx = {}
    for i in range(final_part_idx, block_length):
        op_list = program_block_ops_list[i]
        for j in range(len(op_list) - 1, -1, -1):
            op = op_list[j]
            #if op.type == "lookup_table_v2_grad":
            #    if j < len(op_list) - 1):
            #    else:
            #        ## get var and record place
            if _grad in op.type:
                forward_op_type = op.type.split("_grad")[0]
                if forward_op_type in SPARSE_OP_TYPE_DICT.keys() \
                        and op.attr('remote_prefetch') is True:
                    param_name = op.input(SPARSE_OP_TYPE_DICT[forward_op_type])[0]
                    var2idx[] = [i,j]  ##
    '''

    union_program_block_ops_list = []
    assert (
        block_length % 2 != 0
    ), "the length of program_block_ops_list should be odd"

    for i in range(0, block_length // 2):
        block_op_list = {"forward": program_block_ops_list[i]}
        block_op_list.update(
            {"backward": program_block_ops_list[block_length - 1 - i]}
        )
        union_program_block_ops_list.append(block_op_list)

    block_op_list = {"forward": [], "backward": []}
    for op in program_block_ops_list[block_length // 2]:
        if not "_grad" in op.type and not (op.type == "sum"):
            block_op_list["forward"].append(op)
        else:
            block_op_list["backward"].append(op)
    union_program_block_ops_list.append(block_op_list)
    return union_program_block_ops_list
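
As a quick illustration of the pairing rule implemented above (forward block i is joined with backward block n-1-i, and the odd middle block is split by op type), here is a standalone sketch that uses plain strings in place of ops; the toy op names are made up and are not Paddle ops.

# Five toy "blocks"; the middle one mixes forward ops with their grads and a sum.
blocks = [["fc"], ["emb"], ["emb2", "emb2_grad", "sum"], ["emb_grad"], ["fc_grad"]]
n = len(blocks)
union = [{"forward": blocks[i], "backward": blocks[n - 1 - i]} for i in range(n // 2)]
middle = {
    "forward": [op for op in blocks[n // 2] if "_grad" not in op and op != "sum"],
    "backward": [op for op in blocks[n // 2] if "_grad" in op or op == "sum"],
}
union.append(middle)
print(union)  # three forward/backward pairs, matching len(blocks) // 2 + 1
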
def find_block_joints(program, program_block_ops_list, heter_ops):
    block_var_detail = find_entrance_exit_private(program, program_block_ops_list)
    block_var_detail = entrance_exit_check(
        program, program_block_ops_list, block_var_detail, heter_ops
    )
    block_var_detail = delete_block_useless_exit(
        program, program_block_ops_list, block_var_detail
    )
    return block_var_detail


def find_entrance_exit_private(program, program_block_ops_list):
    block_var_detail = []
    persistables = []

    for index, block_op_list in enumerate(program_block_ops_list):
        ## forward
        block_input, block_output = find_ops_list_input_output(
            program, block_op_list["forward"]
        )
        persistables = screen_persistables(program, block_input) + screen_persistables(
            program, block_output
        )
        # find entrance & exit
        block_private_vars = list(set(block_input) & set(block_output))
        block_entrance = list(set(block_input) - set(block_private_vars))
        block_exit = list(set(block_output) - set(block_private_vars))
        detail = {
            "forward": {
                "entrance": block_entrance,
                "exit": block_exit,
                "private": block_private_vars,
                "persistables": persistables,
            }
        }

        ## backward
        bp_block_input, bp_block_output = find_ops_list_input_output(
            program, block_op_list["backward"]
        )
        bp_persistables = screen_persistables(
            program, bp_block_input
        ) + screen_persistables(program, bp_block_output)
        # find entrance & exit
        bp_block_private_vars = list(set(bp_block_input) & set(bp_block_output))
        bp_block_entrance = list(set(bp_block_input) - set(bp_block_private_vars))
        bp_block_exit = list(set(bp_block_output) - set(bp_block_private_vars))
        detail.update(
            {
                "backward": {
                    "entrance": bp_block_entrance,
                    "exit": bp_block_exit,
                    "private": bp_block_private_vars,
                    "persistables": bp_persistables,
                }
            }
        )
        block_var_detail.append(detail)
    return block_var_detail


def entrance_exit_check(program, program_block_ops_list, block_var_detail, heter_ops):
    for index in range(len(block_var_detail) - 1, -1, -1):
        if index - 1 < 0:
            break
        previous_block_exit = block_var_detail[index - 1]["forward"]["exit"]
        previous_block_exit.sort()
        current_block_entrance = block_var_detail[index]["forward"]["entrance"]

        backward_entrance = block_var_detail[index]["backward"]["entrance"]

        forward_all = (
            block_var_detail[index]["forward"]["entrance"]
            + block_var_detail[index]["forward"]["private"]
            + block_var_detail[index]["forward"]["exit"]
        )

        for var in backward_entrance:
            if not ("@GRAD" in var) and not (var in forward_all):
                current_block_entrance.append(var)

        current_block_entrance.sort()

        if previous_block_exit == current_block_entrance:
            continue
        exist_vars = list(set(previous_block_exit) & set(current_block_entrance))
        need_add_vars = list(set(current_block_entrance) - set(exist_vars))
        # var in different stage should not be ignored, since they are not placed in the same program & device
        # need_add_vars = find_need_var_from_previous_block(
        #     need_add_vars, block_var_detail, index, heter_ops)

        previous_block_private = block_var_detail[index - 1]["forward"]["private"]
        previous_block_entrance = block_var_detail[index - 1]["forward"]["entrance"]
        for var in need_add_vars:
            if (
                var not in previous_block_private
                and var not in previous_block_entrance
            ):
                previous_block_entrance.append(var)
            previous_block_exit.append(var)
            if not var in current_block_entrance:
                current_block_entrance.append(var)

    for index in range(0, len(block_var_detail) - 1, 1):
        previous_block_exit = block_var_detail[index + 1]["backward"]["exit"]
        previous_block_exit.sort()
        current_block_entrance = block_var_detail[index]["backward"]["entrance"]

        current_block_entrance.sort()

        if previous_block_exit == current_block_entrance:
            continue
        exist_vars = list(set(previous_block_exit) & set(current_block_entrance))
        need_add_vars = list(set(current_block_entrance) - set(exist_vars))
        need_ignore_vars = []
        for var in need_add_vars:
            if not "@GRAD" in var:
                need_ignore_vars.append(var)
        need_add_vars = list(
            set(need_add_vars).difference(set(need_ignore_vars))
        )
        previous_block_private = block_var_detail[index + 1]["backward"]["private"]
        previous_block_entrance = block_var_detail[index + 1]["backward"]["entrance"]
        for var in need_add_vars:
            if (
                var not in previous_block_private
                and var not in previous_block_entrance
            ):
                previous_block_entrance.append(var)
            previous_block_exit.append(var)
    return block_var_detail


def find_need_var_from_previous_block(
    need_add_vars, block_var_detail, current_index, heter_ops
):
    # create index_device_map
    index_device_map = {}
    for index in range(len(block_var_detail)):
        index_device_map[index] = DEFAULT_DEVICE
    for device in heter_ops:
        for index in heter_ops[device].keys():
            if index < len(block_var_detail):
                index_device_map[index] = device

    pre_index = current_index - 1
    need_ignore_var = []

    # if need_add_var is on the current device, no communication is needed
    for var in need_add_vars:
        while pre_index >= 0:
            previous_block_private = block_var_detail[pre_index]["private"]
            previous_block_exit = block_var_detail[pre_index]["exit"]
            previous_block_entrance = block_var_detail[pre_index]["entrance"]
            total_var = (
                previous_block_private
                + previous_block_exit
                + previous_block_entrance
            )
            if var in total_var:
                if (
                    index_device_map[current_index] == index_device_map[pre_index]
                    and index_device_map[current_index] == DEFAULT_DEVICE
                ):
                    need_ignore_var.append(var)
                    break
            pre_index -= 1

    need_add_vars = list(set(need_add_vars).difference(set(need_ignore_var)))
    return need_add_vars


def delete_block_useless_exit(program, program_block_ops_list, block_var_detail):
    ## forward
    for index in range(len(block_var_detail)):
        if index == len(block_var_detail) - 1:
            break
        current_block_exit = block_var_detail[index]["forward"]["exit"]
        next_block_entrance = block_var_detail[index + 1]["forward"]["entrance"]
        need_delete_var = []
        for var in current_block_exit:
            if var not in next_block_entrance:
                need_delete_var.append(var)

        for var in need_delete_var:
            current_block_exit.remove(var)

    ## backward
    for index in range(len(block_var_detail) - 1, -1, -1):
        if index - 1 < 0:
            break
        current_block_exit = block_var_detail[index]["backward"]["exit"]
        next_block_entrance = block_var_detail[index - 1]["backward"]["entrance"]
        need_delete_var = []
        for var in current_block_exit:
            if var not in next_block_entrance:
                need_delete_var.append(var)
        for var in need_delete_var:
            current_block_exit.remove(var)

    return block_var_detail


def check_op_device(block, device):
    for op in block.ops:
        op._set_attr('op_device', device)


def screen_persistables(program, var_list):
    need_remove = []
    for var_name in var_list:
        if "@GRAD" in var_name:
            if "GRAD" != var_name.split("@")[-1]:
                continue
            origin_var_name = var_name.split("@GRAD")[0]
            var = program.global_block().vars[origin_var_name]
        else:
            var = program.global_block().vars[var_name]

        if paddle.static.is_persistable(var):
            need_remove.append(var_name)

    for var_name in need_remove:
        var_list.remove(var_name)
    return need_remove
def insert_reshape_op(
    program, block, index, var_name, new_var_name, new_var_shape=None
):
    input_var = block.vars[var_name]

    if new_var_name not in block.vars:
        out = block.create_var(
            name=new_var_name,
            shape=new_var_shape,
            dtype=input_var.dtype,
            type=input_var.type,
        )
    else:
        out = block.vars[new_var_name]
        new_var_shape = out.shape

    x_shape = block.create_var(
        name="{}.xshape@Heter".format(var_name), dtype=input_var.dtype
    )
    block._insert_op(
        index=index,
        type="reshape2",
        inputs={"X": input_var},
        attrs={'shape': new_var_shape},
        outputs={"Out": out, "XShape": x_shape},
    )


def insert_send_concat_op(
    program, block, index, var_name_list, new_var_name, new_var_shape
):
    input_var_list = [block.vars[var_name] for var_name in var_name_list]

    out = program.global_block().create_var(
        name=new_var_name,
        shape=new_var_shape,
        dtype=input_var_list[0].dtype,
        type=input_var_list[0].type,
    )

    block._insert_op(
        index=index,
        type='concat',
        inputs={"X": input_var_list},
        outputs={'Out': [out]},
        attrs={'axis': -1, 'use_stack': False},
    )


def insert_recv_slice_op(
    program,
    block,
    index,
    var_name,
    var_shape,
    dtype,
    type,
    new_var_name_list,
    new_var_shape_list,
):
    if var_name not in program.global_block().vars:
        input_var = program.global_block().create_var(
            name=var_name, shape=var_shape, dtype=dtype, type=type
        )
    else:
        input_var = program.global_block().vars[var_name]

    out_list = []
    for i in range(len(new_var_name_list)):
        if new_var_name_list[i] not in block.vars:
            out = block.create_var(
                name=new_var_name_list[i],
                shape=new_var_shape_list[i],
                dtype=input_var.dtype,
                type=input_var.type,
            )
        else:
            out = block.vars[new_var_name_list[i]]
        out_list.append(out)

    start_index = 0
    end_index = 0
    for i in range(len(new_var_name_list)):
        starts = []
        ends = []
        attrs = {'axes': [1]}
        end_index += new_var_shape_list[i][1]
        starts.append(start_index)
        ends.append(end_index)
        attrs['starts'] = starts
        attrs['ends'] = ends

        block._insert_op(
            index=index,
            type='slice',
            inputs={'Input': input_var},
            attrs=attrs,
            outputs={'Out': out_list[i]},
        )
        start_index = end_index
        index += 1


def add_heter_trainer_useful_vars(
    config, program, heter_program, heter_block, static_var
):
    static_var = list(set(static_var))
    for var_name in static_var:
        if (
            var_name not in heter_program.global_block().vars
            and var_name not in heter_block.vars
        ):
            var = program.global_block().vars[var_name]
            if var.persistable:
                heter_program.global_block()._clone_variable(
                    var, force_persistable=False
                )
            else:
                heter_block._clone_variable(var, force_persistable=False)


def delete_trainer_useless_var(config, program, static_var):
    static_var = list(set(static_var))
    program_useful_var_list = []
    for op in program.global_block().ops:
        input_var_list, output_var_list = find_op_input_output(
            program, program.global_block(), op
        )
        op_var_list = list(set(input_var_list).union(set(output_var_list)))
        program_useful_var_list = list(
            set(program_useful_var_list).union(set(op_var_list))
        )
    program_useful_var_list += static_var
    program_useless_var_list = list(
        set(get_vars_name_in_block(program.global_block())).difference(
            set(program_useful_var_list)
        )
    )
    for var in program_useless_var_list:
        program.global_block()._remove_var(var)
    return program_useless_var_list
def block_append_op(program, origin_program, block, op):
    merge_ordereddict = origin_program.global_block().vars.copy()
    merge_ordereddict.update(block.vars)
    inputs = _get_input_map_from_op(merge_ordereddict, op)
    for key, varlist in inputs.items():
        if not isinstance(varlist, list):
            varlist = [varlist]
        for var in varlist:
            if (
                var.name not in program.global_block().vars
                and var.name not in block.vars
            ):
                if var.persistable:
                    program.global_block()._clone_variable(
                        var, force_persistable=False
                    )
                else:
                    block._clone_variable(var, force_persistable=False)

    outputs = _get_output_map_from_op(origin_program.global_block().vars, op)
    for key, varlist in outputs.items():
        if not isinstance(varlist, list):
            varlist = [varlist]
        for var in varlist:
            if (
                var.name not in program.global_block().vars
                and var.name not in block.vars
            ):
                if var.persistable:
                    program.global_block()._clone_variable(
                        var, force_persistable=False
                    )
                else:
                    block._clone_variable(var, force_persistable=False)

    if "_grad" not in op.type:
        # for forward op
        return block.append_op(
            type=op.type, inputs=inputs, outputs=outputs, attrs=op.all_attrs()
        )
    else:
        # for grad op
        op_desc = op.desc
        op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
        backward = core.op_proto_and_checker_maker.OpRole.Backward
        device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()

        # append grad op
        new_op_desc = block.desc.append_op()
        new_op_desc.copy_from(op_desc)
        new_op_desc._set_attr(op_role_attr_name, backward)

        # set device guard
        if op.desc.has_attr(device_attr_name):
            op_device = op_desc.attr(device_attr_name)
            new_op_desc._set_attr(device_attr_name, op_device)
        block._sync_with_cpp()
def add_vars_by_var_list(var_name_list, origin_program, program, block):
    for var_name in var_name_list:
        if (
            var_name not in program.global_block().vars
            and var_name not in block.vars
        ):
            var = origin_program.global_block().vars[var_name]
            if var.persistable:
                program.global_block()._clone_variable(var, force_persistable=False)
            else:
                block._clone_variable(var, force_persistable=False)


def get_varlist_from_op_map(var_map):
    var_list = []
    for key, varlist in var_map.items():
        if not isinstance(varlist, list):
            varlist = [varlist]
        for i in range(len(varlist)):
            var = varlist[i]
            var_list.append(var.name)
    return var_list


def find_ops_list_input_output(program, ops_list):
    input_var_list = []
    output_var_list = []
    for op in ops_list:
        inputs = _get_input_map_from_op(program.global_block().vars, op)
        input_var_list += get_varlist_from_op_map(inputs)
        outputs = _get_output_map_from_op(program.global_block().vars, op)
        output_var_list += get_varlist_from_op_map(outputs)

    input_var_list = list(set(input_var_list))
    output_var_list = list(set(output_var_list))
    return input_var_list, output_var_list


def find_op_input_output(program, block, op):
    input_var_list = []
    output_var_list = []
    inputs = _get_input_map_from_op(block.vars, op)
    input_var_list += get_varlist_from_op_map(inputs)
    outputs = _get_output_map_from_op(block.vars, op)
    output_var_list += get_varlist_from_op_map(outputs)
    input_var_list = list(set(input_var_list))
    output_var_list = list(set(output_var_list))
    return input_var_list, output_var_list


def get_vars_name_in_block(block):
    vars_list = block.vars.keys()
    vars_name_list = [var_name for var_name in vars_list]
    return vars_name_list


def is_same_op(op1, op2):
    if str(op1) != str(op2):
        return False
    return True


def _get_input_map_from_op(varmap, op):
    """Returns a dict from op input name to the vars in varmap."""
    iomap = collections.OrderedDict()
    for key in op.input_names:
        vars = []
        for varname in op.input(key):
            if varname == "@EMPTY@":
                continue
            if "lod_tensor_blocking_queue" in varname:
                continue
            vars.append(varmap[varname])
        if len(vars) == 1:
            iomap[key] = vars[0]
        else:
            iomap[key] = vars
    return iomap


def _get_output_map_from_op(varmap, op):
    """Returns a dict from op output name to the vars in varmap."""
    iomap = collections.OrderedDict()
    for key in op.output_names:
        vars = []
        for varname in op.output(key):
            if varname == "@EMPTY@":
                continue
            if "lod_tensor_blocking_queue" in varname:
                continue
            vars.append(varmap[varname])
        if len(vars) == 1:
            iomap[key] = vars[0]
        else:
            iomap[key] = vars
    return iomap


def delete_same_ops(block, ops):
    for op in ops:
        try:
            for origin_op in block.ops:
                if is_same_op(origin_op, op):
                    idx = list(block.ops).index(origin_op)
                    block._remove_op(idx)
                    break
        except Exception as e:
            print(e)
python/paddle/fluid/incubate/fleet/parameter_server/ir/ufind.py
deleted  100644 → 0
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class UnionFind:
    """Union-find data structure.

    Union-find is a data structure that keeps track of a set of elements partitioned
    into a number of disjoint (non-overlapping) subsets.

    Reference:
    https://en.wikipedia.org/wiki/Disjoint-set_data_structure

    Args:
      elements(list): The initial element list.
    """

    def __init__(self, elements=None):
        self._parents = []  # index -> parent index
        self._index = {}  # element -> index
        self._curr_idx = 0
        if not elements:
            elements = []
        for ele in elements:
            self._parents.append(self._curr_idx)
            self._index.update({ele: self._curr_idx})
            self._curr_idx += 1

    def find(self, x):
        # Find the root index of the given element x,
        # compressing the path while finding the root index.
        if not x in self._index:
            return -1
        idx = self._index[x]
        while idx != self._parents[idx]:
            t = self._parents[idx]
            self._parents[idx] = self._parents[t]
            idx = t
        return idx

    def union(self, x, y):
        # Union the two given elements.
        x_root = self.find(x)
        y_root = self.find(y)

        if x_root == y_root:
            return
        self._parents[x_root] = y_root

    def is_connected(self, x, y):
        # If two given elements have the same root index,
        # then they are connected.
        return self.find(x) == self.find(y)
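
A brief usage sketch of the class above, assuming the UnionFind definition shown here is in scope:

uf = UnionFind(["w", "w@GRAD", "b"])   # three singleton sets
uf.union("w", "w@GRAD")                # merge a parameter with its gradient
print(uf.is_connected("w", "w@GRAD"))  # True
print(uf.is_connected("w", "b"))       # False
print(uf.find("missing"))              # -1 for elements that were never added
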
python/paddle/fluid/incubate/fleet/parameter_server/ir/vars_metatools.py
deleted  100644 → 0
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce

from paddle.framework.io import Variable
from paddle.framework import core

dtype_to_size = {
    core.VarDesc.VarType.FP16: 2,
    core.VarDesc.VarType.FP32: 4,
    core.VarDesc.VarType.FP64: 8,
    core.VarDesc.VarType.INT16: 2,
    core.VarDesc.VarType.INT32: 4,
    core.VarDesc.VarType.INT64: 8,
    core.VarDesc.VarType.BOOL: 1,
    core.VarDesc.VarType.UINT8: 1,
}


class VarBlock:
    def __init__(self, varname, offset, size):
        self.varname = varname
        # NOTE: real offset is offset * size
        self.offset = offset
        self.size = size

    def __str__(self):
        return "%s:%d:%d" % (self.varname, self.offset, self.size)


def create_var_struct(var):
    if var.type == core.VarDesc.VarType.SELECTED_ROWS:
        lod_level = None
    elif var.type == core.VarDesc.VarType.LOD_TENSOR:
        lod_level = var.lod_level
    else:
        raise ValueError("can only support SELECTED_ROWS/LOD_TENSOR now")

    return VarStruct(
        var.name, var.shape, var.dtype, var.type, lod_level, var.persistable
    )


class VarStruct:
    """
    record part properties of a Variable in python.
    """

    def __init__(self, name, shape, dtype, type, lod_level, persistable):
        self.name = name
        self.shape = shape
        self.dtype = dtype
        self.type = type
        self.lod_level = lod_level
        self.persistable = persistable
        self.m_size = 1
        self.m_size = reduce(lambda x, y: x * y, shape)
        self.m_size *= dtype_to_size[dtype]

    def __str__(self):
        return "N: {}, S: {}, D: {}, T: {}, LL: {}, P: {}, M: {}".format(
            self.name,
            self.shape,
            self.dtype,
            self.type,
            self.lod_level,
            self.persistable,
            self.m_size,
        )


class VarDistributed:
    """
    a class to record the var distributed on parameter servers.
    the class will record the relationship between origin var and slice var.
    the slice var's properties, such as type/shape/offset/endpoint.
    """

    def __init__(
        self,
        origin_var,
        slice_var,
        is_slice=None,
        block_id=None,
        offset=None,
        vtype=None,
        endpoint=None,
    ):
        """
        Args:
            origin_var(Variable|VarStruct): origin var properties
            slice_var(Variable|VarStruct): slice var properties
            is_slice(bool|None): slice or not, slice_var=True/False and its block size > 8192 are the judgement standard.
            block_id(int|None): the number about the slice var.
            offset(int|None): if the slice var is sliced, offset is the numel before the var.
            vtype(str|None): a tag, such as Optimizer/Param/RemoteProfetch.
            endpoint(str|None): which parameter server the slice var is on, such as "127.0.0.1:1001"
        """
        if isinstance(origin_var, Variable):
            self.origin = create_var_struct(origin_var)
        else:
            self.origin = origin_var

        if isinstance(slice_var, Variable):
            self.slice = create_var_struct(slice_var)
        else:
            self.slice = slice_var

        if self.equal(self.origin, self.slice):
            self.is_slice = False
            self.block_id = 0
            self.offset = 0
        else:
            self.is_slice = True
            self.block_id = 0
            self.offset = 0

        if is_slice is not None:
            self.is_slice = is_slice
        if block_id is not None:
            self.block_id = block_id
        if offset is not None:
            self.offset = offset

        self.vtype = vtype
        self.endpoint = endpoint

    @staticmethod
    def equal(var1, var2):
        """
        whether the two vars are equal.
        Returns:
            bool: True if equal, else False
        """
        assert isinstance(var1, VarStruct) and isinstance(var2, VarStruct)

        return (
            var1.name == var2.name
            and var1.type == var2.type
            and var1.shape == var2.shape
            and var1.dtype == var2.dtype
            and var1.lod_level == var2.lod_level
            and var1.persistable == var2.persistable
        )

    def __str__(self):
        origin_var_str = (
            "{name} : fluid.{type}.shape{shape}.astype({dtype})".format(
                i="{",
                e="}",
                name=self.origin.name,
                type=self.origin.type,
                shape=self.origin.shape,
                dtype=self.origin.dtype,
            )
        )

        slice_var_str = (
            "{name} : fluid.{type}.shape{shape}.astype({dtype})"
            ".slice({is_slice}).block({block_id}).offset({offset})".format(
                i="{",
                e="}",
                name=self.slice.name,
                type=self.slice.type,
                shape=self.slice.shape,
                dtype=self.slice.dtype,
                is_slice=self.is_slice,
                block_id=self.block_id,
                offset=self.offset,
            )
        )

        return "var owned: {}, origin var: ( {} ), slice var: ( {} ), endpoint: {} ".format(
            self.vtype, origin_var_str, slice_var_str, self.endpoint
        )


class VarsDistributed:
    """
    a gather about VarDistributed with many methods to find distributed vars.
    through the class, we can get an overview of the distributed parameters on parameter servers.
    this class is centralized and convenient for developers to manage and query a variable's distribution.
    other modules can also use this to find variables, such as io.py.
    """

    def __init__(self):
        self.distributed_vars = []

    def add_distributed_var(
        self,
        origin_var,
        slice_var,
        is_slice=None,
        block_id=None,
        offset=None,
        vtype=None,
        endpoint=None,
    ):
        """
        add distributed var in this.

        Args:
            origin_var(Variable|VarStruct): origin var properties
            slice_var(Variable|VarStruct): slice var properties
            is_slice(bool|None): slice or not, slice_var=True/False and its block size > 8192 are the judgement standard.
            block_id(int|None): the number about the slice var.
            offset(int|None): if the slice var is sliced, offset is the numel before the var.
            vtype(str|None): a tag, such as Optimizer/Param/RemoteProfetch.
            endpoint(str|None): which parameter server the slice var is on, such as "127.0.0.1:1001"
        Returns:
            None
        """
        self.distributed_vars.append(
            VarDistributed(
                origin_var,
                slice_var,
                is_slice,
                block_id,
                offset,
                vtype,
                endpoint,
            )
        )
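
To make the m_size bookkeeping in VarStruct concrete, here is a small standalone calculation that mirrors it without the Paddle dtype enum; the shape and the FP32 choice are assumptions for illustration (4 bytes per element comes from the dtype_to_size table above).

from functools import reduce

shape = (1024, 64)      # hypothetical dense parameter
bytes_per_elem = 4      # core.VarDesc.VarType.FP32 maps to 4 in dtype_to_size
m_size = reduce(lambda x, y: x * y, shape) * bytes_per_elem
print(m_size)           # 262144 bytes, matching VarStruct.m_size for an FP32 var
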
python/paddle/fluid/tests/unittests/ps/ps_dnn_trainer.py
@@ -444,7 +444,7 @@ class DnnTrainer:
             print(
                 "entering run {} - old".format(str(config["applied_pass_name"]))
             )
-            from paddle.fluid.incubate.fleet.parameter_server.ir import (
+            from paddle.incubate.fleet.parameter_server.ir import (
                 public as public,
             )
@@ -458,7 +458,7 @@ class DnnTrainer:
             _main = compiled_config.origin_main_program.clone()
             _startup = compiled_config.origin_startup_program.clone()
-            from paddle.fluid.incubate.fleet.parameter_server.ir import (
+            from paddle.incubate.fleet.parameter_server.ir import (
                 trainer_pass as worker,
             )
...
python/paddle/fluid/tests/unittests/test_fleet_ps.py
@@ -15,7 +15,7 @@
 import unittest
 from paddle.fluid.framework import default_main_program
-from paddle.fluid.incubate.fleet.parameter_server.ir.pserver_pass import (
+from paddle.incubate.fleet.parameter_server.ir.pserver_pass import (
     _get_optimizer_input_shape,
 )
...
python/paddle/fluid/tests/unittests/test_ps_dispatcher.py
@@ -14,7 +14,7 @@
 import unittest
-from paddle.fluid.incubate.fleet.parameter_server.ir.ps_dispatcher import (
+from paddle.incubate.fleet.parameter_server.ir.ps_dispatcher import (
     HashName,
     PSDispatcher,
     RoundRobin,
...
python/paddle/incubate/fleet/parameter_server/ir/pserver_pass.py
@@ -14,7 +14,8 @@
 import collections
-from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+from paddle.framework import core
+from paddle.incubate.fleet.parameter_server.ir.public import (
     _get_lr_ops,
     _get_optimize_ops,
     _get_varname_parts,
@@ -23,7 +24,6 @@ from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
     get_sparse_tablenames,
     is_distributed_sparse_op,
 )
-from paddle.framework import core
 LEARNING_RATE_DECAY_COUNTER = "@LR_DECAY_COUNTER@"
 OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
...
python/paddle/incubate/fleet/parameter_server/ir/public.py
@@ -19,12 +19,10 @@ import warnings
 from functools import reduce
 import paddle
-from paddle.fluid.incubate.fleet.parameter_server.ir import vars_metatools
-from paddle.fluid.incubate.fleet.parameter_server.ir.ps_dispatcher import (
-    RoundRobin,
-)
 from paddle.fluid.incubate.fleet.parameter_server.mode import DistributedMode
 from paddle.framework import core
+from paddle.incubate.fleet.parameter_server.ir import vars_metatools
+from paddle.incubate.fleet.parameter_server.ir.ps_dispatcher import RoundRobin
 OP_NAME_SCOPE = "op_namescope"
 CLIP_OP_NAME_SCOPE = "gradient_clip"
...
python/paddle/incubate/fleet/parameter_server/ir/trainer_pass.py
@@ -21,12 +21,12 @@ from functools import reduce
 import paddle
 import paddle.framework as framework
 from paddle.distributed.transpiler.details.program_utils import delete_ops
-from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+from paddle.framework import core
+from paddle.incubate.fleet.parameter_server.ir.public import (
     _get_lr_ops,
     _get_optimize_ops,
     get_sparse_tablenames,
 )
-from paddle.framework import core
 from paddle.incubate.fleet.parameter_server.mode import DistributedMode
 OP_NAME_SCOPE = "op_namescope"
...
python/setup.py.in
@@ -407,7 +407,6 @@ packages=['paddle',
           'paddle.fluid.incubate.fleet.base',
           'paddle.fluid.incubate.fleet.collective',
           'paddle.fluid.incubate.fleet.utils',
-          'paddle.fluid.incubate.fleet.parameter_server.ir',
           'paddle.fluid.incubate.fleet.parameter_server',
           'paddle.amp',
           'paddle.cost_model',
...
setup.py
@@ -1301,7 +1301,6 @@ def get_setup_parameters():
         'paddle.fluid.incubate.fleet.collective',
         'paddle.fluid.incubate.fleet.utils',
         'paddle.fluid.incubate.fleet.parameter_server',
-        'paddle.fluid.incubate.fleet.parameter_server.ir',
         'paddle.amp',
         'paddle.cost_model',
         'paddle.hapi',
...