Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
20c24c05
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
20c24c05
编写于
5月 29, 2018
作者:
Y
Yancey1989
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
singleton rpc_client
上级
28596a33
变更
18
显示空白变更内容
内联
并排
Showing
18 changed file
with
161 addition
and
240 deletion
+161
-240
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+3
-10
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+0
-2
paddle/fluid/framework/op_proto_maker.cc
paddle/fluid/framework/op_proto_maker.cc
+1
-1
paddle/fluid/framework/op_proto_maker.h
paddle/fluid/framework/op_proto_maker.h
+1
-0
paddle/fluid/inference/analysis/data_flow_graph_tester.cc
paddle/fluid/inference/analysis/data_flow_graph_tester.cc
+2
-2
paddle/fluid/operators/detail/grpc_client.cc
paddle/fluid/operators/detail/grpc_client.cc
+15
-0
paddle/fluid/operators/detail/grpc_client.h
paddle/fluid/operators/detail/grpc_client.h
+10
-0
paddle/fluid/operators/detail/grpc_server_test.cc
paddle/fluid/operators/detail/grpc_server_test.cc
+7
-4
paddle/fluid/operators/fetch_barrier_op.cc
paddle/fluid/operators/fetch_barrier_op.cc
+1
-21
paddle/fluid/operators/prefetch_op.cc
paddle/fluid/operators/prefetch_op.cc
+1
-21
paddle/fluid/operators/recv_op.cc
paddle/fluid/operators/recv_op.cc
+1
-9
paddle/fluid/operators/send_barrier_op.cc
paddle/fluid/operators/send_barrier_op.cc
+2
-21
paddle/fluid/operators/send_op.cc
paddle/fluid/operators/send_op.cc
+2
-22
paddle/fluid/operators/send_recv_op_test.cc
paddle/fluid/operators/send_recv_op_test.cc
+68
-68
paddle/fluid/operators/send_vars_op.cc
paddle/fluid/operators/send_vars_op.cc
+1
-21
paddle/fluid/pybind/const_value.cc
paddle/fluid/pybind/const_value.cc
+2
-1
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+8
-6
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+36
-31
未找到文件。
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
20c24c05
...
@@ -146,15 +146,6 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
...
@@ -146,15 +146,6 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
checker
(
op
.
InputArgumentNames
(),
recv_vars
);
checker
(
op
.
InputArgumentNames
(),
recv_vars
);
}
}
bool
MultiDevSSAGraphBuilder
::
IsRPCOp
(
const
OpDesc
&
op
)
const
{
for
(
auto
&
name
:
op
.
OutputNames
())
{
if
(
name
==
"RPCClient"
)
{
return
true
;
}
}
return
false
;
}
std
::
unique_ptr
<
SSAGraph
>
MultiDevSSAGraphBuilder
::
Build
(
std
::
unique_ptr
<
SSAGraph
>
MultiDevSSAGraphBuilder
::
Build
(
const
ProgramDesc
&
program
)
const
{
const
ProgramDesc
&
program
)
const
{
std
::
unordered_map
<
std
::
string
,
proto
::
VarType
::
Type
>
var_types
;
std
::
unordered_map
<
std
::
string
,
proto
::
VarType
::
Type
>
var_types
;
...
@@ -184,7 +175,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
...
@@ -184,7 +175,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
bool
is_forwarding
=
true
;
bool
is_forwarding
=
true
;
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
if
(
IsRPCOp
(
*
op
))
{
if
(
boost
::
get
<
int
>
(
op
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
static_cast
<
int
>
(
OpRole
::
kRPC
))
{
// append rpc op if program is distributed trainer main program.
// append rpc op if program is distributed trainer main program.
// always use the first device
// always use the first device
CreateRPCOp
(
&
result
,
*
op
);
CreateRPCOp
(
&
result
,
*
op
);
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
20c24c05
...
@@ -80,8 +80,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
...
@@ -80,8 +80,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
std
::
vector
<
std
::
string
>
FindDistTrainRecvVars
(
std
::
vector
<
std
::
string
>
FindDistTrainRecvVars
(
const
ProgramDesc
&
program
)
const
;
const
ProgramDesc
&
program
)
const
;
bool
IsRPCOp
(
const
OpDesc
&
op
)
const
;
void
ConnectOp
(
SSAGraph
*
result
,
OpHandleBase
*
op
,
void
ConnectOp
(
SSAGraph
*
result
,
OpHandleBase
*
op
,
const
std
::
string
&
prev_op_name
)
const
;
const
std
::
string
&
prev_op_name
)
const
;
...
...
paddle/fluid/framework/op_proto_maker.cc
浏览文件 @
20c24c05
...
@@ -66,7 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
...
@@ -66,7 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
.
InEnum
(
.
InEnum
(
{
static_cast
<
int
>
(
OpRole
::
kForward
),
{
static_cast
<
int
>
(
OpRole
::
kForward
),
static_cast
<
int
>
(
OpRole
::
kBackward
),
static_cast
<
int
>
(
OpRole
::
kBackward
),
static_cast
<
int
>
(
OpRole
::
kOptimize
),
static_cast
<
int
>
(
OpRole
::
kOptimize
),
static_cast
<
int
>
(
OpRole
::
kRPC
),
static_cast
<
int
>
(
OpRole
::
kLoss
)
|
static_cast
<
int
>
(
OpRole
::
kForward
),
static_cast
<
int
>
(
OpRole
::
kLoss
)
|
static_cast
<
int
>
(
OpRole
::
kForward
),
static_cast
<
int
>
(
OpRole
::
kLoss
)
|
static_cast
<
int
>
(
OpRole
::
kLoss
)
|
static_cast
<
int
>
(
OpRole
::
kBackward
),
static_cast
<
int
>
(
OpRole
::
kBackward
),
...
...
paddle/fluid/framework/op_proto_maker.h
浏览文件 @
20c24c05
...
@@ -24,6 +24,7 @@ enum class OpRole {
...
@@ -24,6 +24,7 @@ enum class OpRole {
kForward
=
0x0000
,
kForward
=
0x0000
,
kBackward
=
0x0001
,
kBackward
=
0x0001
,
kOptimize
=
0x0002
,
kOptimize
=
0x0002
,
kRPC
=
0x0003
,
kLoss
=
0x0100
,
kLoss
=
0x0100
,
// The default value of op's role. This should be only used for unittests and
// The default value of op's role. This should be only used for unittests and
...
...
paddle/fluid/inference/analysis/data_flow_graph_tester.cc
浏览文件 @
20c24c05
...
@@ -35,7 +35,7 @@ TEST(DataFlowGraph, BFS) {
...
@@ -35,7 +35,7 @@ TEST(DataFlowGraph, BFS) {
GraphTraits
<
DataFlowGraph
>
trait
(
&
dfg
);
GraphTraits
<
DataFlowGraph
>
trait
(
&
dfg
);
auto
nodes
=
trait
.
nodes
();
auto
nodes
=
trait
.
nodes
();
in
t
count
=
0
;
size_
t
count
=
0
;
for
(
auto
it
=
nodes
.
begin
();
it
!=
nodes
.
end
();
++
it
)
{
for
(
auto
it
=
nodes
.
begin
();
it
!=
nodes
.
end
();
++
it
)
{
LOG
(
INFO
)
<<
"visiting "
<<
it
->
name
();
LOG
(
INFO
)
<<
"visiting "
<<
it
->
name
();
++
count
;
++
count
;
...
@@ -49,7 +49,7 @@ TEST(DataFlowGraph, DFS) {
...
@@ -49,7 +49,7 @@ TEST(DataFlowGraph, DFS) {
dfg
.
Build
();
dfg
.
Build
();
GraphTraits
<
DataFlowGraph
>
trait
(
&
dfg
);
GraphTraits
<
DataFlowGraph
>
trait
(
&
dfg
);
auto
nodes
=
trait
.
nodes_in_DFS
();
auto
nodes
=
trait
.
nodes_in_DFS
();
in
t
count
=
0
;
size_
t
count
=
0
;
for
(
auto
it
=
nodes
.
begin
();
it
!=
nodes
.
end
();
++
it
)
{
for
(
auto
it
=
nodes
.
begin
();
it
!=
nodes
.
end
();
++
it
)
{
LOG
(
INFO
)
<<
"visiting "
<<
it
->
name
();
LOG
(
INFO
)
<<
"visiting "
<<
it
->
name
();
++
count
;
++
count
;
...
...
paddle/fluid/operators/detail/grpc_client.cc
浏览文件 @
20c24c05
...
@@ -25,6 +25,21 @@ namespace paddle {
...
@@ -25,6 +25,21 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
detail
{
namespace
detail
{
std
::
once_flag
RPCClient
::
init_flag_
;
std
::
unique_ptr
<
RPCClient
>
RPCClient
::
rpc_client_
(
nullptr
);
RPCClient
*
RPCClient
::
GetInstance
()
{
std
::
call_once
(
init_flag_
,
&
RPCClient
::
Init
);
return
rpc_client_
.
get
();
}
void
RPCClient
::
Init
()
{
if
(
rpc_client_
.
get
()
==
nullptr
)
{
rpc_client_
.
reset
(
new
RPCClient
());
}
}
bool
RPCClient
::
AsyncSendVariable
(
const
std
::
string
&
ep
,
bool
RPCClient
::
AsyncSendVariable
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
framework
::
Scope
&
scope
,
...
...
paddle/fluid/operators/detail/grpc_client.h
浏览文件 @
20c24c05
...
@@ -36,6 +36,7 @@ limitations under the License. */
...
@@ -36,6 +36,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -162,6 +163,10 @@ class FetchBarrierProcessor : public BaseProcessor {
...
@@ -162,6 +163,10 @@ class FetchBarrierProcessor : public BaseProcessor {
class
RPCClient
{
class
RPCClient
{
public:
public:
RPCClient
()
{}
static
RPCClient
*
GetInstance
();
bool
AsyncSendVariable
(
const
std
::
string
&
ep
,
bool
AsyncSendVariable
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
framework
::
Scope
&
scope
,
...
@@ -192,12 +197,17 @@ class RPCClient {
...
@@ -192,12 +197,17 @@ class RPCClient {
private:
private:
bool
Proceed
();
bool
Proceed
();
std
::
shared_ptr
<
grpc
::
Channel
>
GetChannel
(
const
std
::
string
&
ep
);
std
::
shared_ptr
<
grpc
::
Channel
>
GetChannel
(
const
std
::
string
&
ep
);
// Init is called by GetInstance.
static
void
Init
();
private:
private:
grpc
::
CompletionQueue
cq_
;
grpc
::
CompletionQueue
cq_
;
std
::
map
<
std
::
string
,
std
::
shared_ptr
<
grpc
::
Channel
>>
channels_
;
std
::
map
<
std
::
string
,
std
::
shared_ptr
<
grpc
::
Channel
>>
channels_
;
std
::
atomic
<
int64_t
>
req_count_
{
0
};
std
::
atomic
<
int64_t
>
req_count_
{
0
};
std
::
mutex
mutex_
;
std
::
mutex
mutex_
;
static
std
::
unique_ptr
<
RPCClient
>
rpc_client_
;
static
std
::
once_flag
init_flag_
;
DISABLE_COPY_AND_ASSIGN
(
RPCClient
);
};
};
}
// namespace detail
}
// namespace detail
...
...
paddle/fluid/operators/detail/grpc_server_test.cc
浏览文件 @
20c24c05
...
@@ -121,10 +121,13 @@ TEST(PREFETCH, DISABLED_CPU) {
...
@@ -121,10 +121,13 @@ TEST(PREFETCH, DISABLED_CPU) {
std
::
string
in_var_name
(
"ids"
);
std
::
string
in_var_name
(
"ids"
);
std
::
string
out_var_name
(
"out"
);
std
::
string
out_var_name
(
"out"
);
detail
::
RPCClient
client
;
detail
::
RPCClient
::
GetInstance
();
client
.
AsyncPrefetchVariable
(
"127.0.0.1:8889"
,
ctx
,
scope
,
in_var_name
,
out_var_name
);
// detail::RPCClient::GetInstance();
client
.
Wait
();
// client->Wait();
// client->AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
// out_var_name);
// client->Wait();
auto
var
=
scope
.
Var
(
out_var_name
);
auto
var
=
scope
.
Var
(
out_var_name
);
auto
value
=
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
value
();
auto
value
=
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
value
();
...
...
paddle/fluid/operators/fetch_barrier_op.cc
浏览文件 @
20c24c05
...
@@ -43,12 +43,7 @@ class FetchBarrierOp : public framework::OperatorBase {
...
@@ -43,12 +43,7 @@ class FetchBarrierOp : public framework::OperatorBase {
// For profiling
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
client_var_name
=
Output
(
"RPCClient"
);
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
client_var_name
),
"Can not find variable '%s' in the scope."
,
client_var_name
);
auto
*
client_var
=
scope
.
FindVar
(
client_var_name
);
detail
::
RPCClient
*
rpc_client
=
client_var
->
GetMutable
<
detail
::
RPCClient
>
();
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
...
@@ -63,9 +58,6 @@ class FetchBarrierOp : public framework::OperatorBase {
...
@@ -63,9 +58,6 @@ class FetchBarrierOp : public framework::OperatorBase {
class
FetchBarrierOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
FetchBarrierOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
void
Make
()
{
void
Make
()
{
AddOutput
(
"RPCClient"
,
"(RPCClient) The RPC client object which is"
"initialized at most once."
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
SendBarrier operator
SendBarrier operator
...
@@ -80,17 +72,6 @@ the Parameter Server would knew all variables have been sent.
...
@@ -80,17 +72,6 @@ the Parameter Server would knew all variables have been sent.
}
}
};
};
class
FetchBarrierOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
out_var_name
=
op_desc
.
Output
(
"RPCClient"
).
front
();
auto
&
out_var
=
block
->
FindRecursiveOrCreateVar
(
out_var_name
);
auto
var_type
=
framework
::
proto
::
VarType
::
RAW
;
out_var
.
SetType
(
var_type
);
}
};
class
FetchBarrierOpShapeInference
:
public
framework
::
InferShapeBase
{
class
FetchBarrierOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
...
@@ -103,5 +84,4 @@ namespace ops = paddle::operators;
...
@@ -103,5 +84,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR
(
fetch_barrier
,
ops
::
FetchBarrierOp
,
REGISTER_OPERATOR
(
fetch_barrier
,
ops
::
FetchBarrierOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
FetchBarrierOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
FetchBarrierOpMaker
,
ops
::
FetchBarrierOpVarTypeInference
,
ops
::
FetchBarrierOpShapeInference
);
ops
::
FetchBarrierOpShapeInference
);
paddle/fluid/operators/prefetch_op.cc
浏览文件 @
20c24c05
...
@@ -41,12 +41,7 @@ class PrefetchOp : public framework::OperatorBase {
...
@@ -41,12 +41,7 @@ class PrefetchOp : public framework::OperatorBase {
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
auto
&
ctx
=
*
pool
.
Get
(
place
);
auto
client_var_name
=
Output
(
"RPCClient"
);
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
client_var_name
),
"Can not find variable '%s' in the scope."
,
client_var_name
);
auto
*
client_var
=
scope
.
FindVar
(
client_var_name
);
detail
::
RPCClient
*
rpc_client
=
client_var
->
GetMutable
<
detail
::
RPCClient
>
();
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
...
@@ -66,9 +61,6 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -66,9 +61,6 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
public:
void
Make
()
{
void
Make
()
{
AddInput
(
"X"
,
"(LoDTensor) Input Id variables to be sent"
).
AsDuplicable
();
AddInput
(
"X"
,
"(LoDTensor) Input Id variables to be sent"
).
AsDuplicable
();
AddOutput
(
"RPCClient"
,
"(RPCClient) The RPC client object which will be"
"initialized at most once."
);
AddOutput
(
"Out"
,
AddOutput
(
"Out"
,
"(LoDTensor) result "
"(LoDTensor) result "
"to be fetched from parameter server"
)
"to be fetched from parameter server"
)
...
@@ -87,17 +79,6 @@ the parameter server and fetch result back.
...
@@ -87,17 +79,6 @@ the parameter server and fetch result back.
}
}
};
};
class
PrefetchOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
out_var_name
=
op_desc
.
Output
(
"RPCClient"
).
front
();
auto
&
out_var
=
block
->
FindRecursiveOrCreateVar
(
out_var_name
);
auto
var_type
=
framework
::
proto
::
VarType
::
RAW
;
out_var
.
SetType
(
var_type
);
}
};
class
PrefetchOpShapeInference
:
public
framework
::
InferShapeBase
{
class
PrefetchOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
...
@@ -110,5 +91,4 @@ namespace ops = paddle::operators;
...
@@ -110,5 +91,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR
(
prefetch
,
ops
::
PrefetchOp
,
REGISTER_OPERATOR
(
prefetch
,
ops
::
PrefetchOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
PrefetchOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
PrefetchOpMaker
,
ops
::
PrefetchOpVarTypeInference
,
ops
::
PrefetchOpShapeInference
);
ops
::
PrefetchOpShapeInference
);
paddle/fluid/operators/recv_op.cc
浏览文件 @
20c24c05
...
@@ -37,7 +37,6 @@ class RecvOp : public framework::OperatorBase {
...
@@ -37,7 +37,6 @@ class RecvOp : public framework::OperatorBase {
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
outs
=
Outputs
(
"Out"
);
auto
outs
=
Outputs
(
"Out"
);
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
auto
client_var_name
=
Output
(
"RPCClient"
);
int
sync_mode
=
Attr
<
int
>
(
"sync_mode"
);
int
sync_mode
=
Attr
<
int
>
(
"sync_mode"
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
...
@@ -45,11 +44,7 @@ class RecvOp : public framework::OperatorBase {
...
@@ -45,11 +44,7 @@ class RecvOp : public framework::OperatorBase {
// For profiling
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
client_var_name
),
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
"Can not find variable '%s' in the scope."
,
client_var_name
);
auto
*
client_var
=
scope
.
FindVar
(
client_var_name
);
detail
::
RPCClient
*
rpc_client
=
client_var
->
GetMutable
<
detail
::
RPCClient
>
();
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
...
@@ -65,9 +60,6 @@ class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -65,9 +60,6 @@ class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
public:
void
Make
()
{
void
Make
()
{
AddOutput
(
"Out"
,
"(Tensor) Variables to get from server."
).
AsDuplicable
();
AddOutput
(
"Out"
,
"(Tensor) Variables to get from server."
).
AsDuplicable
();
AddOutput
(
"RPCClient"
,
"(RPCClient) The RPC client object which is"
"initialized at most once."
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Recv operator
Recv operator
...
...
paddle/fluid/operators/send_barrier_op.cc
浏览文件 @
20c24c05
...
@@ -43,12 +43,8 @@ class SendBarrierOp : public framework::OperatorBase {
...
@@ -43,12 +43,8 @@ class SendBarrierOp : public framework::OperatorBase {
auto
&
ctx
=
*
pool
.
Get
(
place
);
auto
&
ctx
=
*
pool
.
Get
(
place
);
// For profiling
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
client_var_name
=
Output
(
"RPCClient"
);
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
client_var_name
),
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
"Can not find variable '%s' in the scope."
,
client_var_name
);
auto
*
client_var
=
scope
.
FindVar
(
client_var_name
);
detail
::
RPCClient
*
rpc_client
=
client_var
->
GetMutable
<
detail
::
RPCClient
>
();
// need to wait before sending send_barrier message
// need to wait before sending send_barrier message
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
...
@@ -65,9 +61,6 @@ class SendBarrierOp : public framework::OperatorBase {
...
@@ -65,9 +61,6 @@ class SendBarrierOp : public framework::OperatorBase {
class
SendBarrierOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
SendBarrierOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
void
Make
()
{
void
Make
()
{
AddOutput
(
"RPCClient"
,
"(RPCClient) The RPC client object which is"
"initialized at most once."
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
SendBarrier operator
SendBarrier operator
...
@@ -83,17 +76,6 @@ the Parameter Server would knew all variables have been sent.
...
@@ -83,17 +76,6 @@ the Parameter Server would knew all variables have been sent.
}
}
};
};
class
SendBarrierOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
out_var_name
=
op_desc
.
Output
(
"RPCClient"
).
front
();
auto
&
out_var
=
block
->
FindRecursiveOrCreateVar
(
out_var_name
);
auto
var_type
=
framework
::
proto
::
VarType
::
RAW
;
out_var
.
SetType
(
var_type
);
}
};
class
SendBarrierOpShapeInference
:
public
framework
::
InferShapeBase
{
class
SendBarrierOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
...
@@ -106,5 +88,4 @@ namespace ops = paddle::operators;
...
@@ -106,5 +88,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR
(
send_barrier
,
ops
::
SendBarrierOp
,
REGISTER_OPERATOR
(
send_barrier
,
ops
::
SendBarrierOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
SendBarrierOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
SendBarrierOpMaker
,
ops
::
SendBarrierOpVarTypeInference
,
ops
::
SendBarrierOpShapeInference
);
ops
::
SendBarrierOpShapeInference
);
paddle/fluid/operators/send_op.cc
浏览文件 @
20c24c05
...
@@ -49,12 +49,7 @@ class SendOp : public framework::OperatorBase {
...
@@ -49,12 +49,7 @@ class SendOp : public framework::OperatorBase {
// For profiling
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
client_var_name
=
Output
(
"RPCClient"
);
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
client_var_name
),
"Can not find variable '%s' in the scope."
,
client_var_name
);
auto
*
client_var
=
scope
.
FindVar
(
client_var_name
);
detail
::
RPCClient
*
rpc_client
=
client_var
->
GetMutable
<
detail
::
RPCClient
>
();
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
...
@@ -96,9 +91,6 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -96,9 +91,6 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput
(
"X"
,
"(Tensor) Input tensor to be sent"
).
AsDuplicable
();
AddInput
(
"X"
,
"(Tensor) Input tensor to be sent"
).
AsDuplicable
();
AddOutput
(
"Out"
,
"(Tensor) Output tensor to be received from server"
)
AddOutput
(
"Out"
,
"(Tensor) Output tensor to be received from server"
)
.
AsDuplicable
();
.
AsDuplicable
();
AddOutput
(
"RPCClient"
,
"(RPCClient) The RPC client object which is"
"initialized at most once."
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Send operator
Send operator
...
@@ -119,17 +111,6 @@ This operator will send tensor to recv_op at the parameter server.
...
@@ -119,17 +111,6 @@ This operator will send tensor to recv_op at the parameter server.
}
}
};
};
class
SendOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
out_var_name
=
op_desc
.
Output
(
"RPCClient"
).
front
();
auto
&
out_var
=
block
->
FindRecursiveOrCreateVar
(
out_var_name
);
auto
var_type
=
framework
::
proto
::
VarType
::
RAW
;
out_var
.
SetType
(
var_type
);
}
};
class
SendOpShapeInference
:
public
framework
::
InferShapeBase
{
class
SendOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
...
@@ -141,5 +122,4 @@ class SendOpShapeInference : public framework::InferShapeBase {
...
@@ -141,5 +122,4 @@ class SendOpShapeInference : public framework::InferShapeBase {
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
send
,
ops
::
SendOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
REGISTER_OPERATOR
(
send
,
ops
::
SendOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
SendOpMaker
,
ops
::
SendOpVarTypeInference
,
ops
::
SendOpMaker
,
ops
::
SendOpShapeInference
);
ops
::
SendOpShapeInference
);
paddle/fluid/operators/send_recv_op_test.cc
浏览文件 @
20c24c05
...
@@ -177,7 +177,7 @@ TEST(SendRecvOp, CPUDense) {
...
@@ -177,7 +177,7 @@ TEST(SendRecvOp, CPUDense) {
attrs
.
insert
({
"epmap"
,
std
::
vector
<
std
::
string
>
({
endpoint
})});
attrs
.
insert
({
"epmap"
,
std
::
vector
<
std
::
string
>
({
endpoint
})});
auto
send_op
=
f
::
OpRegistry
::
CreateOp
(
auto
send_op
=
f
::
OpRegistry
::
CreateOp
(
"send"
,
{{
"X"
,
{
"x1"
}}},
"send"
,
{{
"X"
,
{
"x1"
}}},
{{
"Out"
,
{
"Out"
}},
{
"RPCClient"
,
{
"RPC_CLIENT_VAR"
}}},
attrs
);
{{
"Out"
,
{
"Out"
}},
attrs
);
send_op
->
Run
(
scope
,
place
);
send_op
->
Run
(
scope
,
place
);
auto
in_var
=
scope
.
Var
(
"x1"
);
auto
in_var
=
scope
.
Var
(
"x1"
);
...
@@ -217,12 +217,12 @@ TEST(SendRecvOp, CPUSparse) {
...
@@ -217,12 +217,12 @@ TEST(SendRecvOp, CPUSparse) {
scope
.
Var
(
"RPC_CLIENT_VAR"
);
scope
.
Var
(
"RPC_CLIENT_VAR"
);
f
::
AttributeMap
attrs
;
f
::
AttributeMap
attrs
;
selected_port
=
listen_and_serv_op_ptr
->
GetSelectedPort
();
selected_port
=
listen_and_serv_op_ptr
->
GetSelectedPort
();
std
::
string
endpoint
=
paddle
::
string
::
Sprintf
(
"127.0.0.1:%d"
,
selected_port
);
std
::
string
endpoint
=
paddle
::
string
::
Sprintf
(
"127.0.0.1:%d"
,
selected_port
);
attrs
.
insert
({
"endpoints"
,
std
::
vector
<
std
::
string
>
({
endpoint
})});
attrs
.
insert
({
"endpoints"
,
std
::
vector
<
std
::
string
>
({
endpoint
})});
attrs
.
insert
({
"epmap"
,
std
::
vector
<
std
::
string
>
({
endpoint
})});
attrs
.
insert
({
"epmap"
,
std
::
vector
<
std
::
string
>
({
endpoint
})});
auto
send_op
=
f
::
OpRegistry
::
CreateOp
(
auto
send_op
=
f
::
OpRegistry
::
CreateOp
(
"send"
,
{{
"X"
,
{
"x1"
}}},
"send"
,
{{
"X"
,
{
"x1"
}}},
{{
"Out"
,
{
"Out"
}}},
attrs
);
{{
"Out"
,
{
"Out"
}},
{
"RPCClient"
,
{
"RPC_CLIENT_VAR"
}}},
attrs
);
send_op
->
Run
(
scope
,
place
);
send_op
->
Run
(
scope
,
place
);
auto
x0
=
scope
.
Var
(
"x0"
)
->
GetMutable
<
f
::
SelectedRows
>
();
auto
x0
=
scope
.
Var
(
"x0"
)
->
GetMutable
<
f
::
SelectedRows
>
();
...
...
paddle/fluid/operators/send_vars_op.cc
浏览文件 @
20c24c05
...
@@ -45,12 +45,7 @@ class SendVarsOp : public framework::OperatorBase {
...
@@ -45,12 +45,7 @@ class SendVarsOp : public framework::OperatorBase {
// For profiling
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
client_var_name
=
Output
(
"RPCClient"
);
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
client_var_name
),
"Can not find variable '%s' in the scope."
,
client_var_name
);
auto
*
client_var
=
scope
.
FindVar
(
client_var_name
);
detail
::
RPCClient
*
rpc_client
=
client_var
->
GetMutable
<
detail
::
RPCClient
>
();
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
...
@@ -73,9 +68,6 @@ class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -73,9 +68,6 @@ class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker {
void
Make
()
{
void
Make
()
{
AddInput
(
"X"
,
"(Tensor, SelectedRows) Input variables to be sent"
)
AddInput
(
"X"
,
"(Tensor, SelectedRows) Input variables to be sent"
)
.
AsDuplicable
();
.
AsDuplicable
();
AddOutput
(
"RPCClient"
,
"(RPCClient) The RPC client object which will be"
"initialized at most once."
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Send operator
Send operator
...
@@ -93,17 +85,6 @@ This operator will send variables to listen_and_serve op at the parameter server
...
@@ -93,17 +85,6 @@ This operator will send variables to listen_and_serve op at the parameter server
}
}
};
};
class
SendVarsOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
out_var_name
=
op_desc
.
Output
(
"RPCClient"
).
front
();
auto
&
out_var
=
block
->
FindRecursiveOrCreateVar
(
out_var_name
);
auto
var_type
=
framework
::
proto
::
VarType
::
RAW
;
out_var
.
SetType
(
var_type
);
}
};
class
SendVarsOpShapeInference
:
public
framework
::
InferShapeBase
{
class
SendVarsOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
...
@@ -116,5 +97,4 @@ namespace ops = paddle::operators;
...
@@ -116,5 +97,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR
(
send_vars
,
ops
::
SendVarsOp
,
REGISTER_OPERATOR
(
send_vars
,
ops
::
SendVarsOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
SendVarsOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
SendVarsOpMaker
,
ops
::
SendVarsOpVarTypeInference
,
ops
::
SendVarsOpShapeInference
);
ops
::
SendVarsOpShapeInference
);
paddle/fluid/pybind/const_value.cc
浏览文件 @
20c24c05
...
@@ -32,7 +32,8 @@ void BindConstValue(pybind11::module* m) {
...
@@ -32,7 +32,8 @@ void BindConstValue(pybind11::module* m) {
.
value
(
"Forward"
,
framework
::
OpRole
::
kForward
)
.
value
(
"Forward"
,
framework
::
OpRole
::
kForward
)
.
value
(
"Backward"
,
framework
::
OpRole
::
kBackward
)
.
value
(
"Backward"
,
framework
::
OpRole
::
kBackward
)
.
value
(
"Optimize"
,
framework
::
OpRole
::
kOptimize
)
.
value
(
"Optimize"
,
framework
::
OpRole
::
kOptimize
)
.
value
(
"Loss"
,
framework
::
OpRole
::
kLoss
);
.
value
(
"Loss"
,
framework
::
OpRole
::
kLoss
)
.
value
(
"RPC"
,
framework
::
OpRole
::
kRPC
);
op_proto_and_checker_maker
.
def
(
op_proto_and_checker_maker
.
def
(
"kOpRoleAttrName"
,
framework
::
OpProtoAndCheckerMaker
::
OpRoleAttrName
);
"kOpRoleAttrName"
,
framework
::
OpProtoAndCheckerMaker
::
OpRoleAttrName
);
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
20c24c05
...
@@ -195,21 +195,23 @@ def Send(endpoints, send_vars, get_vars=None):
...
@@ -195,21 +195,23 @@ def Send(endpoints, send_vars, get_vars=None):
endpoints
=
list
(
set
(
epmap
))
endpoints
=
list
(
set
(
epmap
))
helper
=
LayerHelper
(
"Send"
,
**
locals
())
helper
=
LayerHelper
(
"Send"
,
**
locals
())
rpc_client_var
=
default_main_program
().
global_block
().
create_var
(
name
=
"RPC_CLIENT_VAR"
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
if
not
get_vars
:
if
not
get_vars
:
get_vars
=
[]
get_vars
=
[]
for
s
in
send_vars
:
for
s
in
send_vars
:
v
=
helper
.
create_tmp_variable
(
dtype
=
s
.
dtype
,
stop_gradient
=
True
)
v
=
helper
.
create_tmp_variable
(
dtype
=
s
.
dtype
,
stop_gradient
=
True
)
get_vars
.
append
(
v
)
get_vars
.
append
(
v
)
rpc_op_role_name
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
()
helper
.
append_op
(
helper
.
append_op
(
type
=
"send"
,
type
=
"send"
,
inputs
=
{
"X"
:
send_vars
},
inputs
=
{
"X"
:
send_vars
},
outputs
=
{
"Out"
:
get_vars
,
outputs
=
{
"Out"
:
get_vars
},
"RPCClient"
:
rpc_client_var
},
attrs
=
{
attrs
=
{
"endpoints"
:
endpoints
,
"endpoints"
:
endpoints
,
"epmap"
:
epmap
})
"epmap"
:
epmap
,
rpc_op_role_name
:
core
.
op_proto_and_checker_maker
.
OpRole
.
RPC
})
return
get_vars
return
get_vars
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
20c24c05
...
@@ -24,7 +24,9 @@ from ..framework import Program, default_main_program, \
...
@@ -24,7 +24,9 @@ from ..framework import Program, default_main_program, \
LOOKUP_TABLE_TYPE
=
"lookup_table"
LOOKUP_TABLE_TYPE
=
"lookup_table"
LOOKUP_TABLE_GRAD_TYPE
=
"lookup_table_grad"
LOOKUP_TABLE_GRAD_TYPE
=
"lookup_table_grad"
RPC_CLIENT_VAR_NAME
=
"RPC_CLIENT_VAR"
RPC_OP_ROLE_ATTR_NAME
=
op_role_attr_name
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
(
)
RPC_OP_ROLE_ATTR_VALUE
=
core
.
op_proto_and_checker_maker
.
OpRole
.
RPC
class
VarBlock
:
class
VarBlock
:
...
@@ -297,11 +299,6 @@ class DistributeTranspiler:
...
@@ -297,11 +299,6 @@ class DistributeTranspiler:
grad_param_mapping
[
grad_var_mapping
[
g_name
][
int
(
g_bid
)]]
=
\
grad_param_mapping
[
grad_var_mapping
[
g_name
][
int
(
g_bid
)]]
=
\
param_var_mapping
[
p_name
][
int
(
p_bid
)]
param_var_mapping
[
p_name
][
int
(
p_bid
)]
rpc_client_var
=
program
.
global_block
().
create_var
(
name
=
RPC_CLIENT_VAR_NAME
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
# step 3: transpile trainer side program, insert recv op and send op.
# step 3: transpile trainer side program, insert recv op and send op.
# create mapping of endpoint -> split var to create pserver side program
# create mapping of endpoint -> split var to create pserver side program
...
@@ -338,8 +335,11 @@ class DistributeTranspiler:
...
@@ -338,8 +335,11 @@ class DistributeTranspiler:
index
=
index
+
1
,
index
=
index
+
1
,
type
=
"send_vars"
,
type
=
"send_vars"
,
inputs
=
{
"X"
:
splited_vars
},
inputs
=
{
"X"
:
splited_vars
},
outputs
=
{
"RPCClient"
:
rpc_client_var
},
outputs
=
{},
attrs
=
{
"epmap"
:
eplist
})
attrs
=
{
"epmap"
:
eplist
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
for
_
,
var
in
enumerate
(
splited_vars
):
for
_
,
var
in
enumerate
(
splited_vars
):
send_vars
.
append
(
var
)
send_vars
.
append
(
var
)
...
@@ -347,10 +347,11 @@ class DistributeTranspiler:
...
@@ -347,10 +347,11 @@ class DistributeTranspiler:
program
.
global_block
().
append_op
(
program
.
global_block
().
append_op
(
type
=
"send_barrier"
,
type
=
"send_barrier"
,
inputs
=
{},
inputs
=
{},
outputs
=
{
"RPCClient"
:
rpc_client_var
},
outputs
=
{},
attrs
=
{
attrs
=
{
"endpoints"
:
pserver_endpoints
,
"endpoints"
:
pserver_endpoints
,
"sync_mode"
:
self
.
sync_mode
"sync_mode"
:
self
.
sync_mode
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
})
# step 3.2: insert recv op to receive parameters from parameter server
# step 3.2: insert recv op to receive parameters from parameter server
...
@@ -373,15 +374,20 @@ class DistributeTranspiler:
...
@@ -373,15 +374,20 @@ class DistributeTranspiler:
program
.
global_block
().
append_op
(
program
.
global_block
().
append_op
(
type
=
"recv"
,
type
=
"recv"
,
inputs
=
{},
inputs
=
{},
outputs
=
{
"Out"
:
splited_var
,
outputs
=
{
"Out"
:
splited_var
},
"RPCClient"
:
rpc_client_var
},
attrs
=
{
attrs
=
{
"epmap"
:
eps
})
"epmap"
:
eps
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
program
.
global_block
().
append_op
(
program
.
global_block
().
append_op
(
type
=
"fetch_barrier"
,
type
=
"fetch_barrier"
,
inputs
=
{},
inputs
=
{},
outputs
=
{
"RPCClient"
:
rpc_client_var
},
outputs
=
{},
attrs
=
{
"endpoints"
:
pserver_endpoints
})
attrs
=
{
"endpoints"
:
pserver_endpoints
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
for
varname
,
splited_var
in
param_var_mapping
.
iteritems
():
for
varname
,
splited_var
in
param_var_mapping
.
iteritems
():
if
len
(
splited_var
)
<=
1
:
if
len
(
splited_var
)
<=
1
:
...
@@ -394,10 +400,8 @@ class DistributeTranspiler:
...
@@ -394,10 +400,8 @@ class DistributeTranspiler:
attrs
=
{
"axis"
:
0
})
attrs
=
{
"axis"
:
0
})
if
self
.
has_distributed_lookup_table
:
if
self
.
has_distributed_lookup_table
:
self
.
_replace_lookup_table_op_with_prefetch
(
program
,
rpc_client_var
,
self
.
_replace_lookup_table_op_with_prefetch
(
program
,
eplist
)
eplist
)
self
.
_split_table_grad_and_add_send_vars
(
program
,
pserver_endpoints
)
self
.
_split_table_grad_and_add_send_vars
(
program
,
rpc_client_var
,
pserver_endpoints
)
def
get_trainer_program
(
self
):
def
get_trainer_program
(
self
):
# remove optimize ops and add a send op to main_program
# remove optimize ops and add a send op to main_program
...
@@ -617,8 +621,7 @@ class DistributeTranspiler:
...
@@ -617,8 +621,7 @@ class DistributeTranspiler:
return
s_prog
return
s_prog
# transpiler function for dis lookup_table
# transpiler function for dis lookup_table
def
_replace_lookup_table_op_with_prefetch
(
self
,
program
,
rpc_client_var
,
def
_replace_lookup_table_op_with_prefetch
(
self
,
program
,
eplist
):
eplist
):
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
self
.
prefetch_input_vars
=
None
self
.
prefetch_input_vars
=
None
self
.
prefetch_output_vars
=
None
self
.
prefetch_output_vars
=
None
...
@@ -665,11 +668,11 @@ class DistributeTranspiler:
...
@@ -665,11 +668,11 @@ class DistributeTranspiler:
index
=
op_index
+
1
,
index
=
op_index
+
1
,
type
=
"prefetch"
,
type
=
"prefetch"
,
inputs
=
{
'X'
:
self
.
prefetch_input_vars
},
inputs
=
{
'X'
:
self
.
prefetch_input_vars
},
outputs
=
{
outputs
=
{
"Out"
:
self
.
prefetch_output_vars
},
"Out"
:
self
.
prefetch_output_vars
,
attrs
=
{
"
RPCClient"
:
rpc_client_var
"
epmap"
:
eplist
,
},
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
attrs
=
{
"epmap"
:
eplist
})
})
# insert concat_op
# insert concat_op
program
.
global_block
().
insert_op
(
program
.
global_block
().
insert_op
(
...
@@ -689,8 +692,7 @@ class DistributeTranspiler:
...
@@ -689,8 +692,7 @@ class DistributeTranspiler:
# break for loop
# break for loop
break
break
def
_split_table_grad_and_add_send_vars
(
self
,
program
,
rpc_client_var
,
def
_split_table_grad_and_add_send_vars
(
self
,
program
,
pserver_endpoints
):
pserver_endpoints
):
# 2. add split_ids_op and send_vars_op to send gradient to pservers
# 2. add split_ids_op and send_vars_op to send gradient to pservers
# there should only be one table_name
# there should only be one table_name
all_ops
=
program
.
global_block
().
ops
all_ops
=
program
.
global_block
().
ops
...
@@ -710,9 +712,12 @@ class DistributeTranspiler:
...
@@ -710,9 +712,12 @@ class DistributeTranspiler:
index
=
op_index
+
2
,
index
=
op_index
+
2
,
type
=
"send_vars"
,
type
=
"send_vars"
,
inputs
=
{
'X'
:
self
.
table_grad_list
},
inputs
=
{
'X'
:
self
.
table_grad_list
},
outputs
=
{
"RPCClient"
:
rpc_client_var
},
outputs
=
{},
attrs
=
{
"sync_send"
:
True
,
attrs
=
{
"epmap"
:
pserver_endpoints
})
"sync_send"
:
True
,
"epmap"
:
pserver_endpoints
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
break
break
def
_create_prefetch_block
(
self
,
pserver_index
,
pserver_program
,
def
_create_prefetch_block
(
self
,
pserver_index
,
pserver_program
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录