Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
dba59db7
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
dba59db7
编写于
12月 27, 2021
作者:
W
WangXi
提交者:
GitHub
12月 27, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[fleet_executor] Add task loop thread pool (#38420)
上级
5b6b88ab
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
544 addition
and
190 deletion
+544
-190
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
+4
-2
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+24
-78
paddle/fluid/distributed/fleet_executor/carrier.h
paddle/fluid/distributed/fleet_executor/carrier.h
+12
-14
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
...e/fluid/distributed/fleet_executor/compute_interceptor.cc
+1
-1
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+5
-0
paddle/fluid/distributed/fleet_executor/interceptor.cc
paddle/fluid/distributed/fleet_executor/interceptor.cc
+43
-59
paddle/fluid/distributed/fleet_executor/interceptor.h
paddle/fluid/distributed/fleet_executor/interceptor.h
+7
-19
paddle/fluid/distributed/fleet_executor/message_bus.cc
paddle/fluid/distributed/fleet_executor/message_bus.cc
+44
-13
paddle/fluid/distributed/fleet_executor/message_bus.h
paddle/fluid/distributed/fleet_executor/message_bus.h
+9
-2
paddle/fluid/distributed/fleet_executor/task_loop.cc
paddle/fluid/distributed/fleet_executor/task_loop.cc
+82
-0
paddle/fluid/distributed/fleet_executor/task_loop.h
paddle/fluid/distributed/fleet_executor/task_loop.h
+81
-0
paddle/fluid/distributed/fleet_executor/task_loop_thread.cc
paddle/fluid/distributed/fleet_executor/task_loop_thread.cc
+58
-0
paddle/fluid/distributed/fleet_executor/task_loop_thread.h
paddle/fluid/distributed/fleet_executor/task_loop_thread.h
+44
-0
paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.cc
...fluid/distributed/fleet_executor/task_loop_thread_pool.cc
+66
-0
paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h
.../fluid/distributed/fleet_executor/task_loop_thread_pool.h
+47
-0
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
...eet_executor/test/interceptor_ping_pong_with_brpc_test.cc
+7
-2
paddle/fluid/framework/blocking_queue.h
paddle/fluid/framework/blocking_queue.h
+10
-0
未找到文件。
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
浏览文件 @
dba59db7
...
@@ -10,10 +10,12 @@ else()
...
@@ -10,10 +10,12 @@ else()
set
(
BRPC_DEPS
""
)
set
(
BRPC_DEPS
""
)
endif
()
endif
()
cc_library
(
task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog
)
cc_library
(
fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
cc_library
(
fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto
collective_helper op_registry
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto
task_loop_thread_pool collective_helper
executor_gc_helper gflags glog
${
BRPC_DEPS
}
)
op_registry
executor_gc_helper gflags glog
${
BRPC_DEPS
}
)
if
(
WITH_DISTRIBUTE
)
if
(
WITH_DISTRIBUTE
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
...
...
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
dba59db7
...
@@ -42,30 +42,17 @@ void Carrier::Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
...
@@ -42,30 +42,17 @@ void Carrier::Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
place_
=
place
;
place_
=
place
;
root_scope_
=
root_scope
;
root_scope_
=
root_scope
;
dev_ctx_
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place_
);
dev_ctx_
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place_
);
// TODO(fleet_exe dev): thread pool
thread_num_
=
1
;
thread_pool_
.
SetThreadNum
(
thread_num_
);
thread_pool_
.
Start
();
CreateInterceptors
();
CreateInterceptors
();
is_init_
=
true
;
is_init_
=
true
;
}
}
void
Carrier
::
Release
()
{
void
Carrier
::
Release
()
{}
// NOTE(wangxi): must join before `Derived Interceptor` destruct,
// otherwise Derived object will be destructed before thread complete.
for
(
int64_t
id
:
source_interceptor_ids_
)
{
VLOG
(
3
)
<<
"Carrier Release is sending stop to source interceptor "
<<
id
<<
"."
;
InterceptorMessage
stop_msg
;
// source node STOP is send by carrier, so set src_id=-1
stop_msg
.
set_src_id
(
-
1
);
stop_msg
.
set_dst_id
(
id
);
stop_msg
.
set_message_type
(
STOP
);
Send
(
stop_msg
);
}
// TODO(wangxi): Maybe need a better to use thread.
for
(
auto
&
interceptor
:
interceptor_idx_to_interceptor_
)
{
interceptor
.
second
->
Join
();
}
}
Carrier
::~
Carrier
()
{
VLOG
(
3
)
<<
"Carrier's destructor."
;
}
Carrier
::~
Carrier
()
{
VLOG
(
3
)
<<
"Carrier's destructor."
;
}
...
@@ -75,18 +62,9 @@ bool Carrier::EnqueueInterceptorMessage(
...
@@ -75,18 +62,9 @@ bool Carrier::EnqueueInterceptorMessage(
VLOG
(
3
)
<<
"Receiving control message from rank "
VLOG
(
3
)
<<
"Receiving control message from rank "
<<
interceptor_message
.
src_id
()
<<
" to rank "
<<
interceptor_message
.
src_id
()
<<
" to rank "
<<
interceptor_message
.
dst_id
();
<<
interceptor_message
.
dst_id
();
// for barrier
msg_bus_
->
IncreaseBarrierCount
();
}
else
{
}
else
{
{
std
::
unique_lock
<
std
::
mutex
>
lock_creating
(
creating_flag_mutex_
);
if
(
creating_interceptors_
)
{
std
::
unique_lock
<
std
::
mutex
>
lock_message
(
tmp_message_mutex_
);
// Cannot handle the message to interceptor since interceptors
// are still under creating. Will enqueue into a tmp stack.
VLOG
(
3
)
<<
"Receiving message while creating interceptors."
;
message_tmp_
.
emplace_back
(
interceptor_message
);
return
true
;
}
}
int64_t
dst_id
=
interceptor_message
.
dst_id
();
int64_t
dst_id
=
interceptor_message
.
dst_id
();
Interceptor
*
dst_interceptor
=
GetInterceptor
(
dst_id
);
Interceptor
*
dst_interceptor
=
GetInterceptor
(
dst_id
);
dst_interceptor
->
EnqueueRemoteInterceptorMessage
(
interceptor_message
);
dst_interceptor
->
EnqueueRemoteInterceptorMessage
(
interceptor_message
);
...
@@ -94,6 +72,8 @@ bool Carrier::EnqueueInterceptorMessage(
...
@@ -94,6 +72,8 @@ bool Carrier::EnqueueInterceptorMessage(
return
true
;
return
true
;
}
}
void
Carrier
::
Barrier
()
{
msg_bus_
->
Barrier
();
}
Interceptor
*
Carrier
::
GetInterceptor
(
int64_t
interceptor_id
)
{
Interceptor
*
Carrier
::
GetInterceptor
(
int64_t
interceptor_id
)
{
auto
iter
=
interceptor_idx_to_interceptor_
.
find
(
interceptor_id
);
auto
iter
=
interceptor_idx_to_interceptor_
.
find
(
interceptor_id
);
PADDLE_ENFORCE_NE
(
iter
,
interceptor_idx_to_interceptor_
.
end
(),
PADDLE_ENFORCE_NE
(
iter
,
interceptor_idx_to_interceptor_
.
end
(),
...
@@ -109,6 +89,11 @@ void Carrier::Wait() {
...
@@ -109,6 +89,11 @@ void Carrier::Wait() {
cond_var_
.
wait
(
lock
);
cond_var_
.
wait
(
lock
);
}
}
void
Carrier
::
WakeUp
()
{
// probably double notify, but ok for ut
cond_var_
.
notify_all
();
}
void
Carrier
::
Start
()
{
void
Carrier
::
Start
()
{
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
...
@@ -126,12 +111,11 @@ void Carrier::Start() {
...
@@ -126,12 +111,11 @@ void Carrier::Start() {
start_msg
.
set_message_type
(
DATA_IS_READY
);
start_msg
.
set_message_type
(
DATA_IS_READY
);
Send
(
start_msg
);
Send
(
start_msg
);
}
}
// TODO(wangxi): async step
Wait
();
Wait
();
dev_ctx_
->
Wait
();
dev_ctx_
->
Wait
();
}
}
std
::
condition_variable
&
Carrier
::
GetCondVar
()
{
return
cond_var_
;
}
bool
Carrier
::
IsInit
()
const
{
return
is_init_
;
}
bool
Carrier
::
IsInit
()
const
{
return
is_init_
;
}
int64_t
Carrier
::
GetRank
(
int64_t
interceptor_id
)
const
{
int64_t
Carrier
::
GetRank
(
int64_t
interceptor_id
)
const
{
...
@@ -183,51 +167,19 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
...
@@ -183,51 +167,19 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
"The interceptor id should be unique."
,
"The interceptor id should be unique."
,
interceptor_id
));
interceptor_id
));
interceptor
->
RegisterCarrier
(
this
);
interceptor
->
RegisterCarrier
(
this
);
// TODO(fleet_exe dev): get loop
auto
*
loop
=
thread_pool_
.
GetLoop
(
interceptor_id
%
thread_num_
);
PADDLE_ENFORCE_NOT_NULL
(
loop
,
platform
::
errors
::
Fatal
(
"thread task loop must not null"
));
interceptor
->
RegisterTaskLoop
(
loop
);
auto
*
ptr
=
interceptor
.
get
();
auto
*
ptr
=
interceptor
.
get
();
interceptor_idx_to_interceptor_
.
insert
(
interceptor_idx_to_interceptor_
.
insert
(
std
::
make_pair
(
interceptor_id
,
std
::
move
(
interceptor
)));
std
::
make_pair
(
interceptor_id
,
std
::
move
(
interceptor
)));
return
ptr
;
return
ptr
;
}
}
void
Carrier
::
SetCreatingFlag
(
bool
flag
)
{
// set the creating flag
creating_flag_mutex_
.
lock
();
VLOG
(
3
)
<<
"Carrier is set the creating flag from "
<<
creating_interceptors_
<<
" to "
<<
flag
<<
"."
;
creating_interceptors_
=
flag
;
creating_flag_mutex_
.
unlock
();
if
(
!
flag
)
{
for
(
auto
&
pair
:
interceptor_idx_to_interceptor_
)
{
// update the source interceptor id
if
(
std
::
find
(
source_interceptor_ids_
.
begin
(),
source_interceptor_ids_
.
end
(),
pair
.
first
)
==
source_interceptor_ids_
.
end
())
{
auto
task
=
pair
.
second
->
GetTaskNode
();
if
(
task
!=
nullptr
&&
task
->
upstream
().
empty
())
{
source_interceptor_ids_
.
emplace_back
(
pair
.
first
);
}
}
}
// finish create interceptors outside, handle tmp messsages
HandleTmpMessages
();
}
}
void
Carrier
::
HandleTmpMessages
()
{
// NOTE: It's ok lock on the tmp_message_mutex_ here, when enter this
// `HandleTmpMessages` method, the creating_interceptors_ flag
// must be false, therefore, there won't have conflict with the
// lock on the tmp_message_mutex_ inside `EnqueueInterceptorMessage`
// on the same thread.
std
::
unique_lock
<
std
::
mutex
>
lock
(
tmp_message_mutex_
);
VLOG
(
3
)
<<
"Carrier has received "
<<
message_tmp_
.
size
()
<<
" messages during creating interceptors."
;
for
(
const
auto
&
msg
:
message_tmp_
)
{
EnqueueInterceptorMessage
(
msg
);
}
message_tmp_
.
clear
();
}
static
std
::
shared_ptr
<
framework
::
GarbageCollector
>
GetGC
(
static
std
::
shared_ptr
<
framework
::
GarbageCollector
>
GetGC
(
const
platform
::
Place
&
place
)
{
const
platform
::
Place
&
place
)
{
int64_t
max_memory_size
=
framework
::
GetEagerDeletionThreshold
();
int64_t
max_memory_size
=
framework
::
GetEagerDeletionThreshold
();
...
@@ -285,12 +237,6 @@ void Carrier::CreateInterceptors() {
...
@@ -285,12 +237,6 @@ void Carrier::CreateInterceptors() {
source_interceptor_ids_
.
emplace_back
(
interceptor_id
);
source_interceptor_ids_
.
emplace_back
(
interceptor_id
);
}
}
}
}
// The carrier will be always waiting for outside initializer
// since there is no interceptor has been created during auto init
creating_flag_mutex_
.
lock
();
creating_interceptors_
=
false
;
creating_flag_mutex_
.
unlock
();
HandleTmpMessages
();
}
}
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/carrier.h
浏览文件 @
dba59db7
...
@@ -24,6 +24,7 @@
...
@@ -24,6 +24,7 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/errors.h"
...
@@ -47,7 +48,11 @@ class Carrier final {
...
@@ -47,7 +48,11 @@ class Carrier final {
Carrier
()
=
default
;
Carrier
()
=
default
;
Carrier
(
int64_t
rank
,
Carrier
(
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
)
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
)
:
rank_
(
rank
),
interceptor_id_to_rank_
(
interceptor_id_to_rank
)
{}
:
rank_
(
rank
),
interceptor_id_to_rank_
(
interceptor_id_to_rank
)
{
thread_num_
=
1
;
thread_pool_
.
SetThreadNum
(
thread_num_
);
thread_pool_
.
Start
();
}
~
Carrier
();
~
Carrier
();
void
Init
(
int64_t
rank
,
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
void
Init
(
int64_t
rank
,
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
...
@@ -56,6 +61,7 @@ class Carrier final {
...
@@ -56,6 +61,7 @@ class Carrier final {
void
Release
();
void
Release
();
void
Wait
();
void
Wait
();
void
WakeUp
();
// Enqueue a message to corresponding interceptor id
// Enqueue a message to corresponding interceptor id
bool
EnqueueInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
);
bool
EnqueueInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
);
...
@@ -67,23 +73,18 @@ class Carrier final {
...
@@ -67,23 +73,18 @@ class Carrier final {
Interceptor
*
SetInterceptor
(
int64_t
interceptor_id
,
Interceptor
*
SetInterceptor
(
int64_t
interceptor_id
,
std
::
unique_ptr
<
Interceptor
>
);
std
::
unique_ptr
<
Interceptor
>
);
void
SetCreatingFlag
(
bool
flag
)
;
void
SetCreatingFlag
(
bool
flag
)
{}
void
SetMsgBus
(
const
std
::
shared_ptr
<
MessageBus
>&
msg_bus
)
{
void
SetMsgBus
(
const
std
::
shared_ptr
<
MessageBus
>&
msg_bus
)
{
msg_bus_
=
msg_bus
;
msg_bus_
=
msg_bus
;
}
}
std
::
condition_variable
&
GetCondVar
();
void
Start
();
void
Start
();
bool
IsInit
()
const
;
bool
IsInit
()
const
;
bool
Send
(
const
InterceptorMessage
&
msg
);
bool
Send
(
const
InterceptorMessage
&
msg
);
// NOTE: This mutex will be used in interceptor's RunOps function.
void
Barrier
();
// This mutex is used for avoiding forward ops and backward ops run
// simultaneously, which will lead to a random hang for some sync ops.
std
::
mutex
run
;
private:
private:
DISABLE_COPY_AND_ASSIGN
(
Carrier
);
DISABLE_COPY_AND_ASSIGN
(
Carrier
);
...
@@ -91,8 +92,6 @@ class Carrier final {
...
@@ -91,8 +92,6 @@ class Carrier final {
// create each Interceptor
// create each Interceptor
void
CreateInterceptors
();
void
CreateInterceptors
();
void
HandleTmpMessages
();
int64_t
GetRank
(
int64_t
interceptor_id
)
const
;
int64_t
GetRank
(
int64_t
interceptor_id
)
const
;
// interceptor logic id to actually interceptor
// interceptor logic id to actually interceptor
...
@@ -101,10 +100,6 @@ class Carrier final {
...
@@ -101,10 +100,6 @@ class Carrier final {
std
::
vector
<
int64_t
>
source_interceptor_ids_
;
std
::
vector
<
int64_t
>
source_interceptor_ids_
;
std
::
vector
<
InterceptorMessage
>
message_tmp_
{};
std
::
mutex
tmp_message_mutex_
;
bool
creating_interceptors_
{
true
};
std
::
mutex
creating_flag_mutex_
;
bool
is_init_
{
false
};
bool
is_init_
{
false
};
std
::
mutex
running_mutex_
;
std
::
mutex
running_mutex_
;
...
@@ -118,6 +113,9 @@ class Carrier final {
...
@@ -118,6 +113,9 @@ class Carrier final {
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
int64_t
rank_
;
int64_t
rank_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
int
thread_num_
;
TaskLoopThreadPool
thread_pool_
;
};
};
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
浏览文件 @
dba59db7
...
@@ -170,7 +170,6 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
...
@@ -170,7 +170,6 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
}
}
void
ComputeInterceptor
::
RunOps
()
{
void
ComputeInterceptor
::
RunOps
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
carrier_
->
run
);
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
<<
" running ops for the "
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
<<
" running ops for the "
<<
step_
+
1
<<
" time."
;
<<
step_
+
1
<<
" time."
;
for
(
auto
op
:
node_
->
ops
())
{
for
(
auto
op
:
node_
->
ops
())
{
...
@@ -198,6 +197,7 @@ void ComputeInterceptor::Run() {
...
@@ -198,6 +197,7 @@ void ComputeInterceptor::Run() {
if
(
is_last_
&&
(
step_
%
node_
->
max_run_times
()
==
0
))
{
if
(
is_last_
&&
(
step_
%
node_
->
max_run_times
()
==
0
))
{
VLOG
(
3
)
<<
"Interceptor "
<<
GetInterceptorId
()
VLOG
(
3
)
<<
"Interceptor "
<<
GetInterceptorId
()
<<
" is stopping carrier."
;
<<
" is stopping carrier."
;
// FIXME(wangxi): with multi sink interceptor
StopCarrier
();
StopCarrier
();
}
}
}
}
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
浏览文件 @
dba59db7
...
@@ -89,6 +89,11 @@ void FleetExecutor::Init(
...
@@ -89,6 +89,11 @@ void FleetExecutor::Init(
CreateCarrier
();
CreateCarrier
();
InitCarrier
();
InitCarrier
();
InitMessageBus
();
InitMessageBus
();
// refine this? wait all carrier ready
// NOTE(wangxi): must add after Carrier::SetMsgBus, for we use
// MessageBus::IncreaseBarrierCount when receive barrier msg.
GetCarrier
()
->
Barrier
();
}
}
void
FleetExecutor
::
InitCarrier
()
{
void
FleetExecutor
::
InitCarrier
()
{
...
...
paddle/fluid/distributed/fleet_executor/interceptor.cc
浏览文件 @
dba59db7
...
@@ -14,26 +14,21 @@
...
@@ -14,26 +14,21 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
Interceptor
::
Interceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
Interceptor
::
Interceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
interceptor_id_
(
interceptor_id
),
node_
(
node
)
{
:
interceptor_id_
(
interceptor_id
),
node_
(
node
)
{}
interceptor_thread_
=
std
::
thread
([
this
]()
{
VLOG
(
3
)
<<
"Interceptor "
<<
interceptor_id_
Interceptor
::~
Interceptor
()
{
<<
" starts the thread pooling it's local mailbox."
;
// FIXME(wangxi): throw in stop function
PoolTheMailbox
();
// std::lock_guard<std::mutex> lock(mutex_);
});
// PADDLE_ENFORCE_EQ(messages_.empty(), true,
}
// platform::errors::PreconditionNotMet(
// "Interceptor must destruct with messages empty"));
Interceptor
::~
Interceptor
()
{
Join
();
}
void
Interceptor
::
Join
()
{
if
(
interceptor_thread_
.
joinable
())
{
interceptor_thread_
.
join
();
}
}
}
void
Interceptor
::
RegisterMsgHandle
(
MsgHandle
handle
)
{
handle_
=
handle
;
}
void
Interceptor
::
RegisterMsgHandle
(
MsgHandle
handle
)
{
handle_
=
handle
;
}
...
@@ -44,25 +39,47 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
...
@@ -44,25 +39,47 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
handle_
(
msg
);
handle_
(
msg
);
}
}
void
Interceptor
::
LoopOnce
()
{
std
::
deque
<
InterceptorMessage
>
tmp_messages
;
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
messages_
.
swap
(
tmp_messages
);
}
PADDLE_ENFORCE_EQ
(
tmp_messages
.
empty
(),
false
,
platform
::
errors
::
PreconditionNotMet
(
"tmp_messages must not empty in task loop"
));
for
(
auto
&
msg
:
tmp_messages
)
{
const
MessageType
message_type
=
msg
.
message_type
();
VLOG
(
3
)
<<
"Interceptor "
<<
interceptor_id_
<<
" has received a message"
<<
" from interceptor "
<<
msg
.
src_id
()
<<
" with message: "
<<
message_type
<<
"."
;
Handle
(
msg
);
}
}
void
Interceptor
::
StopCarrier
()
{
void
Interceptor
::
StopCarrier
()
{
PADDLE_ENFORCE_NOT_NULL
(
carrier_
,
platform
::
errors
::
PreconditionNotMet
(
PADDLE_ENFORCE_NOT_NULL
(
carrier_
,
platform
::
errors
::
PreconditionNotMet
(
"Carrier is not registered."
));
"Carrier is not registered."
));
std
::
condition_variable
&
cond_var
=
carrier_
->
GetCondVar
();
carrier_
->
WakeUp
();
// probably double notify, but ok for ut
cond_var
.
notify_all
();
}
int64_t
Interceptor
::
GetInterceptorId
()
const
{
// return the interceptor id
return
interceptor_id_
;
}
}
void
Interceptor
::
EnqueueRemoteInterceptorMessage
(
void
Interceptor
::
EnqueueRemoteInterceptorMessage
(
const
InterceptorMessage
&
interceptor_
message
)
{
const
InterceptorMessage
&
message
)
{
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
VLOG
(
3
)
<<
"Enqueue message: "
<<
interceptor_message
.
message_type
()
VLOG
(
3
)
<<
"Enqueue message: "
<<
message
.
message_type
()
<<
" into "
<<
" into "
<<
interceptor_id_
<<
"'s remote mailbox."
;
<<
interceptor_id_
<<
"'s remote mailbox."
;
remote_mailbox_
.
Push
(
interceptor_message
);
bool
empty
=
false
;
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
empty
=
messages_
.
empty
();
messages_
.
emplace_back
(
message
);
}
if
(
empty
)
{
loop_
->
QueueInLoop
([
this
]()
{
LoopOnce
();
});
}
}
}
bool
Interceptor
::
Send
(
int64_t
dst_id
,
InterceptorMessage
&
msg
)
{
bool
Interceptor
::
Send
(
int64_t
dst_id
,
InterceptorMessage
&
msg
)
{
...
@@ -73,39 +90,6 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
...
@@ -73,39 +90,6 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
return
carrier_
->
Send
(
msg
);
return
carrier_
->
Send
(
msg
);
}
}
void
Interceptor
::
PoolTheMailbox
()
{
// pool the local mailbox, parse the Message
for
(;;)
{
if
(
local_mailbox_
.
empty
())
{
// local mailbox is empty, fetch the remote mailbox
VLOG
(
3
)
<<
interceptor_id_
<<
"'s local mailbox is empty. "
<<
"Fetch the remote mailbox."
;
PADDLE_ENFORCE_EQ
(
FetchRemoteMailbox
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Error encountered when fetch remote mailbox."
));
}
const
InterceptorMessage
interceptor_message
=
local_mailbox_
.
front
();
local_mailbox_
.
pop_front
();
const
MessageType
message_type
=
interceptor_message
.
message_type
();
VLOG
(
3
)
<<
"Interceptor "
<<
interceptor_id_
<<
" has received a message"
<<
" from interceptor "
<<
interceptor_message
.
src_id
()
<<
" with message: "
<<
message_type
<<
"."
;
Handle
(
interceptor_message
);
if
(
stop_
)
{
// break the pooling thread
VLOG
(
3
)
<<
"Interceptor "
<<
interceptor_id_
<<
" is quiting."
;
break
;
}
}
}
bool
Interceptor
::
FetchRemoteMailbox
()
{
remote_mailbox_
.
PopAll
(
&
local_mailbox_
);
return
!
local_mailbox_
.
empty
();
}
static
InterceptorFactory
::
CreateInterceptorMap
&
GetInterceptorMap
()
{
static
InterceptorFactory
::
CreateInterceptorMap
&
GetInterceptorMap
()
{
static
InterceptorFactory
::
CreateInterceptorMap
interceptorMap
;
static
InterceptorFactory
::
CreateInterceptorMap
interceptorMap
;
return
interceptorMap
;
return
interceptorMap
;
...
...
paddle/fluid/distributed/fleet_executor/interceptor.h
浏览文件 @
dba59db7
...
@@ -38,6 +38,7 @@ namespace distributed {
...
@@ -38,6 +38,7 @@ namespace distributed {
class
TaskNode
;
class
TaskNode
;
class
Carrier
;
class
Carrier
;
class
TaskLoop
;
class
Interceptor
{
class
Interceptor
{
public:
public:
...
@@ -50,15 +51,13 @@ class Interceptor {
...
@@ -50,15 +51,13 @@ class Interceptor {
virtual
~
Interceptor
();
virtual
~
Interceptor
();
void
Join
();
// register interceptor handle
// register interceptor handle
void
RegisterMsgHandle
(
MsgHandle
handle
);
void
RegisterMsgHandle
(
MsgHandle
handle
);
void
Handle
(
const
InterceptorMessage
&
msg
);
void
Handle
(
const
InterceptorMessage
&
msg
);
// return the interceptor id
// return the interceptor id
int64_t
GetInterceptorId
()
const
;
int64_t
GetInterceptorId
()
const
{
return
interceptor_id_
;
}
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
void
EnqueueRemoteInterceptorMessage
(
void
EnqueueRemoteInterceptorMessage
(
...
@@ -77,6 +76,7 @@ class Interceptor {
...
@@ -77,6 +76,7 @@ class Interceptor {
gc_
=
gc
;
gc_
=
gc
;
}
}
void
RegisterCarrier
(
Carrier
*
carrier
)
{
carrier_
=
carrier
;
}
void
RegisterCarrier
(
Carrier
*
carrier
)
{
carrier_
=
carrier
;
}
void
RegisterTaskLoop
(
TaskLoop
*
loop
)
{
loop_
=
loop
;
}
TaskNode
*
GetTaskNode
()
const
{
return
node_
;
}
TaskNode
*
GetTaskNode
()
const
{
return
node_
;
}
...
@@ -101,28 +101,16 @@ class Interceptor {
...
@@ -101,28 +101,16 @@ class Interceptor {
std
::
shared_ptr
<
framework
::
GarbageCollector
>
gc_
{
nullptr
};
std
::
shared_ptr
<
framework
::
GarbageCollector
>
gc_
{
nullptr
};
Carrier
*
carrier_
;
Carrier
*
carrier_
;
TaskLoop
*
loop_
;
private:
private:
// pool the local mailbox, parse the Message
void
LoopOnce
();
void
PoolTheMailbox
();
// fetch all Message from remote mailbox to local mailbox
// return true if remote mailbox not empty, otherwise return false
bool
FetchRemoteMailbox
();
// interceptor handle which process message
// interceptor handle which process message
MsgHandle
handle_
{
nullptr
};
MsgHandle
handle_
{
nullptr
};
// interceptor runs PoolTheMailbox() function to poll local mailbox
std
::
mutex
mutex_
;
std
::
thread
interceptor_thread_
;
std
::
deque
<
InterceptorMessage
>
messages_
;
// remote mailbox, written by EnqueueRemoteMessage()
// read by FetchRemoteMailbox()
framework
::
BlockingQueue
<
InterceptorMessage
>
remote_mailbox_
;
// local mailbox, written by FetchRemoteMailbox()
// read by PoolTheMailbox()
std
::
deque
<
InterceptorMessage
>
local_mailbox_
;
int64_t
already_run_times_
{
0
};
int64_t
already_run_times_
{
0
};
int64_t
used_slot_nums_
{
0
};
int64_t
used_slot_nums_
{
0
};
...
...
paddle/fluid/distributed/fleet_executor/message_bus.cc
浏览文件 @
dba59db7
...
@@ -105,21 +105,53 @@ bool MessageBus::Send(int64_t dst_rank,
...
@@ -105,21 +105,53 @@ bool MessageBus::Send(int64_t dst_rank,
return
true
;
return
true
;
}
}
void
MessageBus
::
TestConnection
()
{
void
MessageBus
::
IncreaseBarrierCount
()
{
InterceptorMessage
ctrl_msg
;
VLOG
(
3
)
<<
"IncreaseBarrierCount"
;
ctrl_msg
.
set_ctrl_message
(
true
);
{
ctrl_msg
.
set_src_id
(
rank_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
for
(
const
auto
&
dst_rank_pair
:
rank_to_addr_
)
{
++
count_
;
int64_t
dst_rank
=
dst_rank_pair
.
first
;
cv_
.
notify_one
();
if
(
dst_rank
!=
rank_
)
{
}
ctrl_msg
.
set_dst_id
(
dst_rank
);
VLOG
(
3
)
<<
"End IncreaseBarrierCount"
;
VLOG
(
3
)
<<
"Send control message bus from rank "
<<
rank_
<<
" to rank "
}
<<
dst_rank
;
while
(
!
Send
(
dst_rank
,
ctrl_msg
))
{
void
MessageBus
::
Barrier
()
{
// gather to root
if
(
rank_
!=
0
)
{
InterceptorMessage
ctrl_msg
;
ctrl_msg
.
set_ctrl_message
(
true
);
ctrl_msg
.
set_src_id
(
rank_
);
ctrl_msg
.
set_dst_id
(
0
);
VLOG
(
3
)
<<
"Barrier Gather ctrl message from "
<<
rank_
<<
" to 0"
;
while
(
!
Send
(
0
,
ctrl_msg
))
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
}
}
else
{
VLOG
(
3
)
<<
"Barrier 0 wait others rank ready"
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cv_
.
wait
(
lock
,
[
this
]
{
return
count_
==
static_cast
<
int
>
(
rank_to_addr_
.
size
()
-
1
);
});
count_
=
0
;
}
// scatter from root
if
(
rank_
==
0
)
{
for
(
int
i
=
1
;
i
<
static_cast
<
int
>
(
rank_to_addr_
.
size
());
++
i
)
{
InterceptorMessage
ctrl_msg
;
ctrl_msg
.
set_ctrl_message
(
true
);
ctrl_msg
.
set_src_id
(
0
);
ctrl_msg
.
set_dst_id
(
i
);
VLOG
(
3
)
<<
"Barrier Scatter ctrl message from 0 to "
<<
i
;
while
(
!
Send
(
i
,
ctrl_msg
))
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
}
}
VLOG
(
3
)
<<
"Message bus has connected to rank: "
<<
dst_rank
<<
"."
;
}
}
}
else
{
VLOG
(
3
)
<<
"Barrier "
<<
rank_
<<
" wait others rank ready"
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cv_
.
wait
(
lock
,
[
this
]
{
return
count_
==
1
;
});
count_
=
0
;
}
}
}
}
...
@@ -151,7 +183,6 @@ void MessageBus::ListenPort() {
...
@@ -151,7 +183,6 @@ void MessageBus::ListenPort() {
interval
+=
500
;
interval
+=
500
;
}
}
LOG
(
INFO
)
<<
"Message bus's listen port thread starts successful."
;
LOG
(
INFO
)
<<
"Message bus's listen port thread starts successful."
;
TestConnection
();
#else
#else
LOG
(
WARNING
)
LOG
(
WARNING
)
<<
"Fleet executor's ListenPort() is a fake function when Paddle is "
<<
"Fleet executor's ListenPort() is a fake function when Paddle is "
...
...
paddle/fluid/distributed/fleet_executor/message_bus.h
浏览文件 @
dba59db7
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#pragma once
#pragma once
#include <condition_variable>
#include <mutex>
#include <mutex>
#include <string>
#include <string>
#include <thread>
#include <thread>
...
@@ -51,14 +52,15 @@ class MessageBus final {
...
@@ -51,14 +52,15 @@ class MessageBus final {
// called by Interceptor, send InterceptorMessage to dst
// called by Interceptor, send InterceptorMessage to dst
bool
Send
(
int64_t
dst_rank
,
const
InterceptorMessage
&
interceptor_message
);
bool
Send
(
int64_t
dst_rank
,
const
InterceptorMessage
&
interceptor_message
);
void
IncreaseBarrierCount
();
void
Barrier
();
private:
private:
DISABLE_COPY_AND_ASSIGN
(
MessageBus
);
DISABLE_COPY_AND_ASSIGN
(
MessageBus
);
// function keep listen the port and handle the message
// function keep listen the port and handle the message
void
ListenPort
();
void
ListenPort
();
void
TestConnection
();
const
std
::
string
&
GetAddr
(
int64_t
rank
)
const
;
const
std
::
string
&
GetAddr
(
int64_t
rank
)
const
;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
...
@@ -84,6 +86,11 @@ class MessageBus final {
...
@@ -84,6 +86,11 @@ class MessageBus final {
// brpc server
// brpc server
brpc
::
Server
server_
;
brpc
::
Server
server_
;
#endif
#endif
// for barrier
std
::
mutex
mutex_
;
std
::
condition_variable
cv_
;
int
count_
{
0
};
};
};
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/task_loop.cc
0 → 100644
浏览文件 @
dba59db7
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace
paddle
{
namespace
distributed
{
thread_local
TaskLoop
*
TaskLoop
::
thread_local_loop_
=
nullptr
;
TaskLoop
*
TaskLoop
::
GetTaskLoopOfCurrentThread
()
{
return
thread_local_loop_
;
}
TaskLoop
::
TaskLoop
()
:
looping_
(
false
),
quit_
(
false
),
thread_id_
(
std
::
this_thread
::
get_id
())
{
PADDLE_ENFORCE_EQ
(
thread_local_loop_
,
nullptr
,
platform
::
errors
::
AlreadyExists
(
"Another TaskLoop is already init."
));
thread_local_loop_
=
this
;
}
TaskLoop
::~
TaskLoop
()
{
thread_local_loop_
=
nullptr
;
}
void
TaskLoop
::
Loop
()
{
PADDLE_ENFORCE_EQ
(
looping_
,
false
,
platform
::
errors
::
PreconditionNotMet
(
"Loop can only execute in one loop thread"
));
AssertInLoopThread
();
looping_
=
true
;
quit_
=
false
;
while
(
!
quit_
)
{
auto
tasks
=
tasks_
.
PopAll
();
for
(
auto
&
task
:
tasks
)
{
task
();
}
}
looping_
=
false
;
}
void
TaskLoop
::
Quit
()
{
quit_
=
true
;
if
(
!
IsInLoopThread
())
WakeUp
();
}
void
TaskLoop
::
RunInLoop
(
Functor
cb
)
{
if
(
IsInLoopThread
())
{
cb
();
}
else
{
QueueInLoop
(
cb
);
}
}
void
TaskLoop
::
QueueInLoop
(
Functor
cb
)
{
tasks_
.
Push
(
cb
);
}
void
TaskLoop
::
WakeUp
()
{
Functor
task
([]
{});
QueueInLoop
(
task
);
}
void
TaskLoop
::
AbortNotInLoopThread
()
{
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"This TaskLoop was created in thread %d, but current thread is %d"
,
thread_id_
,
std
::
this_thread
::
get_id
()));
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/task_loop.h
0 → 100644
浏览文件 @
dba59db7
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <future>
#include <map>
#include <thread>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
namespace
paddle
{
namespace
distributed
{
class
TaskLoop
{
public:
static
TaskLoop
*
GetTaskLoopOfCurrentThread
();
using
Functor
=
std
::
function
<
void
()
>
;
TaskLoop
();
~
TaskLoop
();
void
Loop
();
void
Quit
();
void
RunInLoop
(
Functor
cb
);
void
QueueInLoop
(
Functor
cb
);
template
<
class
F
,
class
...
Args
>
auto
Enqueue
(
F
&&
f
,
Args
&&
...
args
)
->
std
::
future
<
typename
std
::
result_of
<
F
(
Args
...)
>::
type
>
{
using
return_type
=
typename
std
::
result_of
<
F
(
Args
...)
>::
type
;
auto
task
=
std
::
make_shared
<
std
::
packaged_task
<
return_type
()
>>
(
std
::
bind
(
std
::
forward
<
F
>
(
f
),
std
::
forward
<
Args
>
(
args
)...));
std
::
future
<
return_type
>
task_future
=
task
->
get_future
();
tasks_
.
Push
([
task
]()
{
(
*
task
)();
});
return
task_future
;
}
void
WakeUp
();
bool
IsInLoopThread
()
const
{
return
thread_id_
==
std
::
this_thread
::
get_id
();
}
void
AssertInLoopThread
()
{
if
(
!
IsInLoopThread
())
{
AbortNotInLoopThread
();
}
}
private:
void
AbortNotInLoopThread
();
static
thread_local
TaskLoop
*
thread_local_loop_
;
bool
looping_
;
std
::
atomic
<
bool
>
quit_
;
std
::
thread
::
id
thread_id_
;
framework
::
BlockingQueue
<
Functor
>
tasks_
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/task_loop_thread.cc
0 → 100644
浏览文件 @
dba59db7
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace
paddle
{
namespace
distributed
{
TaskLoopThread
::
TaskLoopThread
()
:
start_
(
false
),
loop_
(
nullptr
)
{}
TaskLoopThread
::~
TaskLoopThread
()
{
if
(
loop_
!=
nullptr
)
{
loop_
->
Quit
();
thread_
.
join
();
}
}
TaskLoop
*
TaskLoopThread
::
StartLoop
()
{
PADDLE_ENFORCE_EQ
(
start_
,
false
,
platform
::
errors
::
PreconditionNotMet
(
"thread is already running."
));
start_
=
true
;
thread_
=
std
::
thread
([
this
]()
{
Loop
();
});
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cv_
.
wait
(
lock
,
[
=
]
{
return
loop_
!=
nullptr
;
});
return
loop_
;
}
void
TaskLoopThread
::
Loop
()
{
TaskLoop
loop
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
loop_
=
&
loop
;
cv_
.
notify_one
();
}
loop
.
Loop
();
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
loop_
=
nullptr
;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/task_loop_thread.h
0 → 100644
浏览文件 @
dba59db7
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable>
#include <mutex>
#include <thread>
namespace
paddle
{
namespace
distributed
{
class
TaskLoop
;
class
TaskLoopThread
{
public:
TaskLoopThread
();
~
TaskLoopThread
();
TaskLoop
*
StartLoop
();
private:
void
Loop
();
bool
start_
;
TaskLoop
*
loop_
;
std
::
thread
thread_
;
std
::
mutex
mutex_
;
std
::
condition_variable
cv_
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.cc
0 → 100644
浏览文件 @
dba59db7
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace
paddle
{
namespace
distributed
{
TaskLoopThreadPool
::
TaskLoopThreadPool
()
:
TaskLoopThreadPool
(
1
)
{}
TaskLoopThreadPool
::
TaskLoopThreadPool
(
int
thread_num
)
:
start_
(
false
),
thread_num_
(
thread_num
)
{}
TaskLoopThreadPool
::~
TaskLoopThreadPool
()
=
default
;
void
TaskLoopThreadPool
::
Start
()
{
PADDLE_ENFORCE_EQ
(
start_
,
false
,
platform
::
errors
::
PreconditionNotMet
(
"thread pool is already start."
));
PADDLE_ENFORCE_GT
(
thread_num_
,
0
,
platform
::
errors
::
InvalidArgument
(
"thread num must greater than 0, but now is %d"
,
thread_num_
));
start_
=
true
;
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
threads_
.
emplace_back
(
new
TaskLoopThread
());
loops_
.
push_back
(
threads_
[
i
]
->
StartLoop
());
}
}
TaskLoop
*
TaskLoopThreadPool
::
GetLoop
(
int
tid
)
{
PADDLE_ENFORCE_EQ
(
start_
,
true
,
platform
::
errors
::
PreconditionNotMet
(
"thread pool must start first."
));
PADDLE_ENFORCE_GE
(
tid
,
0
,
platform
::
errors
::
OutOfRange
(
"tid must >= 0, but now is %d"
,
tid
));
PADDLE_ENFORCE_LT
(
tid
,
thread_num_
,
platform
::
errors
::
OutOfRange
(
"tid must < thread_num, but now tid=%d thread_num=%d"
,
tid
,
thread_num_
));
return
loops_
[
tid
];
}
std
::
vector
<
TaskLoop
*>
TaskLoopThreadPool
::
GetAllLoops
()
{
PADDLE_ENFORCE_EQ
(
start_
,
true
,
platform
::
errors
::
PreconditionNotMet
(
"thread pool must start first."
));
return
loops_
;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h
0 → 100644
浏览文件 @
dba59db7
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
namespace
paddle
{
namespace
distributed
{
class
TaskLoop
;
class
TaskLoopThread
;
class
TaskLoopThreadPool
{
public:
TaskLoopThreadPool
();
explicit
TaskLoopThreadPool
(
int
thread_num
);
~
TaskLoopThreadPool
();
void
SetThreadNum
(
int
thread_num
)
{
thread_num_
=
thread_num
;
}
void
Start
();
TaskLoop
*
GetLoop
(
int
tid
);
std
::
vector
<
TaskLoop
*>
GetAllLoops
();
private:
bool
start_
;
int
thread_num_
;
std
::
vector
<
std
::
unique_ptr
<
TaskLoopThread
>>
threads_
;
std
::
vector
<
TaskLoop
*>
loops_
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
浏览文件 @
dba59db7
...
@@ -115,10 +115,13 @@ TEST(InterceptorTest, PingPong) {
...
@@ -115,10 +115,13 @@ TEST(InterceptorTest, PingPong) {
FleetExecutor
::
CreateCarrier
(
0
,
interceptor_id_to_rank
);
FleetExecutor
::
CreateCarrier
(
0
,
interceptor_id_to_rank
);
carrier
->
SetCreatingFlag
(
false
);
carrier
->
SetCreatingFlag
(
false
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
0
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
carrier
->
SetMsgBus
(
msg_bus
);
carrier
->
SetMsgBus
(
msg_bus
);
// NOTE: need Init msg_bus after carrier SetMsgBus
msg_bus
->
Init
(
0
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
Interceptor
*
a
=
carrier
->
SetInterceptor
(
Interceptor
*
a
=
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
carrier
->
Barrier
();
InterceptorMessage
msg
;
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
carrier
->
Wait
();
carrier
->
Wait
();
...
@@ -127,10 +130,12 @@ TEST(InterceptorTest, PingPong) {
...
@@ -127,10 +130,12 @@ TEST(InterceptorTest, PingPong) {
FleetExecutor
::
CreateCarrier
(
1
,
interceptor_id_to_rank
);
FleetExecutor
::
CreateCarrier
(
1
,
interceptor_id_to_rank
);
carrier
->
SetCreatingFlag
(
false
);
carrier
->
SetCreatingFlag
(
false
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
1
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
carrier
->
SetMsgBus
(
msg_bus
);
carrier
->
SetMsgBus
(
msg_bus
);
msg_bus
->
Init
(
1
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
carrier
->
SetInterceptor
(
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
1
,
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
carrier
->
Barrier
();
carrier
->
Wait
();
carrier
->
Wait
();
int
status
;
int
status
;
int
ret
=
waitpid
(
pid
,
&
status
,
0
);
int
ret
=
waitpid
(
pid
,
&
status
,
0
);
...
...
paddle/fluid/framework/blocking_queue.h
浏览文件 @
dba59db7
...
@@ -81,6 +81,16 @@ class BlockingQueue {
...
@@ -81,6 +81,16 @@ class BlockingQueue {
std
::
swap
(
*
empty_queue
,
q_
);
std
::
swap
(
*
empty_queue
,
q_
);
}
}
std
::
deque
<
T
>
PopAll
()
{
std
::
deque
<
T
>
ret
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cv_
.
wait
(
lock
,
[
this
]
{
return
!
q_
.
empty
();
});
std
::
swap
(
ret
,
q_
);
}
return
ret
;
}
T
Pop
()
{
T
Pop
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cv_
.
wait
(
lock
,
[
=
]
{
return
!
q_
.
empty
();
});
cv_
.
wait
(
lock
,
[
=
]
{
return
!
q_
.
empty
();
});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录