Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
b1e51836
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b1e51836
编写于
5月 10, 2018
作者:
Y
Yancey1989
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
overlap sendop and backward ops
上级
2a22da6c
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
167 addition
and
55 deletion
+167
-55
paddle/fluid/operators/recv_op.cc
paddle/fluid/operators/recv_op.cc
+12
-6
python/paddle/fluid/transpiler/__init__.py
python/paddle/fluid/transpiler/__init__.py
+2
-1
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+75
-48
python/paddle/fluid/transpiler/ps_dispatcher.py
python/paddle/fluid/transpiler/ps_dispatcher.py
+78
-0
未找到文件。
paddle/fluid/operators/recv_op.cc
浏览文件 @
b1e51836
...
@@ -36,19 +36,22 @@ class RecvOp : public framework::OperatorBase {
...
@@ -36,19 +36,22 @@ class RecvOp : public framework::OperatorBase {
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
outs
=
Outputs
(
"Out"
);
auto
outs
=
Outputs
(
"Out"
);
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
auto
client_var_name
=
Output
(
"RPCClient"
);
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
client_var_name
),
"Can not find variable '%s' in the scope."
,
client_var_name
);
auto
*
client_var
=
scope
.
FindVar
(
client_var_name
);
detail
::
RPCClient
*
rpc_client
=
client_var
->
GetMutable
<
detail
::
RPCClient
>
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
auto
&
ctx
=
*
pool
.
Get
(
place
);
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
];
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
]
;
client_
.
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
rpc_client
->
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
}
PADDLE_ENFORCE
(
client_
.
Wait
());
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
}
}
private:
mutable
detail
::
RPCClient
client_
;
};
};
class
RecvOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
RecvOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
@@ -56,6 +59,9 @@ class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -56,6 +59,9 @@ class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
RecvOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
RecvOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddOutput
(
"Out"
,
"(Tensor) Variables to get from server."
).
AsDuplicable
();
AddOutput
(
"Out"
,
"(Tensor) Variables to get from server."
).
AsDuplicable
();
AddOutput
(
"RPCClient"
,
"(RPCClient) The RPC client object which is"
"initialized at most once."
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Recv operator
Recv operator
...
...
python/paddle/fluid/transpiler/__init__.py
浏览文件 @
b1e51836
...
@@ -15,8 +15,9 @@ from distribute_transpiler import DistributeTranspiler
...
@@ -15,8 +15,9 @@ from distribute_transpiler import DistributeTranspiler
from
inference_transpiler
import
InferenceTranspiler
from
inference_transpiler
import
InferenceTranspiler
from
memory_optimization_transpiler
import
memory_optimize
,
release_memory
from
memory_optimization_transpiler
import
memory_optimize
,
release_memory
from
distribute_transpiler_simple
import
SimpleDistributeTranspiler
from
distribute_transpiler_simple
import
SimpleDistributeTranspiler
from
ps_dispatcher
import
HashName
,
RoundRobin
__all__
=
[
__all__
=
[
"DistributeTranspiler"
,
"InferenceTranspiler"
,
"SimpleDistributeTranspiler"
,
"DistributeTranspiler"
,
"InferenceTranspiler"
,
"SimpleDistributeTranspiler"
,
"memory_optimize"
,
"release_memory"
"memory_optimize"
,
"release_memory"
,
"HashName"
,
"RoundRobin"
]
]
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
b1e51836
...
@@ -17,7 +17,8 @@ from __future__ import print_function
...
@@ -17,7 +17,8 @@ from __future__ import print_function
import
math
import
math
import
distributed_splitter
as
splitter
import
distributed_splitter
as
splitter
from
..
import
core
from
ps_dispatcher
import
RoundRobin
,
HashName
,
PSDispatcher
from
..
import
core
,
framework
from
..framework
import
Program
,
default_main_program
,
Variable
,
Parameter
from
..framework
import
Program
,
default_main_program
,
Variable
,
Parameter
LOOKUP_TABLE_TYPE
=
"lookup_table"
LOOKUP_TABLE_TYPE
=
"lookup_table"
...
@@ -144,13 +145,27 @@ def delete_ops(block, ops):
...
@@ -144,13 +145,27 @@ def delete_ops(block, ops):
block
.
program
.
sync_with_cpp
()
block
.
program
.
sync_with_cpp
()
def
find_op_by_input_arg
(
block
,
arg_name
):
for
index
,
op
in
enumerate
(
block
.
ops
):
if
arg_name
in
op
.
input_arg_names
:
return
index
return
-
1
def
find_op_by_output_arg
(
block
,
arg_name
):
for
index
,
op
in
enumerate
(
block
.
ops
):
if
arg_name
in
op
.
output_arg_names
:
return
index
return
-
1
class
DistributeTranspiler
:
class
DistributeTranspiler
:
def
transpile
(
self
,
def
transpile
(
self
,
trainer_id
,
trainer_id
,
program
=
None
,
program
=
None
,
pservers
=
"127.0.0.1:6174"
,
pservers
=
"127.0.0.1:6174"
,
trainers
=
1
,
trainers
=
1
,
split_method
=
splitter
.
round_r
obin
,
split_method
=
RoundR
obin
,
sync_mode
=
True
):
sync_mode
=
True
):
"""
"""
Transpile the program to distributed data-parallelism programs.
Transpile the program to distributed data-parallelism programs.
...
@@ -184,14 +199,14 @@ class DistributeTranspiler:
...
@@ -184,14 +199,14 @@ class DistributeTranspiler:
:type pservers: string
:type pservers: string
:param trainers: total number of workers/trainers in the job
:param trainers: total number of workers/trainers in the job
:type trainers: int
:type trainers: int
:param split_method: A
function to determin how to split variables
:param split_method: A
instance to determin how to dispatch variable
to different servers equally.
blocks
to different servers equally.
:type split_method:
function
:type split_method:
A instance based on PSDispatcher class.
:type sync_mode: boolean default True
:type sync_mode: boolean default True
:param sync_mode: if sync_mode is set True, it means that dist transpiler
:param sync_mode: if sync_mode is set True, it means that dist transpiler
will transpile the program into sync_mode pserver and trainer program.
will transpile the program into sync_mode pserver and trainer program.
"""
"""
assert
(
callable
(
split_method
)
)
assert
(
split_method
.
__bases__
[
0
]
==
PSDispatcher
)
if
program
is
None
:
if
program
is
None
:
program
=
default_main_program
()
program
=
default_main_program
()
self
.
origin_program
=
program
self
.
origin_program
=
program
...
@@ -204,6 +219,7 @@ class DistributeTranspiler:
...
@@ -204,6 +219,7 @@ class DistributeTranspiler:
pserver_endpoints
=
pservers
.
split
(
","
)
pserver_endpoints
=
pservers
.
split
(
","
)
self
.
pserver_endpoints
=
pserver_endpoints
self
.
pserver_endpoints
=
pserver_endpoints
self
.
optimize_ops
,
params_grads
=
self
.
_get_optimize_pass
()
self
.
optimize_ops
,
params_grads
=
self
.
_get_optimize_pass
()
ps_dispatcher
=
split_method
(
pserver_endpoints
)
# process lookup_table_op
# process lookup_table_op
# 1. check all lookup_table_op is distributed
# 1. check all lookup_table_op is distributed
...
@@ -268,56 +284,67 @@ class DistributeTranspiler:
...
@@ -268,56 +284,67 @@ class DistributeTranspiler:
grad_var_mapping
=
self
.
_append_split_op
(
program
,
grad_blocks
)
grad_var_mapping
=
self
.
_append_split_op
(
program
,
grad_blocks
)
param_var_mapping
=
self
.
_create_vars_from_blocklist
(
program
,
param_var_mapping
=
self
.
_create_vars_from_blocklist
(
program
,
param_blocks
)
param_blocks
)
# step3: Add gradients as send op inputs and parameters as send
# op outputs.
send_inputs
=
[]
send_outputs
=
[]
for
b
in
grad_blocks
:
# append by order
varname
,
block_id
,
_
=
b
.
split
(
":"
)
send_inputs
.
append
(
grad_var_mapping
[
varname
][
int
(
block_id
)])
for
b
in
param_blocks
:
varname
,
block_id
,
_
=
b
.
split
(
":"
)
send_outputs
.
append
(
param_var_mapping
[
varname
][
int
(
block_id
)])
# let send_op know which endpoint to send which var to, eplist has the same
# order as send_inputs.
eplist
=
split_method
(
send_inputs
,
pserver_endpoints
)
# create mapping of endpoint -> split var to create pserver side program
self
.
param_grad_ep_mapping
=
dict
()
for
i
,
ep
in
enumerate
(
eplist
):
param
=
send_outputs
[
i
]
grad
=
send_inputs
[
i
]
if
not
self
.
param_grad_ep_mapping
.
has_key
(
ep
):
self
.
param_grad_ep_mapping
[
ep
]
=
{
"params"
:
[],
"grads"
:
[]}
self
.
param_grad_ep_mapping
[
ep
][
"params"
].
append
(
param
)
self
.
param_grad_ep_mapping
[
ep
][
"grads"
].
append
(
grad
)
rpc_client_var
=
program
.
global_block
().
create_var
(
rpc_client_var
=
program
.
global_block
().
create_var
(
name
=
RPC_CLIENT_VAR_NAME
,
name
=
RPC_CLIENT_VAR_NAME
,
persistable
=
True
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
type
=
core
.
VarDesc
.
VarType
.
RAW
)
# create send_op
# step 3: transpile trainer side program, insert recv op and send op.
# create mapping of endpoint -> split var to create pserver side program
self
.
param_grad_ep_mapping
=
dict
()
[
self
.
param_grad_ep_mapping
.
update
({
ep
:
{
"params"
:
[],
"grads"
:
[]
}
})
for
ep
in
self
.
pserver_endpoints
]
# step 3.1: insert send op to send gradient vars to parameter servers
ps_dispatcher
.
reset
()
for
varname
,
send_vars
in
grad_var_mapping
.
items
():
index
=
find_op_by_output_arg
(
program
.
global_block
(),
varname
)
eplist
=
ps_dispatcher
.
dispatch
(
send_vars
)
program
.
global_block
().
insert_op
(
index
=
index
,
type
=
"send_vars"
,
inputs
=
{
"X"
:
send_vars
},
outputs
=
{
"RPCClient"
:
rpc_client_var
},
attrs
=
{
"epmap"
:
eplist
})
if
self
.
sync_mode
:
program
.
global_block
().
append_op
(
type
=
"send_barrier"
,
inputs
=
{},
outputs
=
{
"RPCClient"
:
rpc_client_var
},
attrs
=
{
"endpoints"
:
pserver_endpoints
})
# step 3.2: insert recv op to receive parameters from parameter server
ps_dispatcher
.
reset
()
recv_vars
=
[]
for
b
in
param_blocks
:
varname
,
block_id
,
_
=
b
.
split
(
":"
)
recv_vars
.
append
(
param_var_mapping
[
varname
][
int
(
block_id
)])
for
b
in
grad_blocks
:
varname
,
block_id
,
_
=
b
.
split
(
":"
)
send_vars
.
append
(
grad_var_mapping
[
varname
][
int
(
block_id
)])
eplist
=
ps_dispatcher
.
dispatch
(
recv_vars
)
for
i
,
ep
in
enumerate
(
eplist
):
self
.
param_grad_ep_mapping
[
ep
][
"params"
].
append
(
recv_vars
[
i
])
self
.
param_grad_ep_mapping
[
ep
][
"grads"
].
append
(
send_vars
[
i
])
program
.
global_block
().
append_op
(
program
.
global_block
().
append_op
(
type
=
"
send
"
,
type
=
"
recv
"
,
inputs
=
{
"X"
:
send_inputs
},
inputs
=
{},
outputs
=
{
"Out"
:
send_output
s
,
outputs
=
{
"Out"
:
recv_var
s
,
"RPCClient"
:
rpc_client_var
},
"RPCClient"
:
rpc_client_var
},
attrs
=
{
attrs
=
{
"epmap"
:
eplist
})
"endpoints"
:
pserver_endpoints
,
"epmap"
:
eplist
,
"sync_mode"
:
self
.
sync_mode
})
# step4: Concat the parameters splits together after recv.
for
varname
,
splited_var
in
param_var_mapping
.
iteritems
():
if
len
(
splited_var
)
<=
1
:
continue
orig_param
=
program
.
global_block
().
vars
[
varname
]
program
.
global_block
().
append_op
(
type
=
"concat"
,
inputs
=
{
"X"
:
splited_var
},
outputs
=
{
"Out"
:
[
orig_param
]},
attrs
=
{
"axis"
:
0
})
# TODO(Yancey1989): check dist lookup table
if
self
.
has_distributed_lookup_table
:
if
self
.
has_distributed_lookup_table
:
self
.
_replace_lookup_table_op_with_prefetch
(
program
,
rpc_client_var
,
self
.
_replace_lookup_table_op_with_prefetch
(
program
,
rpc_client_var
,
eplist
)
eplist
)
...
...
python/paddle/fluid/transpiler/
distributed_splitt
er.py
→
python/paddle/fluid/transpiler/
ps_dispatch
er.py
浏览文件 @
b1e51836
...
@@ -13,45 +13,66 @@
...
@@ -13,45 +13,66 @@
# limitations under the License.
# limitations under the License.
def
hash_name
(
varlist
,
pserver_endpoints
):
class
PSDispatcher
(
object
):
"""
"""
hash variable names to several endpoints.
DistributedSpliter is the base class for dispatching vars
into different pserver instance.
You need to implement the `dispatch` inferface.
"""
def
__init__
(
self
,
pserver_endpoints
):
self
.
_eps
=
pserver_endpoints
self
.
_step
=
0
@
property
def
eps
(
self
):
return
self
.
_eps
def
reset
(
self
):
self
.
_step
=
0
def
dispatch
(
self
,
varlist
):
"""
:param varlist: a list of Variables
:return: a map of pserver endpoint -> varname
"""
AssertionError
(
"Interface has not been implemented."
)
Args:
varlist(list): a list of Variables
Returns(dict): a map of pserver endpoint -> varname
class
HashName
(
PSDispatcher
):
"""
"""
Hash variable names to servral endpoints
"""
def
__init__
(
self
,
pserver_endpoints
):
super
(
self
.
__class__
,
self
).
__init__
(
pserver_endpoints
)
def
_hash_block
(
block_str
,
total
):
def
_hash_block
(
self
,
block_str
,
total
):
return
hash
(
block_str
)
%
total
return
hash
(
block_str
)
%
total
eplist
=
[]
def
dispatch
(
self
,
varlist
):
for
var
in
varlist
:
eplist
=
[]
server_id
=
_hash_block
(
var
.
name
(),
len
(
pserver_endpoints
))
for
var
in
varlist
:
server_for_param
=
pserver_endpoints
[
server_id
]
server_id
=
self
.
_hash_block
(
var
.
name
(),
len
(
self
.
_eps
))
eplist
.
append
(
server_for_param
)
server_for_param
=
self
.
_eps
[
server_id
]
return
eplist
eplist
.
append
(
server_for_param
)
return
eplist
def
round_robin
(
varlist
,
pserver_endpoints
):
class
RoundRobin
(
PSDispatcher
):
"""
"""
Distribute variables to several endpoints.
Distribute variables to serveral endpoints.
Args:
varlist(list): a list of variables
pserver_endpoints(list): a list of pserver endpoints
Returns(list[int]): the endpoint for each variable
"""
"""
assert
(
len
(
varlist
)
>=
len
(
pserver_endpoints
))
def
__init__
(
self
,
pserver_endpoints
):
eplist
=
[]
super
(
self
.
__class__
,
self
).
__init__
(
pserver_endpoints
)
pserver_idx
=
0
for
var
in
varlist
:
def
dispatch
(
self
,
varlist
):
server_for_param
=
pserver_endpoints
[
pserver_idx
]
eplist
=
[]
eplist
.
append
(
server_for_param
)
for
var
in
varlist
:
server_for_param
=
self
.
_eps
[
self
.
_step
]
pserver_idx
+=
1
eplist
.
append
(
server_for_param
)
if
pserver_idx
>=
len
(
pserver_endpoints
):
self
.
_step
+=
1
pserver_idx
=
0
if
self
.
_step
>=
len
(
self
.
_eps
):
return
eplist
self
.
_step
=
0
return
eplist
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录