BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)

Commit d92a75be (unverified)
Authored by Yancey on May 29, 2018; committed via GitHub on May 29, 2018
Merge pull request #10550 from Yancey1989/overlap_send_op
overlap send ops and backward ops
Parents: 7655e4cf, 5d7c58e4
Showing 26 changed files with 559 additions and 283 deletions (+559 / −283).
paddle/fluid/framework/details/CMakeLists.txt                    +2    −2
paddle/fluid/framework/details/multi_devices_graph_builder.cc    +104  −37
paddle/fluid/framework/details/multi_devices_graph_builder.h     +14   −8
paddle/fluid/framework/details/rpc_op_handle.cc                  +9    −7
paddle/fluid/framework/details/rpc_op_handle.h                   +4    −3
paddle/fluid/framework/op_proto_maker.cc                         +1    −1
paddle/fluid/framework/op_proto_maker.h                          +1    −0
paddle/fluid/inference/analysis/data_flow_graph_tester.cc        +2    −2
paddle/fluid/operators/CMakeLists.txt                            +3    −1
paddle/fluid/operators/detail/grpc_client.cc                     +17   −3
paddle/fluid/operators/detail/grpc_client.h                      +13   −1
paddle/fluid/operators/detail/grpc_server_test.cc                +4    −4
paddle/fluid/operators/fetch_barrier_op.cc                       +87   −0
paddle/fluid/operators/prefetch_op.cc                            +1    −21
paddle/fluid/operators/recv_op.cc                                +15   −6
paddle/fluid/operators/send_barrier_op.cc                        +15   −26
paddle/fluid/operators/send_op.cc                                +2    −22
paddle/fluid/operators/send_recv_op_test.cc                      +7    −6
paddle/fluid/operators/send_recv_util.h                          +3    −0
paddle/fluid/operators/send_vars_op.cc                           +5    −21
paddle/fluid/pybind/const_value.cc                               +2    −1
python/paddle/fluid/layers/io.py                                 +8    −6
python/paddle/fluid/tests/unittests/test_dist_transpiler.py      +7    −4
python/paddle/fluid/transpiler/__init__.py                       +2    −1
python/paddle/fluid/transpiler/distribute_transpiler.py          +153  −100
python/paddle/fluid/transpiler/ps_dispatcher.py                  +78   −0
paddle/fluid/framework/details/CMakeLists.txt

@@ -3,7 +3,7 @@ cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context
 cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
-cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry)
+cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
 cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
 cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)

@@ -26,7 +26,7 @@ endif()
 cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
 cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
-           scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
+           scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
paddle/fluid/framework/details/multi_devices_graph_builder.cc

@@ -12,12 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
+#include <fstream>
 #include <utility>
 #include "paddle/fluid/framework/details/broadcast_op_handle.h"
 #include "paddle/fluid/framework/details/computation_op_handle.h"
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
+#include "paddle/fluid/framework/details/rpc_op_handle.h"
 #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
-#include "paddle/fluid/framework/details/send_op_handle.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/scope.h"

@@ -28,6 +29,10 @@
 #include <string>
 #include <vector>

+DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot",
+              "the ssa graph path only print with GLOG_v=10,"
+              "default /tmp/graph.dot");
+
 namespace paddle {
 namespace framework {
 namespace details {

@@ -79,9 +84,44 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
   }
 }

-bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
-                                            OpDesc *send_op) const {
-  if (send_op == nullptr) {
+std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
+    const ProgramDesc &program) const {
+  std::vector<std::string> send_vars;
+  // since parameters are all in block 0,
+  // it's enough to only scan send ops in block 0
+  for (auto *op : program.Block(0).AllOps()) {
+    // TODO(Yancey1989): use a graceful method to find send op,
+    // instead of the the hard code string
+    if (op->Type() == "send_vars") {
+      auto op_vars = op->InputArgumentNames();
+      send_vars.reserve(send_vars.size() +
+                        std::distance(op_vars.begin(), op_vars.end()));
+      send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
+    }
+  }
+  return send_vars;
+}
+
+std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
+    const ProgramDesc &program) const {
+  std::vector<std::string> recv_vars;
+  for (auto *op : program.Block(0).AllOps()) {
+    // TODO(Yancey1989): use a graceful method to find recv op,
+    // instead of the hard code string
+    if (op->Type() == "recv") {
+      auto op_vars = op->OutputArgumentNames();
+      recv_vars.reserve(recv_vars.size() +
+                        std::distance(op_vars.begin(), op_vars.end()));
+      recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
+    }
+  }
+  return recv_vars;
+}
+
+bool MultiDevSSAGraphBuilder::IsDistTrainOp(
+    const OpDesc &op, const std::vector<std::string> &send_vars,
+    const std::vector<std::string> &recv_vars) const {
+  if (send_vars.size() == 0 || recv_vars.size() == 0) {
     return false;
   }

@@ -89,22 +129,21 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
   /**
    * Check any of opvars contains `.block` and in sendvars
    */
   auto checker = [](const std::vector<std::string> &opvars,
-                    const std::vector<std::string> &sendvars) -> bool {
+                    const std::vector<std::string> &rpc_vars) -> bool {
     for (auto &var : opvars) {
+      // a variable name with the suffix `.block` means it's a splited
+      // variable by (DistributeTranspiler)
+      // [python/paddle/fluid/transpiler/distribute_transpiler.py]
       if (var.find(".block") != std::string::npos &&
-          std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
+          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
         return true;
       }
     }
     return false;
   };

-  if (op.Type() == "split" || op.Type() == "split_byref") {
-    return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
-  } else if (op.Type() == "concat") {
-    return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
-  }
-  return false;
+  return checker(op.OutputArgumentNames(), send_vars) ||
+         checker(op.InputArgumentNames(), recv_vars);
 }

 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(

@@ -123,8 +162,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
           places_.size());

-  // Find "send" op first for split is in front of send.
-  OpDesc *send_op = GetSendOpDesc(program);
+  // find send/recv vars so that we can place the distributed training
+  // realted op in the place 0
+  auto send_vars = FindDistTrainSendVars(program);
+  auto recv_vars = FindDistTrainRecvVars(program);

   size_t cur_device_id = 0;
   std::vector<std::unordered_set<std::string>> var_name_on_devices;

@@ -134,12 +175,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   bool is_forwarding = true;

   for (auto *op : program.Block(0).AllOps()) {
-    if (op->Type() == "send") {
-      // append send op if program is distributed trainer main program.
+    if (boost::get<int>(
+            op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
+        static_cast<int>(OpRole::kRPC)) {
+      // append rpc op if program is distributed trainer main program.
       // always use the first device
-      CreateSendOp(&result, *op);
-    } else if (IsDistTrainOp(*op, send_op)) {
-      CreateComputationalOps(&result, *op, 1);
+      CreateRPCOp(&result, *op);
+    } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
+      CreateDistTrainOp(&result, *op);
     } else if (IsScaleLossOp(*op)) {
       // user can customize loss@grad if not use_default_grad_scale_
       if (strategy_.gradient_scale_ !=

@@ -218,9 +261,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   AddOutputToLeafOps(&result);

   if (VLOG_IS_ON(10)) {
-    std::ostringstream sout;
-    PrintGraphviz(*graph, sout);
-    VLOG(10) << sout.str();
+    std::ofstream fout(FLAGS_ssa_graph_path);
+    PrintGraphviz(*graph, fout);
   }

   return std::unique_ptr<SSAGraph>(graph);

@@ -270,15 +312,6 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
   CreateOpHandleIOs(result, op, dev_id);
 }

-OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(const ProgramDesc &program) const {
-  for (auto *op : program.Block(0).AllOps()) {
-    if (op->Type() == "send") {
-      return op;
-    }
-  }
-  return nullptr;
-}
-
 void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
     SSAGraph *result, const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA

@@ -401,14 +434,48 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
   return var;
 }

-void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
-                                           const OpDesc &op) const {
+void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
+                                        const std::string &prev_op_name) const {
+  for (auto &prev_op : result->ops_) {
+    if (prev_op->Name() == prev_op_name) {
+      auto *dep_var = new DummyVarHandle();
+      prev_op->AddOutput(dep_var);
+      result->dep_vars_.emplace(dep_var);
+      op->AddInput(dep_var);
+    }
+  }
+}
+
+void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
+                                                const OpDesc &op) const {
+  CreateComputationalOp(result, op, 0);
+  if (op.Type() == "concat") {
+    ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
+  }
+}
+
+void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
+                                          const OpDesc &op) const {
   auto &p = places_[0];
   auto *s = local_scopes_[0];
-  // FIXME(wuyi): send op always copy from GPU 0
-  result->ops_.emplace_back(new SendOpHandle(op, s, p));
-  // Create inputs for output on original place and no ssa output
-  // is created for send op.
+  result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
+
+  if (op.Type() == "send_barrier") {
+    ConnectOp(result, result->ops_.back().get(), "send_vars");
+  } else if (op.Type() == "recv") {
+    ConnectOp(result, result->ops_.back().get(), "send_barrier");
+  } else if (op.Type() == "fetch_barrier") {
+    ConnectOp(result, result->ops_.back().get(), "recv");
+  } else if (op.Type() == "send_vars") {
+    // do nothing
+  } else {
+    PADDLE_THROW(
+        "rpc op should be in ["
+        "send_vars, send_barrier. recv, fetch_barrier]");
+  }
+
+  // TODO(Yancey1989): schedule rpc op on different place may
+  // increate throughput
   CreateOpHandleIOs(result, op, 0);
 }
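The ConnectOp/CreateRPCOp changes above order the fine-grained RPC ops (send_vars → send_barrier → recv → fetch_barrier) by threading dummy dependency variables between them instead of relying on data edges. Below is a minimal standalone sketch of that chaining idea; the Op/DummyVar types and the Connect helper are simplified stand-ins for OpHandleBase, DummyVarHandle and ConnectOp, not the real Paddle classes.

// Standalone sketch: chain ops through dummy dependency variables so that
// each RPC op waits for the previous stage to finish.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct DummyVar {};  // stands in for DummyVarHandle

struct Op {
  std::string name;
  std::vector<std::shared_ptr<DummyVar>> inputs;   // dependencies this op waits on
  std::vector<std::shared_ptr<DummyVar>> outputs;  // dependencies this op produces
};

// Mirrors the idea of ConnectOp: make `op` wait for every already-created op
// named `prev_name` by threading a dummy variable between them.
void Connect(std::vector<Op>* graph, Op* op, const std::string& prev_name) {
  for (auto& prev : *graph) {
    if (prev.name == prev_name) {
      auto dep = std::make_shared<DummyVar>();
      prev.outputs.push_back(dep);
      op->inputs.push_back(dep);
    }
  }
}

int main() {
  std::vector<Op> graph;
  graph.push_back(Op{"send_vars", {}, {}});

  Op send_barrier{"send_barrier", {}, {}};
  Connect(&graph, &send_barrier, "send_vars");
  graph.push_back(send_barrier);

  Op recv{"recv", {}, {}};
  Connect(&graph, &recv, "send_barrier");
  graph.push_back(recv);

  Op fetch_barrier{"fetch_barrier", {}, {}};
  Connect(&graph, &fetch_barrier, "recv");
  graph.push_back(fetch_barrier);

  for (const auto& op : graph) {
    std::cout << op.name << " waits on " << op.inputs.size()
              << " dependency var(s)\n";
  }
  return 0;
}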
paddle/fluid/framework/details/multi_devices_graph_builder.h

@@ -64,12 +64,24 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   bool IsScaleLossOp(const OpDesc &op) const;

-  void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
+  void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
+
+  void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;

   /**
    * Is this operator as the end-point operator before/after send operator.
    */
-  bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;
+  bool IsDistTrainOp(const OpDesc &op,
+                     const std::vector<std::string> &send_vars,
+                     const std::vector<std::string> &recv_vars) const;
+
+  std::vector<std::string> FindDistTrainSendVars(
+      const ProgramDesc &program) const;
+
+  std::vector<std::string> FindDistTrainRecvVars(
+      const ProgramDesc &program) const;
+
+  void ConnectOp(SSAGraph *result, OpHandleBase *op,
+                 const std::string &prev_op_name) const;

   void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
                               size_t num_places) const;

@@ -93,12 +105,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
                          size_t src_dev_id) const;

-  /**
-   * Get send op in the global block of program.
-   * nullptr if not found.
-   */
-  OpDesc *GetSendOpDesc(const ProgramDesc &program) const;
-
   bool IsSparseGradient(
       const std::unordered_map<std::string, proto::VarType::Type> &var_types,
       const std::string &og) const;
paddle/fluid/framework/details/send_op_handle.cc → paddle/fluid/framework/details/rpc_op_handle.cc

@@ -12,24 +12,26 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/framework/details/send_op_handle.h"
+#include "paddle/fluid/framework/details/rpc_op_handle.h"

 namespace paddle {
 namespace framework {
 namespace details {

-SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
-                           const Scope *local_scope,
-                           const platform::Place &place)
+RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc,
+                         const Scope *local_scope,
+                         const platform::Place &place,
+                         const std::string &name)
     : op_(framework::OpRegistry::CreateOp(op_desc)),
       local_scope_(local_scope),
-      place_(place) {}
+      place_(place), name_(name) {}

-void SendOpHandle::RunImpl() {
+void RPCOpHandle::RunImpl() {
   // TODO(wuyi): need further analysis whether wait VarDummyHandle.
   // Wait input done
   for (auto *in : inputs_) {
     auto &p = static_cast<VarHandle *>(in)->place_;
+    // FIXME(Yancey1989): need a better solution instead of use DebugString()
     if (in->DebugString() == "dummy") {  // HACK
       continue;
     }

@@ -43,7 +45,7 @@ void SendOpHandle::RunImpl() {
   op_->Run(*tmp_scope, place_);
 }

-std::string SendOpHandle::Name() const { return "send"; }
+std::string RPCOpHandle::Name() const { return name_; }

 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/details/send_op_handle.h → paddle/fluid/framework/details/rpc_op_handle.h

@@ -27,9 +27,9 @@ namespace paddle {
 namespace framework {
 namespace details {

-struct SendOpHandle : public OpHandleBase {
-  SendOpHandle(const framework::OpDesc &op_desc, const Scope *local_scope,
-               const platform::Place &place);
+struct RPCOpHandle : public OpHandleBase {
+  RPCOpHandle(const framework::OpDesc &op_desc, const Scope *local_scope,
+              const platform::Place &place, const std::string &name);

   std::string Name() const override;

@@ -44,6 +44,7 @@ struct SendOpHandle : public OpHandleBase {
   std::unique_ptr<OperatorBase> op_;
   const Scope *local_scope_;
   const platform::Place &place_;
+  const std::string name_;
 };

 }  // namespace details
paddle/fluid/framework/op_proto_maker.cc

@@ -66,7 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
       .InEnum(
           {static_cast<int>(OpRole::kForward),
            static_cast<int>(OpRole::kBackward),
-           static_cast<int>(OpRole::kOptimize),
+           static_cast<int>(OpRole::kOptimize), static_cast<int>(OpRole::kRPC),
            static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
            static_cast<int>(OpRole::kLoss) |
                static_cast<int>(OpRole::kBackward),
paddle/fluid/framework/op_proto_maker.h

@@ -24,6 +24,7 @@ enum class OpRole {
   kForward = 0x0000,
   kBackward = 0x0001,
   kOptimize = 0x0002,
+  kRPC = 0x0003,

   kLoss = 0x0100,
   // The default value of op's role. This should be only used for unittests and
paddle/fluid/inference/analysis/data_flow_graph_tester.cc

@@ -35,7 +35,7 @@ TEST(DataFlowGraph, BFS) {
   GraphTraits<DataFlowGraph> trait(&dfg);
   auto nodes = trait.nodes();
-  int count = 0;
+  size_t count = 0;
   for (auto it = nodes.begin(); it != nodes.end(); ++it) {
     LOG(INFO) << "visiting " << it->name();
     ++count;

@@ -49,7 +49,7 @@ TEST(DataFlowGraph, DFS) {
   dfg.Build();
   GraphTraits<DataFlowGraph> trait(&dfg);
   auto nodes = trait.nodes_in_DFS();
-  int count = 0;
+  size_t count = 0;
   for (auto it = nodes.begin(); it != nodes.end(); ++it) {
     LOG(INFO) << "visiting " << it->name();
     ++count;
paddle/fluid/operators/CMakeLists.txt

@@ -200,7 +200,9 @@ if(WITH_DISTRIBUTE)
     op_library(send_vars_op DEPS ${DISTRIBUTE_DEPS})
     set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
     op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS})
+    op_library(fetch_barrier_op DEPS ${DISTRIBUTE_DEPS})
     set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+    set_source_files_properties(fetch_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
     #set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
     #cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op
     #        listen_and_serv_op sum_op executor SERIAL)

@@ -214,7 +216,7 @@ if(WITH_DISTRIBUTE)
         set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
     endif()
 else()
-    set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op gen_nccl_id_op)
+    set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
 endif()
 op_library(cross_entropy_op DEPS cross_entropy)
paddle/fluid/operators/detail/grpc_client.cc

@@ -25,6 +25,21 @@ namespace paddle {
 namespace operators {
 namespace detail {

+std::once_flag RPCClient::init_flag_;
+
+std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
+
+RPCClient* RPCClient::GetInstance() {
+  std::call_once(init_flag_, &RPCClient::Init);
+  return rpc_client_.get();
+}
+
+void RPCClient::Init() {
+  if (rpc_client_.get() == nullptr) {
+    rpc_client_.reset(new RPCClient());
+  }
+}
+
 bool RPCClient::AsyncSendVariable(const std::string& ep,
                                   const platform::DeviceContext& ctx,
                                   const framework::Scope& scope,

@@ -60,7 +75,6 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
     call->StartCall();
     call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
   });
   req_count_++;

   return true;

@@ -249,8 +263,9 @@ bool RPCClient::Proceed() {
   delete c;
   return true;
 }
 std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
+  // TODO(Yancey1989): make grpc client completely thread-safe
+  std::unique_lock<std::mutex> lock(mutex_);
   auto it = channels_.find(ep);
   if (it != channels_.end()) {
     return it->second;

@@ -263,7 +278,6 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
   auto ch =
       grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
   channels_[ep] = ch;
   return ch;
 }
paddle/fluid/operators/detail/grpc_client.h

@@ -21,6 +21,7 @@ limitations under the License. */
 #include <functional>
 #include <iostream>
 #include <map>
+#include <mutex>  // NOLINT
 #include <string>
 #include <vector>

@@ -35,6 +36,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
+#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN

 namespace paddle {
 namespace operators {

@@ -161,6 +163,10 @@ class FetchBarrierProcessor : public BaseProcessor {
 class RPCClient {
  public:
+  RPCClient() {}
+
+  static RPCClient* GetInstance();
+
   bool AsyncSendVariable(const std::string& ep,
                          const platform::DeviceContext& ctx,
                          const framework::Scope& scope,

@@ -191,11 +197,17 @@ class RPCClient {
  private:
   bool Proceed();
   std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
+  // Init is called by GetInstance.
+  static void Init();

  private:
   grpc::CompletionQueue cq_;
   std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
-  int64_t req_count_ = 0;
+  std::atomic<int64_t> req_count_{0};
+  std::mutex mutex_;
+  static std::unique_ptr<RPCClient> rpc_client_;
+  static std::once_flag init_flag_;
+  DISABLE_COPY_AND_ASSIGN(RPCClient);
 };

 }  // namespace detail
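With this change RPCClient becomes a lazily initialized, process-wide singleton guarded by std::call_once, so every send/recv/barrier operator shares one client (and one set of gRPC channels) instead of pulling an RPCClient object out of the scope. A minimal sketch of the same lazy-singleton pattern, using a toy Client class rather than the real RPCClient:

// Sketch of the call_once-based lazy singleton pattern adopted by RPCClient.
#include <iostream>
#include <memory>
#include <mutex>

class Client {
 public:
  static Client* GetInstance() {
    // The first caller runs Init exactly once; later callers just read it.
    std::call_once(init_flag_, &Client::Init);
    return instance_.get();
  }

  void Ping() { std::cout << "ping from " << this << "\n"; }

 private:
  Client() = default;
  static void Init() { instance_.reset(new Client()); }

  static std::once_flag init_flag_;
  static std::unique_ptr<Client> instance_;
};

std::once_flag Client::init_flag_;
std::unique_ptr<Client> Client::instance_(nullptr);

int main() {
  // Both calls observe the same instance.
  Client::GetInstance()->Ping();
  Client::GetInstance()->Ping();
  return 0;
}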
paddle/fluid/operators/detail/grpc_server_test.cc

@@ -121,10 +121,10 @@ TEST(PREFETCH, DISABLED_CPU) {
   std::string in_var_name("ids");
   std::string out_var_name("out");

-  detail::RPCClient client;
-  client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
-                               out_var_name);
-  client.Wait();
+  auto client = detail::RPCClient::GetInstance();
+  client->AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
+                                out_var_name);
+  client->Wait();

   auto var = scope.Var(out_var_name);
   auto value = var->GetMutable<framework::SelectedRows>()->value();
paddle/fluid/operators/fetch_barrier_op.cc (new file, mode 100644)

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <future>  // NOLINT
#include <ostream>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace operators {

class FetchBarrierOp : public framework::OperatorBase {
 public:
  FetchBarrierOp(const std::string& type,
                 const framework::VariableNameMap& inputs,
                 const framework::VariableNameMap& outputs,
                 const framework::AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

  void RunImpl(const framework::Scope& scope,
               const platform::Place& place) const override {
    std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");

    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    auto& ctx = *pool.Get(place);
    // For profiling
    platform::RecordEvent record_event(Type(), &ctx);

    auto rpc_client = detail::RPCClient::GetInstance();

    PADDLE_ENFORCE(rpc_client->Wait());

    for (auto& ep : eps) {
      VLOG(3) << "fetch barrier, ep: " << ep;
      rpc_client->AsyncSendFetchBarrier(ep);
    }
    PADDLE_ENFORCE(rpc_client->Wait());
  }
};

class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddComment(R"DOC(
SendBarrier operator

This operator will send a send barrier signal to list_and_serv op, so that
the Parameter Server would knew all variables have been sent.
)DOC");

    AddAttr<std::vector<std::string>>("endpoints",
                                      "(string vector, default 127.0.0.1:6164)"
                                      "Server endpoints to send variables to.")
        .SetDefault({"127.0.0.1:6164"});
  }
};

class FetchBarrierOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {}
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(fetch_barrier, ops::FetchBarrierOp,
                  paddle::framework::EmptyGradOpMaker, ops::FetchBarrierOpMaker,
                  ops::FetchBarrierOpShapeInference);
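The new fetch_barrier operator completes the trainer-side RPC sequence: send_vars pushes each gradient as soon as it is produced during the backward pass, send_barrier marks the batch as fully sent, recv pulls updated parameters back, and fetch_barrier waits for those fetches before the next iteration. The toy client below only mimics that Async*/Wait call pattern; the method names mirror the real detail::RPCClient, but the class itself and the endpoint/variable names are illustrative, not Paddle's implementation.

// Toy model of the trainer-side RPC sequence introduced by this change.
#include <future>
#include <iostream>
#include <string>
#include <vector>

class ToyRpcClient {
 public:
  void AsyncSendVariable(const std::string& ep, const std::string& var) {
    pending_.push_back(std::async(std::launch::async, [=] {
      std::cout << "send " << var << " -> " << ep << "\n";
    }));
  }
  void AsyncSendBatchBarrier(const std::string& ep) {
    pending_.push_back(std::async(std::launch::async, [=] {
      std::cout << "batch barrier -> " << ep << "\n";
    }));
  }
  void AsyncGetVariable(const std::string& ep, const std::string& var) {
    pending_.push_back(std::async(std::launch::async, [=] {
      std::cout << "recv " << var << " <- " << ep << "\n";
    }));
  }
  void AsyncSendFetchBarrier(const std::string& ep) {
    pending_.push_back(std::async(std::launch::async, [=] {
      std::cout << "fetch barrier -> " << ep << "\n";
    }));
  }
  // Block until every async request issued so far has finished.
  bool Wait() {
    for (auto& f : pending_) f.get();
    pending_.clear();
    return true;
  }

 private:
  std::vector<std::future<void>> pending_;
};

int main() {
  const std::string ep = "127.0.0.1:6164";  // illustrative endpoint
  ToyRpcClient client;

  // send_vars: issued per gradient block as the backward pass produces it,
  // so communication overlaps the remaining backward computation.
  // Output lines may interleave since requests run asynchronously.
  std::vector<std::string> grads = {"fc_w@GRAD.block0", "fc_w@GRAD.block1"};
  for (const auto& g : grads) client.AsyncSendVariable(ep, g);

  // send_barrier: mark the batch as fully sent, then wait.
  client.AsyncSendBatchBarrier(ep);
  client.Wait();

  // recv + fetch_barrier: pull updated parameters and wait for the fetches.
  client.AsyncGetVariable(ep, "fc_w.block0");
  client.AsyncGetVariable(ep, "fc_w.block1");
  client.AsyncSendFetchBarrier(ep);
  client.Wait();
  return 0;
}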
paddle/fluid/operators/prefetch_op.cc

@@ -41,12 +41,7 @@ class PrefetchOp : public framework::OperatorBase {
     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     auto& ctx = *pool.Get(place);

-    auto client_var_name = Output("RPCClient");
-    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
-                            "Can not find variable '%s' in the scope.",
-                            client_var_name);
-    auto* client_var = scope.FindVar(client_var_name);
-    detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
+    auto rpc_client = detail::RPCClient::GetInstance();

     for (size_t i = 0; i < ins.size(); i++) {
       if (NeedSend(scope, ins[i])) {

@@ -66,9 +61,6 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() {
     AddInput("X", "(LoDTensor) Input Id variables to be sent").AsDuplicable();
-    AddOutput("RPCClient",
-              "(RPCClient) The RPC client object which will be"
-              "initialized at most once.");
     AddOutput("Out",
               "(LoDTensor) result "
               "to be fetched from parameter server")

@@ -87,17 +79,6 @@ the parameter server and fetch result back.
   }
 };

-class PrefetchOpVarTypeInference : public framework::VarTypeInference {
- public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto out_var_name = op_desc.Output("RPCClient").front();
-    auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    auto var_type = framework::proto::VarType::RAW;
-    out_var.SetType(var_type);
-  }
-};
-
 class PrefetchOpShapeInference : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext* ctx) const override {}

@@ -110,5 +91,4 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(prefetch, ops::PrefetchOp,
                   paddle::framework::EmptyGradOpMaker, ops::PrefetchOpMaker,
-                  ops::PrefetchOpVarTypeInference,
                   ops::PrefetchOpShapeInference);
paddle/fluid/operators/recv_op.cc

@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/grpc_client.h"
+#include "paddle/fluid/platform/profiler.h"

 namespace paddle {
 namespace operators {

@@ -36,19 +37,23 @@ class RecvOp : public framework::OperatorBase {
                const platform::Place& place) const override {
     auto outs = Outputs("Out");
     std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
+    int sync_mode = Attr<int>("sync_mode");

     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     auto& ctx = *pool.Get(place);

+    // For profiling
+    platform::RecordEvent record_event(Type(), &ctx);
+
+    auto rpc_client = detail::RPCClient::GetInstance();
+
     for (size_t i = 0; i < outs.size(); i++) {
-      VLOG(3) << "getting " << outs[i];
-      client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
+      VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
+      rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
     }
-    PADDLE_ENFORCE(client_.Wait());
+    if (sync_mode) {
+      PADDLE_ENFORCE(rpc_client->Wait());
+    }
   }
-
- private:
-  mutable detail::RPCClient client_;
 };

 class RecvOpMaker : public framework::OpProtoAndCheckerMaker {

@@ -65,6 +70,10 @@ This operator can get variables from server side.
                                       "Server endpoints in the order of input "
                                       "variables for mapping")
         .SetDefault({});
+    AddAttr<int>("sync_mode",
+                 "(int, default 0)"
+                 "sync recv or async recv.")
+        .SetDefault(0);
   }
 };
paddle/fluid/operators/send_barrier_op.cc

@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/grpc_client.h"
+#include "paddle/fluid/platform/profiler.h"

 namespace paddle {
 namespace operators {

@@ -36,31 +37,30 @@ class SendBarrierOp : public framework::OperatorBase {
   void RunImpl(const framework::Scope& scope,
                const platform::Place& place) const override {
     std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
+    bool sync_mode = Attr<bool>("sync_mode");

-    auto client_var_name = Output("RPCClient");
-    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
-                            "Can not find variable '%s' in the scope.",
-                            client_var_name);
-    auto* client_var = scope.FindVar(client_var_name);
-    detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+    auto& ctx = *pool.Get(place);
+    // For profiling
+    platform::RecordEvent record_event(Type(), &ctx);
+
+    auto rpc_client = detail::RPCClient::GetInstance();

     // need to wait before sending send_barrier message
     PADDLE_ENFORCE(rpc_client->Wait());
-
-    for (auto& ep : eps) {
-      VLOG(3) << "send barrier, ep: " << ep;
-      rpc_client->AsyncSendBatchBarrier(ep);
+    if (sync_mode) {
+      for (auto& ep : eps) {
+        VLOG(3) << "send barrier, ep: " << ep;
+        rpc_client->AsyncSendBatchBarrier(ep);
+      }
+      PADDLE_ENFORCE(rpc_client->Wait());
     }
-    PADDLE_ENFORCE(rpc_client->Wait());
   }
 };

 class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() {
-    AddOutput("RPCClient",
-              "(RPCClient) The RPC client object which is"
-              "initialized at most once.");
     AddComment(R"DOC(
 SendBarrier operator

@@ -72,17 +72,7 @@ the Parameter Server would knew all variables have been sent.
                                       "(string vector, default 127.0.0.1:6164)"
                                       "Server endpoints to send variables to.")
         .SetDefault({"127.0.0.1:6164"});
+    AddAttr<bool>("sync_mode", "work in sync_mode or not").SetDefault(true);
   }
 };

-class SendBarrierOpVarTypeInference : public framework::VarTypeInference {
- public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto out_var_name = op_desc.Output("RPCClient").front();
-    auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    auto var_type = framework::proto::VarType::RAW;
-    out_var.SetType(var_type);
-  }
-};
-

@@ -98,5 +88,4 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(send_barrier, ops::SendBarrierOp,
                   paddle::framework::EmptyGradOpMaker, ops::SendBarrierOpMaker,
-                  ops::SendBarrierOpVarTypeInference,
                   ops::SendBarrierOpShapeInference);
paddle/fluid/operators/send_op.cc

@@ -49,12 +49,7 @@ class SendOp : public framework::OperatorBase {
     // For profiling
     platform::RecordEvent record_event(Type(), &ctx);

-    auto client_var_name = Output("RPCClient");
-    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
-                            "Can not find variable '%s' in the scope.",
-                            client_var_name);
-    auto* client_var = scope.FindVar(client_var_name);
-    detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
+    auto rpc_client = detail::RPCClient::GetInstance();

     for (size_t i = 0; i < ins.size(); i++) {
       if (NeedSend(scope, ins[i])) {

@@ -96,9 +91,6 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable();
     AddOutput("Out", "(Tensor) Output tensor to be received from server")
         .AsDuplicable();
-    AddOutput("RPCClient",
-              "(RPCClient) The RPC client object which is"
-              "initialized at most once.");
     AddComment(R"DOC(
 Send operator

@@ -119,17 +111,6 @@ This operator will send tensor to recv_op at the parameter server.
   }
 };

-class SendOpVarTypeInference : public framework::VarTypeInference {
- public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto out_var_name = op_desc.Output("RPCClient").front();
-    auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    auto var_type = framework::proto::VarType::RAW;
-    out_var.SetType(var_type);
-  }
-};
-
 class SendOpShapeInference : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext* ctx) const override {}

@@ -141,5 +122,4 @@ class SendOpShapeInference : public framework::InferShapeBase {
 namespace ops = paddle::operators;

 REGISTER_OPERATOR(send, ops::SendOp, paddle::framework::EmptyGradOpMaker,
-                  ops::SendOpMaker, ops::SendOpVarTypeInference,
+                  ops::SendOpMaker,
                   ops::SendOpShapeInference);
paddle/fluid/operators/send_recv_op_test.cc

@@ -156,6 +156,7 @@ TEST(SendRecvOp, CPUDense) {
   std::thread server_thread(StartServerNet, false, &initialized);
   while (!initialized) {
   }
   static_cast<paddle::operators::ListenAndServOp*>(listen_and_serv_op.get())
       ->WaitServerReady();

@@ -175,9 +176,10 @@ TEST(SendRecvOp, CPUDense) {
   std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
   attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
   attrs.insert({"epmap", std::vector<std::string>({endpoint})});
-  auto send_op = f::OpRegistry::CreateOp(
-      "send", {{"X", {"x1"}}},
-      {{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs);
+  const f::VariableNameMap& inputs = {{"X", {"x1"}}};
+  const f::VariableNameMap& outputs = {{"Out", {"Out"}}};
+
+  auto send_op = f::OpRegistry::CreateOp("send", inputs, outputs, attrs);
   send_op->Run(scope, place);

   auto in_var = scope.Var("x1");

@@ -220,9 +222,8 @@ TEST(SendRecvOp, CPUSparse) {
   std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
   attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
   attrs.insert({"epmap", std::vector<std::string>({endpoint})});
   auto send_op = f::OpRegistry::CreateOp(
       "send", {{"X", {"x1"}}},
-      {{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs);
+      {{"Out", {"Out"}}}, attrs);
   send_op->Run(scope, place);

   auto x0 = scope.Var("x0")->GetMutable<f::SelectedRows>();
paddle/fluid/operators/send_recv_util.h

@@ -20,6 +20,9 @@ namespace operators {
 inline bool NeedSend(const framework::Scope& scope,
                      const std::string& varname) {
+  // dummy variable is only used in parallel executor to represent
+  // some dependency relationship, we don't need to send/recv it.
+  if (varname == "dummy") return false;
   auto* var = scope.FindVar(varname);
   PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.",
                           varname);
paddle/fluid/operators/send_vars_op.cc

@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/grpc_client.h"
 #include "paddle/fluid/operators/send_recv_util.h"
+#include "paddle/fluid/platform/profiler.h"

 namespace paddle {
 namespace operators {

@@ -41,12 +42,10 @@ class SendVarsOp : public framework::OperatorBase {
     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     auto& ctx = *pool.Get(place);

-    auto client_var_name = Output("RPCClient");
-    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
-                            "Can not find variable '%s' in the scope.",
-                            client_var_name);
-    auto* client_var = scope.FindVar(client_var_name);
-    detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
+    // For profiling
+    platform::RecordEvent record_event(Type(), &ctx);
+
+    auto rpc_client = detail::RPCClient::GetInstance();

     for (size_t i = 0; i < ins.size(); i++) {
       if (NeedSend(scope, ins[i])) {

@@ -69,9 +68,6 @@ class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker {
   void Make() {
     AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
         .AsDuplicable();
-    AddOutput("RPCClient",
-              "(RPCClient) The RPC client object which will be"
-              "initialized at most once.");
     AddComment(R"DOC(
 Send operator

@@ -89,17 +85,6 @@ This operator will send variables to listen_and_serve op at the parameter server
   }
 };

-class SendVarsOpVarTypeInference : public framework::VarTypeInference {
- public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto out_var_name = op_desc.Output("RPCClient").front();
-    auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    auto var_type = framework::proto::VarType::RAW;
-    out_var.SetType(var_type);
-  }
-};
-
 class SendVarsOpShapeInference : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext* ctx) const override {}

@@ -112,5 +97,4 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(send_vars, ops::SendVarsOp,
                   paddle::framework::EmptyGradOpMaker, ops::SendVarsOpMaker,
-                  ops::SendVarsOpVarTypeInference,
                   ops::SendVarsOpShapeInference);
paddle/fluid/pybind/const_value.cc

@@ -32,7 +32,8 @@ void BindConstValue(pybind11::module* m) {
       .value("Forward", framework::OpRole::kForward)
       .value("Backward", framework::OpRole::kBackward)
       .value("Optimize", framework::OpRole::kOptimize)
-      .value("Loss", framework::OpRole::kLoss);
+      .value("Loss", framework::OpRole::kLoss)
+      .value("RPC", framework::OpRole::kRPC);

   op_proto_and_checker_maker.def(
       "kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName);
python/paddle/fluid/layers/io.py
...
@@ -195,21 +195,23 @@ def Send(endpoints, send_vars, get_vars=None):
     endpoints = list(set(epmap))
 
     helper = LayerHelper("Send", **locals())
-    rpc_client_var = default_main_program().global_block().create_var(
-        name="RPC_CLIENT_VAR", persistable=True, type=core.VarDesc.VarType.RAW)
     if not get_vars:
         get_vars = []
         for s in send_vars:
             v = helper.create_tmp_variable(dtype=s.dtype, stop_gradient=True)
             get_vars.append(v)
+    rpc_op_role_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
 
     helper.append_op(
         type="send",
         inputs={"X": send_vars},
-        outputs={"Out": get_vars,
-                 "RPCClient": rpc_client_var},
-        attrs={"endpoints": endpoints,
-               "epmap": epmap})
+        outputs={"Out": get_vars},
+        attrs={
+            "endpoints": endpoints,
+            "epmap": epmap,
+            rpc_op_role_name: core.op_proto_and_checker_maker.OpRole.RPC
+        })
     return get_vars
...
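For reference, a hypothetical call to the layer after this change (a sketch only: it assumes `Send` is importable from `paddle.fluid.layers.io`, that `endpoints` is a comma-separated pserver list, and that a listen_and_serv server is reachable when the program actually runs):

    import paddle.fluid as fluid
    from paddle.fluid.layers.io import Send

    x = fluid.layers.data(name='x', shape=[32], dtype='float32')
    # push x to two pservers; the returned variables receive the fetched results
    got = Send("127.0.0.1:6174,127.0.0.1:6175", [x])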
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
...
@@ -49,7 +49,6 @@ class TestDistTranspiler(unittest.TestCase):
     def test_transpiler(self):
         trainer = self.get_trainer()
         pserver, startup = self.get_pserver(self.current_pserver_ep)
-
         self.assertEqual([op.type for op in trainer.global_block().ops],
                          self.get_expect_trainer_ops())
...
@@ -67,7 +66,7 @@ class TestDistTranspiler(unittest.TestCase):
             "fill_constant", "fill_constant", "uniform_random", "uniform_random"
         ])
         # the variable #fc_w will be split into two blocks
         fc_w_var = startup.global_block().var("fc_w.block1")
         self.assertEqual(fc_w_var.shape, (500, 1000))
...
@@ -86,8 +85,12 @@ class TestDistTranspiler(unittest.TestCase):
         optimize_ops, params_grads = self.net_conf()
 
         delete_ops(trainer.global_block(), optimize_ops)
-        return [op.type for op in trainer.global_block().ops
-                ] + ["split_byref", "send", "concat"]
+        ops = [op.type for op in trainer.global_block().ops] + [
+            "split_byref", "send_vars", "send_barrier", "recv", "recv",
+            "fetch_barrier", "concat"
+        ]
+        ops.insert(ops.index("elementwise_add_grad") + 1, "send_vars")
+        return ops
 
     def get_trainer(self):
         return self._transpiler_instance().get_trainer_program()
...
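The list surgery above is the interesting part of the new expectation: one gradient's `send_vars` op is now expected immediately after `elementwise_add_grad`, the backward op that produces it, instead of a single blocking `send` at the end of the program. A tiny standalone illustration of what `ops.insert(ops.index(...) + 1, ...)` does (the op names here are placeholders, not the test's full op list):

    # placeholder backward-op names, not the test's actual trainer program
    ops = ["mul_grad", "elementwise_add_grad", "split_byref", "send_vars", "concat"]
    ops.insert(ops.index("elementwise_add_grad") + 1, "send_vars")
    assert ops == ["mul_grad", "elementwise_add_grad", "send_vars",
                   "split_byref", "send_vars", "concat"]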
python/paddle/fluid/transpiler/__init__.py
...
@@ -16,8 +16,9 @@ from distribute_transpiler import DistributeTranspiler
 from inference_transpiler import InferenceTranspiler
 from memory_optimization_transpiler import memory_optimize, release_memory
 from distribute_transpiler_simple import SimpleDistributeTranspiler
+from ps_dispatcher import HashName, RoundRobin
 
 __all__ = [
     "DistributeTranspiler", "InferenceTranspiler", "SimpleDistributeTranspiler",
-    "memory_optimize", "release_memory"
+    "memory_optimize", "release_memory", "HashName", "RoundRobin"
 ]
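With the new exports, the split strategy is selected by passing a PSDispatcher subclass instead of a function. A hedged sketch of the updated call site (it assumes a main program with forward, backward, and optimize ops has already been built):

    from paddle.fluid.transpiler import DistributeTranspiler, RoundRobin

    t = DistributeTranspiler()
    t.transpile(trainer_id=0,
                pservers="127.0.0.1:6174,127.0.0.1:6175",
                trainers=2,
                split_method=RoundRobin)   # or HashName
    trainer_prog = t.get_trainer_program()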
python/paddle/fluid/transpiler/distribute_transpiler.py
...
@@ -16,7 +16,7 @@ from __future__ import print_function
 import math
 
-import distributed_splitter as splitter
+from ps_dispatcher import RoundRobin, HashName, PSDispatcher
 from .. import core, framework
 from ..framework import Program, default_main_program, \
                         default_startup_program, \
...
@@ -24,7 +24,9 @@ from ..framework import Program, default_main_program, \
 LOOKUP_TABLE_TYPE = "lookup_table"
 LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
-RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR"
+RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
+)
+RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
 
 
 class VarBlock:
...
@@ -149,13 +151,27 @@ def delete_ops(block, ops):
     block.program.sync_with_cpp()
 
 
+def find_op_by_input_arg(block, arg_name):
+    for index, op in enumerate(block.ops):
+        if arg_name in op.input_arg_names:
+            return index
+    return -1
+
+
+def find_op_by_output_arg(block, arg_name):
+    for index, op in enumerate(block.ops):
+        if arg_name in op.output_arg_names:
+            return index
+    return -1
+
+
 class DistributeTranspiler:
     def transpile(self,
                   trainer_id,
                   program=None,
                   pservers="127.0.0.1:6174",
                   trainers=1,
-                  split_method=splitter.round_robin,
+                  split_method=RoundRobin,
                   sync_mode=True):
         """
         Transpile the program to distributed data-parallelism programs.
...
@@ -196,7 +212,7 @@ class DistributeTranspiler:
         :param sync_mode: if sync_mode is set True, it means that dist transpiler
         will transpile the program into sync_mode pserver and trainer program.
         """
-        assert (callable(split_method))
+        assert (split_method.__bases__[0] == PSDispatcher)
         if program is None:
             program = default_main_program()
         self.origin_program = program
...
@@ -209,6 +225,7 @@ class DistributeTranspiler:
         pserver_endpoints = pservers.split(",")
         self.pserver_endpoints = pserver_endpoints
         self.optimize_ops, params_grads = self._get_optimize_pass()
+        ps_dispatcher = split_method(pserver_endpoints)
 
         # process lookup_table_op
         # 1. check all lookup_table_op is distributed
...
@@ -268,54 +285,110 @@ class DistributeTranspiler:
         grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
         param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
+        assert (len(grad_blocks) == len(param_blocks))
         # step2: Create new vars for the parameters and gradients blocks and
         # add ops to do the split.
-        grad_var_mapping = self._append_split_op(program, grad_blocks)
         param_var_mapping = self._create_vars_from_blocklist(program,
                                                              param_blocks)
+        grad_var_mapping = self._create_vars_from_blocklist(
+            program, grad_blocks, add_trainer_suffix=self.trainer_num > 1)
+        grad_param_mapping = dict()
+        for g, p in zip(grad_blocks, param_blocks):
+            g_name, g_bid, _ = g.split(":")
+            p_name, p_bid, _ = p.split(":")
+            grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \
+                param_var_mapping[p_name][int(p_bid)]
 
-        # step3: Add gradients as send op inputs and parameters as send
-        # op outputs.
-        send_inputs = []
-        send_outputs = []
-        for b in grad_blocks:  # append by order
-            varname, block_id, _ = b.split(":")
-            send_inputs.append(grad_var_mapping[varname][int(block_id)])
-        for b in param_blocks:
-            varname, block_id, _ = b.split(":")
-            send_outputs.append(param_var_mapping[varname][int(block_id)])
-        # let send_op know which endpoint to send which var to, eplist has the same
-        # order as send_inputs.
-        eplist = split_method(send_inputs, pserver_endpoints)
+        # step 3: transpile trainer side program, insert recv op and send op.
+
         # create mapping of endpoint -> split var to create pserver side program
         self.param_grad_ep_mapping = dict()
+        [
+            self.param_grad_ep_mapping.update({
+                ep: {
+                    "params": [],
+                    "grads": []
+                }
+            }) for ep in self.pserver_endpoints
+        ]
+
+        # step 3.1: insert send op to send gradient vars to parameter servers
+        ps_dispatcher.reset()
+        send_vars = []
+        for orig_varname, splited_vars in grad_var_mapping.items():
+            eplist = ps_dispatcher.dispatch(splited_vars)
+            if len(splited_vars) == 1:
+                orig_varname = splited_vars[0].name
+                index = find_op_by_output_arg(program.global_block(),
+                                              orig_varname)
+            elif len(splited_vars) > 1:
+                orig_var = program.global_block().vars[orig_varname]
+                index = find_op_by_output_arg(program.global_block(),
+                                              orig_varname)
+                self._insert_split_op(program, orig_var, index, splited_vars)
+                index += 1
+            else:
+                AssertionError("Can not insert the send op by original "
+                               "variable name :", orig_varname)
+
+            program.global_block().insert_op(
+                index=index + 1,
+                type="send_vars",
+                inputs={"X": splited_vars},
+                outputs={},
+                attrs={
+                    "epmap": eplist,
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                })
+            for _, var in enumerate(splited_vars):
+                send_vars.append(var)
+
+        if self.sync_mode:
+            program.global_block().append_op(
+                type="send_barrier",
+                inputs={},
+                outputs={},
+                attrs={
+                    "endpoints": pserver_endpoints,
+                    "sync_mode": self.sync_mode,
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                })
+
+        # step 3.2: insert recv op to receive parameters from parameter server
+        recv_vars = []
+        for _, var in enumerate(send_vars):
+            recv_vars.append(grad_param_mapping[var])
+        ps_dispatcher.reset()
+        eplist = ps_dispatcher.dispatch(recv_vars)
+
         for i, ep in enumerate(eplist):
-            param = send_outputs[i]
-            grad = send_inputs[i]
-            if not self.param_grad_ep_mapping.has_key(ep):
-                self.param_grad_ep_mapping[ep] = {"params": [], "grads": []}
-            self.param_grad_ep_mapping[ep]["params"].append(param)
-            self.param_grad_ep_mapping[ep]["grads"].append(grad)
+            self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
+            self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
 
-        rpc_client_var = program.global_block().create_var(
-            name=RPC_CLIENT_VAR_NAME,
-            persistable=True,
-            type=core.VarDesc.VarType.RAW)
-
-        # create send_op
-        program.global_block().append_op(
-            type="send",
-            inputs={"X": send_inputs},
-            outputs={"Out": send_outputs,
-                     "RPCClient": rpc_client_var},
-            attrs={
-                "endpoints": pserver_endpoints,
-                "epmap": eplist,
-                "sync_mode": self.sync_mode
-            })
         # step4: Concat the parameters splits together after recv.
+        for varname, splited_var in param_var_mapping.iteritems():
+            eps = []
+            for var in splited_var:
+                index = [v.name for v in recv_vars].index(var.name)
+                eps.append(eplist[index])
+
+            program.global_block().append_op(
+                type="recv",
+                inputs={},
+                outputs={"Out": splited_var},
+                attrs={"epmap": eps,
+                       RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE})
+
+        program.global_block().append_op(
+            type="fetch_barrier",
+            inputs={},
+            outputs={},
+            attrs={
+                "endpoints": pserver_endpoints,
+                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+            })
+
         for varname, splited_var in param_var_mapping.iteritems():
             if len(splited_var) <= 1:
                 continue
...
@@ -327,10 +400,8 @@ class DistributeTranspiler:
                 attrs={"axis": 0})
 
         if self.has_distributed_lookup_table:
-            self._replace_lookup_table_op_with_prefetch(program, rpc_client_var,
-                                                        eplist)
-            self._split_table_grad_and_add_send_vars(program, rpc_client_var,
-                                                     pserver_endpoints)
+            self._replace_lookup_table_op_with_prefetch(program, eplist)
+            self._split_table_grad_and_add_send_vars(program, pserver_endpoints)
 
     def get_trainer_program(self):
         # remove optimize ops and add a send op to main_program
...
@@ -550,8 +621,7 @@ class DistributeTranspiler:
         return s_prog
 
     # transpiler function for dis lookup_table
-    def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var,
-                                               eplist):
+    def _replace_lookup_table_op_with_prefetch(self, program, eplist):
         # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
         self.prefetch_input_vars = None
         self.prefetch_output_vars = None
...
@@ -598,11 +668,11 @@ class DistributeTranspiler:
             index=op_index + 1,
             type="prefetch",
             inputs={'X': self.prefetch_input_vars},
-            outputs={
-                "Out": self.prefetch_output_vars,
-                "RPCClient": rpc_client_var
-            },
-            attrs={"epmap": eplist})
+            outputs={"Out": self.prefetch_output_vars},
+            attrs={
+                "epmap": eplist,
+                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+            })
 
         # insert concat_op
         program.global_block().insert_op(
...
@@ -622,8 +692,7 @@ class DistributeTranspiler:
                 # break for loop
                 break
 
-    def _split_table_grad_and_add_send_vars(self, program, rpc_client_var,
-                                            pserver_endpoints):
+    def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
         # 2. add split_ids_op and send_vars_op to send gradient to pservers
         # there should only be one table_name
         all_ops = program.global_block().ops
...
@@ -643,9 +712,12 @@ class DistributeTranspiler:
             index=op_index + 2,
             type="send_vars",
             inputs={'X': self.table_grad_list},
-            outputs={"RPCClient": rpc_client_var},
-            attrs={"sync_send": True,
-                   "epmap": pserver_endpoints})
+            outputs={},
+            attrs={
+                "sync_send": True,
+                "epmap": pserver_endpoints,
+                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+            })
             break
 
     def _create_prefetch_block(self, pserver_index, pserver_program,
...
@@ -838,50 +910,31 @@ class DistributeTranspiler:
             lod_level=var.lod_level,
             persistable=persistable)
 
-    def _append_split_op(self, program, gradblocks):
-        """
-        Split variables that need to be split and append respective ops
-        Args:
-            program (ProgramDesc): ProgramDesc that gradients blong.
-            gradblocks (list[(varname, block_id, block_size)]): List of gradient blocks.
-        Returns:
-            var_mapping (dict(varname->[new_splitted_variable])):A dict mapping
-                from original var name to each var split.
-        """
-        add_suffix = False
-        if self.trainer_num > 1:
-            add_suffix = True
-        var_mapping = self._create_vars_from_blocklist(
-            program, gradblocks, add_trainer_suffix=add_suffix)
-        for varname, splited_vars in var_mapping.iteritems():
-            # variable that don't need to split have empty splited_vars
-            if len(splited_vars) <= 1:
-                continue
-            orig_var = program.global_block().vars[varname]
-            if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS:
-                height_sections = []
-                for v in splited_vars:
-                    height_sections.append(v.shape[0])
-                program.global_block().append_op(
-                    type="split_selected_rows",
-                    inputs={"X": orig_var},
-                    outputs={"Out": splited_vars},
-                    attrs={"height_sections": height_sections})
-            elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR:
-                sections = []
-                for v in splited_vars:
-                    sections.append(v.shape[0])
-                program.global_block().append_op(
-                    type="split_byref",
-                    inputs={"X": orig_var},
-                    outputs={"Out": splited_vars},
-                    attrs={"sections": sections}  # assume split evenly
-                )
-            else:
-                AssertionError("Variable type should be in set "
-                               "[LOD_TENSOR, SELECTED_ROWS]")
-        return var_mapping
+    def _insert_split_op(self, program, orig_var, index, splited_vars):
+        if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS:
+            height_sections = []
+            for v in splited_vars:
+                height_sections.append(v.shape[0])
+            program.global_block().insert_op(
+                index=index + 1,
+                type="split_selected_rows",
+                inputs={"X": orig_var},
+                outputs={"Out": splited_vars},
+                attrs={"height_sections": height_sections})
+        elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR:
+            sections = []
+            for v in splited_vars:
+                sections.append(v.shape[0])
+            program.global_block().insert_op(
+                index=index + 1,
+                type="split_byref",
+                inputs={"X": orig_var},
+                outputs={"Out": splited_vars},
+                attrs={"sections": sections}  # assume split evenly
+            )
+        else:
+            AssertionError("Variable type should be in set "
+                           "[LOD_TENSOR, SELECTED_ROWS]")
 
     def _get_optimizer_input_shape(self, op_type, varkey, orig_shape,
                                    param_shape):
...
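The heart of the trainer-side change is step 3.1 above: every group of split gradient blocks gets its own non-blocking send_vars op inserted right after the op that produces the gradient, and the op is tagged with the RPC role so the scheduler can overlap communication with the rest of the backward pass. A condensed sketch of that insertion pattern, lifted from the diff (the helper name is illustrative; `program`, `index`, `splited_vars`, and `eplist` are assumed to come from the surrounding transpile logic):

    from paddle.fluid import core

    RPC_OP_ROLE_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleAttrName()
    RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC

    def insert_send_vars(program, index, splited_vars, eplist):
        # insert right after the op at `index`, i.e. the op that produced the gradient
        program.global_block().insert_op(
            index=index + 1,
            type="send_vars",
            inputs={"X": splited_vars},
            outputs={},
            attrs={
                "epmap": eplist,
                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
            })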
python/paddle/fluid/transpiler/distributed_splitter.py → python/paddle/fluid/transpiler/ps_dispatcher.py
...
@@ -13,45 +13,66 @@
 # limitations under the License.
 
-def hash_name(varlist, pserver_endpoints):
-    """
-    hash variable names to several endpoints.
-
-    Args:
-        varlist(list): a list of Variables
-
-    Returns(dict): a map of pserver endpoint -> varname
-    """
-
-    def _hash_block(block_str, total):
-        return hash(block_str) % total
-
-    eplist = []
-    for var in varlist:
-        server_id = _hash_block(var.name(), len(pserver_endpoints))
-        server_for_param = pserver_endpoints[server_id]
-        eplist.append(server_for_param)
-    return eplist
-
-
-def round_robin(varlist, pserver_endpoints):
-    """
-    Distribute variables to several endpoints.
-    Args:
-        varlist(list): a list of variables
-        pserver_endpoints(list): a list of pserver endpoints
-
-    Returns(list[int]): the endpoint for each variable
-    """
-    assert (len(varlist) >= len(pserver_endpoints))
-
-    eplist = []
-    pserver_idx = 0
-    for var in varlist:
-        server_for_param = pserver_endpoints[pserver_idx]
-        eplist.append(server_for_param)
-
-        pserver_idx += 1
-        if pserver_idx >= len(pserver_endpoints):
-            pserver_idx = 0
-    return eplist
+class PSDispatcher(object):
+    """
+    PSDispatcher is the base class for dispatching vars
+    into different pserver instances.
+    You need to implement the `dispatch` interface.
+    """
+
+    def __init__(self, pserver_endpoints):
+        self._eps = pserver_endpoints
+        self._step = 0
+
+    @property
+    def eps(self):
+        return self._eps
+
+    def reset(self):
+        self._step = 0
+
+    def dispatch(self, varlist):
+        """
+        :param varlist: a list of Variables
+        :return: a map of pserver endpoint -> varname
+        """
+        AssertionError("Interface has not been implemented.")
+
+
+class HashName(PSDispatcher):
+    """
+    Hash variable names to several endpoints
+    """
+
+    def __init__(self, pserver_endpoints):
+        super(self.__class__, self).__init__(pserver_endpoints)
+
+    def _hash_block(self, block_str, total):
+        return hash(block_str) % total
+
+    def dispatch(self, varlist):
+        eplist = []
+        for var in varlist:
+            server_id = self._hash_block(var.name(), len(self._eps))
+            server_for_param = self._eps[server_id]
+            eplist.append(server_for_param)
+        return eplist
+
+
+class RoundRobin(PSDispatcher):
+    """
+    Distribute variables to several endpoints.
+    """
+
+    def __init__(self, pserver_endpoints):
+        super(self.__class__, self).__init__(pserver_endpoints)
+
+    def dispatch(self, varlist):
+        eplist = []
+        for var in varlist:
+            server_for_param = self._eps[self._step]
+            eplist.append(server_for_param)
+            self._step += 1
+            if self._step >= len(self._eps):
+                self._step = 0
+        return eplist
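A small usage sketch of the new dispatcher classes (assuming a Paddle build containing this change; RoundRobin only needs the length of the variable list, so plain strings can stand in for fluid Variables here, while HashName additionally calls var.name() on each element):

    from paddle.fluid.transpiler import RoundRobin

    eps = ["127.0.0.1:6174", "127.0.0.1:6175"]
    rr = RoundRobin(eps)
    # each call keeps walking the endpoint list from where the last one stopped
    print(rr.dispatch(["w.block0", "w.block1", "b.block0"]))
    # -> ['127.0.0.1:6174', '127.0.0.1:6175', '127.0.0.1:6174']
    rr.reset()  # start again from the first endpoint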