Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ddc15a18
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ddc15a18
编写于
12月 22, 2021
作者:
L
LiYuRio
提交者:
GitHub
12月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[fleet_executor] Move IntraSend to Carrier. Using blocking queue (#38322)
上级
142ea171
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
228 addition
and
262 deletion
+228
-262
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
+2
-2
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+48
-23
paddle/fluid/distributed/fleet_executor/carrier.h
paddle/fluid/distributed/fleet_executor/carrier.h
+8
-2
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+15
-15
paddle/fluid/distributed/fleet_executor/fleet_executor.h
paddle/fluid/distributed/fleet_executor/fleet_executor.h
+11
-2
paddle/fluid/distributed/fleet_executor/interceptor.cc
paddle/fluid/distributed/fleet_executor/interceptor.cc
+5
-23
paddle/fluid/distributed/fleet_executor/interceptor.h
paddle/fluid/distributed/fleet_executor/interceptor.h
+5
-14
paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc
...distributed/fleet_executor/interceptor_message_service.cc
+2
-2
paddle/fluid/distributed/fleet_executor/message_bus.cc
paddle/fluid/distributed/fleet_executor/message_bus.cc
+62
-110
paddle/fluid/distributed/fleet_executor/message_bus.h
paddle/fluid/distributed/fleet_executor/message_bus.h
+8
-10
paddle/fluid/distributed/fleet_executor/runtime_graph.cc
paddle/fluid/distributed/fleet_executor/runtime_graph.cc
+1
-1
paddle/fluid/distributed/fleet_executor/runtime_graph.h
paddle/fluid/distributed/fleet_executor/runtime_graph.h
+10
-10
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
...ed/fleet_executor/test/compute_interceptor_run_op_test.cc
+2
-3
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
...stributed/fleet_executor/test/compute_interceptor_test.cc
+2
-3
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
...ributed/fleet_executor/test/interceptor_ping_pong_test.cc
+2
-5
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
...eet_executor/test/interceptor_ping_pong_with_brpc_test.cc
+35
-28
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
...leet_executor/test/interceptor_pipeline_long_path_test.cc
+2
-5
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
...eet_executor/test/interceptor_pipeline_short_path_test.cc
+2
-4
paddle/fluid/framework/blocking_queue.h
paddle/fluid/framework/blocking_queue.h
+6
-0
未找到文件。
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
浏览文件 @
ddc15a18
...
...
@@ -5,7 +5,7 @@ endif()
proto_library
(
interceptor_message_proto SRCS interceptor_message.proto
)
if
(
WITH_DISTRIBUTE AND WITH_PSCORE AND
NOT
(
WITH_ASCEND OR WITH_ASCEND_CL
))
set
(
BRPC_DEPS brpc ssl crypto protobuf
gflags glog
zlib leveldb snappy gflags glog
)
set
(
BRPC_DEPS brpc ssl crypto protobuf zlib leveldb snappy gflags glog
)
else
()
set
(
BRPC_DEPS
""
)
endif
()
...
...
@@ -13,7 +13,7 @@ endif()
cc_library
(
fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper op_registry
executor_gc_helper
${
BRPC_DEPS
}
)
executor_gc_helper
gflags glog
${
BRPC_DEPS
}
)
if
(
WITH_DISTRIBUTE
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
...
...
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
ddc15a18
...
...
@@ -27,14 +27,16 @@ namespace distributed {
USE_INTERCEPTOR
(
Compute
);
USE_INTERCEPTOR
(
Amplifier
);
void
Carrier
::
Init
(
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
void
Carrier
::
Init
(
int64_t
rank
,
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
platform
::
Place
&
place
)
{
PADDLE_ENFORCE_EQ
(
is_init_
,
false
,
platform
::
errors
::
AlreadyExists
(
"Carrier is already init."
));
rank_
=
rank
;
runtime_graph_
=
runtime_graph
;
interceptor_id_to_rank_
=
runtime_graph_
->
interceptor_id_to_rank
();
minibatch_scope_
=
minibatch_scope
;
microbatch_scopes_
=
microbatch_scopes
;
place_
=
place
;
...
...
@@ -48,12 +50,6 @@ void Carrier::Release() {
// NOTE(wangxi): must join before `Derived Interceptor` destruct,
// otherwise Derived object will be destructed before thread complete.
// Sending STOP msg to the source interceptor
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."
));
for
(
int64_t
id
:
source_interceptor_ids_
)
{
VLOG
(
3
)
<<
"Carrier Release is sending stop to source interceptor "
<<
id
<<
"."
;
...
...
@@ -75,10 +71,10 @@ Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }
bool
Carrier
::
EnqueueInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
)
{
// enqueue message to interceptor
if
(
interceptor_message
.
ctrl_message
())
{
// handle control message
return
true
;
VLOG
(
3
)
<<
"Receiving control message from rank "
<<
interceptor_message
.
src_id
()
<<
" to rank "
<<
interceptor_message
.
dst_id
();
}
else
{
{
std
::
unique_lock
<
std
::
mutex
>
lock_creating
(
creating_flag_mutex_
);
...
...
@@ -93,15 +89,9 @@ bool Carrier::EnqueueInterceptorMessage(
}
int64_t
dst_id
=
interceptor_message
.
dst_id
();
Interceptor
*
dst_interceptor
=
GetInterceptor
(
dst_id
);
bool
rst
=
dst_interceptor
->
EnqueueRemoteInterceptorMessage
(
interceptor_message
);
if
(
rst
)
{
std
::
condition_variable
&
interceptor_cond_var
=
dst_interceptor
->
GetCondVar
();
interceptor_cond_var
.
notify_all
();
}
return
rst
;
dst_interceptor
->
EnqueueRemoteInterceptorMessage
(
interceptor_message
);
}
return
true
;
}
Interceptor
*
Carrier
::
GetInterceptor
(
int64_t
interceptor_id
)
{
...
...
@@ -144,9 +134,44 @@ std::condition_variable& Carrier::GetCondVar() { return cond_var_; }
bool
Carrier
::
IsInit
()
const
{
return
is_init_
;
}
// TODO(liyurui): Move SendIntra into carrier
bool
Carrier
::
Send
(
const
InterceptorMessage
&
msg
)
const
{
return
msg_bus_
->
Send
(
msg
);
int64_t
Carrier
::
GetRank
(
int64_t
interceptor_id
)
const
{
PADDLE_ENFORCE_NE
(
interceptor_id_to_rank_
.
find
(
interceptor_id
),
interceptor_id_to_rank_
.
end
(),
platform
::
errors
::
NotFound
(
"Cannot find rank for interceptor id %lld."
,
interceptor_id
));
return
interceptor_id_to_rank_
.
at
(
interceptor_id
);
}
bool
Carrier
::
Send
(
const
InterceptorMessage
&
msg
)
{
int64_t
src_id
=
(
msg
.
src_id
()
==
-
1
)
?
msg
.
dst_id
()
:
msg
.
src_id
();
int64_t
dst_id
=
msg
.
dst_id
();
int64_t
src_rank
=
GetRank
(
src_id
);
int64_t
dst_rank
=
GetRank
(
dst_id
);
PADDLE_ENFORCE_EQ
(
src_rank
,
rank_
,
platform
::
errors
::
Fatal
(
"The source rank id %lld, which is not equal to "
"the carrier rank id %lld."
,
src_rank
,
rank_
));
if
(
src_rank
==
dst_rank
)
{
VLOG
(
3
)
<<
"Send a message from interceptor "
<<
src_id
<<
" to interceptor "
<<
dst_id
<<
", which are in the same ranks."
;
return
EnqueueInterceptorMessage
(
msg
);
}
else
{
PADDLE_ENFORCE_NOT_NULL
(
msg_bus_
.
get
(),
platform
::
errors
::
Unavailable
(
"Message bus is released accidently"
));
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."
));
VLOG
(
3
)
<<
"Send a message from interceptor "
<<
src_id
<<
" to interceptor "
<<
dst_id
<<
", which are in different ranks."
;
return
msg_bus_
->
Send
(
dst_rank
,
msg
);
}
}
Interceptor
*
Carrier
::
SetInterceptor
(
int64_t
interceptor_id
,
...
...
@@ -222,13 +247,13 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
}
void
Carrier
::
CreateInterceptors
()
{
if
(
runtime_graph_
->
intercept
e
r_id_to_node
().
empty
())
return
;
if
(
runtime_graph_
->
intercept
o
r_id_to_node
().
empty
())
return
;
auto
gc
=
GetGC
(
place_
);
// create each Interceptor
// no auto init since there is no config
for
(
const
auto
&
item
:
runtime_graph_
->
intercept
e
r_id_to_node
())
{
for
(
const
auto
&
item
:
runtime_graph_
->
intercept
o
r_id_to_node
())
{
int64_t
interceptor_id
=
item
.
first
;
TaskNode
*
task_node
=
item
.
second
;
...
...
paddle/fluid/distributed/fleet_executor/carrier.h
浏览文件 @
ddc15a18
...
...
@@ -45,8 +45,11 @@ class MessageBus;
class
Carrier
final
{
public:
Carrier
()
=
default
;
Carrier
(
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
)
:
rank_
(
rank
),
interceptor_id_to_rank_
(
interceptor_id_to_rank
)
{}
~
Carrier
();
void
Init
(
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
void
Init
(
int64_t
rank
,
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
platform
::
Place
&
place
);
...
...
@@ -75,7 +78,7 @@ class Carrier final {
bool
IsInit
()
const
;
bool
Send
(
const
InterceptorMessage
&
msg
)
const
;
bool
Send
(
const
InterceptorMessage
&
msg
);
// NOTE: This mutex will be used in interceptor's RunOps function.
// This mutex is used for avoiding forward ops and backward ops run
...
...
@@ -90,6 +93,8 @@ class Carrier final {
void
HandleTmpMessages
();
int64_t
GetRank
(
int64_t
interceptor_id
)
const
;
// interceptor logic id to actually interceptor
std
::
unordered_map
<
int64_t
,
std
::
unique_ptr
<
Interceptor
>>
interceptor_idx_to_interceptor_
;
...
...
@@ -111,6 +116,7 @@ class Carrier final {
paddle
::
platform
::
DeviceContext
*
dev_ctx_
{
nullptr
};
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph_
;
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
int64_t
rank_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
};
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
浏览文件 @
ddc15a18
...
...
@@ -13,7 +13,6 @@
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -28,6 +27,8 @@
namespace
paddle
{
namespace
distributed
{
std
::
unique_ptr
<
Carrier
>
FleetExecutor
::
carrier_
;
FleetExecutor
::
FleetExecutor
(
const
std
::
string
&
exe_desc_str
)
{
bool
parse_flag
=
exe_desc_
.
ParseFromString
(
exe_desc_str
);
PADDLE_ENFORCE
(
parse_flag
,
platform
::
errors
::
PreconditionNotMet
(
...
...
@@ -36,12 +37,13 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
FleetExecutor
::~
FleetExecutor
()
{
root_scope_
->
DropKids
();
GetCarrier
()
.
Release
();
GetCarrier
()
->
Release
();
}
Carrier
&
FleetExecutor
::
GetCarrier
()
{
static
Carrier
carrier
;
return
carrier
;
Carrier
*
FleetExecutor
::
GetCarrier
()
{
PADDLE_ENFORCE_NOT_NULL
(
carrier_
.
get
(),
platform
::
errors
::
NotFound
(
"Carrier has not been created."
));
return
carrier_
.
get
();
}
void
FleetExecutor
::
Init
(
...
...
@@ -84,16 +86,16 @@ void FleetExecutor::Init(
}
VLOG
(
5
)
<<
runtime_graph_
->
DebugString
();
msg_bus_
=
std
::
make_shared
<
MessageBus
>
();
CreateCarrier
();
InitCarrier
();
InitMessageBus
();
}
void
FleetExecutor
::
InitCarrier
()
{
Carrier
&
carrier
=
GetCarrier
();
if
(
!
carrier
.
IsInit
())
{
carrier
.
SetMsgBus
(
msg_bus_
);
carrier
.
Init
(
runtime_graph_
,
root_scope_
,
minibatch_scope_
,
microbatch_scopes_
,
place_
);
if
(
!
GetCarrier
()
->
IsInit
())
{
GetCarrier
()
->
SetMsgBus
(
msg_bus_
);
GetCarrier
()
->
Init
(
exe_desc_
.
cur_rank
(),
runtime_graph_
,
root_scope_
,
minibatch_scope_
,
microbatch_scopes_
,
place_
);
}
}
...
...
@@ -128,21 +130,19 @@ void FleetExecutor::InitMessageBus() {
<<
(
rank_to_addr
.
size
()
==
0
?
1
:
rank_to_addr
.
size
())
<<
"."
;
VLOG
(
5
)
<<
ss
.
str
();
if
(
!
msg_bus_
->
IsInit
())
{
msg_bus_
->
Init
(
runtime_graph_
->
intercepter_id_to_rank
(),
rank_to_addr
,
addr
);
msg_bus_
->
Init
(
cur_rank
,
rank_to_addr
,
addr
);
}
}
void
FleetExecutor
::
Run
()
{
// Run
Carrier
&
carrier
=
GetCarrier
();
PADDLE_ENFORCE_EQ
(
carrier
.
IsInit
(),
true
,
GetCarrier
()
->
IsInit
(),
true
,
platform
::
errors
::
Unavailable
(
"Carrier has not been init yet."
));
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
Unavailable
(
"MessageBus has not been init yet."
));
carrier
.
Start
();
GetCarrier
()
->
Start
();
for
(
auto
*
micro_scop
:
microbatch_scopes_
)
{
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.h
浏览文件 @
ddc15a18
...
...
@@ -16,6 +16,7 @@
#include <memory>
#include <string>
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
...
...
@@ -30,7 +31,6 @@ namespace distributed {
class
RuntimeGraph
;
class
MessageBus
;
class
TaskNode
;
class
Carrier
;
class
FleetExecutor
final
{
public:
...
...
@@ -43,7 +43,15 @@ class FleetExecutor final {
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
);
void
Run
();
// TODO(liyurui): Change to use registry table for multi-carrier.
static
Carrier
&
GetCarrier
();
static
Carrier
*
GetCarrier
();
template
<
typename
...
Args
>
static
Carrier
*
CreateCarrier
(
Args
&&
...
args
)
{
PADDLE_ENFORCE_EQ
(
carrier_
.
get
(),
nullptr
,
platform
::
errors
::
AlreadyExists
(
"Carrier has been created already."
));
carrier_
=
std
::
make_unique
<
Carrier
>
(
std
::
forward
<
Args
>
(
args
)...);
return
carrier_
.
get
();
}
private:
DISABLE_COPY_AND_ASSIGN
(
FleetExecutor
);
...
...
@@ -59,6 +67,7 @@ class FleetExecutor final {
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
static
std
::
unique_ptr
<
Carrier
>
carrier_
;
};
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/interceptor.cc
浏览文件 @
ddc15a18
...
...
@@ -52,24 +52,17 @@ void Interceptor::StopCarrier() {
cond_var
.
notify_all
();
}
std
::
condition_variable
&
Interceptor
::
GetCondVar
()
{
// get the conditional var
return
cond_var_
;
}
int64_t
Interceptor
::
GetInterceptorId
()
const
{
// return the interceptor id
return
interceptor_id_
;
}
bool
Interceptor
::
EnqueueRemoteInterceptorMessage
(
void
Interceptor
::
EnqueueRemoteInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
)
{
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
VLOG
(
3
)
<<
"Enqueue message: "
<<
interceptor_message
.
message_type
()
<<
" into "
<<
interceptor_id_
<<
"'s remote mailbox."
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
remote_mailbox_mutex_
);
remote_mailbox_
.
push
(
interceptor_message
);
return
true
;
remote_mailbox_
.
Push
(
interceptor_message
);
}
bool
Interceptor
::
Send
(
int64_t
dst_id
,
InterceptorMessage
&
msg
)
{
...
...
@@ -92,7 +85,7 @@ void Interceptor::PoolTheMailbox() {
"Error encountered when fetch remote mailbox."
));
}
const
InterceptorMessage
interceptor_message
=
local_mailbox_
.
front
();
local_mailbox_
.
pop
();
local_mailbox_
.
pop
_front
();
const
MessageType
message_type
=
interceptor_message
.
message_type
();
VLOG
(
3
)
<<
"Interceptor "
<<
interceptor_id_
<<
" has received a message"
<<
" from interceptor "
<<
interceptor_message
.
src_id
()
...
...
@@ -109,19 +102,8 @@ void Interceptor::PoolTheMailbox() {
}
bool
Interceptor
::
FetchRemoteMailbox
()
{
// fetch all Message from remote mailbox to local mailbox
// return true if remote mailbox not empty, otherwise return false
std
::
unique_lock
<
std
::
mutex
>
lock
(
remote_mailbox_mutex_
);
cond_var_
.
wait
(
lock
,
[
this
]()
{
return
!
remote_mailbox_
.
empty
();
});
if
(
remote_mailbox_
.
empty
())
{
// the thread has been unblocked accidentally
return
false
;
}
while
(
!
remote_mailbox_
.
empty
())
{
local_mailbox_
.
push
(
std
::
move
(
remote_mailbox_
.
front
()));
remote_mailbox_
.
pop
();
}
return
true
;
remote_mailbox_
.
PopAll
(
&
local_mailbox_
);
return
!
local_mailbox_
.
empty
();
}
static
InterceptorFactory
::
CreateInterceptorMap
&
GetInterceptorMap
()
{
...
...
paddle/fluid/distributed/fleet_executor/interceptor.h
浏览文件 @
ddc15a18
...
...
@@ -15,14 +15,15 @@
#pragma once
#include <condition_variable>
#include <deque>
#include <functional>
#include <map>
#include <memory>
#include <queue>
#include <thread>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
...
...
@@ -59,11 +60,8 @@ class Interceptor {
// return the interceptor id
int64_t
GetInterceptorId
()
const
;
// return the conditional var
std
::
condition_variable
&
GetCondVar
();
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
bool
EnqueueRemoteInterceptorMessage
(
void
EnqueueRemoteInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
);
bool
Send
(
int64_t
dst_id
,
InterceptorMessage
&
msg
);
// NOLINT
...
...
@@ -115,23 +113,16 @@ class Interceptor {
// interceptor handle which process message
MsgHandle
handle_
{
nullptr
};
// mutex to control read/write conflict for remote mailbox
std
::
mutex
remote_mailbox_mutex_
;
// interceptor runs PoolTheMailbox() function to poll local mailbox
std
::
thread
interceptor_thread_
;
// conditional variable for blocking the thread when
// fetch an empty remote mailbox
std
::
condition_variable
cond_var_
;
// remote mailbox, written by EnqueueRemoteMessage()
// read by FetchRemoteMailbox()
std
::
q
ueue
<
InterceptorMessage
>
remote_mailbox_
;
framework
::
BlockingQ
ueue
<
InterceptorMessage
>
remote_mailbox_
;
// local mailbox, written by FetchRemoteMailbox()
// read by PoolTheMailbox()
std
::
que
ue
<
InterceptorMessage
>
local_mailbox_
;
std
::
deq
ue
<
InterceptorMessage
>
local_mailbox_
;
int64_t
already_run_times_
{
0
};
int64_t
used_slot_nums_
{
0
};
...
...
paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc
浏览文件 @
ddc15a18
...
...
@@ -29,8 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG
(
3
)
<<
"Interceptor Message Service receives a message from interceptor "
<<
request
->
src_id
()
<<
" to interceptor "
<<
request
->
dst_id
()
<<
", with the message: "
<<
request
->
message_type
();
FleetExecutor
::
GetCarrier
().
EnqueueInterceptorMessage
(
*
request
);
response
->
set_rst
(
true
);
bool
flag
=
FleetExecutor
::
GetCarrier
()
->
EnqueueInterceptorMessage
(
*
request
);
response
->
set_rst
(
flag
);
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/message_bus.cc
浏览文件 @
ddc15a18
...
...
@@ -17,8 +17,6 @@
#include <set>
#include <thread>
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
...
...
@@ -26,16 +24,25 @@ namespace paddle {
namespace
distributed
{
void
MessageBus
::
Init
(
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_map
<
int64_t
,
std
::
string
>&
rank_to_addr
,
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
std
::
string
>&
rank_to_addr
,
const
std
::
string
&
addr
)
{
PADDLE_ENFORCE_EQ
(
is_init_
,
false
,
platform
::
errors
::
AlreadyExists
(
"MessageBus is already init."
));
rank_
=
rank
;
is_init_
=
true
;
interceptor_id_to_rank_
=
interceptor_id_to_rank
;
rank_to_addr_
=
rank_to_addr
;
addr_
=
addr
;
if
(
addr_
!=
""
)
{
const
auto
&
addr
=
GetAddr
(
rank_
);
PADDLE_ENFORCE_EQ
(
addr
,
addr_
,
platform
::
errors
::
Fatal
(
"The current rank's addr is %s, while the "
"message bus's addr is %s, which are different. "
"Init error."
,
addr
,
addr_
));
}
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL)
// NOTE: To make the brpc is compatible with collective,
...
...
@@ -65,43 +72,57 @@ MessageBus::~MessageBus() {
#endif
}
bool
MessageBus
::
Send
(
const
InterceptorMessage
&
interceptor_message
)
{
// called by Interceptor, send InterceptorMessage to dst
int64_t
src_id
=
interceptor_message
.
src_id
();
int64_t
dst_id
=
interceptor_message
.
dst_id
();
if
(
IsSameRank
(
src_id
,
dst_id
))
{
VLOG
(
3
)
<<
"Send a message from interceptor "
<<
src_id
<<
" to interceptor "
<<
dst_id
<<
", which are in the same ranks."
;
return
SendIntraRank
(
interceptor_message
);
}
else
{
VLOG
(
3
)
<<
"Send a message from interceptor "
<<
src_id
<<
" to interceptor "
<<
dst_id
<<
", which are in different ranks."
;
const
std
::
string
&
MessageBus
::
GetAddr
(
int64_t
rank
)
const
{
PADDLE_ENFORCE_NE
(
rank_to_addr_
.
find
(
rank
),
rank_to_addr_
.
end
(),
platform
::
errors
::
NotFound
(
"Cannot find addr rank id %lld."
,
rank
));
return
rank_to_addr_
.
at
(
rank
);
}
bool
MessageBus
::
Send
(
int64_t
dst_rank
,
const
InterceptorMessage
&
interceptor_message
)
{
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
int
retry_time
=
0
;
// message bus will retry sending for 10 times
while
(
retry_time
<
10
)
{
++
retry_time
;
if
(
SendInterRank
(
interceptor_message
))
{
VLOG
(
3
)
<<
"Message bus sends inter rank successfully with "
<<
retry_time
<<
" times retries."
;
return
true
;
}
VLOG
(
3
)
<<
"Message bus sends failed, retry after 1 seconds."
;
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
int
retry_time
=
0
;
// message bus will retry sending for 10 times
while
(
retry_time
<
10
)
{
++
retry_time
;
if
(
SendInterRank
(
dst_rank
,
interceptor_message
))
{
VLOG
(
3
)
<<
"Message bus sends inter rank successfully with "
<<
retry_time
<<
" times retries."
;
return
true
;
}
VLOG
(
3
)
<<
"Message bus sends inter rank fail after 10 times retries."
;
return
false
;
VLOG
(
3
)
<<
"Message bus sends failed, retry after 1 seconds."
;
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
}
VLOG
(
3
)
<<
"Message bus sends inter rank fail after 10 times retries."
;
return
false
;
#else
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Fleet executor does not support sending message between different "
"ranks when Paddle is compiled with npu or "
"isn't compiled with distributed for now."
));
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Fleet executor does not support sending message between different "
"ranks when Paddle is compiled with npu or "
"isn't compiled with distributed for now."
));
#endif
}
return
true
;
}
void
MessageBus
::
TestConnection
()
{
InterceptorMessage
ctrl_msg
;
ctrl_msg
.
set_ctrl_message
(
true
);
ctrl_msg
.
set_src_id
(
rank_
);
for
(
const
auto
&
dst_rank_pair
:
rank_to_addr_
)
{
int64_t
dst_rank
=
dst_rank_pair
.
first
;
if
(
dst_rank
!=
rank_
)
{
ctrl_msg
.
set_dst_id
(
dst_rank
);
VLOG
(
3
)
<<
"Send control message bus from rank "
<<
rank_
<<
" to rank "
<<
dst_rank
;
while
(
!
Send
(
dst_rank
,
ctrl_msg
))
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
}
VLOG
(
3
)
<<
"Message bus has connected to rank: "
<<
dst_rank
<<
"."
;
}
}
}
void
MessageBus
::
ListenPort
()
{
if
(
addr_
==
""
)
{
LOG
(
INFO
)
<<
"No need listen to port since training on single card."
;
...
...
@@ -130,30 +151,7 @@ void MessageBus::ListenPort() {
interval
+=
500
;
}
LOG
(
INFO
)
<<
"Message bus's listen port thread starts successful."
;
std
::
set
<
int64_t
>
visit
;
InterceptorMessage
tmp_msg
;
tmp_msg
.
set_ctrl_message
(
true
);
for
(
auto
pair
:
interceptor_id_to_rank_
)
{
if
(
rank_to_addr_
.
at
(
pair
.
second
)
==
addr_
)
{
tmp_msg
.
set_src_id
(
pair
.
first
);
}
}
for
(
auto
pair
:
interceptor_id_to_rank_
)
{
int64_t
rank
=
pair
.
second
;
if
(
rank_to_addr_
.
at
(
rank
)
==
addr_
)
{
continue
;
}
tmp_msg
.
set_dst_id
(
pair
.
first
);
if
(
visit
.
find
(
rank
)
==
visit
.
end
())
{
VLOG
(
3
)
<<
"Message bus is testing connection for rank: "
<<
rank
<<
"."
;
visit
.
insert
(
rank
);
while
(
!
Send
(
tmp_msg
))
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
}
VLOG
(
3
)
<<
"Message bus has connected to rank: "
<<
rank
<<
"."
;
}
}
TestConnection
();
#else
LOG
(
WARNING
)
<<
"Fleet executor's ListenPort() is a fake function when Paddle is "
...
...
@@ -162,53 +160,13 @@ void MessageBus::ListenPort() {
#endif
}
bool
MessageBus
::
IsSameRank
(
int64_t
src_id
,
int64_t
dst_id
)
{
// -1 is sent by carrier to source interceptor
if
(
src_id
==
-
1
)
src_id
=
dst_id
;
// check whether the dst is the same rank or different rank with src
const
auto
&
src_rank
=
interceptor_id_to_rank_
.
find
(
src_id
);
const
auto
&
dst_rank
=
interceptor_id_to_rank_
.
find
(
dst_id
);
PADDLE_ENFORCE_NE
(
src_rank
,
interceptor_id_to_rank_
.
end
(),
platform
::
errors
::
NotFound
(
"Cannot find rank for src interceptor id %lld. Init error."
,
src_id
));
PADDLE_ENFORCE_NE
(
dst_rank
,
interceptor_id_to_rank_
.
end
(),
platform
::
errors
::
NotFound
(
"Cannot find rank for dst interceptor id %lld. Init error."
,
dst_id
));
if
(
addr_
==
""
)
{
// single card training, must be same rank
return
true
;
}
const
auto
&
src_ip
=
rank_to_addr_
.
find
(
src_rank
->
second
);
PADDLE_ENFORCE_NE
(
src_ip
,
rank_to_addr_
.
end
(),
platform
::
errors
::
NotFound
(
"Cannot find addr for src rank id %lld. Init error."
,
src_rank
->
second
));
PADDLE_ENFORCE_EQ
(
src_ip
->
second
,
addr_
,
platform
::
errors
::
Fatal
(
"The src interceptor's addr is %s, while the "
"message bus's addr is %s, which are different. "
"Init error."
,
src_ip
->
second
,
addr_
));
return
src_rank
->
second
==
dst_rank
->
second
;
}
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
bool
MessageBus
::
SendInterRank
(
const
InterceptorMessage
&
interceptor_message
)
{
// send the message inter rank (dst is different rank with src)
int64_t
dst_id
=
interceptor_message
.
dst_id
();
int64_t
dst_rank
=
interceptor_id_to_rank_
[
dst_id
];
auto
dst_ip
=
rank_to_addr_
.
find
(
dst_rank
);
PADDLE_ENFORCE_NE
(
dst_ip
,
rank_to_addr_
.
end
(),
platform
::
errors
::
InvalidArgument
(
"Cannot find rank for dst interceptor id %lld. "
"Init error."
,
dst_id
));
VLOG
(
3
)
<<
"Message bus sending to addr: "
<<
dst_ip
->
second
;
const
char
*
dst_ip_for_brpc
=
dst_ip
->
second
.
c_str
();
bool
MessageBus
::
SendInterRank
(
int64_t
dst_rank
,
const
InterceptorMessage
&
interceptor_message
)
{
const
auto
&
dst_addr
=
GetAddr
(
dst_rank
);
VLOG
(
3
)
<<
"Message bus sending to addr: "
<<
dst_addr
;
const
char
*
dst_addr_for_brpc
=
dst_addr
.
c_str
();
brpc
::
Channel
channel
;
brpc
::
ChannelOptions
options
;
options
.
protocol
=
"baidu_std"
;
...
...
@@ -216,7 +174,7 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
options
.
timeout_ms
=
1000
;
options
.
max_retry
=
5
;
PADDLE_ENFORCE_EQ
(
channel
.
Init
(
dst_
ip
_for_brpc
,
&
options
),
0
,
channel
.
Init
(
dst_
addr
_for_brpc
,
&
options
),
0
,
platform
::
errors
::
Unavailable
(
"Message bus: init brpc channel error."
));
TheInterceptorMessageService_Stub
stub
(
&
channel
);
InterceptorResponse
response
;
...
...
@@ -239,11 +197,5 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
}
#endif
bool
MessageBus
::
SendIntraRank
(
const
InterceptorMessage
&
interceptor_message
)
{
// send the message intra rank (dst is the same rank with src)
return
FleetExecutor
::
GetCarrier
().
EnqueueInterceptorMessage
(
interceptor_message
);
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/message_bus.h
浏览文件 @
ddc15a18
...
...
@@ -42,14 +42,14 @@ class MessageBus final {
MessageBus
()
=
default
;
~
MessageBus
();
void
Init
(
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_
rank
,
void
Init
(
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
std
::
string
>&
rank_to_addr
,
const
std
::
string
&
addr
);
bool
IsInit
()
const
;
// called by Interceptor, send InterceptorMessage to dst
bool
Send
(
const
InterceptorMessage
&
interceptor_message
);
bool
Send
(
int64_t
dst_rank
,
const
InterceptorMessage
&
interceptor_message
);
private:
DISABLE_COPY_AND_ASSIGN
(
MessageBus
);
...
...
@@ -57,22 +57,20 @@ class MessageBus final {
// function keep listen the port and handle the message
void
ListenPort
();
// check whether the dst is the same rank or different rank with src
bool
IsSameRank
(
int64_t
src_id
,
int64_t
dst_id
);
void
TestConnection
();
const
std
::
string
&
GetAddr
(
int64_t
rank
)
const
;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
// send the message inter rank (dst is different rank with src)
bool
SendInterRank
(
const
InterceptorMessage
&
interceptor_message
);
bool
SendInterRank
(
int64_t
dst_rank
,
const
InterceptorMessage
&
interceptor_message
);
#endif
bool
is_init_
{
false
};
// send the message intra rank (dst is the same rank with src)
bool
SendIntraRank
(
const
InterceptorMessage
&
interceptor_message
);
// handed by above layer, save the info mapping interceptor id to rank id
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
int64_t
rank_
;
// handed by above layer, save the info mapping rank id to addr
std
::
unordered_map
<
int64_t
,
std
::
string
>
rank_to_addr_
;
...
...
paddle/fluid/distributed/fleet_executor/runtime_graph.cc
浏览文件 @
ddc15a18
...
...
@@ -21,7 +21,7 @@ namespace distributed {
std
::
string
RuntimeGraph
::
DebugString
()
const
{
std
::
ostringstream
os
;
os
<<
"
\n
Runtime Graph Debug:
\n
"
;
for
(
const
auto
&
pair
:
intercept
e
r_id_to_node_
)
{
for
(
const
auto
&
pair
:
intercept
o
r_id_to_node_
)
{
os
<<
pair
.
second
->
DebugString
();
os
<<
"
\n
"
;
}
...
...
paddle/fluid/distributed/fleet_executor/runtime_graph.h
浏览文件 @
ddc15a18
...
...
@@ -29,26 +29,26 @@ class RuntimeGraph final {
public:
RuntimeGraph
()
=
default
;
~
RuntimeGraph
()
=
default
;
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
intercept
e
r_id_to_node
()
const
{
return
intercept
e
r_id_to_node_
;
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
intercept
o
r_id_to_node
()
const
{
return
intercept
o
r_id_to_node_
;
}
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
intercept
e
r_id_to_rank
()
const
{
return
intercept
e
r_id_to_rank_
;
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
intercept
o
r_id_to_rank
()
const
{
return
intercept
o
r_id_to_rank_
;
}
void
SetInterceptorIdToRank
(
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
intercept
e
r_id_to_rank
)
{
intercept
er_id_to_rank_
=
intercepte
r_id_to_rank
;
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
intercept
o
r_id_to_rank
)
{
intercept
or_id_to_rank_
=
intercepto
r_id_to_rank
;
}
void
SetInterceptorIdToNode
(
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
intercept
e
r_id_to_node
)
{
intercept
er_id_to_node_
=
intercepte
r_id_to_node
;
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
intercept
o
r_id_to_node
)
{
intercept
or_id_to_node_
=
intercepto
r_id_to_node
;
}
std
::
string
DebugString
()
const
;
private:
DISABLE_COPY_AND_ASSIGN
(
RuntimeGraph
);
std
::
unordered_map
<
int64_t
,
TaskNode
*>
intercept
e
r_id_to_node_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
intercept
e
r_id_to_rank_
;
std
::
unordered_map
<
int64_t
,
TaskNode
*>
intercept
o
r_id_to_node_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
intercept
o
r_id_to_rank_
;
};
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
浏览文件 @
ddc15a18
...
...
@@ -62,11 +62,10 @@ TEST(ComputeInterceptor, Compute) {
std
::
vector
<
framework
::
Scope
*>
scopes
=
{
scope
,
scope
};
platform
::
Place
place
=
platform
::
CPUPlace
();
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
}});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
{{
0
,
0
},
{
1
,
0
}}
,
{{
0
,
"127.0.0.0:0"
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
// FIXME: don't delete, otherwise interceptor will use undefined node
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
浏览文件 @
ddc15a18
...
...
@@ -47,11 +47,10 @@ class StartInterceptor : public Interceptor {
};
TEST
(
ComputeInterceptor
,
Compute
)
{
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
}});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
}}
,
{{
0
,
"127.0.0.0:0"
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
// NOTE: don't delete, otherwise interceptor will use undefined node
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
浏览文件 @
ddc15a18
...
...
@@ -18,7 +18,6 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...
...
@@ -60,11 +59,9 @@ class PingPongInterceptor : public Interceptor {
REGISTER_INTERCEPTOR
(
PingPong
,
PingPongInterceptor
);
TEST
(
InterceptorTest
,
PingPong
)
{
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
}});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
{{
0
,
0
},
{
1
,
0
}}
,
{{
0
,
"127.0.0.0:0"
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
Interceptor
*
a
=
carrier
.
SetInterceptor
(
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
浏览文件 @
ddc15a18
...
...
@@ -104,35 +104,42 @@ TEST(InterceptorTest, PingPong) {
std
::
string
ip1
=
"127.0.0.1:"
+
std
::
to_string
(
port1
);
std
::
cout
<<
"ip0: "
<<
ip0
<<
std
::
endl
;
std
::
cout
<<
"ip1: "
<<
ip1
<<
std
::
endl
;
int
pid
=
fork
();
if
(
pid
==
0
)
{
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
1
}},
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
carrier
.
SetMsgBus
(
msg_bus
);
Interceptor
*
a
=
carrier
.
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
carrier
.
SetCreatingFlag
(
false
);
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
carrier
.
Wait
();
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank
=
{{
0
,
0
},
{
1
,
1
}};
int
exe_pid
=
fork
();
if
(
exe_pid
==
0
)
{
int
pid
=
fork
();
if
(
pid
==
0
)
{
Carrier
*
carrier
=
FleetExecutor
::
CreateCarrier
(
0
,
interceptor_id_to_rank
);
carrier
->
SetCreatingFlag
(
false
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
0
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
carrier
->
SetMsgBus
(
msg_bus
);
Interceptor
*
a
=
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
carrier
->
Wait
();
}
else
{
Carrier
*
carrier
=
FleetExecutor
::
CreateCarrier
(
1
,
interceptor_id_to_rank
);
carrier
->
SetCreatingFlag
(
false
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
1
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
carrier
->
SetMsgBus
(
msg_bus
);
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
carrier
->
Wait
();
int
status
;
int
ret
=
waitpid
(
pid
,
&
status
,
0
);
CHECK_EQ
(
ret
,
pid
);
}
}
else
{
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
1
}},
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
carrier
.
SetMsgBus
(
msg_bus
);
carrier
.
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
carrier
.
SetCreatingFlag
(
false
);
carrier
.
Wait
();
int
status
;
int
ret
=
waitpid
(
exe_pid
,
&
status
,
0
);
CHECK_EQ
(
ret
,
exe_pid
);
}
}
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
浏览文件 @
ddc15a18
...
...
@@ -18,7 +18,6 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -52,11 +51,9 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) {
}
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
}});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
carrier
.
SetMsgBus
(
msg_bus
);
int64_t
micro_steps
=
3
;
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
浏览文件 @
ddc15a18
...
...
@@ -18,7 +18,6 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -70,10 +69,9 @@ void LinkNodes(const std::vector<TaskNode*>& nodes,
}
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}}
,
{{
0
,
""
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
""
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
int64_t
micro_steps
=
6
;
...
...
paddle/fluid/framework/blocking_queue.h
浏览文件 @
ddc15a18
...
...
@@ -75,6 +75,12 @@ class BlockingQueue {
return
ret
;
}
void
PopAll
(
std
::
deque
<
T
>
*
empty_queue
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cv_
.
wait
(
lock
,
[
this
]
{
return
!
q_
.
empty
();
});
std
::
swap
(
*
empty_queue
,
q_
);
}
T
Pop
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cv_
.
wait
(
lock
,
[
=
]
{
return
!
q_
.
empty
();
});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录