Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
dba59db7
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
dba59db7
编写于
12月 27, 2021
作者:
W
WangXi
提交者:
GitHub
12月 27, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[fleet_executor] Add task loop thread pool (#38420)
上级
5b6b88ab
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
544 addition
and
190 deletion
+544
-190
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
+4
-2
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+24
-78
paddle/fluid/distributed/fleet_executor/carrier.h
paddle/fluid/distributed/fleet_executor/carrier.h
+12
-14
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
...e/fluid/distributed/fleet_executor/compute_interceptor.cc
+1
-1
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+5
-0
paddle/fluid/distributed/fleet_executor/interceptor.cc
paddle/fluid/distributed/fleet_executor/interceptor.cc
+43
-59
paddle/fluid/distributed/fleet_executor/interceptor.h
paddle/fluid/distributed/fleet_executor/interceptor.h
+7
-19
paddle/fluid/distributed/fleet_executor/message_bus.cc
paddle/fluid/distributed/fleet_executor/message_bus.cc
+44
-13
paddle/fluid/distributed/fleet_executor/message_bus.h
paddle/fluid/distributed/fleet_executor/message_bus.h
+9
-2
paddle/fluid/distributed/fleet_executor/task_loop.cc
paddle/fluid/distributed/fleet_executor/task_loop.cc
+82
-0
paddle/fluid/distributed/fleet_executor/task_loop.h
paddle/fluid/distributed/fleet_executor/task_loop.h
+81
-0
paddle/fluid/distributed/fleet_executor/task_loop_thread.cc
paddle/fluid/distributed/fleet_executor/task_loop_thread.cc
+58
-0
paddle/fluid/distributed/fleet_executor/task_loop_thread.h
paddle/fluid/distributed/fleet_executor/task_loop_thread.h
+44
-0
paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.cc
...fluid/distributed/fleet_executor/task_loop_thread_pool.cc
+66
-0
paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h
.../fluid/distributed/fleet_executor/task_loop_thread_pool.h
+47
-0
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
...eet_executor/test/interceptor_ping_pong_with_brpc_test.cc
+7
-2
paddle/fluid/framework/blocking_queue.h
paddle/fluid/framework/blocking_queue.h
+10
-0
未找到文件。
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
浏览文件 @
dba59db7
...
...
@@ -10,10 +10,12 @@ else()
set
(
BRPC_DEPS
""
)
endif
()
cc_library
(
task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog
)
cc_library
(
fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto
collective_helper op_registry
executor_gc_helper gflags glog
${
BRPC_DEPS
}
)
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto
task_loop_thread_pool collective_helper
op_registry
executor_gc_helper gflags glog
${
BRPC_DEPS
}
)
if
(
WITH_DISTRIBUTE
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
...
...
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
dba59db7
...
...
@@ -42,30 +42,17 @@ void Carrier::Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
place_
=
place
;
root_scope_
=
root_scope
;
dev_ctx_
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place_
);
// TODO(fleet_exe dev): thread pool
thread_num_
=
1
;
thread_pool_
.
SetThreadNum
(
thread_num_
);
thread_pool_
.
Start
();
CreateInterceptors
();
is_init_
=
true
;
}
void
Carrier
::
Release
()
{
// NOTE(wangxi): must join before `Derived Interceptor` destruct,
// otherwise Derived object will be destructed before thread complete.
for
(
int64_t
id
:
source_interceptor_ids_
)
{
VLOG
(
3
)
<<
"Carrier Release is sending stop to source interceptor "
<<
id
<<
"."
;
InterceptorMessage
stop_msg
;
// source node STOP is send by carrier, so set src_id=-1
stop_msg
.
set_src_id
(
-
1
);
stop_msg
.
set_dst_id
(
id
);
stop_msg
.
set_message_type
(
STOP
);
Send
(
stop_msg
);
}
// TODO(wangxi): Maybe need a better to use thread.
for
(
auto
&
interceptor
:
interceptor_idx_to_interceptor_
)
{
interceptor
.
second
->
Join
();
}
}
void
Carrier
::
Release
()
{}
Carrier
::~
Carrier
()
{
VLOG
(
3
)
<<
"Carrier's destructor."
;
}
...
...
@@ -75,18 +62,9 @@ bool Carrier::EnqueueInterceptorMessage(
VLOG
(
3
)
<<
"Receiving control message from rank "
<<
interceptor_message
.
src_id
()
<<
" to rank "
<<
interceptor_message
.
dst_id
();
// for barrier
msg_bus_
->
IncreaseBarrierCount
();
}
else
{
{
std
::
unique_lock
<
std
::
mutex
>
lock_creating
(
creating_flag_mutex_
);
if
(
creating_interceptors_
)
{
std
::
unique_lock
<
std
::
mutex
>
lock_message
(
tmp_message_mutex_
);
// Cannot handle the message to interceptor since interceptors
// are still under creating. Will enqueue into a tmp stack.
VLOG
(
3
)
<<
"Receiving message while creating interceptors."
;
message_tmp_
.
emplace_back
(
interceptor_message
);
return
true
;
}
}
int64_t
dst_id
=
interceptor_message
.
dst_id
();
Interceptor
*
dst_interceptor
=
GetInterceptor
(
dst_id
);
dst_interceptor
->
EnqueueRemoteInterceptorMessage
(
interceptor_message
);
...
...
@@ -94,6 +72,8 @@ bool Carrier::EnqueueInterceptorMessage(
return
true
;
}
void
Carrier
::
Barrier
()
{
msg_bus_
->
Barrier
();
}
Interceptor
*
Carrier
::
GetInterceptor
(
int64_t
interceptor_id
)
{
auto
iter
=
interceptor_idx_to_interceptor_
.
find
(
interceptor_id
);
PADDLE_ENFORCE_NE
(
iter
,
interceptor_idx_to_interceptor_
.
end
(),
...
...
@@ -109,6 +89,11 @@ void Carrier::Wait() {
cond_var_
.
wait
(
lock
);
}
void
Carrier
::
WakeUp
()
{
// probably double notify, but ok for ut
cond_var_
.
notify_all
();
}
void
Carrier
::
Start
()
{
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
...
...
@@ -126,12 +111,11 @@ void Carrier::Start() {
start_msg
.
set_message_type
(
DATA_IS_READY
);
Send
(
start_msg
);
}
// TODO(wangxi): async step
Wait
();
dev_ctx_
->
Wait
();
}
std
::
condition_variable
&
Carrier
::
GetCondVar
()
{
return
cond_var_
;
}
bool
Carrier
::
IsInit
()
const
{
return
is_init_
;
}
int64_t
Carrier
::
GetRank
(
int64_t
interceptor_id
)
const
{
...
...
@@ -183,51 +167,19 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
"The interceptor id should be unique."
,
interceptor_id
));
interceptor
->
RegisterCarrier
(
this
);
// TODO(fleet_exe dev): get loop
auto
*
loop
=
thread_pool_
.
GetLoop
(
interceptor_id
%
thread_num_
);
PADDLE_ENFORCE_NOT_NULL
(
loop
,
platform
::
errors
::
Fatal
(
"thread task loop must not null"
));
interceptor
->
RegisterTaskLoop
(
loop
);
auto
*
ptr
=
interceptor
.
get
();
interceptor_idx_to_interceptor_
.
insert
(
std
::
make_pair
(
interceptor_id
,
std
::
move
(
interceptor
)));
return
ptr
;
}
void
Carrier
::
SetCreatingFlag
(
bool
flag
)
{
// set the creating flag
creating_flag_mutex_
.
lock
();
VLOG
(
3
)
<<
"Carrier is set the creating flag from "
<<
creating_interceptors_
<<
" to "
<<
flag
<<
"."
;
creating_interceptors_
=
flag
;
creating_flag_mutex_
.
unlock
();
if
(
!
flag
)
{
for
(
auto
&
pair
:
interceptor_idx_to_interceptor_
)
{
// update the source interceptor id
if
(
std
::
find
(
source_interceptor_ids_
.
begin
(),
source_interceptor_ids_
.
end
(),
pair
.
first
)
==
source_interceptor_ids_
.
end
())
{
auto
task
=
pair
.
second
->
GetTaskNode
();
if
(
task
!=
nullptr
&&
task
->
upstream
().
empty
())
{
source_interceptor_ids_
.
emplace_back
(
pair
.
first
);
}
}
}
// finish create interceptors outside, handle tmp messsages
HandleTmpMessages
();
}
}
void
Carrier
::
HandleTmpMessages
()
{
// NOTE: It's ok lock on the tmp_message_mutex_ here, when enter this
// `HandleTmpMessages` method, the creating_interceptors_ flag
// must be false, therefore, there won't have conflict with the
// lock on the tmp_message_mutex_ inside `EnqueueInterceptorMessage`
// on the same thread.
std
::
unique_lock
<
std
::
mutex
>
lock
(
tmp_message_mutex_
);
VLOG
(
3
)
<<
"Carrier has received "
<<
message_tmp_
.
size
()
<<
" messages during creating interceptors."
;
for
(
const
auto
&
msg
:
message_tmp_
)
{
EnqueueInterceptorMessage
(
msg
);
}
message_tmp_
.
clear
();
}
static
std
::
shared_ptr
<
framework
::
GarbageCollector
>
GetGC
(
const
platform
::
Place
&
place
)
{
int64_t
max_memory_size
=
framework
::
GetEagerDeletionThreshold
();
...
...
@@ -285,12 +237,6 @@ void Carrier::CreateInterceptors() {
source_interceptor_ids_
.
emplace_back
(
interceptor_id
);
}
}
// The carrier will be always waiting for outside initializer
// since there is no interceptor has been created during auto init
creating_flag_mutex_
.
lock
();
creating_interceptors_
=
false
;
creating_flag_mutex_
.
unlock
();
HandleTmpMessages
();
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/carrier.h
浏览文件 @
dba59db7
...
...
@@ -24,6 +24,7 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
...
...
@@ -47,7 +48,11 @@ class Carrier final {
Carrier
()
=
default
;
Carrier
(
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
)
:
rank_
(
rank
),
interceptor_id_to_rank_
(
interceptor_id_to_rank
)
{}
:
rank_
(
rank
),
interceptor_id_to_rank_
(
interceptor_id_to_rank
)
{
thread_num_
=
1
;
thread_pool_
.
SetThreadNum
(
thread_num_
);
thread_pool_
.
Start
();
}
~
Carrier
();
void
Init
(
int64_t
rank
,
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
...
...
@@ -56,6 +61,7 @@ class Carrier final {
void
Release
();
void
Wait
();
void
WakeUp
();
// Enqueue a message to corresponding interceptor id
bool
EnqueueInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
);
...
...
@@ -67,23 +73,18 @@ class Carrier final {
Interceptor
*
SetInterceptor
(
int64_t
interceptor_id
,
std
::
unique_ptr
<
Interceptor
>
);
void
SetCreatingFlag
(
bool
flag
)
;
void
SetCreatingFlag
(
bool
flag
)
{}
void
SetMsgBus
(
const
std
::
shared_ptr
<
MessageBus
>&
msg_bus
)
{
msg_bus_
=
msg_bus
;
}
std
::
condition_variable
&
GetCondVar
();
void
Start
();
bool
IsInit
()
const
;
bool
Send
(
const
InterceptorMessage
&
msg
);
// NOTE: This mutex will be used in interceptor's RunOps function.
// This mutex is used for avoiding forward ops and backward ops run
// simultaneously, which will lead to a random hang for some sync ops.
std
::
mutex
run
;
void
Barrier
();
private:
DISABLE_COPY_AND_ASSIGN
(
Carrier
);
...
...
@@ -91,8 +92,6 @@ class Carrier final {
// create each Interceptor
void
CreateInterceptors
();
void
HandleTmpMessages
();
int64_t
GetRank
(
int64_t
interceptor_id
)
const
;
// interceptor logic id to actually interceptor
...
...
@@ -101,10 +100,6 @@ class Carrier final {
std
::
vector
<
int64_t
>
source_interceptor_ids_
;
std
::
vector
<
InterceptorMessage
>
message_tmp_
{};
std
::
mutex
tmp_message_mutex_
;
bool
creating_interceptors_
{
true
};
std
::
mutex
creating_flag_mutex_
;
bool
is_init_
{
false
};
std
::
mutex
running_mutex_
;
...
...
@@ -118,6 +113,9 @@ class Carrier final {
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
int64_t
rank_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
int
thread_num_
;
TaskLoopThreadPool
thread_pool_
;
};
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
浏览文件 @
dba59db7
...
...
@@ -170,7 +170,6 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
}
void
ComputeInterceptor
::
RunOps
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
carrier_
->
run
);
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
<<
" running ops for the "
<<
step_
+
1
<<
" time."
;
for
(
auto
op
:
node_
->
ops
())
{
...
...
@@ -198,6 +197,7 @@ void ComputeInterceptor::Run() {
if
(
is_last_
&&
(
step_
%
node_
->
max_run_times
()
==
0
))
{
VLOG
(
3
)
<<
"Interceptor "
<<
GetInterceptorId
()
<<
" is stopping carrier."
;
// FIXME(wangxi): with multi sink interceptor
StopCarrier
();
}
}
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
浏览文件 @
dba59db7
...
...
@@ -89,6 +89,11 @@ void FleetExecutor::Init(
CreateCarrier
();
InitCarrier
();
InitMessageBus
();
// refine this? wait all carrier ready
// NOTE(wangxi): must add after Carrier::SetMsgBus, for we use
// MessageBus::IncreaseBarrierCount when receive barrier msg.
GetCarrier
()
->
Barrier
();
}
void
FleetExecutor
::
InitCarrier
()
{
...
...
paddle/fluid/distributed/fleet_executor/interceptor.cc
浏览文件 @
dba59db7
...
...
@@ -14,26 +14,21 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace
paddle
{
namespace
distributed
{
Interceptor
::
Interceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
interceptor_id_
(
interceptor_id
),
node_
(
node
)
{
interceptor_thread_
=
std
::
thread
([
this
]()
{
VLOG
(
3
)
<<
"Interceptor "
<<
interceptor_id_
<<
" starts the thread pooling it's local mailbox."
;
PoolTheMailbox
();
});
}
Interceptor
::~
Interceptor
()
{
Join
();
}
void
Interceptor
::
Join
()
{
if
(
interceptor_thread_
.
joinable
())
{
interceptor_thread_
.
join
();
}
:
interceptor_id_
(
interceptor_id
),
node_
(
node
)
{}
Interceptor
::~
Interceptor
()
{
// FIXME(wangxi): throw in stop function
// std::lock_guard<std::mutex> lock(mutex_);
// PADDLE_ENFORCE_EQ(messages_.empty(), true,
// platform::errors::PreconditionNotMet(
// "Interceptor must destruct with messages empty"));
}
void
Interceptor
::
RegisterMsgHandle
(
MsgHandle
handle
)
{
handle_
=
handle
;
}
...
...
@@ -44,25 +39,47 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
handle_
(
msg
);
}
void
Interceptor
::
LoopOnce
()
{
std
::
deque
<
InterceptorMessage
>
tmp_messages
;
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
messages_
.
swap
(
tmp_messages
);
}
PADDLE_ENFORCE_EQ
(
tmp_messages
.
empty
(),
false
,
platform
::
errors
::
PreconditionNotMet
(
"tmp_messages must not empty in task loop"
));
for
(
auto
&
msg
:
tmp_messages
)
{
const
MessageType
message_type
=
msg
.
message_type
();
VLOG
(
3
)
<<
"Interceptor "
<<
interceptor_id_
<<
" has received a message"
<<
" from interceptor "
<<
msg
.
src_id
()
<<
" with message: "
<<
message_type
<<
"."
;
Handle
(
msg
);
}
}
void
Interceptor
::
StopCarrier
()
{
PADDLE_ENFORCE_NOT_NULL
(
carrier_
,
platform
::
errors
::
PreconditionNotMet
(
"Carrier is not registered."
));
std
::
condition_variable
&
cond_var
=
carrier_
->
GetCondVar
();
// probably double notify, but ok for ut
cond_var
.
notify_all
();
}
int64_t
Interceptor
::
GetInterceptorId
()
const
{
// return the interceptor id
return
interceptor_id_
;
carrier_
->
WakeUp
();
}
void
Interceptor
::
EnqueueRemoteInterceptorMessage
(
const
InterceptorMessage
&
interceptor_
message
)
{
const
InterceptorMessage
&
message
)
{
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
VLOG
(
3
)
<<
"Enqueue message: "
<<
interceptor_message
.
message_type
()
<<
" into "
<<
interceptor_id_
<<
"'s remote mailbox."
;
remote_mailbox_
.
Push
(
interceptor_message
);
VLOG
(
3
)
<<
"Enqueue message: "
<<
message
.
message_type
()
<<
" into "
<<
interceptor_id_
<<
"'s remote mailbox."
;
bool
empty
=
false
;
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
empty
=
messages_
.
empty
();
messages_
.
emplace_back
(
message
);
}
if
(
empty
)
{
loop_
->
QueueInLoop
([
this
]()
{
LoopOnce
();
});
}
}
bool
Interceptor
::
Send
(
int64_t
dst_id
,
InterceptorMessage
&
msg
)
{
...
...
@@ -73,39 +90,6 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
return
carrier_
->
Send
(
msg
);
}
void
Interceptor
::
PoolTheMailbox
()
{
// pool the local mailbox, parse the Message
for
(;;)
{
if
(
local_mailbox_
.
empty
())
{
// local mailbox is empty, fetch the remote mailbox
VLOG
(
3
)
<<
interceptor_id_
<<
"'s local mailbox is empty. "
<<
"Fetch the remote mailbox."
;
PADDLE_ENFORCE_EQ
(
FetchRemoteMailbox
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Error encountered when fetch remote mailbox."
));
}
const
InterceptorMessage
interceptor_message
=
local_mailbox_
.
front
();
local_mailbox_
.
pop_front
();
const
MessageType
message_type
=
interceptor_message
.
message_type
();
VLOG
(
3
)
<<
"Interceptor "
<<
interceptor_id_
<<
" has received a message"
<<
" from interceptor "
<<
interceptor_message
.
src_id
()
<<
" with message: "
<<
message_type
<<
"."
;
Handle
(
interceptor_message
);
if
(
stop_
)
{
// break the pooling thread
VLOG
(
3
)
<<
"Interceptor "
<<
interceptor_id_
<<
" is quiting."
;
break
;
}
}
}
bool
Interceptor
::
FetchRemoteMailbox
()
{
remote_mailbox_
.
PopAll
(
&
local_mailbox_
);
return
!
local_mailbox_
.
empty
();
}
static
InterceptorFactory
::
CreateInterceptorMap
&
GetInterceptorMap
()
{
static
InterceptorFactory
::
CreateInterceptorMap
interceptorMap
;
return
interceptorMap
;
...
...
paddle/fluid/distributed/fleet_executor/interceptor.h
浏览文件 @
dba59db7
...
...
@@ -38,6 +38,7 @@ namespace distributed {
class
TaskNode
;
class
Carrier
;
class
TaskLoop
;
class
Interceptor
{
public:
...
...
@@ -50,15 +51,13 @@ class Interceptor {
virtual
~
Interceptor
();
void
Join
();
// register interceptor handle
void
RegisterMsgHandle
(
MsgHandle
handle
);
void
Handle
(
const
InterceptorMessage
&
msg
);
// return the interceptor id
int64_t
GetInterceptorId
()
const
;
int64_t
GetInterceptorId
()
const
{
return
interceptor_id_
;
}
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
void
EnqueueRemoteInterceptorMessage
(
...
...
@@ -77,6 +76,7 @@ class Interceptor {
gc_
=
gc
;
}
void
RegisterCarrier
(
Carrier
*
carrier
)
{
carrier_
=
carrier
;
}
void
RegisterTaskLoop
(
TaskLoop
*
loop
)
{
loop_
=
loop
;
}
TaskNode
*
GetTaskNode
()
const
{
return
node_
;
}
...
...
@@ -101,28 +101,16 @@ class Interceptor {
std
::
shared_ptr
<
framework
::
GarbageCollector
>
gc_
{
nullptr
};
Carrier
*
carrier_
;
TaskLoop
*
loop_
;
private:
// pool the local mailbox, parse the Message
void
PoolTheMailbox
();
// fetch all Message from remote mailbox to local mailbox
// return true if remote mailbox not empty, otherwise return false
bool
FetchRemoteMailbox
();
void
LoopOnce
();
// interceptor handle which process message
MsgHandle
handle_
{
nullptr
};
// interceptor runs PoolTheMailbox() function to poll local mailbox
std
::
thread
interceptor_thread_
;
// remote mailbox, written by EnqueueRemoteMessage()
// read by FetchRemoteMailbox()
framework
::
BlockingQueue
<
InterceptorMessage
>
remote_mailbox_
;
// local mailbox, written by FetchRemoteMailbox()
// read by PoolTheMailbox()
std
::
deque
<
InterceptorMessage
>
local_mailbox_
;
std
::
mutex
mutex_
;
std
::
deque
<
InterceptorMessage
>
messages_
;
int64_t
already_run_times_
{
0
};
int64_t
used_slot_nums_
{
0
};
...
...
paddle/fluid/distributed/fleet_executor/message_bus.cc
浏览文件 @
dba59db7
...
...
@@ -105,21 +105,53 @@ bool MessageBus::Send(int64_t dst_rank,
return
true
;
}
void
MessageBus
::
TestConnection
()
{
InterceptorMessage
ctrl_msg
;
ctrl_msg
.
set_ctrl_message
(
true
);
ctrl_msg
.
set_src_id
(
rank_
);
for
(
const
auto
&
dst_rank_pair
:
rank_to_addr_
)
{
int64_t
dst_rank
=
dst_rank_pair
.
first
;
if
(
dst_rank
!=
rank_
)
{
ctrl_msg
.
set_dst_id
(
dst_rank
);
VLOG
(
3
)
<<
"Send control message bus from rank "
<<
rank_
<<
" to rank "
<<
dst_rank
;
while
(
!
Send
(
dst_rank
,
ctrl_msg
))
{
void
MessageBus
::
IncreaseBarrierCount
()
{
VLOG
(
3
)
<<
"IncreaseBarrierCount"
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
++
count_
;
cv_
.
notify_one
();
}
VLOG
(
3
)
<<
"End IncreaseBarrierCount"
;
}
void
MessageBus
::
Barrier
()
{
// gather to root
if
(
rank_
!=
0
)
{
InterceptorMessage
ctrl_msg
;
ctrl_msg
.
set_ctrl_message
(
true
);
ctrl_msg
.
set_src_id
(
rank_
);
ctrl_msg
.
set_dst_id
(
0
);
VLOG
(
3
)
<<
"Barrier Gather ctrl message from "
<<
rank_
<<
" to 0"
;
while
(
!
Send
(
0
,
ctrl_msg
))
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
}
}
else
{
VLOG
(
3
)
<<
"Barrier 0 wait others rank ready"
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cv_
.
wait
(
lock
,
[
this
]
{
return
count_
==
static_cast
<
int
>
(
rank_to_addr_
.
size
()
-
1
);
});
count_
=
0
;
}
// scatter from root
if
(
rank_
==
0
)
{
for
(
int
i
=
1
;
i
<
static_cast
<
int
>
(
rank_to_addr_
.
size
());
++
i
)
{
InterceptorMessage
ctrl_msg
;
ctrl_msg
.
set_ctrl_message
(
true
);
ctrl_msg
.
set_src_id
(
0
);
ctrl_msg
.
set_dst_id
(
i
);
VLOG
(
3
)
<<
"Barrier Scatter ctrl message from 0 to "
<<
i
;
while
(
!
Send
(
i
,
ctrl_msg
))
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
}
VLOG
(
3
)
<<
"Message bus has connected to rank: "
<<
dst_rank
<<
"."
;
}
}
else
{
VLOG
(
3
)
<<
"Barrier "
<<
rank_
<<
" wait others rank ready"
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cv_
.
wait
(
lock
,
[
this
]
{
return
count_
==
1
;
});
count_
=
0
;
}
}
...
...
@@ -151,7 +183,6 @@ void MessageBus::ListenPort() {
interval
+=
500
;
}
LOG
(
INFO
)
<<
"Message bus's listen port thread starts successful."
;
TestConnection
();
#else
LOG
(
WARNING
)
<<
"Fleet executor's ListenPort() is a fake function when Paddle is "
...
...
paddle/fluid/distributed/fleet_executor/message_bus.h
浏览文件 @
dba59db7
...
...
@@ -14,6 +14,7 @@
#pragma once
#include <condition_variable>
#include <mutex>
#include <string>
#include <thread>
...
...
@@ -51,14 +52,15 @@ class MessageBus final {
// called by Interceptor, send InterceptorMessage to dst
bool
Send
(
int64_t
dst_rank
,
const
InterceptorMessage
&
interceptor_message
);
void
IncreaseBarrierCount
();
void
Barrier
();
private:
DISABLE_COPY_AND_ASSIGN
(
MessageBus
);
// function keep listen the port and handle the message
void
ListenPort
();
void
TestConnection
();
const
std
::
string
&
GetAddr
(
int64_t
rank
)
const
;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
...
...
@@ -84,6 +86,11 @@ class MessageBus final {
// brpc server
brpc
::
Server
server_
;
#endif
// for barrier
std
::
mutex
mutex_
;
std
::
condition_variable
cv_
;
int
count_
{
0
};
};
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/task_loop.cc
0 → 100644
浏览文件 @
dba59db7
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace
paddle
{
namespace
distributed
{
thread_local
TaskLoop
*
TaskLoop
::
thread_local_loop_
=
nullptr
;
TaskLoop
*
TaskLoop
::
GetTaskLoopOfCurrentThread
()
{
return
thread_local_loop_
;
}
TaskLoop
::
TaskLoop
()
:
looping_
(
false
),
quit_
(
false
),
thread_id_
(
std
::
this_thread
::
get_id
())
{
PADDLE_ENFORCE_EQ
(
thread_local_loop_
,
nullptr
,
platform
::
errors
::
AlreadyExists
(
"Another TaskLoop is already init."
));
thread_local_loop_
=
this
;
}
TaskLoop
::~
TaskLoop
()
{
thread_local_loop_
=
nullptr
;
}
void
TaskLoop
::
Loop
()
{
PADDLE_ENFORCE_EQ
(
looping_
,
false
,
platform
::
errors
::
PreconditionNotMet
(
"Loop can only execute in one loop thread"
));
AssertInLoopThread
();
looping_
=
true
;
quit_
=
false
;
while
(
!
quit_
)
{
auto
tasks
=
tasks_
.
PopAll
();
for
(
auto
&
task
:
tasks
)
{
task
();
}
}
looping_
=
false
;
}
void
TaskLoop
::
Quit
()
{
quit_
=
true
;
if
(
!
IsInLoopThread
())
WakeUp
();
}
void
TaskLoop
::
RunInLoop
(
Functor
cb
)
{
if
(
IsInLoopThread
())
{
cb
();
}
else
{
QueueInLoop
(
cb
);
}
}
void
TaskLoop
::
QueueInLoop
(
Functor
cb
)
{
tasks_
.
Push
(
cb
);
}
void
TaskLoop
::
WakeUp
()
{
Functor
task
([]
{});
QueueInLoop
(
task
);
}
void
TaskLoop
::
AbortNotInLoopThread
()
{
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"This TaskLoop was created in thread %d, but current thread is %d"
,
thread_id_
,
std
::
this_thread
::
get_id
()));
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/task_loop.h
0 → 100644
浏览文件 @
dba59db7
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <future>
#include <map>
#include <thread>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
namespace
paddle
{
namespace
distributed
{
class
TaskLoop
{
public:
static
TaskLoop
*
GetTaskLoopOfCurrentThread
();
using
Functor
=
std
::
function
<
void
()
>
;
TaskLoop
();
~
TaskLoop
();
void
Loop
();
void
Quit
();
void
RunInLoop
(
Functor
cb
);
void
QueueInLoop
(
Functor
cb
);
template
<
class
F
,
class
...
Args
>
auto
Enqueue
(
F
&&
f
,
Args
&&
...
args
)
->
std
::
future
<
typename
std
::
result_of
<
F
(
Args
...)
>::
type
>
{
using
return_type
=
typename
std
::
result_of
<
F
(
Args
...)
>::
type
;
auto
task
=
std
::
make_shared
<
std
::
packaged_task
<
return_type
()
>>
(
std
::
bind
(
std
::
forward
<
F
>
(
f
),
std
::
forward
<
Args
>
(
args
)...));
std
::
future
<
return_type
>
task_future
=
task
->
get_future
();
tasks_
.
Push
([
task
]()
{
(
*
task
)();
});
return
task_future
;
}
void
WakeUp
();
bool
IsInLoopThread
()
const
{
return
thread_id_
==
std
::
this_thread
::
get_id
();
}
void
AssertInLoopThread
()
{
if
(
!
IsInLoopThread
())
{
AbortNotInLoopThread
();
}
}
private:
void
AbortNotInLoopThread
();
static
thread_local
TaskLoop
*
thread_local_loop_
;
bool
looping_
;
std
::
atomic
<
bool
>
quit_
;
std
::
thread
::
id
thread_id_
;
framework
::
BlockingQueue
<
Functor
>
tasks_
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/task_loop_thread.cc
0 → 100644
浏览文件 @
dba59db7
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace
paddle
{
namespace
distributed
{
TaskLoopThread
::
TaskLoopThread
()
:
start_
(
false
),
loop_
(
nullptr
)
{}
TaskLoopThread
::~
TaskLoopThread
()
{
if
(
loop_
!=
nullptr
)
{
loop_
->
Quit
();
thread_
.
join
();
}
}
TaskLoop
*
TaskLoopThread
::
StartLoop
()
{
PADDLE_ENFORCE_EQ
(
start_
,
false
,
platform
::
errors
::
PreconditionNotMet
(
"thread is already running."
));
start_
=
true
;
thread_
=
std
::
thread
([
this
]()
{
Loop
();
});
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cv_
.
wait
(
lock
,
[
=
]
{
return
loop_
!=
nullptr
;
});
return
loop_
;
}
void
TaskLoopThread
::
Loop
()
{
TaskLoop
loop
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
loop_
=
&
loop
;
cv_
.
notify_one
();
}
loop
.
Loop
();
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
loop_
=
nullptr
;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/task_loop_thread.h
0 → 100644
浏览文件 @
dba59db7
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable>
#include <mutex>
#include <thread>
namespace
paddle
{
namespace
distributed
{
class
TaskLoop
;
class
TaskLoopThread
{
public:
TaskLoopThread
();
~
TaskLoopThread
();
TaskLoop
*
StartLoop
();
private:
void
Loop
();
bool
start_
;
TaskLoop
*
loop_
;
std
::
thread
thread_
;
std
::
mutex
mutex_
;
std
::
condition_variable
cv_
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.cc
0 → 100644
浏览文件 @
dba59db7
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"

#include <memory>

#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace
paddle
{
namespace
distributed
{
// Default pool: delegates to the sized constructor with a single thread.
TaskLoopThreadPool::TaskLoopThreadPool() : TaskLoopThreadPool(1) {}
// Records the desired thread count; no threads are spawned until Start().
TaskLoopThreadPool::TaskLoopThreadPool(int thread_num)
    : start_(false), thread_num_(thread_num) {}
// Default destructor: threads_ (vector of unique_ptr) destroys the owned
// TaskLoopThreads; loops_ holds only non-owning pointers into them.
TaskLoopThreadPool::~TaskLoopThreadPool() = default;
void
TaskLoopThreadPool
::
Start
()
{
PADDLE_ENFORCE_EQ
(
start_
,
false
,
platform
::
errors
::
PreconditionNotMet
(
"thread pool is already start."
));
PADDLE_ENFORCE_GT
(
thread_num_
,
0
,
platform
::
errors
::
InvalidArgument
(
"thread num must greater than 0, but now is %d"
,
thread_num_
));
start_
=
true
;
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
threads_
.
emplace_back
(
new
TaskLoopThread
());
loops_
.
push_back
(
threads_
[
i
]
->
StartLoop
());
}
}
// Returns the TaskLoop bound to worker thread `tid` (non-owning pointer;
// the pool retains ownership). Requires Start() to have been called and
// tid in [0, thread_num_); violations raise via PADDLE_ENFORCE.
TaskLoop* TaskLoopThreadPool::GetLoop(int tid) {
  PADDLE_ENFORCE_EQ(start_, true,
                    platform::errors::PreconditionNotMet(
                        "thread pool must start first."));
  PADDLE_ENFORCE_GE(tid, 0, platform::errors::OutOfRange(
                                "tid must >= 0, but now is %d", tid));
  PADDLE_ENFORCE_LT(tid, thread_num_,
                    platform::errors::OutOfRange(
                        "tid must < thread_num, but now tid=%d thread_num=%d",
                        tid, thread_num_));
  return loops_[tid];
}
// Returns a copy of the loop-pointer vector (pointers themselves are
// non-owning; the pool keeps ownership). Requires Start() first.
std::vector<TaskLoop*> TaskLoopThreadPool::GetAllLoops() {
  PADDLE_ENFORCE_EQ(start_, true,
                    platform::errors::PreconditionNotMet(
                        "thread pool must start first."));
  return loops_;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h
0 → 100644
浏览文件 @
dba59db7
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
namespace
paddle
{
namespace
distributed
{
class
TaskLoop
;
class
TaskLoopThread
;
// A fixed-size pool of TaskLoopThreads. Construct (optionally sized),
// call Start() once, then fetch loops with GetLoop()/GetAllLoops().
// Loops returned to callers are non-owning; the pool owns the threads.
class TaskLoopThreadPool {
 public:
  TaskLoopThreadPool();  // defaults to one thread
  explicit TaskLoopThreadPool(int thread_num);
  ~TaskLoopThreadPool();

  // Sets the number of threads to spawn.
  // NOTE(review): no start_ check here — calling this after Start() would
  // desync thread_num_ from the already-spawned threads; presumably it is
  // only meant to be called before Start(). Confirm with callers.
  void SetThreadNum(int thread_num) { thread_num_ = thread_num; }

  // Spawns the threads; must be called exactly once before GetLoop().
  void Start();

  // Loop for worker `tid`, tid in [0, thread_num). Non-owning pointer.
  TaskLoop* GetLoop(int tid);
  // All loops, in thread order. Non-owning pointers.
  std::vector<TaskLoop*> GetAllLoops();

 private:
  bool start_;      // true once Start() has run
  int thread_num_;  // number of worker threads to create
  std::vector<std::unique_ptr<TaskLoopThread>> threads_;  // owned workers
  std::vector<TaskLoop*> loops_;  // cached loop of each worker (non-owning)
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
浏览文件 @
dba59db7
...
...
@@ -115,10 +115,13 @@ TEST(InterceptorTest, PingPong) {
FleetExecutor
::
CreateCarrier
(
0
,
interceptor_id_to_rank
);
carrier
->
SetCreatingFlag
(
false
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
0
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
carrier
->
SetMsgBus
(
msg_bus
);
// NOTE: need Init msg_bus after carrier SetMsgBus
msg_bus
->
Init
(
0
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
Interceptor
*
a
=
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
carrier
->
Barrier
();
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
carrier
->
Wait
();
...
...
@@ -127,10 +130,12 @@ TEST(InterceptorTest, PingPong) {
FleetExecutor
::
CreateCarrier
(
1
,
interceptor_id_to_rank
);
carrier
->
SetCreatingFlag
(
false
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
1
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
carrier
->
SetMsgBus
(
msg_bus
);
msg_bus
->
Init
(
1
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
carrier
->
Barrier
();
carrier
->
Wait
();
int
status
;
int
ret
=
waitpid
(
pid
,
&
status
,
0
);
...
...
paddle/fluid/framework/blocking_queue.h
浏览文件 @
dba59db7
...
...
@@ -81,6 +81,16 @@ class BlockingQueue {
std
::
swap
(
*
empty_queue
,
q_
);
}
std
::
deque
<
T
>
PopAll
()
{
std
::
deque
<
T
>
ret
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cv_
.
wait
(
lock
,
[
this
]
{
return
!
q_
.
empty
();
});
std
::
swap
(
ret
,
q_
);
}
return
ret
;
}
T
Pop
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cv_
.
wait
(
lock
,
[
=
]
{
return
!
q_
.
empty
();
});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录