Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
dba59db7
P
Paddle
项目概览
PaddlePaddle
/
Paddle
接近 2 年 前同步成功
通知
2323
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
dba59db7
编写于
12月 27, 2021
作者:
W
WangXi
提交者:
GitHub
12月 27, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[fleet_executor] Add task loop thread pool (#38420)
上级
5b6b88ab
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
544 addition
and
190 deletion
+544
-190
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
+4
-2
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+24
-78
paddle/fluid/distributed/fleet_executor/carrier.h
paddle/fluid/distributed/fleet_executor/carrier.h
+12
-14
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
...e/fluid/distributed/fleet_executor/compute_interceptor.cc
+1
-1
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+5
-0
paddle/fluid/distributed/fleet_executor/interceptor.cc
paddle/fluid/distributed/fleet_executor/interceptor.cc
+43
-59
paddle/fluid/distributed/fleet_executor/interceptor.h
paddle/fluid/distributed/fleet_executor/interceptor.h
+7
-19
paddle/fluid/distributed/fleet_executor/message_bus.cc
paddle/fluid/distributed/fleet_executor/message_bus.cc
+44
-13
paddle/fluid/distributed/fleet_executor/message_bus.h
paddle/fluid/distributed/fleet_executor/message_bus.h
+9
-2
paddle/fluid/distributed/fleet_executor/task_loop.cc
paddle/fluid/distributed/fleet_executor/task_loop.cc
+82
-0
paddle/fluid/distributed/fleet_executor/task_loop.h
paddle/fluid/distributed/fleet_executor/task_loop.h
+81
-0
paddle/fluid/distributed/fleet_executor/task_loop_thread.cc
paddle/fluid/distributed/fleet_executor/task_loop_thread.cc
+58
-0
paddle/fluid/distributed/fleet_executor/task_loop_thread.h
paddle/fluid/distributed/fleet_executor/task_loop_thread.h
+44
-0
paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.cc
...fluid/distributed/fleet_executor/task_loop_thread_pool.cc
+66
-0
paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h
.../fluid/distributed/fleet_executor/task_loop_thread_pool.h
+47
-0
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
...eet_executor/test/interceptor_ping_pong_with_brpc_test.cc
+7
-2
paddle/fluid/framework/blocking_queue.h
paddle/fluid/framework/blocking_queue.h
+10
-0
未找到文件。
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
浏览文件 @
dba59db7
...
@@ -10,10 +10,12 @@ else()
...
@@ -10,10 +10,12 @@ else()
set
(
BRPC_DEPS
""
)
set
(
BRPC_DEPS
""
)
endif
()
endif
()
cc_library
(
task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog
)
cc_library
(
fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
cc_library
(
fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto
collective_helper op_registry
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto
task_loop_thread_pool collective_helper
executor_gc_helper gflags glog
${
BRPC_DEPS
}
)
op_registry
executor_gc_helper gflags glog
${
BRPC_DEPS
}
)
if
(
WITH_DISTRIBUTE
)
if
(
WITH_DISTRIBUTE
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
...
...
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
dba59db7
...
@@ -42,30 +42,17 @@ void Carrier::Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
...
@@ -42,30 +42,17 @@ void Carrier::Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
place_
=
place
;
place_
=
place
;
root_scope_
=
root_scope
;
root_scope_
=
root_scope
;
dev_ctx_
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place_
);
dev_ctx_
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place_
);
// TODO(fleet_exe dev): thread pool
thread_num_
=
1
;
thread_pool_
.
SetThreadNum
(
thread_num_
);
thread_pool_
.
Start
();
CreateInterceptors
();
CreateInterceptors
();
is_init_
=
true
;
is_init_
=
true
;
}
}
void
Carrier
::
Release
()
{
void
Carrier
::
Release
()
{}
// NOTE(wangxi): must join before `Derived Interceptor` destruct,
// otherwise Derived object will be destructed before thread complete.
for
(
int64_t
id
:
source_interceptor_ids_
)
{
VLOG
(
3
)
<<
"Carrier Release is sending stop to source interceptor "
<<
id
<<
"."
;
InterceptorMessage
stop_msg
;
// source node STOP is send by carrier, so set src_id=-1
stop_msg
.
set_src_id
(
-
1
);
stop_msg
.
set_dst_id
(
id
);
stop_msg
.
set_message_type
(
STOP
);
Send
(
stop_msg
);
}
// TODO(wangxi): Maybe need a better to use thread.
for
(
auto
&
interceptor
:
interceptor_idx_to_interceptor_
)
{
interceptor
.
second
->
Join
();
}
}
Carrier
::~
Carrier
()
{
VLOG
(
3
)
<<
"Carrier's destructor."
;
}
Carrier
::~
Carrier
()
{
VLOG
(
3
)
<<
"Carrier's destructor."
;
}
...
@@ -75,18 +62,9 @@ bool Carrier::EnqueueInterceptorMessage(
...
@@ -75,18 +62,9 @@ bool Carrier::EnqueueInterceptorMessage(
VLOG
(
3
)
<<
"Receiving control message from rank "
VLOG
(
3
)
<<
"Receiving control message from rank "
<<
interceptor_message
.
src_id
()
<<
" to rank "
<<
interceptor_message
.
src_id
()
<<
" to rank "
<<
interceptor_message
.
dst_id
();
<<
interceptor_message
.
dst_id
();
// for barrier
msg_bus_
->
IncreaseBarrierCount
();
}
else
{
}
else
{
{
std
::
unique_lock
<
std
::
mutex
>
lock_creating
(
creating_flag_mutex_
);
if
(
creating_interceptors_
)
{
std
::
unique_lock
<
std
::
mutex
>
lock_message
(
tmp_message_mutex_
);
// Cannot handle the message to interceptor since interceptors
// are still under creating. Will enqueue into a tmp stack.
VLOG
(
3
)
<<
"Receiving message while creating interceptors."
;
message_tmp_
.
emplace_back
(
interceptor_message
);
return
true
;
}
}
int64_t
dst_id
=
interceptor_message
.
dst_id
();
int64_t
dst_id
=
interceptor_message
.
dst_id
();
Interceptor
*
dst_interceptor
=
GetInterceptor
(
dst_id
);
Interceptor
*
dst_interceptor
=
GetInterceptor
(
dst_id
);
dst_interceptor
->
EnqueueRemoteInterceptorMessage
(
interceptor_message
);
dst_interceptor
->
EnqueueRemoteInterceptorMessage
(
interceptor_message
);
...
@@ -94,6 +72,8 @@ bool Carrier::EnqueueInterceptorMessage(
...
@@ -94,6 +72,8 @@ bool Carrier::EnqueueInterceptorMessage(
return
true
;
return
true
;
}
}
void
Carrier
::
Barrier
()
{
msg_bus_
->
Barrier
();
}
Interceptor
*
Carrier
::
GetInterceptor
(
int64_t
interceptor_id
)
{
Interceptor
*
Carrier
::
GetInterceptor
(
int64_t
interceptor_id
)
{
auto
iter
=
interceptor_idx_to_interceptor_
.
find
(
interceptor_id
);
auto
iter
=
interceptor_idx_to_interceptor_
.
find
(
interceptor_id
);
PADDLE_ENFORCE_NE
(
iter
,
interceptor_idx_to_interceptor_
.
end
(),
PADDLE_ENFORCE_NE
(
iter
,
interceptor_idx_to_interceptor_
.
end
(),
...
@@ -109,6 +89,11 @@ void Carrier::Wait() {
...
@@ -109,6 +89,11 @@ void Carrier::Wait() {
cond_var_
.
wait
(
lock
);
cond_var_
.
wait
(
lock
);
}
}
void
Carrier
::
WakeUp
()
{
// probably double notify, but ok for ut
cond_var_
.
notify_all
();
}
void
Carrier
::
Start
()
{
void
Carrier
::
Start
()
{
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
...
@@ -126,12 +111,11 @@ void Carrier::Start() {
...
@@ -126,12 +111,11 @@ void Carrier::Start() {
start_msg
.
set_message_type
(
DATA_IS_READY
);
start_msg
.
set_message_type
(
DATA_IS_READY
);
Send
(
start_msg
);
Send
(
start_msg
);
}
}
// TODO(wangxi): async step
Wait
();
Wait
();
dev_ctx_
->
Wait
();
dev_ctx_
->
Wait
();
}
}
std
::
condition_variable
&
Carrier
::
GetCondVar
()
{
return
cond_var_
;
}
bool
Carrier
::
IsInit
()
const
{
return
is_init_
;
}
bool
Carrier
::
IsInit
()
const
{
return
is_init_
;
}
int64_t
Carrier
::
GetRank
(
int64_t
interceptor_id
)
const
{
int64_t
Carrier
::
GetRank
(
int64_t
interceptor_id
)
const
{
...
@@ -183,51 +167,19 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
...
@@ -183,51 +167,19 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
"The interceptor id should be unique."
,
"The interceptor id should be unique."
,
interceptor_id
));
interceptor_id
));
interceptor
->
RegisterCarrier
(
this
);
interceptor
->
RegisterCarrier
(
this
);
// TODO(fleet_exe dev): get loop
auto
*
loop
=
thread_pool_
.
GetLoop
(
interceptor_id
%
thread_num_
);
PADDLE_ENFORCE_NOT_NULL
(
loop
,
platform
::
errors
::
Fatal
(
"thread task loop must not null"
));
interceptor
->
RegisterTaskLoop
(
loop
);
auto
*
ptr
=
interceptor
.
get
();
auto
*
ptr
=
interceptor
.
get
();
interceptor_idx_to_interceptor_
.
insert
(
interceptor_idx_to_interceptor_
.
insert
(
std
::
make_pair
(
interceptor_id
,
std
::
move
(
interceptor
)));
std
::
make_pair
(
interceptor_id
,
std
::
move
(
interceptor
)));
return
ptr
;
return
ptr
;
}
}
void
Carrier
::
SetCreatingFlag
(
bool
flag
)
{
// set the creating flag
creating_flag_mutex_
.
lock
();
VLOG
(
3
)
<<
"Carrier is set the creating flag from "
<<
creating_interceptors_
<<
" to "
<<
flag
<<
"."
;
creating_interceptors_
=
flag
;
creating_flag_mutex_
.
unlock
();
if
(
!
flag
)
{
for
(
auto
&
pair
:
interceptor_idx_to_interceptor_
)
{
// update the source interceptor id
if
(
std
::
find
(
source_interceptor_ids_
.
begin
(),
source_interceptor_ids_
.
end
(),
pair
.
first
)
==
source_interceptor_ids_
.
end
())
{
auto
task
=
pair
.
second
->
GetTaskNode
();
if
(
task
!=
nullptr
&&
task
->
upstream
().
empty
())
{
source_interceptor_ids_
.
emplace_back
(
pair
.
first
);
}
}
}
// finish create interceptors outside, handle tmp messsages
HandleTmpMessages
();
}
}
void
Carrier
::
HandleTmpMessages
()
{
// NOTE: It's ok lock on the tmp_message_mutex_ here, when enter this
// `HandleTmpMessages` method, the creating_interceptors_ flag
// must be false, therefore, there won't have conflict with the
// lock on the tmp_message_mutex_ inside `EnqueueInterceptorMessage`
// on the same thread.
std
::
unique_lock
<
std
::
mutex
>
lock
(
tmp_message_mutex_
);
VLOG
(
3
)
<<
"Carrier has received "
<<
message_tmp_
.
size
()
<<
" messages during creating interceptors."
;
for
(
const
auto
&
msg
:
message_tmp_
)
{
EnqueueInterceptorMessage
(
msg
);
}
message_tmp_
.
clear
();
}
static
std
::
shared_ptr
<
framework
::
GarbageCollector
>
GetGC
(
static
std
::
shared_ptr
<
framework
::
GarbageCollector
>
GetGC
(
const
platform
::
Place
&
place
)
{
const
platform
::
Place
&
place
)
{
int64_t
max_memory_size
=
framework
::
GetEagerDeletionThreshold
();
int64_t
max_memory_size
=
framework
::
GetEagerDeletionThreshold
();
...
@@ -285,12 +237,6 @@ void Carrier::CreateInterceptors() {
...
@@ -285,12 +237,6 @@ void Carrier::CreateInterceptors() {
source_interceptor_ids_
.
emplace_back
(
interceptor_id
);
source_interceptor_ids_
.
emplace_back
(
interceptor_id
);
}
}
}
}
// The carrier will be always waiting for outside initializer
// since there is no interceptor has been created during auto init
creating_flag_mutex_
.
lock
();
creating_interceptors_
=
false
;
creating_flag_mutex_
.
unlock
();
HandleTmpMessages
();
}
}
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/carrier.h
浏览文件 @
dba59db7
...
@@ -24,6 +24,7 @@
...
@@ -24,6 +24,7 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/errors.h"
...
@@ -47,7 +48,11 @@ class Carrier final {
...
@@ -47,7 +48,11 @@ class Carrier final {
Carrier
()
=
default
;
Carrier
()
=
default
;
Carrier
(
int64_t
rank
,
Carrier
(
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
)
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
)
:
rank_
(
rank
),
interceptor_id_to_rank_
(
interceptor_id_to_rank
)
{}
:
rank_
(
rank
),
interceptor_id_to_rank_
(
interceptor_id_to_rank
)
{
thread_num_
=
1
;
thread_pool_
.
SetThreadNum
(
thread_num_
);
thread_pool_
.
Start
();
}
~
Carrier
();
~
Carrier
();
void
Init
(
int64_t
rank
,
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
void
Init
(
int64_t
rank
,
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
...
@@ -56,6 +61,7 @@ class Carrier final {
...
@@ -56,6 +61,7 @@ class Carrier final {
void
Release
();
void
Release
();
void
Wait
();
void
Wait
();
void
WakeUp
();
// Enqueue a message to corresponding interceptor id
// Enqueue a message to corresponding interceptor id
bool
EnqueueInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
);
bool
EnqueueInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
);
...
@@ -67,23 +73,18 @@ class Carrier final {
...
@@ -67,23 +73,18 @@ class Carrier final {
Interceptor
*
SetInterceptor
(
int64_t
interceptor_id
,
Interceptor
*
SetInterceptor
(
int64_t
interceptor_id
,
std
::
unique_ptr
<
Interceptor
>
);
std
::
unique_ptr
<
Interceptor
>
);
void
SetCreatingFlag
(
bool
flag
)
;
void
SetCreatingFlag
(
bool
flag
)
{}
void
SetMsgBus
(
const
std
::
shared_ptr
<
MessageBus
>&
msg_bus
)
{
void
SetMsgBus
(
const
std
::
shared_ptr
<
MessageBus
>&
msg_bus
)
{
msg_bus_
=
msg_bus
;
msg_bus_
=
msg_bus
;
}
}
std
::
condition_variable
&
GetCondVar
();
void
Start
();
void
Start
();
bool
IsInit
()
const
;
bool
IsInit
()
const
;
bool
Send
(
const
InterceptorMessage
&
msg
);
bool
Send
(
const
InterceptorMessage
&
msg
);
// NOTE: This mutex will be used in interceptor's RunOps function.
void
Barrier
();
// This mutex is used for avoiding forward ops and backward ops run
// simultaneously, which will lead to a random hang for some sync ops.
std
::
mutex
run
;
private:
private:
DISABLE_COPY_AND_ASSIGN
(
Carrier
);
DISABLE_COPY_AND_ASSIGN
(
Carrier
);
...
@@ -91,8 +92,6 @@ class Carrier final {
...
@@ -91,8 +92,6 @@ class Carrier final {
// create each Interceptor
// create each Interceptor
void
CreateInterceptors
();
void
CreateInterceptors
();
void
HandleTmpMessages
();
int64_t
GetRank
(
int64_t
interceptor_id
)
const
;
int64_t
GetRank
(
int64_t
interceptor_id
)
const
;
// interceptor logic id to actually interceptor
// interceptor logic id to actually interceptor
...
@@ -101,10 +100,6 @@ class Carrier final {
...
@@ -101,10 +100,6 @@ class Carrier final {
std
::
vector
<
int64_t
>
source_interceptor_ids_
;
std
::
vector
<
int64_t
>
source_interceptor_ids_
;
std
::
vector
<
InterceptorMessage
>
message_tmp_
{};
std
::
mutex
tmp_message_mutex_
;
bool
creating_interceptors_
{
true
};
std
::
mutex
creating_flag_mutex_
;
bool
is_init_
{
false
};
bool
is_init_
{
false
};
std
::
mutex
running_mutex_
;
std
::
mutex
running_mutex_
;
...
@@ -118,6 +113,9 @@ class Carrier final {
...
@@ -118,6 +113,9 @@ class Carrier final {
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
int64_t
rank_
;
int64_t
rank_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
int
thread_num_
;
TaskLoopThreadPool
thread_pool_
;
};
};
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
浏览文件 @
dba59db7
...
@@ -170,7 +170,6 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
...
@@ -170,7 +170,6 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
}
}
void
ComputeInterceptor
::
RunOps
()
{
void
ComputeInterceptor
::
RunOps
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
carrier_
->
run
);
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
<<
" running ops for the "
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
<<
" running ops for the "
<<
step_
+
1
<<
" time."
;
<<
step_
+
1
<<
" time."
;
for
(
auto
op
:
node_
->
ops
())
{
for
(
auto
op
:
node_
->
ops
())
{
...
@@ -198,6 +197,7 @@ void ComputeInterceptor::Run() {
...
@@ -198,6 +197,7 @@ void ComputeInterceptor::Run() {
if
(
is_last_
&&
(
step_
%
node_
->
max_run_times
()
==
0
))
{
if
(
is_last_
&&
(
step_
%
node_
->
max_run_times
()
==
0
))
{
VLOG
(
3
)
<<
"Interceptor "
<<
GetInterceptorId
()
VLOG
(
3
)
<<
"Interceptor "
<<
GetInterceptorId
()
<<
" is stopping carrier."
;
<<
" is stopping carrier."
;
// FIXME(wangxi): with multi sink interceptor
StopCarrier
();
StopCarrier
();
}
}
}
}
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
浏览文件 @
dba59db7
...
@@ -89,6 +89,11 @@ void FleetExecutor::Init(
...
@@ -89,6 +89,11 @@ void FleetExecutor::Init(
CreateCarrier
();
CreateCarrier
();
InitCarrier
();
InitCarrier
();
InitMessageBus
();
InitMessageBus
();
// refine this? wait all carrier ready
// NOTE(wangxi): must add after Carrier::SetMsgBus, for we use
// MessageBus::IncreaseBarrierCount when receive barrier msg.
GetCarrier
()
->
Barrier
();
}
}
void
FleetExecutor
::
InitCarrier
()
{
void
FleetExecutor
::
InitCarrier
()
{
...
...
paddle/fluid/distributed/fleet_executor/interceptor.cc
浏览文件 @
dba59db7
...
@@ -14,26 +14,21 @@
...
@@ -14,26 +14,21 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
Interceptor
::
Interceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
Interceptor
::
Interceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
interceptor_id_
(
interceptor_id
),
node_
(
node
)
{
:
interceptor_id_
(
interceptor_id
),
node_
(
node
)
{}
interceptor_thread_
=
std
::
thread
([
this
]()
{
VLOG
(
3
)
<<
"Interceptor "
<<
interceptor_id_
Interceptor
::~
Interceptor
()
{
<<
" starts the thread pooling it's local mailbox."
;
// FIXME(wangxi): throw in stop function
PoolTheMailbox
();
// std::lock_guard<std::mutex> lock(mutex_);
});
// PADDLE_ENFORCE_EQ(messages_.empty(), true,
}
// platform::errors::PreconditionNotMet(
// "Interceptor must destruct with messages empty"));
Interceptor
::~
Interceptor
()
{
Join
();
}
void
Interceptor
::
Join
()
{
if
(
interceptor_thread_
.
joinable
())
{
interceptor_thread_
.
join
();
}
}
}
void
Interceptor
::
RegisterMsgHandle
(
MsgHandle
handle
)
{
handle_
=
handle
;
}
void
Interceptor
::
RegisterMsgHandle
(
MsgHandle
handle
)
{
handle_
=
handle
;
}
...
@@ -44,25 +39,47 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
...
@@ -44,25 +39,47 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
handle_
(
msg
);
handle_
(
msg
);
}
}
void
Interceptor
::
LoopOnce
()
{
std
::
deque
<
InterceptorMessage
>
tmp_messages
;
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
messages_
.
swap
(
tmp_messages
);
}
PADDLE_ENFORCE_EQ
(
tmp_messages
.
empty
(),
false
,
platform
::
errors
::
PreconditionNotMet
(
"tmp_messages must not empty in task loop"
));
for
(
auto
&
msg
:
tmp_messages
)
{
const
MessageType
message_type
=
msg
.
message_type
();
VLOG
(
3
)
<<
"Interceptor "
<<
interceptor_id_
<<
" has received a message"
<<
" from interceptor "
<<
msg
.
src_id
()
<<
" with message: "
<<
message_type
<<
"."
;
Handle
(
msg
);
}
}
void
Interceptor
::
StopCarrier
()
{
void
Interceptor
::
StopCarrier
()
{
PADDLE_ENFORCE_NOT_NULL
(
carrier_
,
platform
::
errors
::
PreconditionNotMet
(
PADDLE_ENFORCE_NOT_NULL
(
carrier_
,
platform
::
errors
::
PreconditionNotMet
(
"Carrier is not registered."
));
"Carrier is not registered."
));
std
::
condition_variable
&
cond_var
=
carrier_
->
GetCondVar
();
carrier_
->
WakeUp
();
// probably double notify, but ok for ut
cond_var
.
notify_all
();
}
int64_t
Interceptor
::
GetInterceptorId
()
const
{
// return the interceptor id
return
interceptor_id_
;
}
}
void
Interceptor
::
EnqueueRemoteInterceptorMessage
(
void
Interceptor
::
EnqueueRemoteInterceptorMessage
(
const
InterceptorMessage
&
interceptor_
message
)
{
const
InterceptorMessage
&
message
)
{
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
VLOG
(
3
)
<<
"Enqueue message: "
<<
interceptor_message
.
message_type
()
VLOG
(
3
)
<<
"Enqueue message: "
<<
message
.
message_type
()
<<
" into "
<<
" into "
<<
interceptor_id_
<<
"'s remote mailbox."
;
<<
interceptor_id_
<<
"'s remote mailbox."
;
remote_mailbox_
.
Push
(
interceptor_message
);
bool
empty
=
false
;
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
empty
=
messages_
.
empty
();
messages_
.
emplace_back
(
message
);
}
if
(
empty
)
{
loop_
->
QueueInLoop
([
this
]()
{
LoopOnce
();
});
}
}
}
bool
Interceptor
::
Send
(
int64_t
dst_id
,
InterceptorMessage
&
msg
)
{
bool
Interceptor
::
Send
(
int64_t
dst_id
,
InterceptorMessage
&
msg
)
{
...
@@ -73,39 +90,6 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
...
@@ -73,39 +90,6 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
return
carrier_
->
Send
(
msg
);
return
carrier_
->
Send
(
msg
);
}
}
void
Interceptor
::
PoolTheMailbox
()
{
// pool the local mailbox, parse the Message
for
(;;)
{
if
(
local_mailbox_
.
empty
())
{
// local mailbox is empty, fetch the remote mailbox
VLOG
(
3
)
<<
interceptor_id_
<<
"'s local mailbox is empty. "
<<
"Fetch the remote mailbox."
;
PADDLE_ENFORCE_EQ
(
FetchRemoteMailbox
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Error encountered when fetch remote mailbox."
));
}
const
InterceptorMessage
interceptor_message
=
local_mailbox_
.
front
();
local_mailbox_
.
pop_front
();
const
MessageType
message_type
=
interceptor_message
.
message_type
();
VLOG
(
3
)
<<
"Interceptor "
<<
interceptor_id_
<<
" has received a message"
<<
" from interceptor "
<<
interceptor_message
.
src_id
()
<<
" with message: "
<<
message_type
<<
"."
;
Handle
(
interceptor_message
);
if
(
stop_
)
{
// break the pooling thread
VLOG
(
3
)
<<
"Interceptor "
<<
interceptor_id_
<<
" is quiting."
;
break
;
}
}
}
bool
Interceptor
::
FetchRemoteMailbox
()
{
remote_mailbox_
.
PopAll
(
&
local_mailbox_
);
return
!
local_mailbox_
.
empty
();
}
static
InterceptorFactory
::
CreateInterceptorMap
&
GetInterceptorMap
()
{
static
InterceptorFactory
::
CreateInterceptorMap
&
GetInterceptorMap
()
{
static
InterceptorFactory
::
CreateInterceptorMap
interceptorMap
;
static
InterceptorFactory
::
CreateInterceptorMap
interceptorMap
;
return
interceptorMap
;
return
interceptorMap
;
...
...
paddle/fluid/distributed/fleet_executor/interceptor.h
浏览文件 @
dba59db7
...
@@ -38,6 +38,7 @@ namespace distributed {
...
@@ -38,6 +38,7 @@ namespace distributed {
class
TaskNode
;
class
TaskNode
;
class
Carrier
;
class
Carrier
;
class
TaskLoop
;
class
Interceptor
{
class
Interceptor
{
public:
public:
...
@@ -50,15 +51,13 @@ class Interceptor {
...
@@ -50,15 +51,13 @@ class Interceptor {
virtual
~
Interceptor
();
virtual
~
Interceptor
();
void
Join
();
// register interceptor handle
// register interceptor handle
void
RegisterMsgHandle
(
MsgHandle
handle
);
void
RegisterMsgHandle
(
MsgHandle
handle
);
void
Handle
(
const
InterceptorMessage
&
msg
);
void
Handle
(
const
InterceptorMessage
&
msg
);
// return the interceptor id
// return the interceptor id
int64_t
GetInterceptorId
()
const
;
int64_t
GetInterceptorId
()
const
{
return
interceptor_id_
;
}
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
void
EnqueueRemoteInterceptorMessage
(
void
EnqueueRemoteInterceptorMessage
(
...
@@ -77,6 +76,7 @@ class Interceptor {
...
@@ -77,6 +76,7 @@ class Interceptor {
gc_
=
gc
;
gc_
=
gc
;
}
}
void
RegisterCarrier
(
Carrier
*
carrier
)
{
carrier_
=
carrier
;
}
void
RegisterCarrier
(
Carrier
*
carrier
)
{
carrier_
=
carrier
;
}
void
RegisterTaskLoop
(
TaskLoop
*
loop
)
{
loop_
=
loop
;
}
TaskNode
*
GetTaskNode
()
const
{
return
node_
;
}
TaskNode
*
GetTaskNode
()
const
{
return
node_
;
}
...
@@ -101,28 +101,16 @@ class Interceptor {
...
@@ -101,28 +101,16 @@ class Interceptor {
std
::
shared_ptr
<
framework
::
GarbageCollector
>
gc_
{
nullptr
};
std
::
shared_ptr
<
framework
::
GarbageCollector
>
gc_
{
nullptr
};
Carrier
*
carrier_
;
Carrier
*
carrier_
;
TaskLoop
*
loop_
;
private:
private:
// pool the local mailbox, parse the Message
void
LoopOnce
();
void
PoolTheMailbox
();
// fetch all Message from remote mailbox to local mailbox
// return true if remote mailbox not empty, otherwise return false
bool
FetchRemoteMailbox
();
// interceptor handle which process message
// interceptor handle which process message
MsgHandle
handle_
{
nullptr
};
MsgHandle
handle_
{
nullptr
};
// interceptor runs PoolTheMailbox() function to poll local mailbox
std
::
mutex
mutex_
;
std
::
thread
interceptor_thread_
;
std
::
deque
<
InterceptorMessage
>
messages_
;
// remote mailbox, written by EnqueueRemoteMessage()
// read by FetchRemoteMailbox()
framework
::
BlockingQueue
<
InterceptorMessage
>
remote_mailbox_
;
// local mailbox, written by FetchRemoteMailbox()
// read by PoolTheMailbox()
std
::
deque
<
InterceptorMessage
>
local_mailbox_
;
int64_t
already_run_times_
{
0
};
int64_t
already_run_times_
{
0
};
int64_t
used_slot_nums_
{
0
};
int64_t
used_slot_nums_
{
0
};
...
...
paddle/fluid/distributed/fleet_executor/message_bus.cc
浏览文件 @
dba59db7
...
@@ -105,21 +105,53 @@ bool MessageBus::Send(int64_t dst_rank,
...
@@ -105,21 +105,53 @@ bool MessageBus::Send(int64_t dst_rank,
return
true
;
return
true
;
}
}
void
MessageBus
::
TestConnection
()
{
void
MessageBus
::
IncreaseBarrierCount
()
{
InterceptorMessage
ctrl_msg
;
VLOG
(
3
)
<<
"IncreaseBarrierCount"
;
ctrl_msg
.
set_ctrl_message
(
true
);
{
ctrl_msg
.
set_src_id
(
rank_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
for
(
const
auto
&
dst_rank_pair
:
rank_to_addr_
)
{
++
count_
;
int64_t
dst_rank
=
dst_rank_pair
.
first
;
cv_
.
notify_one
();
if
(
dst_rank
!=
rank_
)
{
}
ctrl_msg
.
set_dst_id
(
dst_rank
);
VLOG
(
3
)
<<
"End IncreaseBarrierCount"
;
VLOG
(
3
)
<<
"Send control message bus from rank "
<<
rank_
<<
" to rank "
}
<<
dst_rank
;
while
(
!
Send
(
dst_rank
,
ctrl_msg
))
{
void
MessageBus
::
Barrier
()
{
// gather to root
if
(
rank_
!=
0
)
{
InterceptorMessage
ctrl_msg
;
ctrl_msg
.
set_ctrl_message
(
true
);
ctrl_msg
.
set_src_id
(
rank_
);
ctrl_msg
.
set_dst_id
(
0
);
VLOG
(
3
)
<<
"Barrier Gather ctrl message from "
<<
rank_
<<
" to 0"
;
while
(
!
Send
(
0
,
ctrl_msg
))
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
}
}
else
{
VLOG
(
3
)
<<
"Barrier 0 wait others rank ready"
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cv_
.
wait
(
lock
,
[
this
]
{
return
count_
==
static_cast
<
int
>
(
rank_to_addr_
.
size
()
-
1
);
});
count_
=
0
;
}
// scatter from root
if
(
rank_
==
0
)
{
for
(
int
i
=
1
;
i
<
static_cast
<
int
>
(
rank_to_addr_
.
size
());
++
i
)
{
InterceptorMessage
ctrl_msg
;
ctrl_msg
.
set_ctrl_message
(
true
);
ctrl_msg
.
set_src_id
(
0
);
ctrl_msg
.
set_dst_id
(
i
);
VLOG
(
3
)
<<
"Barrier Scatter ctrl message from 0 to "
<<
i
;
while
(
!
Send
(
i
,
ctrl_msg
))
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
}
}
VLOG
(
3
)
<<
"Message bus has connected to rank: "
<<
dst_rank
<<
"."
;
}
}
}
else
{
VLOG
(
3
)
<<
"Barrier "
<<
rank_
<<
" wait others rank ready"
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cv_
.
wait
(
lock
,
[
this
]
{
return
count_
==
1
;
});
count_
=
0
;
}
}
}
}
...
@@ -151,7 +183,6 @@ void MessageBus::ListenPort() {
...
@@ -151,7 +183,6 @@ void MessageBus::ListenPort() {
interval
+=
500
;
interval
+=
500
;
}
}
LOG
(
INFO
)
<<
"Message bus's listen port thread starts successful."
;
LOG
(
INFO
)
<<
"Message bus's listen port thread starts successful."
;
TestConnection
();
#else
#else
LOG
(
WARNING
)
LOG
(
WARNING
)
<<
"Fleet executor's ListenPort() is a fake function when Paddle is "
<<
"Fleet executor's ListenPort() is a fake function when Paddle is "
...
...
paddle/fluid/distributed/fleet_executor/message_bus.h
浏览文件 @
dba59db7
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#pragma once
#pragma once
#include <condition_variable>
#include <mutex>
#include <mutex>
#include <string>
#include <string>
#include <thread>
#include <thread>
...
@@ -51,14 +52,15 @@ class MessageBus final {
...
@@ -51,14 +52,15 @@ class MessageBus final {
// called by Interceptor, send InterceptorMessage to dst
// called by Interceptor, send InterceptorMessage to dst
bool
Send
(
int64_t
dst_rank
,
const
InterceptorMessage
&
interceptor_message
);
bool
Send
(
int64_t
dst_rank
,
const
InterceptorMessage
&
interceptor_message
);
void
IncreaseBarrierCount
();
void
Barrier
();
private:
private:
DISABLE_COPY_AND_ASSIGN
(
MessageBus
);
DISABLE_COPY_AND_ASSIGN
(
MessageBus
);
// function keep listen the port and handle the message
// function keep listen the port and handle the message
void
ListenPort
();
void
ListenPort
();
void
TestConnection
();
const
std
::
string
&
GetAddr
(
int64_t
rank
)
const
;
const
std
::
string
&
GetAddr
(
int64_t
rank
)
const
;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
...
@@ -84,6 +86,11 @@ class MessageBus final {
...
@@ -84,6 +86,11 @@ class MessageBus final {
// brpc server
// brpc server
brpc
::
Server
server_
;
brpc
::
Server
server_
;
#endif
#endif
// for barrier
std
::
mutex
mutex_
;
std
::
condition_variable
cv_
;
int
count_
{
0
};
};
};
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/task_loop.cc
0 → 100644
浏览文件 @
dba59db7
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace
paddle
{
namespace
distributed
{
thread_local
TaskLoop
*
TaskLoop
::
thread_local_loop_
=
nullptr
;
TaskLoop
*
TaskLoop
::
GetTaskLoopOfCurrentThread
()
{
return
thread_local_loop_
;
}
TaskLoop
::
TaskLoop
()
:
looping_
(
false
),
quit_
(
false
),
thread_id_
(
std
::
this_thread
::
get_id
())
{
PADDLE_ENFORCE_EQ
(
thread_local_loop_
,
nullptr
,
platform
::
errors
::
AlreadyExists
(
"Another TaskLoop is already init."
));
thread_local_loop_
=
this
;
}
TaskLoop
::~
TaskLoop
()
{
thread_local_loop_
=
nullptr
;
}
void
TaskLoop
::
Loop
()
{
PADDLE_ENFORCE_EQ
(
looping_
,
false
,
platform
::
errors
::
PreconditionNotMet
(
"Loop can only execute in one loop thread"
));
AssertInLoopThread
();
looping_
=
true
;
quit_
=
false
;
while
(
!
quit_
)
{
auto
tasks
=
tasks_
.
PopAll
();
for
(
auto
&
task
:
tasks
)
{
task
();
}
}
looping_
=
false
;
}
void
TaskLoop
::
Quit
()
{
quit_
=
true
;
if
(
!
IsInLoopThread
())
WakeUp
();
}
void
TaskLoop
::
RunInLoop
(
Functor
cb
)
{
if
(
IsInLoopThread
())
{
cb
();
}
else
{
QueueInLoop
(
cb
);
}
}
void
TaskLoop
::
QueueInLoop
(
Functor
cb
)
{
tasks_
.
Push
(
cb
);
}
void
TaskLoop
::
WakeUp
()
{
Functor
task
([]
{});
QueueInLoop
(
task
);
}
void
TaskLoop
::
AbortNotInLoopThread
()
{
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"This TaskLoop was created in thread %d, but current thread is %d"
,
thread_id_
,
std
::
this_thread
::
get_id
()));
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/task_loop.h
0 → 100644
浏览文件 @
dba59db7
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <future>
#include <map>
#include <thread>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
namespace
paddle
{
namespace
distributed
{
class
TaskLoop
{
public:
static
TaskLoop
*
GetTaskLoopOfCurrentThread
();
using
Functor
=
std
::
function
<
void
()
>
;
TaskLoop
();
~
TaskLoop
();
void
Loop
();
void
Quit
();
void
RunInLoop
(
Functor
cb
);
void
QueueInLoop
(
Functor
cb
);
template
<
class
F
,
class
...
Args
>
auto
Enqueue
(
F
&&
f
,
Args
&&
...
args
)
->
std
::
future
<
typename
std
::
result_of
<
F
(
Args
...)
>::
type
>
{
using
return_type
=
typename
std
::
result_of
<
F
(
Args
...)
>::
type
;
auto
task
=
std
::
make_shared
<
std
::
packaged_task
<
return_type
()
>>
(
std
::
bind
(
std
::
forward
<
F
>
(
f
),
std
::
forward
<
Args
>
(
args
)...));
std
::
future
<
return_type
>
task_future
=
task
->
get_future
();
tasks_
.
Push
([
task
]()
{
(
*
task
)();
});
return
task_future
;
}
void
WakeUp
();
bool
IsInLoopThread
()
const
{
return
thread_id_
==
std
::
this_thread
::
get_id
();
}
void
AssertInLoopThread
()
{
if
(
!
IsInLoopThread
())
{
AbortNotInLoopThread
();
}
}
private:
void
AbortNotInLoopThread
();
static
thread_local
TaskLoop
*
thread_local_loop_
;
bool
looping_
;
std
::
atomic
<
bool
>
quit_
;
std
::
thread
::
id
thread_id_
;
framework
::
BlockingQueue
<
Functor
>
tasks_
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/task_loop_thread.cc
0 → 100644
浏览文件 @
dba59db7
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace
paddle
{
namespace
distributed
{
TaskLoopThread
::
TaskLoopThread
()
:
start_
(
false
),
loop_
(
nullptr
)
{}
TaskLoopThread
::~
TaskLoopThread
()
{
if
(
loop_
!=
nullptr
)
{
loop_
->
Quit
();
thread_
.
join
();
}
}
TaskLoop
*
TaskLoopThread
::
StartLoop
()
{
PADDLE_ENFORCE_EQ
(
start_
,
false
,
platform
::
errors
::
PreconditionNotMet
(
"thread is already running."
));
start_
=
true
;
thread_
=
std
::
thread
([
this
]()
{
Loop
();
});
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cv_
.
wait
(
lock
,
[
=
]
{
return
loop_
!=
nullptr
;
});
return
loop_
;
}
void
TaskLoopThread
::
Loop
()
{
TaskLoop
loop
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
loop_
=
&
loop
;
cv_
.
notify_one
();
}
loop
.
Loop
();
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
loop_
=
nullptr
;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/task_loop_thread.h
0 → 100644
浏览文件 @
dba59db7
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable>
#include <mutex>
#include <thread>
namespace
paddle
{
namespace
distributed
{
class
TaskLoop
;
class
TaskLoopThread
{
public:
TaskLoopThread
();
~
TaskLoopThread
();
TaskLoop
*
StartLoop
();
private:
void
Loop
();
bool
start_
;
TaskLoop
*
loop_
;
std
::
thread
thread_
;
std
::
mutex
mutex_
;
std
::
condition_variable
cv_
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.cc
0 → 100644
浏览文件 @
dba59db7
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace
paddle
{
namespace
distributed
{
TaskLoopThreadPool
::
TaskLoopThreadPool
()
:
TaskLoopThreadPool
(
1
)
{}
TaskLoopThreadPool
::
TaskLoopThreadPool
(
int
thread_num
)
:
start_
(
false
),
thread_num_
(
thread_num
)
{}
TaskLoopThreadPool
::~
TaskLoopThreadPool
()
=
default
;
void
TaskLoopThreadPool
::
Start
()
{
PADDLE_ENFORCE_EQ
(
start_
,
false
,
platform
::
errors
::
PreconditionNotMet
(
"thread pool is already start."
));
PADDLE_ENFORCE_GT
(
thread_num_
,
0
,
platform
::
errors
::
InvalidArgument
(
"thread num must greater than 0, but now is %d"
,
thread_num_
));
start_
=
true
;
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
threads_
.
emplace_back
(
new
TaskLoopThread
());
loops_
.
push_back
(
threads_
[
i
]
->
StartLoop
());
}
}
TaskLoop
*
TaskLoopThreadPool
::
GetLoop
(
int
tid
)
{
PADDLE_ENFORCE_EQ
(
start_
,
true
,
platform
::
errors
::
PreconditionNotMet
(
"thread pool must start first."
));
PADDLE_ENFORCE_GE
(
tid
,
0
,
platform
::
errors
::
OutOfRange
(
"tid must >= 0, but now is %d"
,
tid
));
PADDLE_ENFORCE_LT
(
tid
,
thread_num_
,
platform
::
errors
::
OutOfRange
(
"tid must < thread_num, but now tid=%d thread_num=%d"
,
tid
,
thread_num_
));
return
loops_
[
tid
];
}
std
::
vector
<
TaskLoop
*>
TaskLoopThreadPool
::
GetAllLoops
()
{
PADDLE_ENFORCE_EQ
(
start_
,
true
,
platform
::
errors
::
PreconditionNotMet
(
"thread pool must start first."
));
return
loops_
;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h
0 → 100644
浏览文件 @
dba59db7
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
namespace
paddle
{
namespace
distributed
{
class
TaskLoop
;
class
TaskLoopThread
;
class
TaskLoopThreadPool
{
public:
TaskLoopThreadPool
();
explicit
TaskLoopThreadPool
(
int
thread_num
);
~
TaskLoopThreadPool
();
void
SetThreadNum
(
int
thread_num
)
{
thread_num_
=
thread_num
;
}
void
Start
();
TaskLoop
*
GetLoop
(
int
tid
);
std
::
vector
<
TaskLoop
*>
GetAllLoops
();
private:
bool
start_
;
int
thread_num_
;
std
::
vector
<
std
::
unique_ptr
<
TaskLoopThread
>>
threads_
;
std
::
vector
<
TaskLoop
*>
loops_
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
浏览文件 @
dba59db7
...
@@ -115,10 +115,13 @@ TEST(InterceptorTest, PingPong) {
...
@@ -115,10 +115,13 @@ TEST(InterceptorTest, PingPong) {
FleetExecutor
::
CreateCarrier
(
0
,
interceptor_id_to_rank
);
FleetExecutor
::
CreateCarrier
(
0
,
interceptor_id_to_rank
);
carrier
->
SetCreatingFlag
(
false
);
carrier
->
SetCreatingFlag
(
false
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
0
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
carrier
->
SetMsgBus
(
msg_bus
);
carrier
->
SetMsgBus
(
msg_bus
);
// NOTE: need Init msg_bus after carrier SetMsgBus
msg_bus
->
Init
(
0
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
Interceptor
*
a
=
carrier
->
SetInterceptor
(
Interceptor
*
a
=
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
carrier
->
Barrier
();
InterceptorMessage
msg
;
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
carrier
->
Wait
();
carrier
->
Wait
();
...
@@ -127,10 +130,12 @@ TEST(InterceptorTest, PingPong) {
...
@@ -127,10 +130,12 @@ TEST(InterceptorTest, PingPong) {
FleetExecutor
::
CreateCarrier
(
1
,
interceptor_id_to_rank
);
FleetExecutor
::
CreateCarrier
(
1
,
interceptor_id_to_rank
);
carrier
->
SetCreatingFlag
(
false
);
carrier
->
SetCreatingFlag
(
false
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
1
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
carrier
->
SetMsgBus
(
msg_bus
);
carrier
->
SetMsgBus
(
msg_bus
);
msg_bus
->
Init
(
1
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
carrier
->
SetInterceptor
(
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
1
,
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
carrier
->
Barrier
();
carrier
->
Wait
();
carrier
->
Wait
();
int
status
;
int
status
;
int
ret
=
waitpid
(
pid
,
&
status
,
0
);
int
ret
=
waitpid
(
pid
,
&
status
,
0
);
...
...
paddle/fluid/framework/blocking_queue.h
浏览文件 @
dba59db7
...
@@ -81,6 +81,16 @@ class BlockingQueue {
...
@@ -81,6 +81,16 @@ class BlockingQueue {
std
::
swap
(
*
empty_queue
,
q_
);
std
::
swap
(
*
empty_queue
,
q_
);
}
}
std
::
deque
<
T
>
PopAll
()
{
std
::
deque
<
T
>
ret
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cv_
.
wait
(
lock
,
[
this
]
{
return
!
q_
.
empty
();
});
std
::
swap
(
ret
,
q_
);
}
return
ret
;
}
T
Pop
()
{
T
Pop
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cv_
.
wait
(
lock
,
[
=
]
{
return
!
q_
.
empty
();
});
cv_
.
wait
(
lock
,
[
=
]
{
return
!
q_
.
empty
();
});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录