Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
62af10d4
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
62af10d4
编写于
5月 21, 2018
作者:
Y
Yancey1989
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support multiple devices
上级
274df85c
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
208 addition
and
26 deletion
+208
-26
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+2
-1
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+49
-11
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+5
-0
paddle/fluid/framework/details/rpc_op_handle.cc
paddle/fluid/framework/details/rpc_op_handle.cc
+50
-0
paddle/fluid/framework/details/rpc_op_handle.h
paddle/fluid/framework/details/rpc_op_handle.h
+52
-0
paddle/fluid/framework/variable.h
paddle/fluid/framework/variable.h
+3
-0
paddle/fluid/operators/detail/grpc_client.cc
paddle/fluid/operators/detail/grpc_client.cc
+16
-10
paddle/fluid/operators/detail/grpc_client.h
paddle/fluid/operators/detail/grpc_client.h
+4
-1
paddle/fluid/operators/fetch_barrier_op.cc
paddle/fluid/operators/fetch_barrier_op.cc
+6
-0
paddle/fluid/operators/recv_op.cc
paddle/fluid/operators/recv_op.cc
+7
-3
paddle/fluid/operators/send_barrier_op.cc
paddle/fluid/operators/send_barrier_op.cc
+5
-0
paddle/fluid/operators/send_recv_util.h
paddle/fluid/operators/send_recv_util.h
+3
-0
paddle/fluid/operators/send_vars_op.cc
paddle/fluid/operators/send_vars_op.cc
+6
-0
未找到文件。
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
62af10d4
...
@@ -4,6 +4,7 @@ cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_h
...
@@ -4,6 +4,7 @@ cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_h
cc_library
(
fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
)
cc_library
(
fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
)
cc_library
(
computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base
)
cc_library
(
ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base
)
cc_library
(
ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph
)
cc_library
(
ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph
)
...
@@ -26,7 +27,7 @@ endif()
...
@@ -26,7 +27,7 @@ endif()
cc_library
(
gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor
)
cc_library
(
gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor
)
cc_library
(
multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
cc_library
(
multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle send_op_handle
${
multi_devices_graph_builder_deps
}
reduce_op_handle broadcast_op_handle
)
scale_loss_grad_op_handle send_op_handle
rpc_op_handle
${
multi_devices_graph_builder_deps
}
reduce_op_handle broadcast_op_handle
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto
)
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
62af10d4
...
@@ -12,10 +12,12 @@
...
@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include <fstream>
#include <utility>
#include <utility>
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
...
@@ -77,7 +79,6 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
...
@@ -77,7 +79,6 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
CreateOpOutput
(
result
,
op_handle
,
each_var_name
,
p
,
place_id
);
CreateOpOutput
(
result
,
op_handle
,
each_var_name
,
p
,
place_id
);
}
}
}
}
bool
MultiDevSSAGraphBuilder
::
IsDistTrainOp
(
const
OpDesc
&
op
,
bool
MultiDevSSAGraphBuilder
::
IsDistTrainOp
(
const
OpDesc
&
op
,
OpDesc
*
send_op
)
const
{
OpDesc
*
send_op
)
const
{
if
(
send_op
==
nullptr
)
{
if
(
send_op
==
nullptr
)
{
...
@@ -98,7 +99,7 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
...
@@ -98,7 +99,7 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
return
false
;
return
false
;
};
};
if
(
op
.
Type
()
==
"split"
)
{
if
(
op
.
Type
()
==
"split"
||
op
.
Type
()
==
"split_byref"
)
{
return
checker
(
op
.
OutputArgumentNames
(),
send_op
->
InputArgumentNames
());
return
checker
(
op
.
OutputArgumentNames
(),
send_op
->
InputArgumentNames
());
}
else
if
(
op
.
Type
()
==
"concat"
)
{
}
else
if
(
op
.
Type
()
==
"concat"
)
{
return
checker
(
op
.
InputArgumentNames
(),
send_op
->
OutputArgumentNames
());
return
checker
(
op
.
InputArgumentNames
(),
send_op
->
OutputArgumentNames
());
...
@@ -106,6 +107,15 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
...
@@ -106,6 +107,15 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
return
false
;
return
false
;
}
}
bool
MultiDevSSAGraphBuilder
::
IsRPCOp
(
const
OpDesc
&
op
)
const
{
for
(
auto
&
name
:
op
.
OutputNames
())
{
if
(
name
==
"RPCClient"
)
{
return
true
;
}
}
return
false
;
}
std
::
unique_ptr
<
SSAGraph
>
MultiDevSSAGraphBuilder
::
Build
(
std
::
unique_ptr
<
SSAGraph
>
MultiDevSSAGraphBuilder
::
Build
(
const
ProgramDesc
&
program
)
const
{
const
ProgramDesc
&
program
)
const
{
std
::
unordered_map
<
std
::
string
,
proto
::
VarType
::
Type
>
var_types
;
std
::
unordered_map
<
std
::
string
,
proto
::
VarType
::
Type
>
var_types
;
...
@@ -133,10 +143,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
...
@@ -133,10 +143,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
bool
is_forwarding
=
true
;
bool
is_forwarding
=
true
;
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
if
(
op
->
Type
()
==
"send"
)
{
if
(
IsRPCOp
(
*
op
)
)
{
// append
send
op if program is distributed trainer main program.
// append
rpc
op if program is distributed trainer main program.
// always use the first device
// always use the first device
Create
Send
Op
(
&
result
,
*
op
);
Create
RPC
Op
(
&
result
,
*
op
);
}
else
if
(
IsDistTrainOp
(
*
op
,
send_op
))
{
}
else
if
(
IsDistTrainOp
(
*
op
,
send_op
))
{
CreateComputationalOps
(
&
result
,
*
op
,
1
);
CreateComputationalOps
(
&
result
,
*
op
,
1
);
}
else
if
(
IsScaleLossOp
(
*
op
))
{
}
else
if
(
IsScaleLossOp
(
*
op
))
{
...
@@ -203,9 +213,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
...
@@ -203,9 +213,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
AddOutputToLeafOps
(
&
result
);
AddOutputToLeafOps
(
&
result
);
if
(
VLOG_IS_ON
(
10
))
{
if
(
VLOG_IS_ON
(
10
))
{
std
::
ostringstream
sout
;
std
::
string
filename
=
"/tmp/graph"
;
PrintGraphviz
(
*
graph
,
sout
);
std
::
ofstream
fout
(
filename
);
VLOG
(
10
)
<<
sout
.
str
(
);
PrintGraphviz
(
*
graph
,
fout
);
}
}
return
std
::
unique_ptr
<
SSAGraph
>
(
graph
);
return
std
::
unique_ptr
<
SSAGraph
>
(
graph
);
...
@@ -386,12 +396,40 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
...
@@ -386,12 +396,40 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
return
var
;
return
var
;
}
}
void
MultiDevSSAGraphBuilder
::
CreateSendOp
(
SSAGraph
*
result
,
void
MultiDevSSAGraphBuilder
::
ConnectOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
)
const
{
std
::
string
op_name
)
const
{
for
(
auto
&
prev_op
:
result
->
ops_
)
{
if
(
prev_op
->
Name
()
==
op_name
)
{
auto
*
dep_var
=
new
DummyVarHandle
();
prev_op
->
AddOutput
(
dep_var
);
result
->
dep_vars_
.
emplace
(
dep_var
);
result
->
ops_
.
back
().
get
()
->
AddInput
(
dep_var
);
}
}
}
void
MultiDevSSAGraphBuilder
::
CreateRPCOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
)
const
{
auto
&
p
=
places_
[
0
];
auto
&
p
=
places_
[
0
];
auto
*
s
=
local_scopes_
[
0
];
auto
*
s
=
local_scopes_
[
0
];
VLOG
(
3
)
<<
"create rpc op: "
<<
op
.
Type
();
result
->
ops_
.
emplace_back
(
new
RPCOpHandle
(
op
,
s
,
p
,
op
.
Type
()));
if
(
op
.
Type
()
==
"send_barrier"
)
{
ConnectOp
(
result
,
"send_vars"
);
}
else
if
(
op
.
Type
()
==
"recv"
)
{
ConnectOp
(
result
,
"send_barrier"
);
}
else
if
(
op
.
Type
()
==
"fetch_barrier"
)
{
ConnectOp
(
result
,
"recv"
);
}
else
if
(
op
.
Type
()
==
"send"
||
op
.
Type
()
==
"send_vars"
)
{
// do nothing
}
else
{
PADDLE_THROW
(
"rpc op should be in [send,"
"send_vars, send_barrier. recv, fetch_barrier]"
);
}
// FIXME(wuyi): send op always copy from GPU 0
// FIXME(wuyi): send op always copy from GPU 0
result
->
ops_
.
emplace_back
(
new
SendOpHandle
(
op
,
s
,
p
));
// result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()
));
// Create inputs for output on original place and no ssa output
// Create inputs for output on original place and no ssa output
// is created for send op.
// is created for send op.
CreateOpHandleIOs
(
result
,
op
,
0
);
CreateOpHandleIOs
(
result
,
op
,
0
);
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
62af10d4
...
@@ -65,12 +65,17 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
...
@@ -65,12 +65,17 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool
IsScaleLossOp
(
const
OpDesc
&
op
)
const
;
bool
IsScaleLossOp
(
const
OpDesc
&
op
)
const
;
void
CreateSendOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
)
const
;
void
CreateSendOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
)
const
;
void
CreateRPCOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
)
const
;
/**
/**
* Is this operator as the end-point operator before/after send operator.
* Is this operator as the end-point operator before/after send operator.
*/
*/
bool
IsDistTrainOp
(
const
OpDesc
&
op
,
OpDesc
*
send_op
)
const
;
bool
IsDistTrainOp
(
const
OpDesc
&
op
,
OpDesc
*
send_op
)
const
;
bool
IsRPCOp
(
const
OpDesc
&
op
)
const
;
void
ConnectOp
(
SSAGraph
*
result
,
std
::
string
op_name
)
const
;
void
CreateComputationalOps
(
SSAGraph
*
result
,
const
OpDesc
&
op
,
void
CreateComputationalOps
(
SSAGraph
*
result
,
const
OpDesc
&
op
,
size_t
num_places
)
const
;
size_t
num_places
)
const
;
...
...
paddle/fluid/framework/details/rpc_op_handle.cc
0 → 100644
浏览文件 @
62af10d4
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/rpc_op_handle.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
RPCOpHandle
::
RPCOpHandle
(
const
framework
::
OpDesc
&
op_desc
,
const
Scope
*
local_scope
,
const
platform
::
Place
&
place
,
const
std
::
string
&
name
)
:
op_
(
framework
::
OpRegistry
::
CreateOp
(
op_desc
)),
local_scope_
(
local_scope
),
place_
(
place
),
name_
(
name
)
{}
void
RPCOpHandle
::
RunImpl
()
{
// TODO(wuyi): need further analysis whether wait VarDummyHandle.
// Wait input done
for
(
auto
*
in
:
inputs_
)
{
auto
&
p
=
static_cast
<
VarHandle
*>
(
in
)
->
place_
;
if
(
in
->
DebugString
()
==
"dummy"
)
{
// HACK
continue
;
}
if
(
in
->
generated_op_
)
{
in
->
generated_op_
->
RecordWaitEventOnCtx
(
dev_ctxes_
[
p
]);
}
}
auto
&
tmp_scope
=
local_scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
();
// FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead
// lock.
op_
->
Run
(
*
tmp_scope
,
place_
);
}
std
::
string
RPCOpHandle
::
Name
()
const
{
return
name_
;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/rpc_op_handle.h
0 → 100644
浏览文件 @
62af10d4
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
struct
RPCOpHandle
:
public
OpHandleBase
{
RPCOpHandle
(
const
framework
::
OpDesc
&
op_desc
,
const
Scope
*
local_scope
,
const
platform
::
Place
&
place
,
const
std
::
string
&
name
);
std
::
string
Name
()
const
override
;
// Delay and buffer nccl_all_reduce together can significantly increase
// performance. Disable this feature by returning false.
bool
IsMultiDeviceTransfer
()
override
{
return
false
;
};
protected:
void
RunImpl
()
override
;
private:
std
::
unique_ptr
<
OperatorBase
>
op_
;
const
Scope
*
local_scope_
;
const
platform
::
Place
&
place_
;
const
std
::
string
name_
;
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/variable.h
浏览文件 @
62af10d4
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#pragma once
#pragma once
#include <memory>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <string>
#include <typeindex>
#include <typeindex>
#include <typeinfo>
#include <typeinfo>
...
@@ -38,6 +39,7 @@ class Variable {
...
@@ -38,6 +39,7 @@ class Variable {
template
<
typename
T
>
template
<
typename
T
>
T
*
GetMutable
()
{
T
*
GetMutable
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
if
(
!
IsType
<
T
>
())
{
if
(
!
IsType
<
T
>
())
{
holder_
.
reset
(
new
PlaceholderImpl
<
T
>
(
new
T
()));
holder_
.
reset
(
new
PlaceholderImpl
<
T
>
(
new
T
()));
}
}
...
@@ -90,6 +92,7 @@ class Variable {
...
@@ -90,6 +92,7 @@ class Variable {
// by its address but not the unreadable name.
// by its address but not the unreadable name.
friend
class
Scope
;
friend
class
Scope
;
const
std
::
string
*
name_
;
const
std
::
string
*
name_
;
std
::
mutex
mutex_
;
};
};
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/operators/detail/grpc_client.cc
浏览文件 @
62af10d4
...
@@ -33,7 +33,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
...
@@ -33,7 +33,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
const
std
::
string
ep_val
=
ep
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
var_name_val
=
var_name
;
const
std
::
string
var_name_val
=
var_name
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch
=
GetChannel
(
ep_val
);
const
auto
ch
=
GetChannel
(
ep_val
,
ep_val
+
":"
+
var_name_val
);
framework
::
AsyncIO
([
var_name_val
,
p_ctx
,
ep_val
,
p_scope
,
time_out
,
ch
,
framework
::
AsyncIO
([
var_name_val
,
p_ctx
,
ep_val
,
p_scope
,
time_out
,
ch
,
this
]
{
this
]
{
...
@@ -88,7 +88,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
...
@@ -88,7 +88,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
const
std
::
string
ep_val
=
ep
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
var_name_val
=
var_name
;
const
std
::
string
var_name_val
=
var_name
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch
=
GetChannel
(
ep_val
);
const
auto
ch
=
GetChannel
(
ep_val
,
ep_val
+
":"
+
var_name_val
);
framework
::
AsyncIO
([
var_name_val
,
ep_val
,
p_scope
,
p_ctx
,
time_out
,
ch
,
framework
::
AsyncIO
([
var_name_val
,
ep_val
,
p_scope
,
p_ctx
,
time_out
,
ch
,
this
]
{
this
]
{
...
@@ -132,7 +132,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
...
@@ -132,7 +132,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
const
std
::
string
in_var_name_val
=
in_var_name
;
const
std
::
string
in_var_name_val
=
in_var_name
;
const
std
::
string
out_var_name_val
=
out_var_name
;
const
std
::
string
out_var_name_val
=
out_var_name
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch
=
GetChannel
(
ep_val
);
const
auto
ch
=
GetChannel
(
ep_val
,
ep_val
+
":"
+
in_var_name_val
);
framework
::
AsyncIO
([
in_var_name_val
,
out_var_name_val
,
ep_val
,
p_scope
,
p_ctx
,
framework
::
AsyncIO
([
in_var_name_val
,
out_var_name_val
,
ep_val
,
p_scope
,
p_ctx
,
time_out
,
ch
,
this
]
{
time_out
,
ch
,
this
]
{
...
@@ -165,7 +165,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
...
@@ -165,7 +165,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
}
}
void
RPCClient
::
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
void
RPCClient
::
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
);
const
auto
ch
=
GetChannel
(
ep
,
ep
);
BatchBarrierProcessor
*
s
=
new
BatchBarrierProcessor
(
ch
);
BatchBarrierProcessor
*
s
=
new
BatchBarrierProcessor
(
ch
);
s
->
Prepare
(
time_out
);
s
->
Prepare
(
time_out
);
...
@@ -178,7 +178,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
...
@@ -178,7 +178,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
}
}
void
RPCClient
::
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
void
RPCClient
::
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
);
const
auto
ch
=
GetChannel
(
ep
,
ep
);
FetchBarrierProcessor
*
s
=
new
FetchBarrierProcessor
(
ch
);
FetchBarrierProcessor
*
s
=
new
FetchBarrierProcessor
(
ch
);
s
->
Prepare
(
time_out
);
s
->
Prepare
(
time_out
);
...
@@ -243,12 +243,19 @@ bool RPCClient::Proceed() {
...
@@ -243,12 +243,19 @@ bool RPCClient::Proceed() {
delete
c
;
delete
c
;
return
true
;
return
true
;
}
}
std
::
shared_ptr
<
grpc
::
Channel
>
RPCClient
::
GetChannel
(
const
std
::
string
&
ep
,
std
::
shared_ptr
<
grpc
::
Channel
>
RPCClient
::
GetChannel
(
const
std
::
string
&
ep
)
{
const
std
::
string
&
key
)
{
auto
it
=
channels_
.
find
(
ep
);
VLOG
(
3
)
<<
"this addr: "
<<
this
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
auto
it
=
channels_
.
find
(
key
);
if
(
it
!=
channels_
.
end
())
{
if
(
it
!=
channels_
.
end
())
{
VLOG
(
3
)
<<
"find ep: "
<<
ep
;
return
it
->
second
;
return
it
->
second
;
}
}
VLOG
(
3
)
<<
"can not find ep: "
<<
ep
;
for
(
auto
it
=
channels_
.
begin
();
it
!=
channels_
.
end
();
++
it
)
{
VLOG
(
3
)
<<
"ep: "
<<
it
->
first
;
}
grpc
::
ChannelArguments
args
;
grpc
::
ChannelArguments
args
;
args
.
SetCompressionAlgorithm
(
GRPC_COMPRESS_NONE
);
args
.
SetCompressionAlgorithm
(
GRPC_COMPRESS_NONE
);
...
@@ -257,8 +264,7 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
...
@@ -257,8 +264,7 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
auto
ch
=
auto
ch
=
grpc
::
CreateCustomChannel
(
ep
,
grpc
::
InsecureChannelCredentials
(),
args
);
grpc
::
CreateCustomChannel
(
ep
,
grpc
::
InsecureChannelCredentials
(),
args
);
channels_
[
key
]
=
ch
;
channels_
[
ep
]
=
ch
;
return
ch
;
return
ch
;
}
}
...
...
paddle/fluid/operators/detail/grpc_client.h
浏览文件 @
62af10d4
...
@@ -21,6 +21,7 @@ limitations under the License. */
...
@@ -21,6 +21,7 @@ limitations under the License. */
#include <functional>
#include <functional>
#include <iostream>
#include <iostream>
#include <map>
#include <map>
#include <mutex> // NOLINT
#include <string>
#include <string>
#include <vector>
#include <vector>
...
@@ -190,12 +191,14 @@ class RPCClient {
...
@@ -190,12 +191,14 @@ class RPCClient {
private:
private:
bool
Proceed
();
bool
Proceed
();
std
::
shared_ptr
<
grpc
::
Channel
>
GetChannel
(
const
std
::
string
&
ep
);
std
::
shared_ptr
<
grpc
::
Channel
>
GetChannel
(
const
std
::
string
&
ep
,
const
std
::
string
&
key
);
private:
private:
grpc
::
CompletionQueue
cq_
;
grpc
::
CompletionQueue
cq_
;
std
::
map
<
std
::
string
,
std
::
shared_ptr
<
grpc
::
Channel
>>
channels_
;
std
::
map
<
std
::
string
,
std
::
shared_ptr
<
grpc
::
Channel
>>
channels_
;
int64_t
req_count_
=
0
;
int64_t
req_count_
=
0
;
std
::
mutex
mutex_
;
};
};
}
// namespace detail
}
// namespace detail
...
...
paddle/fluid/operators/fetch_barrier_op.cc
浏览文件 @
62af10d4
...
@@ -21,6 +21,7 @@ limitations under the License. */
...
@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -37,6 +38,11 @@ class FetchBarrierOp : public framework::OperatorBase {
...
@@ -37,6 +38,11 @@ class FetchBarrierOp : public framework::OperatorBase {
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
std
::
vector
<
std
::
string
>
eps
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
);
std
::
vector
<
std
::
string
>
eps
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
client_var_name
=
Output
(
"RPCClient"
);
auto
client_var_name
=
Output
(
"RPCClient"
);
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
client_var_name
),
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
client_var_name
),
"Can not find variable '%s' in the scope."
,
"Can not find variable '%s' in the scope."
,
...
...
paddle/fluid/operators/recv_op.cc
浏览文件 @
62af10d4
...
@@ -21,6 +21,7 @@ limitations under the License. */
...
@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -37,15 +38,18 @@ class RecvOp : public framework::OperatorBase {
...
@@ -37,15 +38,18 @@ class RecvOp : public framework::OperatorBase {
auto
outs
=
Outputs
(
"Out"
);
auto
outs
=
Outputs
(
"Out"
);
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
auto
client_var_name
=
Output
(
"RPCClient"
);
auto
client_var_name
=
Output
(
"RPCClient"
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
client_var_name
),
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
client_var_name
),
"Can not find variable '%s' in the scope."
,
"Can not find variable '%s' in the scope."
,
client_var_name
);
client_var_name
);
auto
*
client_var
=
scope
.
FindVar
(
client_var_name
);
auto
*
client_var
=
scope
.
FindVar
(
client_var_name
);
detail
::
RPCClient
*
rpc_client
=
client_var
->
GetMutable
<
detail
::
RPCClient
>
();
detail
::
RPCClient
*
rpc_client
=
client_var
->
GetMutable
<
detail
::
RPCClient
>
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
rpc_client
->
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
rpc_client
->
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
...
...
paddle/fluid/operators/send_barrier_op.cc
浏览文件 @
62af10d4
...
@@ -21,6 +21,7 @@ limitations under the License. */
...
@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -37,6 +38,10 @@ class SendBarrierOp : public framework::OperatorBase {
...
@@ -37,6 +38,10 @@ class SendBarrierOp : public framework::OperatorBase {
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
std
::
vector
<
std
::
string
>
eps
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
);
std
::
vector
<
std
::
string
>
eps
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
client_var_name
=
Output
(
"RPCClient"
);
auto
client_var_name
=
Output
(
"RPCClient"
);
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
client_var_name
),
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
client_var_name
),
"Can not find variable '%s' in the scope."
,
"Can not find variable '%s' in the scope."
,
...
...
paddle/fluid/operators/send_recv_util.h
浏览文件 @
62af10d4
...
@@ -20,6 +20,9 @@ namespace operators {
...
@@ -20,6 +20,9 @@ namespace operators {
inline
bool
NeedSend
(
const
framework
::
Scope
&
scope
,
inline
bool
NeedSend
(
const
framework
::
Scope
&
scope
,
const
std
::
string
&
varname
)
{
const
std
::
string
&
varname
)
{
// dummy variable is only used in parallel executor to represent
// some dependency relationship, we don't need to send/recv it.
if
(
varname
==
"dummy"
)
return
false
;
auto
*
var
=
scope
.
FindVar
(
varname
);
auto
*
var
=
scope
.
FindVar
(
varname
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
"Can not find variable '%s' in the send side."
,
PADDLE_ENFORCE_NOT_NULL
(
var
,
"Can not find variable '%s' in the send side."
,
varname
);
varname
);
...
...
paddle/fluid/operators/send_vars_op.cc
浏览文件 @
62af10d4
...
@@ -20,6 +20,7 @@ limitations under the License. */
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -41,12 +42,17 @@ class SendVarsOp : public framework::OperatorBase {
...
@@ -41,12 +42,17 @@ class SendVarsOp : public framework::OperatorBase {
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
auto
&
ctx
=
*
pool
.
Get
(
place
);
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
client_var_name
=
Output
(
"RPCClient"
);
auto
client_var_name
=
Output
(
"RPCClient"
);
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
client_var_name
),
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
client_var_name
),
"Can not find variable '%s' in the scope."
,
"Can not find variable '%s' in the scope."
,
client_var_name
);
client_var_name
);
auto
*
client_var
=
scope
.
FindVar
(
client_var_name
);
auto
*
client_var
=
scope
.
FindVar
(
client_var_name
);
VLOG
(
3
)
<<
"client var addr: "
<<
client_var
;
detail
::
RPCClient
*
rpc_client
=
client_var
->
GetMutable
<
detail
::
RPCClient
>
();
detail
::
RPCClient
*
rpc_client
=
client_var
->
GetMutable
<
detail
::
RPCClient
>
();
VLOG
(
3
)
<<
"rpc_client addr: "
<<
rpc_client
;
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录