Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
ddc15a18
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ddc15a18
编写于
12月 22, 2021
作者:
L
LiYuRio
提交者:
GitHub
12月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[fleet_executor] Move IntraSend to Carrier. Using blocking queue (#38322)
上级
142ea171
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
228 addition
and
262 deletion
+228
-262
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
+2
-2
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+48
-23
paddle/fluid/distributed/fleet_executor/carrier.h
paddle/fluid/distributed/fleet_executor/carrier.h
+8
-2
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+15
-15
paddle/fluid/distributed/fleet_executor/fleet_executor.h
paddle/fluid/distributed/fleet_executor/fleet_executor.h
+11
-2
paddle/fluid/distributed/fleet_executor/interceptor.cc
paddle/fluid/distributed/fleet_executor/interceptor.cc
+5
-23
paddle/fluid/distributed/fleet_executor/interceptor.h
paddle/fluid/distributed/fleet_executor/interceptor.h
+5
-14
paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc
...distributed/fleet_executor/interceptor_message_service.cc
+2
-2
paddle/fluid/distributed/fleet_executor/message_bus.cc
paddle/fluid/distributed/fleet_executor/message_bus.cc
+62
-110
paddle/fluid/distributed/fleet_executor/message_bus.h
paddle/fluid/distributed/fleet_executor/message_bus.h
+8
-10
paddle/fluid/distributed/fleet_executor/runtime_graph.cc
paddle/fluid/distributed/fleet_executor/runtime_graph.cc
+1
-1
paddle/fluid/distributed/fleet_executor/runtime_graph.h
paddle/fluid/distributed/fleet_executor/runtime_graph.h
+10
-10
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
...ed/fleet_executor/test/compute_interceptor_run_op_test.cc
+2
-3
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
...stributed/fleet_executor/test/compute_interceptor_test.cc
+2
-3
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
...ributed/fleet_executor/test/interceptor_ping_pong_test.cc
+2
-5
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
...eet_executor/test/interceptor_ping_pong_with_brpc_test.cc
+35
-28
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
...leet_executor/test/interceptor_pipeline_long_path_test.cc
+2
-5
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
...eet_executor/test/interceptor_pipeline_short_path_test.cc
+2
-4
paddle/fluid/framework/blocking_queue.h
paddle/fluid/framework/blocking_queue.h
+6
-0
未找到文件。
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
浏览文件 @
ddc15a18
...
...
@@ -5,7 +5,7 @@ endif()
proto_library
(
interceptor_message_proto SRCS interceptor_message.proto
)
if
(
WITH_DISTRIBUTE AND WITH_PSCORE AND
NOT
(
WITH_ASCEND OR WITH_ASCEND_CL
))
set
(
BRPC_DEPS brpc ssl crypto protobuf
gflags glog
zlib leveldb snappy gflags glog
)
set
(
BRPC_DEPS brpc ssl crypto protobuf zlib leveldb snappy gflags glog
)
else
()
set
(
BRPC_DEPS
""
)
endif
()
...
...
@@ -13,7 +13,7 @@ endif()
cc_library
(
fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper op_registry
executor_gc_helper
${
BRPC_DEPS
}
)
executor_gc_helper
gflags glog
${
BRPC_DEPS
}
)
if
(
WITH_DISTRIBUTE
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
...
...
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
ddc15a18
...
...
@@ -27,14 +27,16 @@ namespace distributed {
USE_INTERCEPTOR
(
Compute
);
USE_INTERCEPTOR
(
Amplifier
);
void
Carrier
::
Init
(
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
void
Carrier
::
Init
(
int64_t
rank
,
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
platform
::
Place
&
place
)
{
PADDLE_ENFORCE_EQ
(
is_init_
,
false
,
platform
::
errors
::
AlreadyExists
(
"Carrier is already init."
));
rank_
=
rank
;
runtime_graph_
=
runtime_graph
;
interceptor_id_to_rank_
=
runtime_graph_
->
interceptor_id_to_rank
();
minibatch_scope_
=
minibatch_scope
;
microbatch_scopes_
=
microbatch_scopes
;
place_
=
place
;
...
...
@@ -48,12 +50,6 @@ void Carrier::Release() {
// NOTE(wangxi): must join before `Derived Interceptor` destruct,
// otherwise Derived object will be destructed before thread complete.
// Sending STOP msg to the source interceptor
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."
));
for
(
int64_t
id
:
source_interceptor_ids_
)
{
VLOG
(
3
)
<<
"Carrier Release is sending stop to source interceptor "
<<
id
<<
"."
;
...
...
@@ -75,10 +71,10 @@ Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }
bool
Carrier
::
EnqueueInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
)
{
// enqueue message to interceptor
if
(
interceptor_message
.
ctrl_message
())
{
// handle control message
return
true
;
VLOG
(
3
)
<<
"Receiving control message from rank "
<<
interceptor_message
.
src_id
()
<<
" to rank "
<<
interceptor_message
.
dst_id
();
}
else
{
{
std
::
unique_lock
<
std
::
mutex
>
lock_creating
(
creating_flag_mutex_
);
...
...
@@ -93,15 +89,9 @@ bool Carrier::EnqueueInterceptorMessage(
}
int64_t
dst_id
=
interceptor_message
.
dst_id
();
Interceptor
*
dst_interceptor
=
GetInterceptor
(
dst_id
);
bool
rst
=
dst_interceptor
->
EnqueueRemoteInterceptorMessage
(
interceptor_message
);
if
(
rst
)
{
std
::
condition_variable
&
interceptor_cond_var
=
dst_interceptor
->
GetCondVar
();
interceptor_cond_var
.
notify_all
();
}
return
rst
;
dst_interceptor
->
EnqueueRemoteInterceptorMessage
(
interceptor_message
);
}
return
true
;
}
Interceptor
*
Carrier
::
GetInterceptor
(
int64_t
interceptor_id
)
{
...
...
@@ -144,9 +134,44 @@ std::condition_variable& Carrier::GetCondVar() { return cond_var_; }
bool
Carrier
::
IsInit
()
const
{
return
is_init_
;
}
// TODO(liyurui): Move SendIntra into carrier
bool
Carrier
::
Send
(
const
InterceptorMessage
&
msg
)
const
{
return
msg_bus_
->
Send
(
msg
);
int64_t
Carrier
::
GetRank
(
int64_t
interceptor_id
)
const
{
PADDLE_ENFORCE_NE
(
interceptor_id_to_rank_
.
find
(
interceptor_id
),
interceptor_id_to_rank_
.
end
(),
platform
::
errors
::
NotFound
(
"Cannot find rank for interceptor id %lld."
,
interceptor_id
));
return
interceptor_id_to_rank_
.
at
(
interceptor_id
);
}
bool
Carrier
::
Send
(
const
InterceptorMessage
&
msg
)
{
int64_t
src_id
=
(
msg
.
src_id
()
==
-
1
)
?
msg
.
dst_id
()
:
msg
.
src_id
();
int64_t
dst_id
=
msg
.
dst_id
();
int64_t
src_rank
=
GetRank
(
src_id
);
int64_t
dst_rank
=
GetRank
(
dst_id
);
PADDLE_ENFORCE_EQ
(
src_rank
,
rank_
,
platform
::
errors
::
Fatal
(
"The source rank id %lld, which is not equal to "
"the carrier rank id %lld."
,
src_rank
,
rank_
));
if
(
src_rank
==
dst_rank
)
{
VLOG
(
3
)
<<
"Send a message from interceptor "
<<
src_id
<<
" to interceptor "
<<
dst_id
<<
", which are in the same ranks."
;
return
EnqueueInterceptorMessage
(
msg
);
}
else
{
PADDLE_ENFORCE_NOT_NULL
(
msg_bus_
.
get
(),
platform
::
errors
::
Unavailable
(
"Message bus is released accidently"
));
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."
));
VLOG
(
3
)
<<
"Send a message from interceptor "
<<
src_id
<<
" to interceptor "
<<
dst_id
<<
", which are in different ranks."
;
return
msg_bus_
->
Send
(
dst_rank
,
msg
);
}
}
Interceptor
*
Carrier
::
SetInterceptor
(
int64_t
interceptor_id
,
...
...
@@ -222,13 +247,13 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
}
void
Carrier
::
CreateInterceptors
()
{
if
(
runtime_graph_
->
intercept
e
r_id_to_node
().
empty
())
return
;
if
(
runtime_graph_
->
intercept
o
r_id_to_node
().
empty
())
return
;
auto
gc
=
GetGC
(
place_
);
// create each Interceptor
// no auto init since there is no config
for
(
const
auto
&
item
:
runtime_graph_
->
intercept
e
r_id_to_node
())
{
for
(
const
auto
&
item
:
runtime_graph_
->
intercept
o
r_id_to_node
())
{
int64_t
interceptor_id
=
item
.
first
;
TaskNode
*
task_node
=
item
.
second
;
...
...
paddle/fluid/distributed/fleet_executor/carrier.h
浏览文件 @
ddc15a18
...
...
@@ -45,8 +45,11 @@ class MessageBus;
class
Carrier
final
{
public:
Carrier
()
=
default
;
Carrier
(
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
)
:
rank_
(
rank
),
interceptor_id_to_rank_
(
interceptor_id_to_rank
)
{}
~
Carrier
();
void
Init
(
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
void
Init
(
int64_t
rank
,
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
platform
::
Place
&
place
);
...
...
@@ -75,7 +78,7 @@ class Carrier final {
bool
IsInit
()
const
;
bool
Send
(
const
InterceptorMessage
&
msg
)
const
;
bool
Send
(
const
InterceptorMessage
&
msg
);
// NOTE: This mutex will be used in interceptor's RunOps function.
// This mutex is used for avoiding forward ops and backward ops run
...
...
@@ -90,6 +93,8 @@ class Carrier final {
void
HandleTmpMessages
();
int64_t
GetRank
(
int64_t
interceptor_id
)
const
;
// interceptor logic id to actually interceptor
std
::
unordered_map
<
int64_t
,
std
::
unique_ptr
<
Interceptor
>>
interceptor_idx_to_interceptor_
;
...
...
@@ -111,6 +116,7 @@ class Carrier final {
paddle
::
platform
::
DeviceContext
*
dev_ctx_
{
nullptr
};
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph_
;
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
int64_t
rank_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
};
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
浏览文件 @
ddc15a18
...
...
@@ -13,7 +13,6 @@
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -28,6 +27,8 @@
namespace
paddle
{
namespace
distributed
{
std
::
unique_ptr
<
Carrier
>
FleetExecutor
::
carrier_
;
FleetExecutor
::
FleetExecutor
(
const
std
::
string
&
exe_desc_str
)
{
bool
parse_flag
=
exe_desc_
.
ParseFromString
(
exe_desc_str
);
PADDLE_ENFORCE
(
parse_flag
,
platform
::
errors
::
PreconditionNotMet
(
...
...
@@ -36,12 +37,13 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
FleetExecutor
::~
FleetExecutor
()
{
root_scope_
->
DropKids
();
GetCarrier
()
.
Release
();
GetCarrier
()
->
Release
();
}
Carrier
&
FleetExecutor
::
GetCarrier
()
{
static
Carrier
carrier
;
return
carrier
;
Carrier
*
FleetExecutor
::
GetCarrier
()
{
PADDLE_ENFORCE_NOT_NULL
(
carrier_
.
get
(),
platform
::
errors
::
NotFound
(
"Carrier has not been created."
));
return
carrier_
.
get
();
}
void
FleetExecutor
::
Init
(
...
...
@@ -84,16 +86,16 @@ void FleetExecutor::Init(
}
VLOG
(
5
)
<<
runtime_graph_
->
DebugString
();
msg_bus_
=
std
::
make_shared
<
MessageBus
>
();
CreateCarrier
();
InitCarrier
();
InitMessageBus
();
}
void
FleetExecutor
::
InitCarrier
()
{
Carrier
&
carrier
=
GetCarrier
();
if
(
!
carrier
.
IsInit
())
{
carrier
.
SetMsgBus
(
msg_bus_
);
carrier
.
Init
(
runtime_graph_
,
root_scope_
,
minibatch_scope_
,
microbatch_scopes_
,
place_
);
if
(
!
GetCarrier
()
->
IsInit
())
{
GetCarrier
()
->
SetMsgBus
(
msg_bus_
);
GetCarrier
()
->
Init
(
exe_desc_
.
cur_rank
(),
runtime_graph_
,
root_scope_
,
minibatch_scope_
,
microbatch_scopes_
,
place_
);
}
}
...
...
@@ -128,21 +130,19 @@ void FleetExecutor::InitMessageBus() {
<<
(
rank_to_addr
.
size
()
==
0
?
1
:
rank_to_addr
.
size
())
<<
"."
;
VLOG
(
5
)
<<
ss
.
str
();
if
(
!
msg_bus_
->
IsInit
())
{
msg_bus_
->
Init
(
runtime_graph_
->
intercepter_id_to_rank
(),
rank_to_addr
,
addr
);
msg_bus_
->
Init
(
cur_rank
,
rank_to_addr
,
addr
);
}
}
void
FleetExecutor
::
Run
()
{
// Run
Carrier
&
carrier
=
GetCarrier
();
PADDLE_ENFORCE_EQ
(
carrier
.
IsInit
(),
true
,
GetCarrier
()
->
IsInit
(),
true
,
platform
::
errors
::
Unavailable
(
"Carrier has not been init yet."
));
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
Unavailable
(
"MessageBus has not been init yet."
));
carrier
.
Start
();
GetCarrier
()
->
Start
();
for
(
auto
*
micro_scop
:
microbatch_scopes_
)
{
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.h
浏览文件 @
ddc15a18
...
...
@@ -16,6 +16,7 @@
#include <memory>
#include <string>
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
...
...
@@ -30,7 +31,6 @@ namespace distributed {
class
RuntimeGraph
;
class
MessageBus
;
class
TaskNode
;
class
Carrier
;
class
FleetExecutor
final
{
public:
...
...
@@ -43,7 +43,15 @@ class FleetExecutor final {
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
);
void
Run
();
// TODO(liyurui): Change to use registry table for multi-carrier.
static
Carrier
&
GetCarrier
();
static
Carrier
*
GetCarrier
();
template
<
typename
...
Args
>
static
Carrier
*
CreateCarrier
(
Args
&&
...
args
)
{
PADDLE_ENFORCE_EQ
(
carrier_
.
get
(),
nullptr
,
platform
::
errors
::
AlreadyExists
(
"Carrier has been created already."
));
carrier_
=
std
::
make_unique
<
Carrier
>
(
std
::
forward
<
Args
>
(
args
)...);
return
carrier_
.
get
();
}
private:
DISABLE_COPY_AND_ASSIGN
(
FleetExecutor
);
...
...
@@ -59,6 +67,7 @@ class FleetExecutor final {
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
static
std
::
unique_ptr
<
Carrier
>
carrier_
;
};
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/interceptor.cc
浏览文件 @
ddc15a18
...
...
@@ -52,24 +52,17 @@ void Interceptor::StopCarrier() {
cond_var
.
notify_all
();
}
std
::
condition_variable
&
Interceptor
::
GetCondVar
()
{
// get the conditional var
return
cond_var_
;
}
int64_t
Interceptor
::
GetInterceptorId
()
const
{
// return the interceptor id
return
interceptor_id_
;
}
bool
Interceptor
::
EnqueueRemoteInterceptorMessage
(
void
Interceptor
::
EnqueueRemoteInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
)
{
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
VLOG
(
3
)
<<
"Enqueue message: "
<<
interceptor_message
.
message_type
()
<<
" into "
<<
interceptor_id_
<<
"'s remote mailbox."
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
remote_mailbox_mutex_
);
remote_mailbox_
.
push
(
interceptor_message
);
return
true
;
remote_mailbox_
.
Push
(
interceptor_message
);
}
bool
Interceptor
::
Send
(
int64_t
dst_id
,
InterceptorMessage
&
msg
)
{
...
...
@@ -92,7 +85,7 @@ void Interceptor::PoolTheMailbox() {
"Error encountered when fetch remote mailbox."
));
}
const
InterceptorMessage
interceptor_message
=
local_mailbox_
.
front
();
local_mailbox_
.
pop
();
local_mailbox_
.
pop
_front
();
const
MessageType
message_type
=
interceptor_message
.
message_type
();
VLOG
(
3
)
<<
"Interceptor "
<<
interceptor_id_
<<
" has received a message"
<<
" from interceptor "
<<
interceptor_message
.
src_id
()
...
...
@@ -109,19 +102,8 @@ void Interceptor::PoolTheMailbox() {
}
bool
Interceptor
::
FetchRemoteMailbox
()
{
// fetch all Message from remote mailbox to local mailbox
// return true if remote mailbox not empty, otherwise return false
std
::
unique_lock
<
std
::
mutex
>
lock
(
remote_mailbox_mutex_
);
cond_var_
.
wait
(
lock
,
[
this
]()
{
return
!
remote_mailbox_
.
empty
();
});
if
(
remote_mailbox_
.
empty
())
{
// the thread has been unblocked accidentally
return
false
;
}
while
(
!
remote_mailbox_
.
empty
())
{
local_mailbox_
.
push
(
std
::
move
(
remote_mailbox_
.
front
()));
remote_mailbox_
.
pop
();
}
return
true
;
remote_mailbox_
.
PopAll
(
&
local_mailbox_
);
return
!
local_mailbox_
.
empty
();
}
static
InterceptorFactory
::
CreateInterceptorMap
&
GetInterceptorMap
()
{
...
...
paddle/fluid/distributed/fleet_executor/interceptor.h
浏览文件 @
ddc15a18
...
...
@@ -15,14 +15,15 @@
#pragma once
#include <condition_variable>
#include <deque>
#include <functional>
#include <map>
#include <memory>
#include <queue>
#include <thread>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
...
...
@@ -59,11 +60,8 @@ class Interceptor {
// return the interceptor id
int64_t
GetInterceptorId
()
const
;
// return the conditional var
std
::
condition_variable
&
GetCondVar
();
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
bool
EnqueueRemoteInterceptorMessage
(
void
EnqueueRemoteInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
);
bool
Send
(
int64_t
dst_id
,
InterceptorMessage
&
msg
);
// NOLINT
...
...
@@ -115,23 +113,16 @@ class Interceptor {
// interceptor handle which process message
MsgHandle
handle_
{
nullptr
};
// mutex to control read/write conflict for remote mailbox
std
::
mutex
remote_mailbox_mutex_
;
// interceptor runs PoolTheMailbox() function to poll local mailbox
std
::
thread
interceptor_thread_
;
// conditional variable for blocking the thread when
// fetch an empty remote mailbox
std
::
condition_variable
cond_var_
;
// remote mailbox, written by EnqueueRemoteMessage()
// read by FetchRemoteMailbox()
std
::
q
ueue
<
InterceptorMessage
>
remote_mailbox_
;
framework
::
BlockingQ
ueue
<
InterceptorMessage
>
remote_mailbox_
;
// local mailbox, written by FetchRemoteMailbox()
// read by PoolTheMailbox()
std
::
que
ue
<
InterceptorMessage
>
local_mailbox_
;
std
::
deq
ue
<
InterceptorMessage
>
local_mailbox_
;
int64_t
already_run_times_
{
0
};
int64_t
used_slot_nums_
{
0
};
...
...
paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc
浏览文件 @
ddc15a18
...
...
@@ -29,8 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG
(
3
)
<<
"Interceptor Message Service receives a message from interceptor "
<<
request
->
src_id
()
<<
" to interceptor "
<<
request
->
dst_id
()
<<
", with the message: "
<<
request
->
message_type
();
FleetExecutor
::
GetCarrier
().
EnqueueInterceptorMessage
(
*
request
);
response
->
set_rst
(
true
);
bool
flag
=
FleetExecutor
::
GetCarrier
()
->
EnqueueInterceptorMessage
(
*
request
);
response
->
set_rst
(
flag
);
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/message_bus.cc
浏览文件 @
ddc15a18
...
...
@@ -17,8 +17,6 @@
#include <set>
#include <thread>
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
...
...
@@ -26,16 +24,25 @@ namespace paddle {
namespace
distributed
{
void
MessageBus
::
Init
(
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_map
<
int64_t
,
std
::
string
>&
rank_to_addr
,
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
std
::
string
>&
rank_to_addr
,
const
std
::
string
&
addr
)
{
PADDLE_ENFORCE_EQ
(
is_init_
,
false
,
platform
::
errors
::
AlreadyExists
(
"MessageBus is already init."
));
rank_
=
rank
;
is_init_
=
true
;
interceptor_id_to_rank_
=
interceptor_id_to_rank
;
rank_to_addr_
=
rank_to_addr
;
addr_
=
addr
;
if
(
addr_
!=
""
)
{
const
auto
&
addr
=
GetAddr
(
rank_
);
PADDLE_ENFORCE_EQ
(
addr
,
addr_
,
platform
::
errors
::
Fatal
(
"The current rank's addr is %s, while the "
"message bus's addr is %s, which are different. "
"Init error."
,
addr
,
addr_
));
}
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL)
// NOTE: To make the brpc is compatible with collective,
...
...
@@ -65,43 +72,57 @@ MessageBus::~MessageBus() {
#endif
}
bool
MessageBus
::
Send
(
const
InterceptorMessage
&
interceptor_message
)
{
// called by Interceptor, send InterceptorMessage to dst
int64_t
src_id
=
interceptor_message
.
src_id
();
int64_t
dst_id
=
interceptor_message
.
dst_id
();
if
(
IsSameRank
(
src_id
,
dst_id
))
{
VLOG
(
3
)
<<
"Send a message from interceptor "
<<
src_id
<<
" to interceptor "
<<
dst_id
<<
", which are in the same ranks."
;
return
SendIntraRank
(
interceptor_message
);
}
else
{
VLOG
(
3
)
<<
"Send a message from interceptor "
<<
src_id
<<
" to interceptor "
<<
dst_id
<<
", which are in different ranks."
;
const
std
::
string
&
MessageBus
::
GetAddr
(
int64_t
rank
)
const
{
PADDLE_ENFORCE_NE
(
rank_to_addr_
.
find
(
rank
),
rank_to_addr_
.
end
(),
platform
::
errors
::
NotFound
(
"Cannot find addr rank id %lld."
,
rank
));
return
rank_to_addr_
.
at
(
rank
);
}
bool
MessageBus
::
Send
(
int64_t
dst_rank
,
const
InterceptorMessage
&
interceptor_message
)
{
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
int
retry_time
=
0
;
// message bus will retry sending for 10 times
while
(
retry_time
<
10
)
{
++
retry_time
;
if
(
SendInterRank
(
interceptor_message
))
{
VLOG
(
3
)
<<
"Message bus sends inter rank successfully with "
<<
retry_time
<<
" times retries."
;
return
true
;
}
VLOG
(
3
)
<<
"Message bus sends failed, retry after 1 seconds."
;
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
int
retry_time
=
0
;
// message bus will retry sending for 10 times
while
(
retry_time
<
10
)
{
++
retry_time
;
if
(
SendInterRank
(
dst_rank
,
interceptor_message
))
{
VLOG
(
3
)
<<
"Message bus sends inter rank successfully with "
<<
retry_time
<<
" times retries."
;
return
true
;
}
VLOG
(
3
)
<<
"Message bus sends inter rank fail after 10 times retries."
;
return
false
;
VLOG
(
3
)
<<
"Message bus sends failed, retry after 1 seconds."
;
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
}
VLOG
(
3
)
<<
"Message bus sends inter rank fail after 10 times retries."
;
return
false
;
#else
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Fleet executor does not support sending message between different "
"ranks when Paddle is compiled with npu or "
"isn't compiled with distributed for now."
));
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Fleet executor does not support sending message between different "
"ranks when Paddle is compiled with npu or "
"isn't compiled with distributed for now."
));
#endif
}
return
true
;
}
void
MessageBus
::
TestConnection
()
{
InterceptorMessage
ctrl_msg
;
ctrl_msg
.
set_ctrl_message
(
true
);
ctrl_msg
.
set_src_id
(
rank_
);
for
(
const
auto
&
dst_rank_pair
:
rank_to_addr_
)
{
int64_t
dst_rank
=
dst_rank_pair
.
first
;
if
(
dst_rank
!=
rank_
)
{
ctrl_msg
.
set_dst_id
(
dst_rank
);
VLOG
(
3
)
<<
"Send control message bus from rank "
<<
rank_
<<
" to rank "
<<
dst_rank
;
while
(
!
Send
(
dst_rank
,
ctrl_msg
))
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
}
VLOG
(
3
)
<<
"Message bus has connected to rank: "
<<
dst_rank
<<
"."
;
}
}
}
void
MessageBus
::
ListenPort
()
{
if
(
addr_
==
""
)
{
LOG
(
INFO
)
<<
"No need listen to port since training on single card."
;
...
...
@@ -130,30 +151,7 @@ void MessageBus::ListenPort() {
interval
+=
500
;
}
LOG
(
INFO
)
<<
"Message bus's listen port thread starts successful."
;
std
::
set
<
int64_t
>
visit
;
InterceptorMessage
tmp_msg
;
tmp_msg
.
set_ctrl_message
(
true
);
for
(
auto
pair
:
interceptor_id_to_rank_
)
{
if
(
rank_to_addr_
.
at
(
pair
.
second
)
==
addr_
)
{
tmp_msg
.
set_src_id
(
pair
.
first
);
}
}
for
(
auto
pair
:
interceptor_id_to_rank_
)
{
int64_t
rank
=
pair
.
second
;
if
(
rank_to_addr_
.
at
(
rank
)
==
addr_
)
{
continue
;
}
tmp_msg
.
set_dst_id
(
pair
.
first
);
if
(
visit
.
find
(
rank
)
==
visit
.
end
())
{
VLOG
(
3
)
<<
"Message bus is testing connection for rank: "
<<
rank
<<
"."
;
visit
.
insert
(
rank
);
while
(
!
Send
(
tmp_msg
))
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
}
VLOG
(
3
)
<<
"Message bus has connected to rank: "
<<
rank
<<
"."
;
}
}
TestConnection
();
#else
LOG
(
WARNING
)
<<
"Fleet executor's ListenPort() is a fake function when Paddle is "
...
...
@@ -162,53 +160,13 @@ void MessageBus::ListenPort() {
#endif
}
bool
MessageBus
::
IsSameRank
(
int64_t
src_id
,
int64_t
dst_id
)
{
// -1 is sent by carrier to source interceptor
if
(
src_id
==
-
1
)
src_id
=
dst_id
;
// check whether the dst is the same rank or different rank with src
const
auto
&
src_rank
=
interceptor_id_to_rank_
.
find
(
src_id
);
const
auto
&
dst_rank
=
interceptor_id_to_rank_
.
find
(
dst_id
);
PADDLE_ENFORCE_NE
(
src_rank
,
interceptor_id_to_rank_
.
end
(),
platform
::
errors
::
NotFound
(
"Cannot find rank for src interceptor id %lld. Init error."
,
src_id
));
PADDLE_ENFORCE_NE
(
dst_rank
,
interceptor_id_to_rank_
.
end
(),
platform
::
errors
::
NotFound
(
"Cannot find rank for dst interceptor id %lld. Init error."
,
dst_id
));
if
(
addr_
==
""
)
{
// single card training, must be same rank
return
true
;
}
const
auto
&
src_ip
=
rank_to_addr_
.
find
(
src_rank
->
second
);
PADDLE_ENFORCE_NE
(
src_ip
,
rank_to_addr_
.
end
(),
platform
::
errors
::
NotFound
(
"Cannot find addr for src rank id %lld. Init error."
,
src_rank
->
second
));
PADDLE_ENFORCE_EQ
(
src_ip
->
second
,
addr_
,
platform
::
errors
::
Fatal
(
"The src interceptor's addr is %s, while the "
"message bus's addr is %s, which are different. "
"Init error."
,
src_ip
->
second
,
addr_
));
return
src_rank
->
second
==
dst_rank
->
second
;
}
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
bool
MessageBus
::
SendInterRank
(
const
InterceptorMessage
&
interceptor_message
)
{
// send the message inter rank (dst is different rank with src)
int64_t
dst_id
=
interceptor_message
.
dst_id
();
int64_t
dst_rank
=
interceptor_id_to_rank_
[
dst_id
];
auto
dst_ip
=
rank_to_addr_
.
find
(
dst_rank
);
PADDLE_ENFORCE_NE
(
dst_ip
,
rank_to_addr_
.
end
(),
platform
::
errors
::
InvalidArgument
(
"Cannot find rank for dst interceptor id %lld. "
"Init error."
,
dst_id
));
VLOG
(
3
)
<<
"Message bus sending to addr: "
<<
dst_ip
->
second
;
const
char
*
dst_ip_for_brpc
=
dst_ip
->
second
.
c_str
();
bool
MessageBus
::
SendInterRank
(
int64_t
dst_rank
,
const
InterceptorMessage
&
interceptor_message
)
{
const
auto
&
dst_addr
=
GetAddr
(
dst_rank
);
VLOG
(
3
)
<<
"Message bus sending to addr: "
<<
dst_addr
;
const
char
*
dst_addr_for_brpc
=
dst_addr
.
c_str
();
brpc
::
Channel
channel
;
brpc
::
ChannelOptions
options
;
options
.
protocol
=
"baidu_std"
;
...
...
@@ -216,7 +174,7 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
options
.
timeout_ms
=
1000
;
options
.
max_retry
=
5
;
PADDLE_ENFORCE_EQ
(
channel
.
Init
(
dst_
ip
_for_brpc
,
&
options
),
0
,
channel
.
Init
(
dst_
addr
_for_brpc
,
&
options
),
0
,
platform
::
errors
::
Unavailable
(
"Message bus: init brpc channel error."
));
TheInterceptorMessageService_Stub
stub
(
&
channel
);
InterceptorResponse
response
;
...
...
@@ -239,11 +197,5 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
}
#endif
bool
MessageBus
::
SendIntraRank
(
const
InterceptorMessage
&
interceptor_message
)
{
// send the message intra rank (dst is the same rank with src)
return
FleetExecutor
::
GetCarrier
().
EnqueueInterceptorMessage
(
interceptor_message
);
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/message_bus.h
浏览文件 @
ddc15a18
...
...
@@ -42,14 +42,14 @@ class MessageBus final {
MessageBus
()
=
default
;
~
MessageBus
();
void
Init
(
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_
rank
,
void
Init
(
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
std
::
string
>&
rank_to_addr
,
const
std
::
string
&
addr
);
bool
IsInit
()
const
;
// called by Interceptor, send InterceptorMessage to dst
bool
Send
(
const
InterceptorMessage
&
interceptor_message
);
bool
Send
(
int64_t
dst_rank
,
const
InterceptorMessage
&
interceptor_message
);
private:
DISABLE_COPY_AND_ASSIGN
(
MessageBus
);
...
...
@@ -57,22 +57,20 @@ class MessageBus final {
// function keep listen the port and handle the message
void
ListenPort
();
// check whether the dst is the same rank or different rank with src
bool
IsSameRank
(
int64_t
src_id
,
int64_t
dst_id
);
void
TestConnection
();
const
std
::
string
&
GetAddr
(
int64_t
rank
)
const
;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
// send the message inter rank (dst is different rank with src)
bool
SendInterRank
(
const
InterceptorMessage
&
interceptor_message
);
bool
SendInterRank
(
int64_t
dst_rank
,
const
InterceptorMessage
&
interceptor_message
);
#endif
bool
is_init_
{
false
};
// send the message intra rank (dst is the same rank with src)
bool
SendIntraRank
(
const
InterceptorMessage
&
interceptor_message
);
// handed by above layer, save the info mapping interceptor id to rank id
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
int64_t
rank_
;
// handed by above layer, save the info mapping rank id to addr
std
::
unordered_map
<
int64_t
,
std
::
string
>
rank_to_addr_
;
...
...
paddle/fluid/distributed/fleet_executor/runtime_graph.cc
浏览文件 @
ddc15a18
...
...
@@ -21,7 +21,7 @@ namespace distributed {
std
::
string
RuntimeGraph
::
DebugString
()
const
{
std
::
ostringstream
os
;
os
<<
"
\n
Runtime Graph Debug:
\n
"
;
for
(
const
auto
&
pair
:
intercept
e
r_id_to_node_
)
{
for
(
const
auto
&
pair
:
intercept
o
r_id_to_node_
)
{
os
<<
pair
.
second
->
DebugString
();
os
<<
"
\n
"
;
}
...
...
paddle/fluid/distributed/fleet_executor/runtime_graph.h
浏览文件 @
ddc15a18
...
...
@@ -29,26 +29,26 @@ class RuntimeGraph final {
public:
RuntimeGraph
()
=
default
;
~
RuntimeGraph
()
=
default
;
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
intercept
e
r_id_to_node
()
const
{
return
intercept
e
r_id_to_node_
;
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
intercept
o
r_id_to_node
()
const
{
return
intercept
o
r_id_to_node_
;
}
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
intercept
e
r_id_to_rank
()
const
{
return
intercept
e
r_id_to_rank_
;
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
intercept
o
r_id_to_rank
()
const
{
return
intercept
o
r_id_to_rank_
;
}
void
SetInterceptorIdToRank
(
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
intercept
e
r_id_to_rank
)
{
intercept
er_id_to_rank_
=
intercepte
r_id_to_rank
;
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
intercept
o
r_id_to_rank
)
{
intercept
or_id_to_rank_
=
intercepto
r_id_to_rank
;
}
void
SetInterceptorIdToNode
(
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
intercept
e
r_id_to_node
)
{
intercept
er_id_to_node_
=
intercepte
r_id_to_node
;
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
intercept
o
r_id_to_node
)
{
intercept
or_id_to_node_
=
intercepto
r_id_to_node
;
}
std
::
string
DebugString
()
const
;
private:
DISABLE_COPY_AND_ASSIGN
(
RuntimeGraph
);
std
::
unordered_map
<
int64_t
,
TaskNode
*>
intercept
e
r_id_to_node_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
intercept
e
r_id_to_rank_
;
std
::
unordered_map
<
int64_t
,
TaskNode
*>
intercept
o
r_id_to_node_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
intercept
o
r_id_to_rank_
;
};
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
浏览文件 @
ddc15a18
...
...
@@ -62,11 +62,10 @@ TEST(ComputeInterceptor, Compute) {
std
::
vector
<
framework
::
Scope
*>
scopes
=
{
scope
,
scope
};
platform
::
Place
place
=
platform
::
CPUPlace
();
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
}});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
{{
0
,
0
},
{
1
,
0
}}
,
{{
0
,
"127.0.0.0:0"
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
// FIXME: don't delete, otherwise interceptor will use undefined node
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
浏览文件 @
ddc15a18
...
...
@@ -47,11 +47,10 @@ class StartInterceptor : public Interceptor {
};
TEST
(
ComputeInterceptor
,
Compute
)
{
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
}});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
}}
,
{{
0
,
"127.0.0.0:0"
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
// NOTE: don't delete, otherwise interceptor will use undefined node
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
浏览文件 @
ddc15a18
...
...
@@ -18,7 +18,6 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...
...
@@ -60,11 +59,9 @@ class PingPongInterceptor : public Interceptor {
REGISTER_INTERCEPTOR
(
PingPong
,
PingPongInterceptor
);
TEST
(
InterceptorTest
,
PingPong
)
{
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
}});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
{{
0
,
0
},
{
1
,
0
}}
,
{{
0
,
"127.0.0.0:0"
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
Interceptor
*
a
=
carrier
.
SetInterceptor
(
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
浏览文件 @
ddc15a18
...
...
@@ -104,35 +104,42 @@ TEST(InterceptorTest, PingPong) {
std
::
string
ip1
=
"127.0.0.1:"
+
std
::
to_string
(
port1
);
std
::
cout
<<
"ip0: "
<<
ip0
<<
std
::
endl
;
std
::
cout
<<
"ip1: "
<<
ip1
<<
std
::
endl
;
int
pid
=
fork
();
if
(
pid
==
0
)
{
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
1
}},
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
carrier
.
SetMsgBus
(
msg_bus
);
Interceptor
*
a
=
carrier
.
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
carrier
.
SetCreatingFlag
(
false
);
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
carrier
.
Wait
();
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank
=
{{
0
,
0
},
{
1
,
1
}};
int
exe_pid
=
fork
();
if
(
exe_pid
==
0
)
{
int
pid
=
fork
();
if
(
pid
==
0
)
{
Carrier
*
carrier
=
FleetExecutor
::
CreateCarrier
(
0
,
interceptor_id_to_rank
);
carrier
->
SetCreatingFlag
(
false
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
0
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
carrier
->
SetMsgBus
(
msg_bus
);
Interceptor
*
a
=
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
carrier
->
Wait
();
}
else
{
Carrier
*
carrier
=
FleetExecutor
::
CreateCarrier
(
1
,
interceptor_id_to_rank
);
carrier
->
SetCreatingFlag
(
false
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
1
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
carrier
->
SetMsgBus
(
msg_bus
);
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
carrier
->
Wait
();
int
status
;
int
ret
=
waitpid
(
pid
,
&
status
,
0
);
CHECK_EQ
(
ret
,
pid
);
}
}
else
{
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
1
}},
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
carrier
.
SetMsgBus
(
msg_bus
);
carrier
.
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
carrier
.
SetCreatingFlag
(
false
);
carrier
.
Wait
();
int
status
;
int
ret
=
waitpid
(
exe_pid
,
&
status
,
0
);
CHECK_EQ
(
ret
,
exe_pid
);
}
}
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
浏览文件 @
ddc15a18
...
...
@@ -18,7 +18,6 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -52,11 +51,9 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) {
}
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
}});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
carrier
.
SetMsgBus
(
msg_bus
);
int64_t
micro_steps
=
3
;
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
浏览文件 @
ddc15a18
...
...
@@ -18,7 +18,6 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -70,10 +69,9 @@ void LinkNodes(const std::vector<TaskNode*>& nodes,
}
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}}
,
{{
0
,
""
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
""
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
int64_t
micro_steps
=
6
;
...
...
paddle/fluid/framework/blocking_queue.h
浏览文件 @
ddc15a18
...
...
@@ -75,6 +75,12 @@ class BlockingQueue {
return
ret
;
}
void
PopAll
(
std
::
deque
<
T
>
*
empty_queue
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cv_
.
wait
(
lock
,
[
this
]
{
return
!
q_
.
empty
();
});
std
::
swap
(
*
empty_queue
,
q_
);
}
T
Pop
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cv_
.
wait
(
lock
,
[
=
]
{
return
!
q_
.
empty
();
});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录