Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d92a75be
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d92a75be
编写于
5月 29, 2018
作者:
Y
Yancey
提交者:
GitHub
5月 29, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #10550 from Yancey1989/overlap_send_op
overlap send ops and backward ops
上级
7655e4cf
5d7c58e4
变更
26
显示空白变更内容
内联
并排
Showing
26 changed file
with
559 addition
and
283 deletion
+559
-283
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+2
-2
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+104
-37
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+14
-8
paddle/fluid/framework/details/rpc_op_handle.cc
paddle/fluid/framework/details/rpc_op_handle.cc
+9
-7
paddle/fluid/framework/details/rpc_op_handle.h
paddle/fluid/framework/details/rpc_op_handle.h
+4
-3
paddle/fluid/framework/op_proto_maker.cc
paddle/fluid/framework/op_proto_maker.cc
+1
-1
paddle/fluid/framework/op_proto_maker.h
paddle/fluid/framework/op_proto_maker.h
+1
-0
paddle/fluid/inference/analysis/data_flow_graph_tester.cc
paddle/fluid/inference/analysis/data_flow_graph_tester.cc
+2
-2
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+3
-1
paddle/fluid/operators/detail/grpc_client.cc
paddle/fluid/operators/detail/grpc_client.cc
+17
-3
paddle/fluid/operators/detail/grpc_client.h
paddle/fluid/operators/detail/grpc_client.h
+13
-1
paddle/fluid/operators/detail/grpc_server_test.cc
paddle/fluid/operators/detail/grpc_server_test.cc
+4
-4
paddle/fluid/operators/fetch_barrier_op.cc
paddle/fluid/operators/fetch_barrier_op.cc
+87
-0
paddle/fluid/operators/prefetch_op.cc
paddle/fluid/operators/prefetch_op.cc
+1
-21
paddle/fluid/operators/recv_op.cc
paddle/fluid/operators/recv_op.cc
+15
-6
paddle/fluid/operators/send_barrier_op.cc
paddle/fluid/operators/send_barrier_op.cc
+15
-26
paddle/fluid/operators/send_op.cc
paddle/fluid/operators/send_op.cc
+2
-22
paddle/fluid/operators/send_recv_op_test.cc
paddle/fluid/operators/send_recv_op_test.cc
+7
-6
paddle/fluid/operators/send_recv_util.h
paddle/fluid/operators/send_recv_util.h
+3
-0
paddle/fluid/operators/send_vars_op.cc
paddle/fluid/operators/send_vars_op.cc
+5
-21
paddle/fluid/pybind/const_value.cc
paddle/fluid/pybind/const_value.cc
+2
-1
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+8
-6
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
+7
-4
python/paddle/fluid/transpiler/__init__.py
python/paddle/fluid/transpiler/__init__.py
+2
-1
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+153
-100
python/paddle/fluid/transpiler/ps_dispatcher.py
python/paddle/fluid/transpiler/ps_dispatcher.py
+78
-0
未找到文件。
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
d92a75be
...
...
@@ -3,7 +3,7 @@ cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context
cc_library
(
scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
)
cc_library
(
fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
)
cc_library
(
computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
send_op_handle SRCS send
_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
rpc_op_handle SRCS rpc
_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base
)
cc_library
(
ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph
)
...
...
@@ -26,7 +26,7 @@ endif()
cc_library
(
gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor
)
cc_library
(
multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle
send
_op_handle
${
multi_devices_graph_builder_deps
}
reduce_op_handle broadcast_op_handle
)
scale_loss_grad_op_handle
rpc
_op_handle
${
multi_devices_graph_builder_deps
}
reduce_op_handle broadcast_op_handle
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto
)
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
d92a75be
...
...
@@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include <fstream>
#include <utility>
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h"
...
...
@@ -28,6 +29,10 @@
#include <string>
#include <vector>
DEFINE_string
(
ssa_graph_path
,
"/tmp/ssa_graph.dot"
,
"the ssa graph path only print with GLOG_v=10,"
"default /tmp/graph.dot"
);
namespace
paddle
{
namespace
framework
{
namespace
details
{
...
...
@@ -79,9 +84,44 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
}
}
bool
MultiDevSSAGraphBuilder
::
IsDistTrainOp
(
const
OpDesc
&
op
,
OpDesc
*
send_op
)
const
{
if
(
send_op
==
nullptr
)
{
std
::
vector
<
std
::
string
>
MultiDevSSAGraphBuilder
::
FindDistTrainSendVars
(
const
ProgramDesc
&
program
)
const
{
std
::
vector
<
std
::
string
>
send_vars
;
// since parameters are all in block 0,
// it's enough to only scan send ops in block 0
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
// TODO(Yancey1989): use a graceful method to find send op,
// instead of the the hard code string
if
(
op
->
Type
()
==
"send_vars"
)
{
auto
op_vars
=
op
->
InputArgumentNames
();
send_vars
.
reserve
(
send_vars
.
size
()
+
std
::
distance
(
op_vars
.
begin
(),
op_vars
.
end
()));
send_vars
.
insert
(
send_vars
.
end
(),
op_vars
.
begin
(),
op_vars
.
end
());
}
}
return
send_vars
;
}
std
::
vector
<
std
::
string
>
MultiDevSSAGraphBuilder
::
FindDistTrainRecvVars
(
const
ProgramDesc
&
program
)
const
{
std
::
vector
<
std
::
string
>
recv_vars
;
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
// TODO(Yancey1989): use a graceful method to find recv op,
// instead of the hard code string
if
(
op
->
Type
()
==
"recv"
)
{
auto
op_vars
=
op
->
OutputArgumentNames
();
recv_vars
.
reserve
(
recv_vars
.
size
()
+
std
::
distance
(
op_vars
.
begin
(),
op_vars
.
end
()));
recv_vars
.
insert
(
recv_vars
.
end
(),
op_vars
.
begin
(),
op_vars
.
end
());
}
}
return
recv_vars
;
}
bool
MultiDevSSAGraphBuilder
::
IsDistTrainOp
(
const
OpDesc
&
op
,
const
std
::
vector
<
std
::
string
>
&
send_vars
,
const
std
::
vector
<
std
::
string
>
&
recv_vars
)
const
{
if
(
send_vars
.
size
()
==
0
||
recv_vars
.
size
()
==
0
)
{
return
false
;
}
...
...
@@ -89,22 +129,21 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
* Check any of opvars contains `.block` and in sendvars
*/
auto
checker
=
[](
const
std
::
vector
<
std
::
string
>
&
opvars
,
const
std
::
vector
<
std
::
string
>
&
send
vars
)
->
bool
{
const
std
::
vector
<
std
::
string
>
&
rpc_
vars
)
->
bool
{
for
(
auto
&
var
:
opvars
)
{
// a variable name with the suffix `.block` means it's a splited
// variable by (DistributeTranspiler)
// [python/paddle/fluid/transpiler/distribute_transpiler.py]
if
(
var
.
find
(
".block"
)
!=
std
::
string
::
npos
&&
std
::
find
(
sendvars
.
begin
(),
sendvars
.
end
(),
var
)
!=
send
vars
.
end
())
{
std
::
find
(
rpc_vars
.
begin
(),
rpc_vars
.
end
(),
var
)
!=
rpc_
vars
.
end
())
{
return
true
;
}
}
return
false
;
};
if
(
op
.
Type
()
==
"split"
||
op
.
Type
()
==
"split_byref"
)
{
return
checker
(
op
.
OutputArgumentNames
(),
send_op
->
InputArgumentNames
());
}
else
if
(
op
.
Type
()
==
"concat"
)
{
return
checker
(
op
.
InputArgumentNames
(),
send_op
->
OutputArgumentNames
());
}
return
false
;
return
checker
(
op
.
OutputArgumentNames
(),
send_vars
)
||
checker
(
op
.
InputArgumentNames
(),
recv_vars
);
}
std
::
unique_ptr
<
SSAGraph
>
MultiDevSSAGraphBuilder
::
Build
(
...
...
@@ -123,8 +162,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
VarHandle
>>>>
(
places_
.
size
());
// Find "send" op first for split is in front of send.
OpDesc
*
send_op
=
GetSendOpDesc
(
program
);
// find send/recv vars so that we can place the distributed training
// realted op in the place 0
auto
send_vars
=
FindDistTrainSendVars
(
program
);
auto
recv_vars
=
FindDistTrainRecvVars
(
program
);
size_t
cur_device_id
=
0
;
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
var_name_on_devices
;
...
...
@@ -134,12 +175,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
bool
is_forwarding
=
true
;
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
if
(
op
->
Type
()
==
"send"
)
{
// append send op if program is distributed trainer main program.
if
(
boost
::
get
<
int
>
(
op
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
static_cast
<
int
>
(
OpRole
::
kRPC
))
{
// append rpc op if program is distributed trainer main program.
// always use the first device
Create
Send
Op
(
&
result
,
*
op
);
}
else
if
(
IsDistTrainOp
(
*
op
,
send_
op
))
{
Create
ComputationalOps
(
&
result
,
*
op
,
1
);
Create
RPC
Op
(
&
result
,
*
op
);
}
else
if
(
IsDistTrainOp
(
*
op
,
send_
vars
,
recv_vars
))
{
Create
DistTrainOp
(
&
result
,
*
op
);
}
else
if
(
IsScaleLossOp
(
*
op
))
{
// user can customize loss@grad if not use_default_grad_scale_
if
(
strategy_
.
gradient_scale_
!=
...
...
@@ -218,9 +261,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
AddOutputToLeafOps
(
&
result
);
if
(
VLOG_IS_ON
(
10
))
{
std
::
ostringstream
sout
;
PrintGraphviz
(
*
graph
,
sout
);
VLOG
(
10
)
<<
sout
.
str
();
std
::
ofstream
fout
(
FLAGS_ssa_graph_path
);
PrintGraphviz
(
*
graph
,
fout
);
}
return
std
::
unique_ptr
<
SSAGraph
>
(
graph
);
...
...
@@ -270,15 +312,6 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
CreateOpHandleIOs
(
result
,
op
,
dev_id
);
}
OpDesc
*
MultiDevSSAGraphBuilder
::
GetSendOpDesc
(
const
ProgramDesc
&
program
)
const
{
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
if
(
op
->
Type
()
==
"send"
)
{
return
op
;
}
}
return
nullptr
;
}
void
MultiDevSSAGraphBuilder
::
InsertNCCLAllReduceOp
(
SSAGraph
*
result
,
const
std
::
string
&
og
)
const
{
#ifdef PADDLE_WITH_CUDA
...
...
@@ -401,14 +434,48 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
return
var
;
}
void
MultiDevSSAGraphBuilder
::
CreateSendOp
(
SSAGraph
*
result
,
void
MultiDevSSAGraphBuilder
::
ConnectOp
(
SSAGraph
*
result
,
OpHandleBase
*
op
,
const
std
::
string
&
prev_op_name
)
const
{
for
(
auto
&
prev_op
:
result
->
ops_
)
{
if
(
prev_op
->
Name
()
==
prev_op_name
)
{
auto
*
dep_var
=
new
DummyVarHandle
();
prev_op
->
AddOutput
(
dep_var
);
result
->
dep_vars_
.
emplace
(
dep_var
);
op
->
AddInput
(
dep_var
);
}
}
}
void
MultiDevSSAGraphBuilder
::
CreateDistTrainOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
)
const
{
CreateComputationalOp
(
result
,
op
,
0
);
if
(
op
.
Type
()
==
"concat"
)
{
ConnectOp
(
result
,
result
->
ops_
.
back
().
get
(),
"fetch_barrier"
);
}
}
void
MultiDevSSAGraphBuilder
::
CreateRPCOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
)
const
{
auto
&
p
=
places_
[
0
];
auto
*
s
=
local_scopes_
[
0
];
// FIXME(wuyi): send op always copy from GPU 0
result
->
ops_
.
emplace_back
(
new
SendOpHandle
(
op
,
s
,
p
));
// Create inputs for output on original place and no ssa output
// is created for send op.
result
->
ops_
.
emplace_back
(
new
RPCOpHandle
(
op
,
s
,
p
,
op
.
Type
()));
if
(
op
.
Type
()
==
"send_barrier"
)
{
ConnectOp
(
result
,
result
->
ops_
.
back
().
get
(),
"send_vars"
);
}
else
if
(
op
.
Type
()
==
"recv"
)
{
ConnectOp
(
result
,
result
->
ops_
.
back
().
get
(),
"send_barrier"
);
}
else
if
(
op
.
Type
()
==
"fetch_barrier"
)
{
ConnectOp
(
result
,
result
->
ops_
.
back
().
get
(),
"recv"
);
}
else
if
(
op
.
Type
()
==
"send_vars"
)
{
// do nothing
}
else
{
PADDLE_THROW
(
"rpc op should be in ["
"send_vars, send_barrier. recv, fetch_barrier]"
);
}
// TODO(Yancey1989): schedule rpc op on different place may
// increate throughput
CreateOpHandleIOs
(
result
,
op
,
0
);
}
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
d92a75be
...
...
@@ -64,12 +64,24 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool
IsScaleLossOp
(
const
OpDesc
&
op
)
const
;
void
CreateSendOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
)
const
;
void
CreateRPCOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
)
const
;
void
CreateDistTrainOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
)
const
;
/**
* Is this operator as the end-point operator before/after send operator.
*/
bool
IsDistTrainOp
(
const
OpDesc
&
op
,
OpDesc
*
send_op
)
const
;
bool
IsDistTrainOp
(
const
OpDesc
&
op
,
const
std
::
vector
<
std
::
string
>
&
send_vars
,
const
std
::
vector
<
std
::
string
>
&
recv_vars
)
const
;
std
::
vector
<
std
::
string
>
FindDistTrainSendVars
(
const
ProgramDesc
&
program
)
const
;
std
::
vector
<
std
::
string
>
FindDistTrainRecvVars
(
const
ProgramDesc
&
program
)
const
;
void
ConnectOp
(
SSAGraph
*
result
,
OpHandleBase
*
op
,
const
std
::
string
&
prev_op_name
)
const
;
void
CreateComputationalOps
(
SSAGraph
*
result
,
const
OpDesc
&
op
,
size_t
num_places
)
const
;
...
...
@@ -93,12 +105,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
void
CreateBroadcastOp
(
SSAGraph
*
result
,
const
std
::
string
&
p_name
,
size_t
src_dev_id
)
const
;
/**
* Get send op in the global block of program.
* nullptr if not found.
*/
OpDesc
*
GetSendOpDesc
(
const
ProgramDesc
&
program
)
const
;
bool
IsSparseGradient
(
const
std
::
unordered_map
<
std
::
string
,
proto
::
VarType
::
Type
>
&
var_types
,
const
std
::
string
&
og
)
const
;
...
...
paddle/fluid/framework/details/
send
_op_handle.cc
→
paddle/fluid/framework/details/
rpc
_op_handle.cc
浏览文件 @
d92a75be
...
...
@@ -12,24 +12,26 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/
send
_op_handle.h"
#include "paddle/fluid/framework/details/
rpc
_op_handle.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
SendOpHandle
::
Send
OpHandle
(
const
framework
::
OpDesc
&
op_desc
,
const
Scope
*
local_scop
e
,
const
platform
::
Place
&
plac
e
)
RPCOpHandle
::
RPC
OpHandle
(
const
framework
::
OpDesc
&
op_desc
,
const
Scope
*
local_scope
,
const
platform
::
Place
&
plac
e
,
const
std
::
string
&
nam
e
)
:
op_
(
framework
::
OpRegistry
::
CreateOp
(
op_desc
)),
local_scope_
(
local_scope
),
place_
(
place
)
{}
place_
(
place
),
name_
(
name
)
{}
void
Send
OpHandle
::
RunImpl
()
{
void
RPC
OpHandle
::
RunImpl
()
{
// TODO(wuyi): need further analysis whether wait VarDummyHandle.
// Wait input done
for
(
auto
*
in
:
inputs_
)
{
auto
&
p
=
static_cast
<
VarHandle
*>
(
in
)
->
place_
;
// FIXME(Yancey1989): need a better solution instead of use DebugString()
if
(
in
->
DebugString
()
==
"dummy"
)
{
// HACK
continue
;
}
...
...
@@ -43,7 +45,7 @@ void SendOpHandle::RunImpl() {
op_
->
Run
(
*
tmp_scope
,
place_
);
}
std
::
string
SendOpHandle
::
Name
()
const
{
return
"send"
;
}
std
::
string
RPCOpHandle
::
Name
()
const
{
return
name_
;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/
send
_op_handle.h
→
paddle/fluid/framework/details/
rpc
_op_handle.h
浏览文件 @
d92a75be
...
...
@@ -27,9 +27,9 @@ namespace paddle {
namespace
framework
{
namespace
details
{
struct
Send
OpHandle
:
public
OpHandleBase
{
Send
OpHandle
(
const
framework
::
OpDesc
&
op_desc
,
const
Scope
*
local_scope
,
const
platform
::
Place
&
plac
e
);
struct
RPC
OpHandle
:
public
OpHandleBase
{
RPC
OpHandle
(
const
framework
::
OpDesc
&
op_desc
,
const
Scope
*
local_scope
,
const
platform
::
Place
&
place
,
const
std
::
string
&
nam
e
);
std
::
string
Name
()
const
override
;
...
...
@@ -44,6 +44,7 @@ struct SendOpHandle : public OpHandleBase {
std
::
unique_ptr
<
OperatorBase
>
op_
;
const
Scope
*
local_scope_
;
const
platform
::
Place
&
place_
;
const
std
::
string
name_
;
};
}
// namespace details
...
...
paddle/fluid/framework/op_proto_maker.cc
浏览文件 @
d92a75be
...
...
@@ -66,7 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
.
InEnum
(
{
static_cast
<
int
>
(
OpRole
::
kForward
),
static_cast
<
int
>
(
OpRole
::
kBackward
),
static_cast
<
int
>
(
OpRole
::
kOptimize
),
static_cast
<
int
>
(
OpRole
::
kOptimize
),
static_cast
<
int
>
(
OpRole
::
kRPC
),
static_cast
<
int
>
(
OpRole
::
kLoss
)
|
static_cast
<
int
>
(
OpRole
::
kForward
),
static_cast
<
int
>
(
OpRole
::
kLoss
)
|
static_cast
<
int
>
(
OpRole
::
kBackward
),
...
...
paddle/fluid/framework/op_proto_maker.h
浏览文件 @
d92a75be
...
...
@@ -24,6 +24,7 @@ enum class OpRole {
kForward
=
0x0000
,
kBackward
=
0x0001
,
kOptimize
=
0x0002
,
kRPC
=
0x0003
,
kLoss
=
0x0100
,
// The default value of op's role. This should be only used for unittests and
...
...
paddle/fluid/inference/analysis/data_flow_graph_tester.cc
浏览文件 @
d92a75be
...
...
@@ -35,7 +35,7 @@ TEST(DataFlowGraph, BFS) {
GraphTraits
<
DataFlowGraph
>
trait
(
&
dfg
);
auto
nodes
=
trait
.
nodes
();
in
t
count
=
0
;
size_
t
count
=
0
;
for
(
auto
it
=
nodes
.
begin
();
it
!=
nodes
.
end
();
++
it
)
{
LOG
(
INFO
)
<<
"visiting "
<<
it
->
name
();
++
count
;
...
...
@@ -49,7 +49,7 @@ TEST(DataFlowGraph, DFS) {
dfg
.
Build
();
GraphTraits
<
DataFlowGraph
>
trait
(
&
dfg
);
auto
nodes
=
trait
.
nodes_in_DFS
();
in
t
count
=
0
;
size_
t
count
=
0
;
for
(
auto
it
=
nodes
.
begin
();
it
!=
nodes
.
end
();
++
it
)
{
LOG
(
INFO
)
<<
"visiting "
<<
it
->
name
();
++
count
;
...
...
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
d92a75be
...
...
@@ -200,7 +200,9 @@ if(WITH_DISTRIBUTE)
op_library
(
send_vars_op DEPS
${
DISTRIBUTE_DEPS
}
)
set_source_files_properties
(
send_vars_op.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
op_library
(
send_barrier_op DEPS
${
DISTRIBUTE_DEPS
}
)
op_library
(
fetch_barrier_op DEPS
${
DISTRIBUTE_DEPS
}
)
set_source_files_properties
(
send_barrier_op.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
fetch_barrier_op.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
#set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
#cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op
# listen_and_serv_op sum_op executor SERIAL)
...
...
@@ -214,7 +216,7 @@ if(WITH_DISTRIBUTE)
set
(
DEPS_OPS
${
DEPS_OPS
}
gen_nccl_id_op
)
endif
()
else
()
set
(
DEPS_OPS
${
DEPS_OPS
}
send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op gen_nccl_id_op
)
set
(
DEPS_OPS
${
DEPS_OPS
}
send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op
fetch_barrier_op
gen_nccl_id_op
)
endif
()
op_library
(
cross_entropy_op DEPS cross_entropy
)
...
...
paddle/fluid/operators/detail/grpc_client.cc
浏览文件 @
d92a75be
...
...
@@ -25,6 +25,21 @@ namespace paddle {
namespace
operators
{
namespace
detail
{
std
::
once_flag
RPCClient
::
init_flag_
;
std
::
unique_ptr
<
RPCClient
>
RPCClient
::
rpc_client_
(
nullptr
);
RPCClient
*
RPCClient
::
GetInstance
()
{
std
::
call_once
(
init_flag_
,
&
RPCClient
::
Init
);
return
rpc_client_
.
get
();
}
void
RPCClient
::
Init
()
{
if
(
rpc_client_
.
get
()
==
nullptr
)
{
rpc_client_
.
reset
(
new
RPCClient
());
}
}
bool
RPCClient
::
AsyncSendVariable
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
...
...
@@ -60,7 +75,6 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
call
->
StartCall
();
call
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
});
req_count_
++
;
return
true
;
...
...
@@ -249,8 +263,9 @@ bool RPCClient::Proceed() {
delete
c
;
return
true
;
}
std
::
shared_ptr
<
grpc
::
Channel
>
RPCClient
::
GetChannel
(
const
std
::
string
&
ep
)
{
// TODO(Yancey1989): make grpc client completely thread-safe
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
auto
it
=
channels_
.
find
(
ep
);
if
(
it
!=
channels_
.
end
())
{
return
it
->
second
;
...
...
@@ -263,7 +278,6 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
auto
ch
=
grpc
::
CreateCustomChannel
(
ep
,
grpc
::
InsecureChannelCredentials
(),
args
);
channels_
[
ep
]
=
ch
;
return
ch
;
}
...
...
paddle/fluid/operators/detail/grpc_client.h
浏览文件 @
d92a75be
...
...
@@ -21,6 +21,7 @@ limitations under the License. */
#include <functional>
#include <iostream>
#include <map>
#include <mutex> // NOLINT
#include <string>
#include <vector>
...
...
@@ -35,6 +36,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace
paddle
{
namespace
operators
{
...
...
@@ -161,6 +163,10 @@ class FetchBarrierProcessor : public BaseProcessor {
class
RPCClient
{
public:
RPCClient
()
{}
static
RPCClient
*
GetInstance
();
bool
AsyncSendVariable
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
...
...
@@ -191,11 +197,17 @@ class RPCClient {
private:
bool
Proceed
();
std
::
shared_ptr
<
grpc
::
Channel
>
GetChannel
(
const
std
::
string
&
ep
);
// Init is called by GetInstance.
static
void
Init
();
private:
grpc
::
CompletionQueue
cq_
;
std
::
map
<
std
::
string
,
std
::
shared_ptr
<
grpc
::
Channel
>>
channels_
;
int64_t
req_count_
=
0
;
std
::
atomic
<
int64_t
>
req_count_
{
0
};
std
::
mutex
mutex_
;
static
std
::
unique_ptr
<
RPCClient
>
rpc_client_
;
static
std
::
once_flag
init_flag_
;
DISABLE_COPY_AND_ASSIGN
(
RPCClient
);
};
}
// namespace detail
...
...
paddle/fluid/operators/detail/grpc_server_test.cc
浏览文件 @
d92a75be
...
...
@@ -121,10 +121,10 @@ TEST(PREFETCH, DISABLED_CPU) {
std
::
string
in_var_name
(
"ids"
);
std
::
string
out_var_name
(
"out"
);
detail
::
RPCClient
client
;
client
.
AsyncPrefetchVariable
(
"127.0.0.1:8889"
,
ctx
,
scope
,
in_var_name
,
auto
client
=
detail
::
RPCClient
::
GetInstance
()
;
client
->
AsyncPrefetchVariable
(
"127.0.0.1:8889"
,
ctx
,
scope
,
in_var_name
,
out_var_name
);
client
.
Wait
();
client
->
Wait
();
auto
var
=
scope
.
Var
(
out_var_name
);
auto
value
=
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
value
();
...
...
paddle/fluid/operators/fetch_barrier_op.cc
0 → 100644
浏览文件 @
d92a75be
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
operators
{
class
FetchBarrierOp
:
public
framework
::
OperatorBase
{
public:
FetchBarrierOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
std
::
vector
<
std
::
string
>
eps
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
for
(
auto
&
ep
:
eps
)
{
VLOG
(
3
)
<<
"fetch barrier, ep: "
<<
ep
;
rpc_client
->
AsyncSendFetchBarrier
(
ep
);
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
}
};
class
FetchBarrierOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddComment
(
R"DOC(
SendBarrier operator
This operator will send a send barrier signal to list_and_serv op, so that
the Parameter Server would knew all variables have been sent.
)DOC"
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
,
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to."
)
.
SetDefault
({
"127.0.0.1:6164"
});
}
};
class
FetchBarrierOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
fetch_barrier
,
ops
::
FetchBarrierOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
FetchBarrierOpMaker
,
ops
::
FetchBarrierOpShapeInference
);
paddle/fluid/operators/prefetch_op.cc
浏览文件 @
d92a75be
...
...
@@ -41,12 +41,7 @@ class PrefetchOp : public framework::OperatorBase {
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
auto
client_var_name
=
Output
(
"RPCClient"
);
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
client_var_name
),
"Can not find variable '%s' in the scope."
,
client_var_name
);
auto
*
client_var
=
scope
.
FindVar
(
client_var_name
);
detail
::
RPCClient
*
rpc_client
=
client_var
->
GetMutable
<
detail
::
RPCClient
>
();
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
...
...
@@ -66,9 +61,6 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void
Make
()
{
AddInput
(
"X"
,
"(LoDTensor) Input Id variables to be sent"
).
AsDuplicable
();
AddOutput
(
"RPCClient"
,
"(RPCClient) The RPC client object which will be"
"initialized at most once."
);
AddOutput
(
"Out"
,
"(LoDTensor) result "
"to be fetched from parameter server"
)
...
...
@@ -87,17 +79,6 @@ the parameter server and fetch result back.
}
};
class
PrefetchOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
out_var_name
=
op_desc
.
Output
(
"RPCClient"
).
front
();
auto
&
out_var
=
block
->
FindRecursiveOrCreateVar
(
out_var_name
);
auto
var_type
=
framework
::
proto
::
VarType
::
RAW
;
out_var
.
SetType
(
var_type
);
}
};
class
PrefetchOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
...
...
@@ -110,5 +91,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR
(
prefetch
,
ops
::
PrefetchOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
PrefetchOpMaker
,
ops
::
PrefetchOpVarTypeInference
,
ops
::
PrefetchOpShapeInference
);
paddle/fluid/operators/recv_op.cc
浏览文件 @
d92a75be
...
...
@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -36,19 +37,23 @@ class RecvOp : public framework::OperatorBase {
const
platform
::
Place
&
place
)
const
override
{
auto
outs
=
Outputs
(
"Out"
);
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
int
sync_mode
=
Attr
<
int
>
(
"sync_mode"
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
];
client_
.
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
rpc_client
->
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
if
(
sync_mode
)
{
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
}
PADDLE_ENFORCE
(
client_
.
Wait
());
}
private:
mutable
detail
::
RPCClient
client_
;
};
class
RecvOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
@@ -65,6 +70,10 @@ This operator can get variables from server side.
"Server endpoints in the order of input "
"variables for mapping"
)
.
SetDefault
({});
AddAttr
<
int
>
(
"sync_mode"
,
"(int, default 0)"
"sync recv or async recv."
)
.
SetDefault
(
0
);
}
};
...
...
paddle/fluid/operators/send_barrier_op.cc
浏览文件 @
d92a75be
...
...
@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -36,31 +37,30 @@ class SendBarrierOp : public framework::OperatorBase {
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
std
::
vector
<
std
::
string
>
eps
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
);
bool
sync_mode
=
Attr
<
bool
>
(
"sync_mode"
);
auto
client_var_name
=
Output
(
"RPCClient"
);
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
client_var_name
),
"Can not find variable '%s' in the scope."
,
client_var_name
);
auto
*
client_var
=
scope
.
FindVar
(
client_var_name
);
detail
::
RPCClient
*
rpc_client
=
client_var
->
GetMutable
<
detail
::
RPCClient
>
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
(
);
auto
&
ctx
=
*
pool
.
Get
(
place
);
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
// need to wait before sending send_barrier message
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
if
(
sync_mode
)
{
for
(
auto
&
ep
:
eps
)
{
VLOG
(
3
)
<<
"send barrier, ep: "
<<
ep
;
rpc_client
->
AsyncSendBatchBarrier
(
ep
);
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
}
}
};
class
SendBarrierOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddOutput
(
"RPCClient"
,
"(RPCClient) The RPC client object which is"
"initialized at most once."
);
AddComment
(
R"DOC(
SendBarrier operator
...
...
@@ -72,17 +72,7 @@ the Parameter Server would knew all variables have been sent.
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to."
)
.
SetDefault
({
"127.0.0.1:6164"
});
}
};
class
SendBarrierOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
out_var_name
=
op_desc
.
Output
(
"RPCClient"
).
front
();
auto
&
out_var
=
block
->
FindRecursiveOrCreateVar
(
out_var_name
);
auto
var_type
=
framework
::
proto
::
VarType
::
RAW
;
out_var
.
SetType
(
var_type
);
AddAttr
<
bool
>
(
"sync_mode"
,
"work in sync_mode or not"
).
SetDefault
(
true
);
}
};
...
...
@@ -98,5 +88,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR
(
send_barrier
,
ops
::
SendBarrierOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
SendBarrierOpMaker
,
ops
::
SendBarrierOpVarTypeInference
,
ops
::
SendBarrierOpShapeInference
);
paddle/fluid/operators/send_op.cc
浏览文件 @
d92a75be
...
...
@@ -49,12 +49,7 @@ class SendOp : public framework::OperatorBase {
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
client_var_name
=
Output
(
"RPCClient"
);
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
client_var_name
),
"Can not find variable '%s' in the scope."
,
client_var_name
);
auto
*
client_var
=
scope
.
FindVar
(
client_var_name
);
detail
::
RPCClient
*
rpc_client
=
client_var
->
GetMutable
<
detail
::
RPCClient
>
();
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
...
...
@@ -96,9 +91,6 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput
(
"X"
,
"(Tensor) Input tensor to be sent"
).
AsDuplicable
();
AddOutput
(
"Out"
,
"(Tensor) Output tensor to be received from server"
)
.
AsDuplicable
();
AddOutput
(
"RPCClient"
,
"(RPCClient) The RPC client object which is"
"initialized at most once."
);
AddComment
(
R"DOC(
Send operator
...
...
@@ -119,17 +111,6 @@ This operator will send tensor to recv_op at the parameter server.
}
};
class
SendOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
out_var_name
=
op_desc
.
Output
(
"RPCClient"
).
front
();
auto
&
out_var
=
block
->
FindRecursiveOrCreateVar
(
out_var_name
);
auto
var_type
=
framework
::
proto
::
VarType
::
RAW
;
out_var
.
SetType
(
var_type
);
}
};
class
SendOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
...
...
@@ -141,5 +122,4 @@ class SendOpShapeInference : public framework::InferShapeBase {
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
send
,
ops
::
SendOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
SendOpMaker
,
ops
::
SendOpVarTypeInference
,
ops
::
SendOpShapeInference
);
ops
::
SendOpMaker
,
ops
::
SendOpShapeInference
);
paddle/fluid/operators/send_recv_op_test.cc
浏览文件 @
d92a75be
...
...
@@ -156,6 +156,7 @@ TEST(SendRecvOp, CPUDense) {
std
::
thread
server_thread
(
StartServerNet
,
false
,
&
initialized
);
while
(
!
initialized
)
{
}
static_cast
<
paddle
::
operators
::
ListenAndServOp
*>
(
listen_and_serv_op
.
get
())
->
WaitServerReady
();
...
...
@@ -175,9 +176,10 @@ TEST(SendRecvOp, CPUDense) {
std
::
string
endpoint
=
paddle
::
string
::
Sprintf
(
"127.0.0.1:%d"
,
selected_port
);
attrs
.
insert
({
"endpoints"
,
std
::
vector
<
std
::
string
>
({
endpoint
})});
attrs
.
insert
({
"epmap"
,
std
::
vector
<
std
::
string
>
({
endpoint
})});
auto
send_op
=
f
::
OpRegistry
::
CreateOp
(
"send"
,
{{
"X"
,
{
"x1"
}}},
{{
"Out"
,
{
"Out"
}},
{
"RPCClient"
,
{
"RPC_CLIENT_VAR"
}}},
attrs
);
const
f
::
VariableNameMap
&
inputs
=
{{
"X"
,
{
"x1"
}}};
const
f
::
VariableNameMap
&
outputs
=
{{
"Out"
,
{
"Out"
}}};
auto
send_op
=
f
::
OpRegistry
::
CreateOp
(
"send"
,
inputs
,
outputs
,
attrs
);
send_op
->
Run
(
scope
,
place
);
auto
in_var
=
scope
.
Var
(
"x1"
);
...
...
@@ -220,9 +222,8 @@ TEST(SendRecvOp, CPUSparse) {
std
::
string
endpoint
=
paddle
::
string
::
Sprintf
(
"127.0.0.1:%d"
,
selected_port
);
attrs
.
insert
({
"endpoints"
,
std
::
vector
<
std
::
string
>
({
endpoint
})});
attrs
.
insert
({
"epmap"
,
std
::
vector
<
std
::
string
>
({
endpoint
})});
auto
send_op
=
f
::
OpRegistry
::
CreateOp
(
"send"
,
{{
"X"
,
{
"x1"
}}},
{{
"Out"
,
{
"Out"
}},
{
"RPCClient"
,
{
"RPC_CLIENT_VAR"
}}},
attrs
);
auto
send_op
=
f
::
OpRegistry
::
CreateOp
(
"send"
,
{{
"X"
,
{
"x1"
}}},
{{
"Out"
,
{
"Out"
}}},
attrs
);
send_op
->
Run
(
scope
,
place
);
auto
x0
=
scope
.
Var
(
"x0"
)
->
GetMutable
<
f
::
SelectedRows
>
();
...
...
paddle/fluid/operators/send_recv_util.h
浏览文件 @
d92a75be
...
...
@@ -20,6 +20,9 @@ namespace operators {
inline
bool
NeedSend
(
const
framework
::
Scope
&
scope
,
const
std
::
string
&
varname
)
{
// dummy variable is only used in parallel executor to represent
// some dependency relationship, we don't need to send/recv it.
if
(
varname
==
"dummy"
)
return
false
;
auto
*
var
=
scope
.
FindVar
(
varname
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
"Can not find variable '%s' in the send side."
,
varname
);
...
...
paddle/fluid/operators/send_vars_op.cc
浏览文件 @
d92a75be
...
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -41,12 +42,10 @@ class SendVarsOp : public framework::OperatorBase {
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
auto
client_var_name
=
Output
(
"RPCClient"
);
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
client_var_name
),
"Can not find variable '%s' in the scope."
,
client_var_name
);
auto
*
client_var
=
scope
.
FindVar
(
client_var_name
);
detail
::
RPCClient
*
rpc_client
=
client_var
->
GetMutable
<
detail
::
RPCClient
>
();
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
...
...
@@ -69,9 +68,6 @@ class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker {
void
Make
()
{
AddInput
(
"X"
,
"(Tensor, SelectedRows) Input variables to be sent"
)
.
AsDuplicable
();
AddOutput
(
"RPCClient"
,
"(RPCClient) The RPC client object which will be"
"initialized at most once."
);
AddComment
(
R"DOC(
Send operator
...
...
@@ -89,17 +85,6 @@ This operator will send variables to listen_and_serve op at the parameter server
}
};
class
SendVarsOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
out_var_name
=
op_desc
.
Output
(
"RPCClient"
).
front
();
auto
&
out_var
=
block
->
FindRecursiveOrCreateVar
(
out_var_name
);
auto
var_type
=
framework
::
proto
::
VarType
::
RAW
;
out_var
.
SetType
(
var_type
);
}
};
class
SendVarsOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
...
...
@@ -112,5 +97,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR
(
send_vars
,
ops
::
SendVarsOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
SendVarsOpMaker
,
ops
::
SendVarsOpVarTypeInference
,
ops
::
SendVarsOpShapeInference
);
paddle/fluid/pybind/const_value.cc
浏览文件 @
d92a75be
...
...
@@ -32,7 +32,8 @@ void BindConstValue(pybind11::module* m) {
.
value
(
"Forward"
,
framework
::
OpRole
::
kForward
)
.
value
(
"Backward"
,
framework
::
OpRole
::
kBackward
)
.
value
(
"Optimize"
,
framework
::
OpRole
::
kOptimize
)
.
value
(
"Loss"
,
framework
::
OpRole
::
kLoss
);
.
value
(
"Loss"
,
framework
::
OpRole
::
kLoss
)
.
value
(
"RPC"
,
framework
::
OpRole
::
kRPC
);
op_proto_and_checker_maker
.
def
(
"kOpRoleAttrName"
,
framework
::
OpProtoAndCheckerMaker
::
OpRoleAttrName
);
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
d92a75be
...
...
@@ -195,21 +195,23 @@ def Send(endpoints, send_vars, get_vars=None):
endpoints
=
list
(
set
(
epmap
))
helper
=
LayerHelper
(
"Send"
,
**
locals
())
rpc_client_var
=
default_main_program
().
global_block
().
create_var
(
name
=
"RPC_CLIENT_VAR"
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
if
not
get_vars
:
get_vars
=
[]
for
s
in
send_vars
:
v
=
helper
.
create_tmp_variable
(
dtype
=
s
.
dtype
,
stop_gradient
=
True
)
get_vars
.
append
(
v
)
rpc_op_role_name
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
()
helper
.
append_op
(
type
=
"send"
,
inputs
=
{
"X"
:
send_vars
},
outputs
=
{
"Out"
:
get_vars
,
"RPCClient"
:
rpc_client_var
},
attrs
=
{
"endpoints"
:
endpoints
,
"epmap"
:
epmap
})
outputs
=
{
"Out"
:
get_vars
},
attrs
=
{
"endpoints"
:
endpoints
,
"epmap"
:
epmap
,
rpc_op_role_name
:
core
.
op_proto_and_checker_maker
.
OpRole
.
RPC
})
return
get_vars
...
...
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
浏览文件 @
d92a75be
...
...
@@ -49,7 +49,6 @@ class TestDistTranspiler(unittest.TestCase):
def
test_transpiler
(
self
):
trainer
=
self
.
get_trainer
()
pserver
,
startup
=
self
.
get_pserver
(
self
.
current_pserver_ep
)
self
.
assertEqual
([
op
.
type
for
op
in
trainer
.
global_block
().
ops
],
self
.
get_expect_trainer_ops
())
...
...
@@ -86,8 +85,12 @@ class TestDistTranspiler(unittest.TestCase):
optimize_ops
,
params_grads
=
self
.
net_conf
()
delete_ops
(
trainer
.
global_block
(),
optimize_ops
)
return
[
op
.
type
for
op
in
trainer
.
global_block
().
ops
]
+
[
"split_byref"
,
"send"
,
"concat"
]
ops
=
[
op
.
type
for
op
in
trainer
.
global_block
().
ops
]
+
[
"split_byref"
,
"send_vars"
,
"send_barrier"
,
"recv"
,
"recv"
,
"fetch_barrier"
,
"concat"
]
ops
.
insert
(
ops
.
index
(
"elementwise_add_grad"
)
+
1
,
"send_vars"
)
return
ops
def
get_trainer
(
self
):
return
self
.
_transpiler_instance
().
get_trainer_program
()
...
...
python/paddle/fluid/transpiler/__init__.py
浏览文件 @
d92a75be
...
...
@@ -16,8 +16,9 @@ from distribute_transpiler import DistributeTranspiler
from
inference_transpiler
import
InferenceTranspiler
from
memory_optimization_transpiler
import
memory_optimize
,
release_memory
from
distribute_transpiler_simple
import
SimpleDistributeTranspiler
from
ps_dispatcher
import
HashName
,
RoundRobin
__all__
=
[
"DistributeTranspiler"
,
"InferenceTranspiler"
,
"SimpleDistributeTranspiler"
,
"memory_optimize"
,
"release_memory"
"memory_optimize"
,
"release_memory"
,
"HashName"
,
"RoundRobin"
]
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
d92a75be
...
...
@@ -16,7 +16,7 @@ from __future__ import print_function
import
math
import
distributed_splitter
as
splitt
er
from
ps_dispatcher
import
RoundRobin
,
HashName
,
PSDispatch
er
from
..
import
core
,
framework
from
..framework
import
Program
,
default_main_program
,
\
default_startup_program
,
\
...
...
@@ -24,7 +24,9 @@ from ..framework import Program, default_main_program, \
LOOKUP_TABLE_TYPE
=
"lookup_table"
LOOKUP_TABLE_GRAD_TYPE
=
"lookup_table_grad"
RPC_CLIENT_VAR_NAME
=
"RPC_CLIENT_VAR"
RPC_OP_ROLE_ATTR_NAME
=
op_role_attr_name
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
(
)
RPC_OP_ROLE_ATTR_VALUE
=
core
.
op_proto_and_checker_maker
.
OpRole
.
RPC
class
VarBlock
:
...
...
@@ -149,13 +151,27 @@ def delete_ops(block, ops):
block
.
program
.
sync_with_cpp
()
def
find_op_by_input_arg
(
block
,
arg_name
):
for
index
,
op
in
enumerate
(
block
.
ops
):
if
arg_name
in
op
.
input_arg_names
:
return
index
return
-
1
def
find_op_by_output_arg
(
block
,
arg_name
):
for
index
,
op
in
enumerate
(
block
.
ops
):
if
arg_name
in
op
.
output_arg_names
:
return
index
return
-
1
class
DistributeTranspiler
:
def
transpile
(
self
,
trainer_id
,
program
=
None
,
pservers
=
"127.0.0.1:6174"
,
trainers
=
1
,
split_method
=
splitter
.
round_r
obin
,
split_method
=
RoundR
obin
,
sync_mode
=
True
):
"""
Transpile the program to distributed data-parallelism programs.
...
...
@@ -196,7 +212,7 @@ class DistributeTranspiler:
:param sync_mode: if sync_mode is set True, it means that dist transpiler
will transpile the program into sync_mode pserver and trainer program.
"""
assert
(
callable
(
split_method
)
)
assert
(
split_method
.
__bases__
[
0
]
==
PSDispatcher
)
if
program
is
None
:
program
=
default_main_program
()
self
.
origin_program
=
program
...
...
@@ -209,6 +225,7 @@ class DistributeTranspiler:
pserver_endpoints
=
pservers
.
split
(
","
)
self
.
pserver_endpoints
=
pserver_endpoints
self
.
optimize_ops
,
params_grads
=
self
.
_get_optimize_pass
()
ps_dispatcher
=
split_method
(
pserver_endpoints
)
# process lookup_table_op
# 1. check all lookup_table_op is distributed
...
...
@@ -268,54 +285,110 @@ class DistributeTranspiler:
grad_blocks
=
split_dense_variable
(
grad_list
,
len
(
pserver_endpoints
))
param_blocks
=
split_dense_variable
(
param_list
,
len
(
pserver_endpoints
))
assert
(
len
(
grad_blocks
)
==
len
(
param_blocks
))
# step2: Create new vars for the parameters and gradients blocks and
# add ops to do the split.
grad_var_mapping
=
self
.
_append_split_op
(
program
,
grad_blocks
)
param_var_mapping
=
self
.
_create_vars_from_blocklist
(
program
,
param_blocks
)
grad_var_mapping
=
self
.
_create_vars_from_blocklist
(
program
,
grad_blocks
,
add_trainer_suffix
=
self
.
trainer_num
>
1
)
grad_param_mapping
=
dict
()
for
g
,
p
in
zip
(
grad_blocks
,
param_blocks
):
g_name
,
g_bid
,
_
=
g
.
split
(
":"
)
p_name
,
p_bid
,
_
=
p
.
split
(
":"
)
grad_param_mapping
[
grad_var_mapping
[
g_name
][
int
(
g_bid
)]]
=
\
param_var_mapping
[
p_name
][
int
(
p_bid
)]
# step 3: transpile trainer side program, insert recv op and send op.
# step3: Add gradients as send op inputs and parameters as send
# op outputs.
send_inputs
=
[]
send_outputs
=
[]
for
b
in
grad_blocks
:
# append by order
varname
,
block_id
,
_
=
b
.
split
(
":"
)
send_inputs
.
append
(
grad_var_mapping
[
varname
][
int
(
block_id
)])
for
b
in
param_blocks
:
varname
,
block_id
,
_
=
b
.
split
(
":"
)
send_outputs
.
append
(
param_var_mapping
[
varname
][
int
(
block_id
)])
# let send_op know which endpoint to send which var to, eplist has the same
# order as send_inputs.
eplist
=
split_method
(
send_inputs
,
pserver_endpoints
)
# create mapping of endpoint -> split var to create pserver side program
self
.
param_grad_ep_mapping
=
dict
()
for
i
,
ep
in
enumerate
(
eplist
):
param
=
send_outputs
[
i
]
grad
=
send_inputs
[
i
]
if
not
self
.
param_grad_ep_mapping
.
has_key
(
ep
):
self
.
param_grad_ep_mapping
[
ep
]
=
{
"params"
:
[],
"grads"
:
[]}
self
.
param_grad_ep_mapping
[
ep
][
"params"
].
append
(
param
)
self
.
param_grad_ep_mapping
[
ep
][
"grads"
].
append
(
grad
)
rpc_client_var
=
program
.
global_block
().
create_var
(
name
=
RPC_CLIENT_VAR_NAME
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
[
self
.
param_grad_ep_mapping
.
update
({
ep
:
{
"params"
:
[],
"grads"
:
[]
}
})
for
ep
in
self
.
pserver_endpoints
]
# step 3.1: insert send op to send gradient vars to parameter servers
ps_dispatcher
.
reset
()
send_vars
=
[]
for
orig_varname
,
splited_vars
in
grad_var_mapping
.
items
():
eplist
=
ps_dispatcher
.
dispatch
(
splited_vars
)
if
len
(
splited_vars
)
==
1
:
orig_varname
=
splited_vars
[
0
].
name
index
=
find_op_by_output_arg
(
program
.
global_block
(),
orig_varname
)
elif
len
(
splited_vars
)
>
1
:
orig_var
=
program
.
global_block
().
vars
[
orig_varname
]
index
=
find_op_by_output_arg
(
program
.
global_block
(),
orig_varname
)
self
.
_insert_split_op
(
program
,
orig_var
,
index
,
splited_vars
)
index
+=
1
else
:
AssertionError
(
"Can not insert the send op by original "
"variable name :"
,
orig_varname
)
program
.
global_block
().
insert_op
(
index
=
index
+
1
,
type
=
"send_vars"
,
inputs
=
{
"X"
:
splited_vars
},
outputs
=
{},
attrs
=
{
"epmap"
:
eplist
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
for
_
,
var
in
enumerate
(
splited_vars
):
send_vars
.
append
(
var
)
# create send_op
if
self
.
sync_mode
:
program
.
global_block
().
append_op
(
type
=
"send"
,
inputs
=
{
"X"
:
send_inputs
},
outputs
=
{
"Out"
:
send_outputs
,
"RPCClient"
:
rpc_client_var
},
type
=
"send_barrier"
,
inputs
=
{},
outputs
=
{},
attrs
=
{
"endpoints"
:
pserver_endpoints
,
"epmap"
:
eplist
,
"sync_mode"
:
self
.
sync_mode
"sync_mode"
:
self
.
sync_mode
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
# step 3.2: insert recv op to receive parameters from parameter server
recv_vars
=
[]
for
_
,
var
in
enumerate
(
send_vars
):
recv_vars
.
append
(
grad_param_mapping
[
var
])
ps_dispatcher
.
reset
()
eplist
=
ps_dispatcher
.
dispatch
(
recv_vars
)
for
i
,
ep
in
enumerate
(
eplist
):
self
.
param_grad_ep_mapping
[
ep
][
"params"
].
append
(
recv_vars
[
i
])
self
.
param_grad_ep_mapping
[
ep
][
"grads"
].
append
(
send_vars
[
i
])
# step4: Concat the parameters splits together after recv.
for
varname
,
splited_var
in
param_var_mapping
.
iteritems
():
eps
=
[]
for
var
in
splited_var
:
index
=
[
v
.
name
for
v
in
recv_vars
].
index
(
var
.
name
)
eps
.
append
(
eplist
[
index
])
program
.
global_block
().
append_op
(
type
=
"recv"
,
inputs
=
{},
outputs
=
{
"Out"
:
splited_var
},
attrs
=
{
"epmap"
:
eps
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
program
.
global_block
().
append_op
(
type
=
"fetch_barrier"
,
inputs
=
{},
outputs
=
{},
attrs
=
{
"endpoints"
:
pserver_endpoints
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
for
varname
,
splited_var
in
param_var_mapping
.
iteritems
():
if
len
(
splited_var
)
<=
1
:
continue
...
...
@@ -327,10 +400,8 @@ class DistributeTranspiler:
attrs
=
{
"axis"
:
0
})
if
self
.
has_distributed_lookup_table
:
self
.
_replace_lookup_table_op_with_prefetch
(
program
,
rpc_client_var
,
eplist
)
self
.
_split_table_grad_and_add_send_vars
(
program
,
rpc_client_var
,
pserver_endpoints
)
self
.
_replace_lookup_table_op_with_prefetch
(
program
,
eplist
)
self
.
_split_table_grad_and_add_send_vars
(
program
,
pserver_endpoints
)
def
get_trainer_program
(
self
):
# remove optimize ops and add a send op to main_program
...
...
@@ -550,8 +621,7 @@ class DistributeTranspiler:
return
s_prog
# transpiler function for dis lookup_table
def
_replace_lookup_table_op_with_prefetch
(
self
,
program
,
rpc_client_var
,
eplist
):
def
_replace_lookup_table_op_with_prefetch
(
self
,
program
,
eplist
):
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
self
.
prefetch_input_vars
=
None
self
.
prefetch_output_vars
=
None
...
...
@@ -598,11 +668,11 @@ class DistributeTranspiler:
index
=
op_index
+
1
,
type
=
"prefetch"
,
inputs
=
{
'X'
:
self
.
prefetch_input_vars
},
outputs
=
{
"Out"
:
self
.
prefetch_output_vars
,
"
RPCClient"
:
rpc_client_var
},
attrs
=
{
"epmap"
:
eplist
})
outputs
=
{
"Out"
:
self
.
prefetch_output_vars
},
attrs
=
{
"
epmap"
:
eplist
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
# insert concat_op
program
.
global_block
().
insert_op
(
...
...
@@ -622,8 +692,7 @@ class DistributeTranspiler:
# break for loop
break
def
_split_table_grad_and_add_send_vars
(
self
,
program
,
rpc_client_var
,
pserver_endpoints
):
def
_split_table_grad_and_add_send_vars
(
self
,
program
,
pserver_endpoints
):
# 2. add split_ids_op and send_vars_op to send gradient to pservers
# there should only be one table_name
all_ops
=
program
.
global_block
().
ops
...
...
@@ -643,9 +712,12 @@ class DistributeTranspiler:
index
=
op_index
+
2
,
type
=
"send_vars"
,
inputs
=
{
'X'
:
self
.
table_grad_list
},
outputs
=
{
"RPCClient"
:
rpc_client_var
},
attrs
=
{
"sync_send"
:
True
,
"epmap"
:
pserver_endpoints
})
outputs
=
{},
attrs
=
{
"sync_send"
:
True
,
"epmap"
:
pserver_endpoints
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
break
def
_create_prefetch_block
(
self
,
pserver_index
,
pserver_program
,
...
...
@@ -838,32 +910,13 @@ class DistributeTranspiler:
lod_level
=
var
.
lod_level
,
persistable
=
persistable
)
def
_append_split_op
(
self
,
program
,
gradblocks
):
"""
Split variables that need to be split and append respective ops
Args:
program (ProgramDesc): ProgramDesc that gradients blong.
gradblocks (list[(varname, block_id, block_size)]): List of gradient blocks.
Returns:
var_mapping (dict(varname->[new_splitted_variable])):A dict mapping
from original var name to each var split.
"""
add_suffix
=
False
if
self
.
trainer_num
>
1
:
add_suffix
=
True
var_mapping
=
self
.
_create_vars_from_blocklist
(
program
,
gradblocks
,
add_trainer_suffix
=
add_suffix
)
for
varname
,
splited_vars
in
var_mapping
.
iteritems
():
# variable that don't need to split have empty splited_vars
if
len
(
splited_vars
)
<=
1
:
continue
orig_var
=
program
.
global_block
().
vars
[
varname
]
def
_insert_split_op
(
self
,
program
,
orig_var
,
index
,
splited_vars
):
if
orig_var
.
type
==
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
:
height_sections
=
[]
for
v
in
splited_vars
:
height_sections
.
append
(
v
.
shape
[
0
])
program
.
global_block
().
append_op
(
program
.
global_block
().
insert_op
(
index
=
index
+
1
,
type
=
"split_selected_rows"
,
inputs
=
{
"X"
:
orig_var
},
outputs
=
{
"Out"
:
splited_vars
},
...
...
@@ -872,7 +925,8 @@ class DistributeTranspiler:
sections
=
[]
for
v
in
splited_vars
:
sections
.
append
(
v
.
shape
[
0
])
program
.
global_block
().
append_op
(
program
.
global_block
().
insert_op
(
index
=
index
+
1
,
type
=
"split_byref"
,
inputs
=
{
"X"
:
orig_var
},
outputs
=
{
"Out"
:
splited_vars
},
...
...
@@ -881,7 +935,6 @@ class DistributeTranspiler:
else
:
AssertionError
(
"Variable type should be in set "
"[LOD_TENSOR, SELECTED_ROWS]"
)
return
var_mapping
def
_get_optimizer_input_shape
(
self
,
op_type
,
varkey
,
orig_shape
,
param_shape
):
...
...
python/paddle/fluid/transpiler/
distributed_splitt
er.py
→
python/paddle/fluid/transpiler/
ps_dispatch
er.py
浏览文件 @
d92a75be
...
...
@@ -13,45 +13,66 @@
# limitations under the License.
def
hash_name
(
varlist
,
pserver_endpoints
):
class
PSDispatcher
(
object
):
"""
hash variable names to several endpoints.
PSDispatcher is the base class for dispatching vars
into different pserver instance.
You need to implement the `dispatch` inferface.
"""
def
__init__
(
self
,
pserver_endpoints
):
self
.
_eps
=
pserver_endpoints
self
.
_step
=
0
@
property
def
eps
(
self
):
return
self
.
_eps
def
reset
(
self
):
self
.
_step
=
0
def
dispatch
(
self
,
varlist
):
"""
:param varlist: a list of Variables
:return: a map of pserver endpoint -> varname
"""
AssertionError
(
"Interface has not been implemented."
)
Args:
varlist(list): a list of Variables
Returns(dict): a map of pserver endpoint -> varname
class
HashName
(
PSDispatcher
):
"""
Hash variable names to several endpoints
"""
def
_hash_block
(
block_str
,
total
):
def
__init__
(
self
,
pserver_endpoints
):
super
(
self
.
__class__
,
self
).
__init__
(
pserver_endpoints
)
def
_hash_block
(
self
,
block_str
,
total
):
return
hash
(
block_str
)
%
total
def
dispatch
(
self
,
varlist
):
eplist
=
[]
for
var
in
varlist
:
server_id
=
_hash_block
(
var
.
name
(),
len
(
pserver_endpoint
s
))
server_for_param
=
pserver_endpoint
s
[
server_id
]
server_id
=
self
.
_hash_block
(
var
.
name
(),
len
(
self
.
_ep
s
))
server_for_param
=
self
.
_ep
s
[
server_id
]
eplist
.
append
(
server_for_param
)
return
eplist
def
round_robin
(
varlist
,
pserver_endpoints
):
class
RoundRobin
(
PSDispatcher
):
"""
Distribute variables to several endpoints.
Args:
varlist(list): a list of variables
pserver_endpoints(list): a list of pserver endpoints
Returns(list[int]): the endpoint for each variable
Distribute variables to serveral endpoints.
"""
assert
(
len
(
varlist
)
>=
len
(
pserver_endpoints
))
def
__init__
(
self
,
pserver_endpoints
):
super
(
self
.
__class__
,
self
).
__init__
(
pserver_endpoints
)
def
dispatch
(
self
,
varlist
):
eplist
=
[]
pserver_idx
=
0
for
var
in
varlist
:
server_for_param
=
pserver_endpoints
[
pserver_idx
]
server_for_param
=
self
.
_eps
[
self
.
_step
]
eplist
.
append
(
server_for_param
)
pserver_idx
+=
1
if
pserver_idx
>=
len
(
pserver_endpoints
):
pserver_idx
=
0
self
.
_step
+=
1
if
self
.
_step
>=
len
(
self
.
_eps
):
self
.
_step
=
0
return
eplist
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录