Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
cd2855b0
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
cd2855b0
编写于
1月 10, 2022
作者:
L
LiYuRio
提交者:
GitHub
1月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[fleet_executor] Add barrier rpc (#38799)
上级
492e6dd0
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
44 addition
and
34 deletion
+44
-34
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
+3
-3
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+0
-1
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+0
-1
paddle/fluid/distributed/fleet_executor/interceptor_message.proto
...luid/distributed/fleet_executor/interceptor_message.proto
+3
-2
paddle/fluid/distributed/fleet_executor/message_bus.cc
paddle/fluid/distributed/fleet_executor/message_bus.cc
+14
-18
paddle/fluid/distributed/fleet_executor/message_bus.h
paddle/fluid/distributed/fleet_executor/message_bus.h
+2
-2
paddle/fluid/distributed/fleet_executor/message_service.cc
paddle/fluid/distributed/fleet_executor/message_service.cc
+14
-3
paddle/fluid/distributed/fleet_executor/message_service.h
paddle/fluid/distributed/fleet_executor/message_service.h
+8
-4
未找到文件。
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
浏览文件 @
cd2855b0
...
...
@@ -13,7 +13,7 @@ endif()
cc_library
(
task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog
)
cc_library
(
fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc
interceptor_
message_service.cc message_bus.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto task_loop_thread_pool collective_helper
op_registry executor_gc_helper gflags glog
${
BRPC_DEPS
}
)
...
...
@@ -29,8 +29,8 @@ if(WITH_DISTRIBUTE)
set_source_files_properties
(
message_bus.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
fleet_executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
carrier.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
interceptor_
message_service.h PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
interceptor_
message_service.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
message_service.h PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
message_service.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
add_subdirectory
(
test
)
endif
()
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
cd2855b0
...
...
@@ -15,7 +15,6 @@
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
浏览文件 @
cd2855b0
...
...
@@ -137,7 +137,6 @@ void FleetExecutor::Run(const std::string& carrier_id) {
// Set current running carrier
if
(
*
GlobalVal
<
std
::
string
>::
Get
()
!=
carrier_id
)
{
GlobalVal
<
std
::
string
>::
Set
(
new
std
::
string
(
carrier_id
));
// TODO(liyurui): Move barrier to service
GlobalVal
<
MessageBus
>::
Get
()
->
Barrier
();
}
carrier
->
Start
();
...
...
paddle/fluid/distributed/fleet_executor/interceptor_message.proto
浏览文件 @
cd2855b0
...
...
@@ -34,7 +34,8 @@ message InterceptorMessage {
message
InterceptorResponse
{
optional
bool
rst
=
1
[
default
=
false
];
}
service
TheInterceptor
MessageService
{
rpc
InterceptorMessageServic
e
(
InterceptorMessage
)
service
MessageService
{
rpc
ReceiveInterceptorMessag
e
(
InterceptorMessage
)
returns
(
InterceptorResponse
);
rpc
IncreaseBarrierCount
(
InterceptorMessage
)
returns
(
InterceptorResponse
);
}
paddle/fluid/distributed/fleet_executor/message_bus.cc
浏览文件 @
cd2855b0
...
...
@@ -163,18 +163,9 @@ void MessageBus::Barrier() {
bool
MessageBus
::
DispatchMsgToCarrier
(
const
InterceptorMessage
&
interceptor_message
)
{
if
(
interceptor_message
.
ctrl_message
())
{
VLOG
(
3
)
<<
"Receiving control message from rank "
<<
interceptor_message
.
src_id
()
<<
" to rank "
<<
interceptor_message
.
dst_id
();
// for barrier
IncreaseBarrierCount
();
return
true
;
}
else
{
const
std
::
string
&
carrier_id
=
*
GlobalVal
<
std
::
string
>::
Get
();
return
GlobalMap
<
std
::
string
,
Carrier
>::
Get
(
carrier_id
)
->
EnqueueInterceptorMessage
(
interceptor_message
);
}
const
std
::
string
&
carrier_id
=
*
GlobalVal
<
std
::
string
>::
Get
();
return
GlobalMap
<
std
::
string
,
Carrier
>::
Get
(
carrier_id
)
->
EnqueueInterceptorMessage
(
interceptor_message
);
}
void
MessageBus
::
ListenPort
()
{
...
...
@@ -185,10 +176,9 @@ void MessageBus::ListenPort() {
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
// function keep listen the port and handle the message
PADDLE_ENFORCE_EQ
(
server_
.
AddService
(
&
interceptor_message_service_
,
brpc
::
SERVER_DOESNT_OWN_SERVICE
),
0
,
platform
::
errors
::
Unavailable
(
"Message bus: init brpc service error."
));
PADDLE_ENFORCE_EQ
(
server_
.
AddService
(
&
message_service_
,
brpc
::
SERVER_DOESNT_OWN_SERVICE
),
0
,
platform
::
errors
::
Unavailable
(
"Message bus: init brpc service error."
));
// start the server
const
char
*
ip_for_brpc
=
addr_
.
c_str
();
...
...
@@ -229,11 +219,16 @@ bool MessageBus::SendInterRank(int64_t dst_rank,
PADDLE_ENFORCE_EQ
(
channel
.
Init
(
dst_addr_for_brpc
,
&
options
),
0
,
platform
::
errors
::
Unavailable
(
"Message bus: init brpc channel error."
));
TheInterceptor
MessageService_Stub
stub
(
&
channel
);
MessageService_Stub
stub
(
&
channel
);
InterceptorResponse
response
;
brpc
::
Controller
ctrl
;
ctrl
.
set_log_id
(
0
);
stub
.
InterceptorMessageService
(
&
ctrl
,
&
interceptor_message
,
&
response
,
NULL
);
if
(
interceptor_message
.
ctrl_message
())
{
stub
.
IncreaseBarrierCount
(
&
ctrl
,
&
interceptor_message
,
&
response
,
NULL
);
}
else
{
stub
.
ReceiveInterceptorMessage
(
&
ctrl
,
&
interceptor_message
,
&
response
,
NULL
);
}
if
(
!
ctrl
.
Failed
())
{
if
(
response
.
rst
())
{
VLOG
(
3
)
<<
"Message bus: brpc sends success."
;
...
...
@@ -248,6 +243,7 @@ bool MessageBus::SendInterRank(int64_t dst_rank,
return
false
;
}
}
#endif
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/message_bus.h
浏览文件 @
cd2855b0
...
...
@@ -24,7 +24,7 @@
!defined(PADDLE_WITH_ASCEND_CL)
#include "brpc/channel.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/
interceptor_
message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_service.h"
#endif
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
...
...
@@ -83,7 +83,7 @@ class MessageBus final {
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
InterceptorMessageServiceImpl
interceptor_
message_service_
;
MessageServiceImpl
message_service_
;
// brpc server
brpc
::
Server
server_
;
#endif
...
...
paddle/fluid/distributed/fleet_executor/
interceptor_
message_service.cc
→
paddle/fluid/distributed/fleet_executor/message_service.cc
浏览文件 @
cd2855b0
...
...
@@ -13,7 +13,7 @@
// limitations under the License.
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/distributed/fleet_executor/
interceptor_
message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_service.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...
...
@@ -21,18 +21,29 @@
namespace
paddle
{
namespace
distributed
{
void
InterceptorMessageServiceImpl
::
InterceptorMessageServic
e
(
void
MessageServiceImpl
::
ReceiveInterceptorMessag
e
(
google
::
protobuf
::
RpcController
*
control_base
,
const
InterceptorMessage
*
request
,
InterceptorResponse
*
response
,
google
::
protobuf
::
Closure
*
done
)
{
brpc
::
ClosureGuard
done_guard
(
done
);
VLOG
(
3
)
<<
"
Interceptor
Message Service receives a message from interceptor "
VLOG
(
3
)
<<
"Message Service receives a message from interceptor "
<<
request
->
src_id
()
<<
" to interceptor "
<<
request
->
dst_id
()
<<
", with the message: "
<<
request
->
message_type
();
bool
flag
=
GlobalVal
<
MessageBus
>::
Get
()
->
DispatchMsgToCarrier
(
*
request
);
response
->
set_rst
(
flag
);
}
void
MessageServiceImpl
::
IncreaseBarrierCount
(
google
::
protobuf
::
RpcController
*
control_base
,
const
InterceptorMessage
*
request
,
InterceptorResponse
*
response
,
google
::
protobuf
::
Closure
*
done
)
{
brpc
::
ClosureGuard
done_guard
(
done
);
VLOG
(
3
)
<<
"Barrier Service receives a message from rank "
<<
request
->
src_id
()
<<
" to rank "
<<
request
->
dst_id
();
GlobalVal
<
MessageBus
>::
Get
()
->
IncreaseBarrierCount
();
response
->
set_rst
(
true
);
}
}
// namespace distributed
}
// namespace paddle
#endif
paddle/fluid/distributed/fleet_executor/
interceptor_
message_service.h
→
paddle/fluid/distributed/fleet_executor/message_service.h
浏览文件 @
cd2855b0
...
...
@@ -21,11 +21,15 @@
namespace
paddle
{
namespace
distributed
{
class
InterceptorMessageServiceImpl
:
public
TheInterceptor
MessageService
{
class
MessageServiceImpl
:
public
MessageService
{
public:
InterceptorMessageServiceImpl
()
{}
virtual
~
InterceptorMessageServiceImpl
()
{}
virtual
void
InterceptorMessageService
(
MessageServiceImpl
()
{}
virtual
~
MessageServiceImpl
()
{}
virtual
void
ReceiveInterceptorMessage
(
google
::
protobuf
::
RpcController
*
control_base
,
const
InterceptorMessage
*
request
,
InterceptorResponse
*
response
,
google
::
protobuf
::
Closure
*
done
);
virtual
void
IncreaseBarrierCount
(
google
::
protobuf
::
RpcController
*
control_base
,
const
InterceptorMessage
*
request
,
InterceptorResponse
*
response
,
google
::
protobuf
::
Closure
*
done
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录