Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
843435ff
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
843435ff
编写于
12月 17, 2021
作者:
L
LiYuRio
提交者:
GitHub
12月 17, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[fleet_executor] Fix the problem in fleet executor stop (#38114)
上级
e3b033f9
变更
16
显示空白变更内容
内联
并排
Showing
16 changed file
with
140 addition
and
83 deletion
+140
-83
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+22
-11
paddle/fluid/distributed/fleet_executor/carrier.h
paddle/fluid/distributed/fleet_executor/carrier.h
+12
-10
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
...e/fluid/distributed/fleet_executor/compute_interceptor.cc
+1
-2
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+22
-14
paddle/fluid/distributed/fleet_executor/fleet_executor.h
paddle/fluid/distributed/fleet_executor/fleet_executor.h
+6
-1
paddle/fluid/distributed/fleet_executor/interceptor.cc
paddle/fluid/distributed/fleet_executor/interceptor.cc
+6
-4
paddle/fluid/distributed/fleet_executor/interceptor.h
paddle/fluid/distributed/fleet_executor/interceptor.h
+4
-0
paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc
...distributed/fleet_executor/interceptor_message_service.cc
+1
-2
paddle/fluid/distributed/fleet_executor/message_bus.cc
paddle/fluid/distributed/fleet_executor/message_bus.cc
+2
-5
paddle/fluid/distributed/fleet_executor/message_bus.h
paddle/fluid/distributed/fleet_executor/message_bus.h
+5
-12
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
...ed/fleet_executor/test/compute_interceptor_run_op_test.cc
+9
-3
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
...stributed/fleet_executor/test/compute_interceptor_test.cc
+10
-3
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
...ributed/fleet_executor/test/interceptor_ping_pong_test.cc
+9
-3
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
...eet_executor/test/interceptor_ping_pong_with_brpc_test.cc
+14
-6
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
...leet_executor/test/interceptor_pipeline_long_path_test.cc
+9
-4
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
...eet_executor/test/interceptor_pipeline_short_path_test.cc
+8
-3
未找到文件。
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
843435ff
...
@@ -49,10 +49,11 @@ void Carrier::Release() {
...
@@ -49,10 +49,11 @@ void Carrier::Release() {
// otherwise Derived object will be destructed before thread complete.
// otherwise Derived object will be destructed before thread complete.
// Sending STOP msg to the source interceptor
// Sending STOP msg to the source interceptor
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
PADDLE_ENFORCE_EQ
(
msg_bus
.
IsInit
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
"Message bus has not been initialized."
));
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."
));
for
(
int64_t
id
:
source_interceptor_ids_
)
{
for
(
int64_t
id
:
source_interceptor_ids_
)
{
VLOG
(
3
)
<<
"Carrier Release is sending stop to source interceptor "
<<
id
VLOG
(
3
)
<<
"Carrier Release is sending stop to source interceptor "
<<
id
<<
"."
;
<<
"."
;
...
@@ -61,7 +62,7 @@ void Carrier::Release() {
...
@@ -61,7 +62,7 @@ void Carrier::Release() {
stop_msg
.
set_src_id
(
-
1
);
stop_msg
.
set_src_id
(
-
1
);
stop_msg
.
set_dst_id
(
id
);
stop_msg
.
set_dst_id
(
id
);
stop_msg
.
set_message_type
(
STOP
);
stop_msg
.
set_message_type
(
STOP
);
msg_bus
.
Send
(
stop_msg
);
Send
(
stop_msg
);
}
}
// TODO(wangxi): Maybe need a better to use thread.
// TODO(wangxi): Maybe need a better to use thread.
...
@@ -113,11 +114,17 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
...
@@ -113,11 +114,17 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
return
iter
->
second
.
get
();
return
iter
->
second
.
get
();
}
}
void
Carrier
::
Wait
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
running_mutex_
);
cond_var_
.
wait
(
lock
);
}
void
Carrier
::
Start
()
{
void
Carrier
::
Start
()
{
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
PADDLE_ENFORCE_EQ
(
msg_bus
.
IsInit
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
"Message bus has not been initialized."
));
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."
));
for
(
int64_t
id
:
source_interceptor_ids_
)
{
for
(
int64_t
id
:
source_interceptor_ids_
)
{
VLOG
(
3
)
<<
"Carrier Start is sending start to source interceptor "
<<
id
VLOG
(
3
)
<<
"Carrier Start is sending start to source interceptor "
<<
id
...
@@ -127,11 +134,9 @@ void Carrier::Start() {
...
@@ -127,11 +134,9 @@ void Carrier::Start() {
start_msg
.
set_src_id
(
-
1
);
start_msg
.
set_src_id
(
-
1
);
start_msg
.
set_dst_id
(
id
);
start_msg
.
set_dst_id
(
id
);
start_msg
.
set_message_type
(
DATA_IS_READY
);
start_msg
.
set_message_type
(
DATA_IS_READY
);
msg_bus
.
Send
(
start_msg
);
Send
(
start_msg
);
}
}
Wait
();
std
::
unique_lock
<
std
::
mutex
>
lock
(
running_mutex_
);
cond_var_
.
wait
(
lock
);
dev_ctx_
->
Wait
();
dev_ctx_
->
Wait
();
}
}
...
@@ -139,6 +144,11 @@ std::condition_variable& Carrier::GetCondVar() { return cond_var_; }
...
@@ -139,6 +144,11 @@ std::condition_variable& Carrier::GetCondVar() { return cond_var_; }
bool
Carrier
::
IsInit
()
const
{
return
is_init_
;
}
bool
Carrier
::
IsInit
()
const
{
return
is_init_
;
}
// TODO(liyurui): Move SendIntra into carrier
bool
Carrier
::
Send
(
const
InterceptorMessage
&
msg
)
const
{
return
msg_bus_
->
Send
(
msg
);
}
Interceptor
*
Carrier
::
SetInterceptor
(
int64_t
interceptor_id
,
Interceptor
*
Carrier
::
SetInterceptor
(
int64_t
interceptor_id
,
std
::
unique_ptr
<
Interceptor
>
interceptor
)
{
std
::
unique_ptr
<
Interceptor
>
interceptor
)
{
auto
iter
=
interceptor_idx_to_interceptor_
.
find
(
interceptor_id
);
auto
iter
=
interceptor_idx_to_interceptor_
.
find
(
interceptor_id
);
...
@@ -147,6 +157,7 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
...
@@ -147,6 +157,7 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
"The interceptor id %lld has already been created! "
"The interceptor id %lld has already been created! "
"The interceptor id should be unique."
,
"The interceptor id should be unique."
,
interceptor_id
));
interceptor_id
));
interceptor
->
RegisterCarrier
(
this
);
auto
*
ptr
=
interceptor
.
get
();
auto
*
ptr
=
interceptor
.
get
();
interceptor_idx_to_interceptor_
.
insert
(
interceptor_idx_to_interceptor_
.
insert
(
std
::
make_pair
(
interceptor_id
,
std
::
move
(
interceptor
)));
std
::
make_pair
(
interceptor_id
,
std
::
move
(
interceptor
)));
...
...
paddle/fluid/distributed/fleet_executor/carrier.h
浏览文件 @
843435ff
...
@@ -40,22 +40,19 @@ namespace distributed {
...
@@ -40,22 +40,19 @@ namespace distributed {
class
TaskNode
;
class
TaskNode
;
class
InterceptorMessageServiceImpl
;
class
InterceptorMessageServiceImpl
;
class
RuntimeGraph
;
class
RuntimeGraph
;
class
MessageBus
;
// A singleton MessageBus
class
Carrier
final
{
class
Carrier
final
{
public:
public:
static
Carrier
&
Instance
()
{
Carrier
()
=
default
;
static
Carrier
carrier
;
~
Carrier
();
return
carrier
;
}
void
Init
(
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
void
Init
(
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
platform
::
Place
&
place
);
const
platform
::
Place
&
place
);
~
Carrier
();
void
Release
();
void
Release
();
void
Wait
();
// Enqueue a message to corresponding interceptor id
// Enqueue a message to corresponding interceptor id
bool
EnqueueInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
);
bool
EnqueueInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
);
...
@@ -68,6 +65,9 @@ class Carrier final {
...
@@ -68,6 +65,9 @@ class Carrier final {
std
::
unique_ptr
<
Interceptor
>
);
std
::
unique_ptr
<
Interceptor
>
);
void
SetCreatingFlag
(
bool
flag
);
void
SetCreatingFlag
(
bool
flag
);
void
SetMsgBus
(
const
std
::
shared_ptr
<
MessageBus
>&
msg_bus
)
{
msg_bus_
=
msg_bus
;
}
std
::
condition_variable
&
GetCondVar
();
std
::
condition_variable
&
GetCondVar
();
...
@@ -75,15 +75,15 @@ class Carrier final {
...
@@ -75,15 +75,15 @@ class Carrier final {
bool
IsInit
()
const
;
bool
IsInit
()
const
;
bool
Send
(
const
InterceptorMessage
&
msg
)
const
;
// NOTE: This mutex will be used in interceptor's RunOps function.
// NOTE: This mutex will be used in interceptor's RunOps function.
// This mutex is used for avoiding forward ops and backward ops run
// This mutex is used for avoiding forward ops and backward ops run
// simultaneously, which will lead to a random hang for some sync ops.
// simultaneously, which will lead to a random hang for some sync ops.
std
::
mutex
run
;
std
::
mutex
run
;
DISABLE_COPY_AND_ASSIGN
(
Carrier
);
private:
private:
Carrier
()
=
default
;
DISABLE_COPY_AND_ASSIGN
(
Carrier
)
;
// create each Interceptor
// create each Interceptor
void
CreateInterceptors
();
void
CreateInterceptors
();
...
@@ -110,6 +110,8 @@ class Carrier final {
...
@@ -110,6 +110,8 @@ class Carrier final {
paddle
::
platform
::
Place
place_
;
paddle
::
platform
::
Place
place_
;
paddle
::
platform
::
DeviceContext
*
dev_ctx_
{
nullptr
};
paddle
::
platform
::
DeviceContext
*
dev_ctx_
{
nullptr
};
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph_
;
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph_
;
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
};
};
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
浏览文件 @
843435ff
...
@@ -170,8 +170,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
...
@@ -170,8 +170,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
}
}
void
ComputeInterceptor
::
RunOps
()
{
void
ComputeInterceptor
::
RunOps
()
{
Carrier
&
carrier_instance
=
Carrier
::
Instance
();
std
::
unique_lock
<
std
::
mutex
>
lock
(
carrier_
->
run
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
carrier_instance
.
run
);
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
<<
" running ops for the "
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
<<
" running ops for the "
<<
step_
+
1
<<
" time."
;
<<
step_
+
1
<<
" time."
;
for
(
auto
op
:
node_
->
ops
())
{
for
(
auto
op
:
node_
->
ops
())
{
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
浏览文件 @
843435ff
...
@@ -34,7 +34,15 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
...
@@ -34,7 +34,15 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
"Error occurs while parsing string to proto"
));
"Error occurs while parsing string to proto"
));
}
}
FleetExecutor
::~
FleetExecutor
()
{
root_scope_
->
DropKids
();
}
FleetExecutor
::~
FleetExecutor
()
{
root_scope_
->
DropKids
();
GetCarrier
().
Release
();
}
Carrier
&
FleetExecutor
::
GetCarrier
()
{
static
Carrier
carrier
;
return
carrier
;
}
void
FleetExecutor
::
Init
(
void
FleetExecutor
::
Init
(
const
framework
::
ProgramDesc
&
program_desc
,
framework
::
Scope
*
scope
,
const
framework
::
ProgramDesc
&
program_desc
,
framework
::
Scope
*
scope
,
...
@@ -78,14 +86,16 @@ void FleetExecutor::Init(
...
@@ -78,14 +86,16 @@ void FleetExecutor::Init(
CopyParameters
(
i
,
program_desc
);
CopyParameters
(
i
,
program_desc
);
}
}
VLOG
(
5
)
<<
runtime_graph_
->
DebugString
();
VLOG
(
5
)
<<
runtime_graph_
->
DebugString
();
msg_bus_
=
std
::
make_shared
<
MessageBus
>
();
InitCarrier
();
InitCarrier
();
InitMessageBus
();
InitMessageBus
();
}
}
void
FleetExecutor
::
InitCarrier
()
{
void
FleetExecutor
::
InitCarrier
()
{
Carrier
&
carrier_instance
=
Carrier
::
Instance
();
Carrier
&
carrier
=
GetCarrier
();
if
(
!
carrier_instance
.
IsInit
())
{
if
(
!
carrier
.
IsInit
())
{
carrier_instance
.
Init
(
runtime_graph_
,
root_scope_
,
minibatch_scope_
,
carrier
.
SetMsgBus
(
msg_bus_
);
carrier
.
Init
(
runtime_graph_
,
root_scope_
,
minibatch_scope_
,
microbatch_scopes_
,
place_
);
microbatch_scopes_
,
place_
);
}
}
}
}
...
@@ -120,24 +130,22 @@ void FleetExecutor::InitMessageBus() {
...
@@ -120,24 +130,22 @@ void FleetExecutor::InitMessageBus() {
VLOG
(
3
)
<<
"The number of ranks are "
VLOG
(
3
)
<<
"The number of ranks are "
<<
(
rank_to_addr
.
size
()
==
0
?
1
:
rank_to_addr
.
size
())
<<
"."
;
<<
(
rank_to_addr
.
size
()
==
0
?
1
:
rank_to_addr
.
size
())
<<
"."
;
VLOG
(
5
)
<<
ss
.
str
();
VLOG
(
5
)
<<
ss
.
str
();
MessageBus
&
message_bus_instance
=
MessageBus
::
Instance
();
if
(
!
msg_bus_
->
IsInit
())
{
if
(
!
message_bus_instance
.
IsInit
())
{
msg_bus_
->
Init
(
runtime_graph_
->
intercepter_id_to_rank
(),
rank_to_addr
,
message_bus_instance
.
Init
(
runtime_graph_
->
intercepter_id_to_rank
(),
addr
);
rank_to_addr
,
addr
);
}
}
}
}
void
FleetExecutor
::
Run
()
{
void
FleetExecutor
::
Run
()
{
// Run
// Run
Carrier
&
carrier_instance
=
Carrier
::
Instance
();
Carrier
&
carrier
=
GetCarrier
();
MessageBus
&
message_bus_instance
=
MessageBus
::
Instance
();
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
carrier
_instance
.
IsInit
(),
true
,
carrier
.
IsInit
(),
true
,
platform
::
errors
::
Unavailable
(
"Carrier has not been init yet."
));
platform
::
errors
::
Unavailable
(
"Carrier has not been init yet."
));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
m
essage_bus_instance
.
IsInit
(),
true
,
m
sg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
Unavailable
(
"MessageBus has not been init yet."
));
platform
::
errors
::
Unavailable
(
"MessageBus has not been init yet."
));
carrier
_instance
.
Start
();
carrier
.
Start
();
for
(
auto
*
micro_scop
:
microbatch_scopes_
)
{
for
(
auto
*
micro_scop
:
microbatch_scopes_
)
{
// By default, we should delete all kid scopes after run executor because
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
// some operators may create local scope when running, such as while_op.
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.h
浏览文件 @
843435ff
...
@@ -28,9 +28,9 @@ class Scope;
...
@@ -28,9 +28,9 @@ class Scope;
namespace
distributed
{
namespace
distributed
{
class
RuntimeGraph
;
class
RuntimeGraph
;
class
Carrier
;
class
MessageBus
;
class
MessageBus
;
class
TaskNode
;
class
TaskNode
;
class
Carrier
;
class
FleetExecutor
final
{
class
FleetExecutor
final
{
public:
public:
...
@@ -42,6 +42,8 @@ class FleetExecutor final {
...
@@ -42,6 +42,8 @@ class FleetExecutor final {
const
std
::
vector
<
TaskNode
*>&
task_nodes
,
const
std
::
vector
<
TaskNode
*>&
task_nodes
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
);
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
);
void
Run
();
void
Run
();
// TODO(liyurui): Change to use registry table for multi-carrier.
static
Carrier
&
GetCarrier
();
private:
private:
DISABLE_COPY_AND_ASSIGN
(
FleetExecutor
);
DISABLE_COPY_AND_ASSIGN
(
FleetExecutor
);
...
@@ -54,6 +56,9 @@ class FleetExecutor final {
...
@@ -54,6 +56,9 @@ class FleetExecutor final {
framework
::
Scope
*
minibatch_scope_
;
framework
::
Scope
*
minibatch_scope_
;
platform
::
Place
place_
;
platform
::
Place
place_
;
std
::
vector
<
framework
::
Scope
*>
microbatch_scopes_
;
std
::
vector
<
framework
::
Scope
*>
microbatch_scopes_
;
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
};
};
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/interceptor.cc
浏览文件 @
843435ff
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -46,8 +45,9 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
...
@@ -46,8 +45,9 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
}
}
void
Interceptor
::
StopCarrier
()
{
void
Interceptor
::
StopCarrier
()
{
Carrier
&
carrier_instance
=
Carrier
::
Instance
();
PADDLE_ENFORCE_NOT_NULL
(
carrier_
,
platform
::
errors
::
PreconditionNotMet
(
std
::
condition_variable
&
cond_var
=
carrier_instance
.
GetCondVar
();
"Carrier is not registered."
));
std
::
condition_variable
&
cond_var
=
carrier_
->
GetCondVar
();
// probably double notify, but ok for ut
// probably double notify, but ok for ut
cond_var
.
notify_all
();
cond_var
.
notify_all
();
}
}
...
@@ -73,9 +73,11 @@ bool Interceptor::EnqueueRemoteInterceptorMessage(
...
@@ -73,9 +73,11 @@ bool Interceptor::EnqueueRemoteInterceptorMessage(
}
}
bool
Interceptor
::
Send
(
int64_t
dst_id
,
InterceptorMessage
&
msg
)
{
bool
Interceptor
::
Send
(
int64_t
dst_id
,
InterceptorMessage
&
msg
)
{
PADDLE_ENFORCE_NOT_NULL
(
carrier_
,
platform
::
errors
::
PreconditionNotMet
(
"Carrier is not registered."
));
msg
.
set_src_id
(
interceptor_id_
);
msg
.
set_src_id
(
interceptor_id_
);
msg
.
set_dst_id
(
dst_id
);
msg
.
set_dst_id
(
dst_id
);
return
MessageBus
::
Instance
().
Send
(
msg
);
return
carrier_
->
Send
(
msg
);
}
}
void
Interceptor
::
PoolTheMailbox
()
{
void
Interceptor
::
PoolTheMailbox
()
{
...
...
paddle/fluid/distributed/fleet_executor/interceptor.h
浏览文件 @
843435ff
...
@@ -36,6 +36,7 @@ class GarbageCollector;
...
@@ -36,6 +36,7 @@ class GarbageCollector;
namespace
distributed
{
namespace
distributed
{
class
TaskNode
;
class
TaskNode
;
class
Carrier
;
class
Interceptor
{
class
Interceptor
{
public:
public:
...
@@ -77,6 +78,7 @@ class Interceptor {
...
@@ -77,6 +78,7 @@ class Interceptor {
void
SetGC
(
const
std
::
shared_ptr
<
framework
::
GarbageCollector
>&
gc
)
{
void
SetGC
(
const
std
::
shared_ptr
<
framework
::
GarbageCollector
>&
gc
)
{
gc_
=
gc
;
gc_
=
gc
;
}
}
void
RegisterCarrier
(
Carrier
*
carrier
)
{
carrier_
=
carrier
;
}
TaskNode
*
GetTaskNode
()
const
{
return
node_
;
}
TaskNode
*
GetTaskNode
()
const
{
return
node_
;
}
...
@@ -100,6 +102,8 @@ class Interceptor {
...
@@ -100,6 +102,8 @@ class Interceptor {
std
::
vector
<
framework
::
Scope
*>
microbatch_scopes_
{};
std
::
vector
<
framework
::
Scope
*>
microbatch_scopes_
{};
std
::
shared_ptr
<
framework
::
GarbageCollector
>
gc_
{
nullptr
};
std
::
shared_ptr
<
framework
::
GarbageCollector
>
gc_
{
nullptr
};
Carrier
*
carrier_
;
private:
private:
// pool the local mailbox, parse the Message
// pool the local mailbox, parse the Message
void
PoolTheMailbox
();
void
PoolTheMailbox
();
...
...
paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc
浏览文件 @
843435ff
...
@@ -29,9 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
...
@@ -29,9 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG
(
3
)
<<
"Interceptor Message Service receives a message from interceptor "
VLOG
(
3
)
<<
"Interceptor Message Service receives a message from interceptor "
<<
request
->
src_id
()
<<
" to interceptor "
<<
request
->
dst_id
()
<<
request
->
src_id
()
<<
" to interceptor "
<<
request
->
dst_id
()
<<
", with the message: "
<<
request
->
message_type
();
<<
", with the message: "
<<
request
->
message_type
();
FleetExecutor
::
GetCarrier
().
EnqueueInterceptorMessage
(
*
request
);
response
->
set_rst
(
true
);
response
->
set_rst
(
true
);
// call interceptor manager's method to handle the message
Carrier
::
Instance
().
EnqueueInterceptorMessage
(
*
request
);
}
}
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/message_bus.cc
浏览文件 @
843435ff
...
@@ -57,10 +57,6 @@ void MessageBus::Init(
...
@@ -57,10 +57,6 @@ void MessageBus::Init(
bool
MessageBus
::
IsInit
()
const
{
return
is_init_
;
}
bool
MessageBus
::
IsInit
()
const
{
return
is_init_
;
}
MessageBus
::~
MessageBus
()
{
MessageBus
::~
MessageBus
()
{
// NOTE: fleet_executor inits carrier before message bus,
// therefore the message bus's destructor will be called first
Carrier
&
carrier
=
Carrier
::
Instance
();
carrier
.
Release
();
VLOG
(
3
)
<<
"Message bus releases resource."
;
VLOG
(
3
)
<<
"Message bus releases resource."
;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
!defined(PADDLE_WITH_ASCEND_CL)
...
@@ -245,7 +241,8 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
...
@@ -245,7 +241,8 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
bool
MessageBus
::
SendIntraRank
(
const
InterceptorMessage
&
interceptor_message
)
{
bool
MessageBus
::
SendIntraRank
(
const
InterceptorMessage
&
interceptor_message
)
{
// send the message intra rank (dst is the same rank with src)
// send the message intra rank (dst is the same rank with src)
return
Carrier
::
Instance
().
EnqueueInterceptorMessage
(
interceptor_message
);
return
FleetExecutor
::
GetCarrier
().
EnqueueInterceptorMessage
(
interceptor_message
);
}
}
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/message_bus.h
浏览文件 @
843435ff
...
@@ -39,10 +39,8 @@ class Carrier;
...
@@ -39,10 +39,8 @@ class Carrier;
// A singleton MessageBus
// A singleton MessageBus
class
MessageBus
final
{
class
MessageBus
final
{
public:
public:
static
MessageBus
&
Instance
()
{
MessageBus
()
=
default
;
static
MessageBus
msg_bus
;
~
MessageBus
();
return
msg_bus
;
}
void
Init
(
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
void
Init
(
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_map
<
int64_t
,
std
::
string
>&
rank_to_addr
,
const
std
::
unordered_map
<
int64_t
,
std
::
string
>&
rank_to_addr
,
...
@@ -53,12 +51,8 @@ class MessageBus final {
...
@@ -53,12 +51,8 @@ class MessageBus final {
// called by Interceptor, send InterceptorMessage to dst
// called by Interceptor, send InterceptorMessage to dst
bool
Send
(
const
InterceptorMessage
&
interceptor_message
);
bool
Send
(
const
InterceptorMessage
&
interceptor_message
);
~
MessageBus
();
DISABLE_COPY_AND_ASSIGN
(
MessageBus
);
private:
private:
MessageBus
()
=
default
;
DISABLE_COPY_AND_ASSIGN
(
MessageBus
)
;
// function keep listen the port and handle the message
// function keep listen the port and handle the message
void
ListenPort
();
void
ListenPort
();
...
@@ -72,12 +66,11 @@ class MessageBus final {
...
@@ -72,12 +66,11 @@ class MessageBus final {
bool
SendInterRank
(
const
InterceptorMessage
&
interceptor_message
);
bool
SendInterRank
(
const
InterceptorMessage
&
interceptor_message
);
#endif
#endif
bool
is_init_
{
false
};
// send the message intra rank (dst is the same rank with src)
// send the message intra rank (dst is the same rank with src)
bool
SendIntraRank
(
const
InterceptorMessage
&
interceptor_message
);
bool
SendIntraRank
(
const
InterceptorMessage
&
interceptor_message
);
bool
is_init_
{
false
};
std
::
once_flag
once_flag_
;
// handed by above layer, save the info mapping interceptor id to rank id
// handed by above layer, save the info mapping interceptor id to rank id
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
浏览文件 @
843435ff
...
@@ -18,6 +18,7 @@ limitations under the License. */
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
@@ -61,10 +62,12 @@ TEST(ComputeInterceptor, Compute) {
...
@@ -61,10 +62,12 @@ TEST(ComputeInterceptor, Compute) {
std
::
vector
<
framework
::
Scope
*>
scopes
=
{
scope
,
scope
};
std
::
vector
<
framework
::
Scope
*>
scopes
=
{
scope
,
scope
};
platform
::
Place
place
=
platform
::
CPUPlace
();
platform
::
Place
place
=
platform
::
CPUPlace
();
Carrier
&
carrier
=
Carrier
::
Instance
();
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
msg_bus
->
Init
({{
0
,
0
},
{
1
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
// FIXME: don't delete, otherwise interceptor will use undefined node
// FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode
*
node_a
=
TaskNode
*
node_a
=
...
@@ -90,6 +93,9 @@ TEST(ComputeInterceptor, Compute) {
...
@@ -90,6 +93,9 @@ TEST(ComputeInterceptor, Compute) {
msg
.
set_src_id
(
-
1
);
msg
.
set_src_id
(
-
1
);
msg
.
set_dst_id
(
0
);
msg
.
set_dst_id
(
0
);
carrier
.
EnqueueInterceptorMessage
(
msg
);
carrier
.
EnqueueInterceptorMessage
(
msg
);
carrier
.
Wait
();
carrier
.
Release
();
}
}
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
浏览文件 @
843435ff
...
@@ -18,6 +18,7 @@ limitations under the License. */
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
@@ -46,9 +47,12 @@ class StartInterceptor : public Interceptor {
...
@@ -46,9 +47,12 @@ class StartInterceptor : public Interceptor {
};
};
TEST
(
ComputeInterceptor
,
Compute
)
{
TEST
(
ComputeInterceptor
,
Compute
)
{
Carrier
&
carrier
=
Carrier
::
Instance
();
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
// NOTE: don't delete, otherwise interceptor will use undefined node
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
3
,
0
);
// role, rank, task_id
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
3
,
0
);
// role, rank, task_id
...
@@ -74,6 +78,9 @@ TEST(ComputeInterceptor, Compute) {
...
@@ -74,6 +78,9 @@ TEST(ComputeInterceptor, Compute) {
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
carrier
.
Wait
();
carrier
.
Release
();
}
}
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
浏览文件 @
843435ff
...
@@ -18,6 +18,7 @@ limitations under the License. */
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...
@@ -44,6 +45,7 @@ class PingPongInterceptor : public Interceptor {
...
@@ -44,6 +45,7 @@ class PingPongInterceptor : public Interceptor {
stop
.
set_message_type
(
STOP
);
stop
.
set_message_type
(
STOP
);
Send
(
0
,
stop
);
Send
(
0
,
stop
);
Send
(
1
,
stop
);
Send
(
1
,
stop
);
StopCarrier
();
return
;
return
;
}
}
...
@@ -58,10 +60,12 @@ class PingPongInterceptor : public Interceptor {
...
@@ -58,10 +60,12 @@ class PingPongInterceptor : public Interceptor {
REGISTER_INTERCEPTOR
(
PingPong
,
PingPongInterceptor
);
REGISTER_INTERCEPTOR
(
PingPong
,
PingPongInterceptor
);
TEST
(
InterceptorTest
,
PingPong
)
{
TEST
(
InterceptorTest
,
PingPong
)
{
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
(
);
Carrier
&
carrier
=
Carrier
::
Instance
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
Interceptor
*
a
=
carrier
.
SetInterceptor
(
Interceptor
*
a
=
carrier
.
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
...
@@ -71,6 +75,8 @@ TEST(InterceptorTest, PingPong) {
...
@@ -71,6 +75,8 @@ TEST(InterceptorTest, PingPong) {
InterceptorMessage
msg
;
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
carrier
.
Wait
();
}
}
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
浏览文件 @
843435ff
...
@@ -20,6 +20,7 @@ limitations under the License. */
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...
@@ -36,6 +37,7 @@ class PingPongInterceptor : public Interceptor {
...
@@ -36,6 +37,7 @@ class PingPongInterceptor : public Interceptor {
void
PingPong
(
const
InterceptorMessage
&
msg
)
{
void
PingPong
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
STOP
)
{
if
(
msg
.
message_type
()
==
STOP
)
{
stop_
=
true
;
stop_
=
true
;
StopCarrier
();
return
;
return
;
}
}
std
::
cout
<<
GetInterceptorId
()
<<
" recv msg, count="
<<
count_
std
::
cout
<<
GetInterceptorId
()
<<
" recv msg, count="
<<
count_
...
@@ -105,10 +107,12 @@ TEST(InterceptorTest, PingPong) {
...
@@ -105,10 +107,12 @@ TEST(InterceptorTest, PingPong) {
int
pid
=
fork
();
int
pid
=
fork
();
if
(
pid
==
0
)
{
if
(
pid
==
0
)
{
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
1
}},
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
msg_bus
->
Init
({{
0
,
0
},
{
1
,
1
}},
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
Carrier
&
carrier
=
Carrier
::
Instance
();
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
carrier
.
SetMsgBus
(
msg_bus
);
Interceptor
*
a
=
carrier
.
SetInterceptor
(
Interceptor
*
a
=
carrier
.
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
...
@@ -116,15 +120,19 @@ TEST(InterceptorTest, PingPong) {
...
@@ -116,15 +120,19 @@ TEST(InterceptorTest, PingPong) {
InterceptorMessage
msg
;
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
carrier
.
Wait
();
}
else
{
}
else
{
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
1
}},
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
msg_bus
->
Init
({{
0
,
0
},
{
1
,
1
}},
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
Carrier
&
carrier
=
Carrier
::
Instance
();
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
carrier
.
SetMsgBus
(
msg_bus
);
carrier
.
SetInterceptor
(
1
,
carrier
.
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
carrier
.
SetCreatingFlag
(
false
);
carrier
.
SetCreatingFlag
(
false
);
carrier
.
Wait
();
}
}
}
}
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
浏览文件 @
843435ff
...
@@ -18,6 +18,7 @@ limitations under the License. */
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
@@ -51,10 +52,12 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) {
...
@@ -51,10 +52,12 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) {
}
}
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
Carrier
&
carrier
=
Carrier
::
Instance
();
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
}},
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
carrier
.
SetMsgBus
(
msg_bus
);
int64_t
micro_steps
=
3
;
int64_t
micro_steps
=
3
;
...
@@ -88,6 +91,8 @@ TEST(AmplifierInterceptor, Amplifier) {
...
@@ -88,6 +91,8 @@ TEST(AmplifierInterceptor, Amplifier) {
msg
.
set_src_id
(
-
1
);
msg
.
set_src_id
(
-
1
);
msg
.
set_dst_id
(
0
);
msg
.
set_dst_id
(
0
);
carrier
.
EnqueueInterceptorMessage
(
msg
);
carrier
.
EnqueueInterceptorMessage
(
msg
);
carrier
.
Wait
();
carrier
.
Release
();
}
}
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
浏览文件 @
843435ff
...
@@ -18,6 +18,7 @@ limitations under the License. */
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
@@ -69,9 +70,11 @@ void LinkNodes(const std::vector<TaskNode*>& nodes,
...
@@ -69,9 +70,11 @@ void LinkNodes(const std::vector<TaskNode*>& nodes,
}
}
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
Carrier
&
carrier
=
Carrier
::
Instance
();
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{{
0
,
""
}},
""
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{{
0
,
""
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
int64_t
micro_steps
=
6
;
int64_t
micro_steps
=
6
;
...
@@ -103,6 +106,8 @@ TEST(AmplifierInterceptor, Amplifier) {
...
@@ -103,6 +106,8 @@ TEST(AmplifierInterceptor, Amplifier) {
msg
.
set_src_id
(
-
1
);
msg
.
set_src_id
(
-
1
);
msg
.
set_dst_id
(
0
);
msg
.
set_dst_id
(
0
);
carrier
.
EnqueueInterceptorMessage
(
msg
);
carrier
.
EnqueueInterceptorMessage
(
msg
);
carrier
.
Wait
();
carrier
.
Release
();
}
}
}
// namespace distributed
}
// namespace distributed
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录