Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
843435ff
P
Paddle
项目概览
PaddlePaddle
/
Paddle
接近 1 年 前同步成功
通知
2292
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
843435ff
编写于
12月 17, 2021
作者:
L
LiYuRio
提交者:
GitHub
12月 17, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[fleet_executor] Fix the problem in fleet executor stop (#38114)
上级
e3b033f9
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
140 addition
and
83 deletion
+140
-83
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+22
-11
paddle/fluid/distributed/fleet_executor/carrier.h
paddle/fluid/distributed/fleet_executor/carrier.h
+12
-10
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
...e/fluid/distributed/fleet_executor/compute_interceptor.cc
+1
-2
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+22
-14
paddle/fluid/distributed/fleet_executor/fleet_executor.h
paddle/fluid/distributed/fleet_executor/fleet_executor.h
+6
-1
paddle/fluid/distributed/fleet_executor/interceptor.cc
paddle/fluid/distributed/fleet_executor/interceptor.cc
+6
-4
paddle/fluid/distributed/fleet_executor/interceptor.h
paddle/fluid/distributed/fleet_executor/interceptor.h
+4
-0
paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc
...distributed/fleet_executor/interceptor_message_service.cc
+1
-2
paddle/fluid/distributed/fleet_executor/message_bus.cc
paddle/fluid/distributed/fleet_executor/message_bus.cc
+2
-5
paddle/fluid/distributed/fleet_executor/message_bus.h
paddle/fluid/distributed/fleet_executor/message_bus.h
+5
-12
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
...ed/fleet_executor/test/compute_interceptor_run_op_test.cc
+9
-3
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
...stributed/fleet_executor/test/compute_interceptor_test.cc
+10
-3
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
...ributed/fleet_executor/test/interceptor_ping_pong_test.cc
+9
-3
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
...eet_executor/test/interceptor_ping_pong_with_brpc_test.cc
+14
-6
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
...leet_executor/test/interceptor_pipeline_long_path_test.cc
+9
-4
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
...eet_executor/test/interceptor_pipeline_short_path_test.cc
+8
-3
未找到文件。
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
843435ff
...
...
@@ -49,10 +49,11 @@ void Carrier::Release() {
// otherwise Derived object will be destructed before thread complete.
// Sending STOP msg to the source interceptor
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
PADDLE_ENFORCE_EQ
(
msg_bus
.
IsInit
(),
true
,
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
"Message bus has not been initialized."
));
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."
));
for
(
int64_t
id
:
source_interceptor_ids_
)
{
VLOG
(
3
)
<<
"Carrier Release is sending stop to source interceptor "
<<
id
<<
"."
;
...
...
@@ -61,7 +62,7 @@ void Carrier::Release() {
stop_msg
.
set_src_id
(
-
1
);
stop_msg
.
set_dst_id
(
id
);
stop_msg
.
set_message_type
(
STOP
);
msg_bus
.
Send
(
stop_msg
);
Send
(
stop_msg
);
}
// TODO(wangxi): Maybe need a better to use thread.
...
...
@@ -113,11 +114,17 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
return
iter
->
second
.
get
();
}
void
Carrier
::
Wait
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
running_mutex_
);
cond_var_
.
wait
(
lock
);
}
void
Carrier
::
Start
()
{
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
PADDLE_ENFORCE_EQ
(
msg_bus
.
IsInit
(),
true
,
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
"Message bus has not been initialized."
));
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."
));
for
(
int64_t
id
:
source_interceptor_ids_
)
{
VLOG
(
3
)
<<
"Carrier Start is sending start to source interceptor "
<<
id
...
...
@@ -127,11 +134,9 @@ void Carrier::Start() {
start_msg
.
set_src_id
(
-
1
);
start_msg
.
set_dst_id
(
id
);
start_msg
.
set_message_type
(
DATA_IS_READY
);
msg_bus
.
Send
(
start_msg
);
Send
(
start_msg
);
}
std
::
unique_lock
<
std
::
mutex
>
lock
(
running_mutex_
);
cond_var_
.
wait
(
lock
);
Wait
();
dev_ctx_
->
Wait
();
}
...
...
@@ -139,6 +144,11 @@ std::condition_variable& Carrier::GetCondVar() { return cond_var_; }
bool
Carrier
::
IsInit
()
const
{
return
is_init_
;
}
// TODO(liyurui): Move SendIntra into carrier
bool
Carrier
::
Send
(
const
InterceptorMessage
&
msg
)
const
{
return
msg_bus_
->
Send
(
msg
);
}
Interceptor
*
Carrier
::
SetInterceptor
(
int64_t
interceptor_id
,
std
::
unique_ptr
<
Interceptor
>
interceptor
)
{
auto
iter
=
interceptor_idx_to_interceptor_
.
find
(
interceptor_id
);
...
...
@@ -147,6 +157,7 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
"The interceptor id %lld has already been created! "
"The interceptor id should be unique."
,
interceptor_id
));
interceptor
->
RegisterCarrier
(
this
);
auto
*
ptr
=
interceptor
.
get
();
interceptor_idx_to_interceptor_
.
insert
(
std
::
make_pair
(
interceptor_id
,
std
::
move
(
interceptor
)));
...
...
paddle/fluid/distributed/fleet_executor/carrier.h
浏览文件 @
843435ff
...
...
@@ -40,22 +40,19 @@ namespace distributed {
class
TaskNode
;
class
InterceptorMessageServiceImpl
;
class
RuntimeGraph
;
class
MessageBus
;
// A singleton MessageBus
class
Carrier
final
{
public:
static
Carrier
&
Instance
()
{
static
Carrier
carrier
;
return
carrier
;
}
Carrier
()
=
default
;
~
Carrier
();
void
Init
(
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
platform
::
Place
&
place
);
~
Carrier
();
void
Release
();
void
Wait
();
// Enqueue a message to corresponding interceptor id
bool
EnqueueInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
);
...
...
@@ -68,6 +65,9 @@ class Carrier final {
std
::
unique_ptr
<
Interceptor
>
);
void
SetCreatingFlag
(
bool
flag
);
void
SetMsgBus
(
const
std
::
shared_ptr
<
MessageBus
>&
msg_bus
)
{
msg_bus_
=
msg_bus
;
}
std
::
condition_variable
&
GetCondVar
();
...
...
@@ -75,15 +75,15 @@ class Carrier final {
bool
IsInit
()
const
;
bool
Send
(
const
InterceptorMessage
&
msg
)
const
;
// NOTE: This mutex will be used in interceptor's RunOps function.
// This mutex is used for avoiding forward ops and backward ops run
// simultaneously, which will lead to a random hang for some sync ops.
std
::
mutex
run
;
DISABLE_COPY_AND_ASSIGN
(
Carrier
);
private:
Carrier
()
=
default
;
DISABLE_COPY_AND_ASSIGN
(
Carrier
)
;
// create each Interceptor
void
CreateInterceptors
();
...
...
@@ -110,6 +110,8 @@ class Carrier final {
paddle
::
platform
::
Place
place_
;
paddle
::
platform
::
DeviceContext
*
dev_ctx_
{
nullptr
};
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph_
;
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
};
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
浏览文件 @
843435ff
...
...
@@ -170,8 +170,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
}
void
ComputeInterceptor
::
RunOps
()
{
Carrier
&
carrier_instance
=
Carrier
::
Instance
();
std
::
unique_lock
<
std
::
mutex
>
lock
(
carrier_instance
.
run
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
carrier_
->
run
);
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
<<
" running ops for the "
<<
step_
+
1
<<
" time."
;
for
(
auto
op
:
node_
->
ops
())
{
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
浏览文件 @
843435ff
...
...
@@ -34,7 +34,15 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
"Error occurs while parsing string to proto"
));
}
FleetExecutor
::~
FleetExecutor
()
{
root_scope_
->
DropKids
();
}
FleetExecutor
::~
FleetExecutor
()
{
root_scope_
->
DropKids
();
GetCarrier
().
Release
();
}
Carrier
&
FleetExecutor
::
GetCarrier
()
{
static
Carrier
carrier
;
return
carrier
;
}
void
FleetExecutor
::
Init
(
const
framework
::
ProgramDesc
&
program_desc
,
framework
::
Scope
*
scope
,
...
...
@@ -78,15 +86,17 @@ void FleetExecutor::Init(
CopyParameters
(
i
,
program_desc
);
}
VLOG
(
5
)
<<
runtime_graph_
->
DebugString
();
msg_bus_
=
std
::
make_shared
<
MessageBus
>
();
InitCarrier
();
InitMessageBus
();
}
void
FleetExecutor
::
InitCarrier
()
{
Carrier
&
carrier_instance
=
Carrier
::
Instance
();
if
(
!
carrier_instance
.
IsInit
())
{
carrier_instance
.
Init
(
runtime_graph_
,
root_scope_
,
minibatch_scope_
,
microbatch_scopes_
,
place_
);
Carrier
&
carrier
=
GetCarrier
();
if
(
!
carrier
.
IsInit
())
{
carrier
.
SetMsgBus
(
msg_bus_
);
carrier
.
Init
(
runtime_graph_
,
root_scope_
,
minibatch_scope_
,
microbatch_scopes_
,
place_
);
}
}
...
...
@@ -120,24 +130,22 @@ void FleetExecutor::InitMessageBus() {
VLOG
(
3
)
<<
"The number of ranks are "
<<
(
rank_to_addr
.
size
()
==
0
?
1
:
rank_to_addr
.
size
())
<<
"."
;
VLOG
(
5
)
<<
ss
.
str
();
MessageBus
&
message_bus_instance
=
MessageBus
::
Instance
();
if
(
!
message_bus_instance
.
IsInit
())
{
message_bus_instance
.
Init
(
runtime_graph_
->
intercepter_id_to_rank
(),
rank_to_addr
,
addr
);
if
(
!
msg_bus_
->
IsInit
())
{
msg_bus_
->
Init
(
runtime_graph_
->
intercepter_id_to_rank
(),
rank_to_addr
,
addr
);
}
}
void
FleetExecutor
::
Run
()
{
// Run
Carrier
&
carrier_instance
=
Carrier
::
Instance
();
MessageBus
&
message_bus_instance
=
MessageBus
::
Instance
();
Carrier
&
carrier
=
GetCarrier
();
PADDLE_ENFORCE_EQ
(
carrier
_instance
.
IsInit
(),
true
,
carrier
.
IsInit
(),
true
,
platform
::
errors
::
Unavailable
(
"Carrier has not been init yet."
));
PADDLE_ENFORCE_EQ
(
m
essage_bus_instance
.
IsInit
(),
true
,
m
sg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
Unavailable
(
"MessageBus has not been init yet."
));
carrier
_instance
.
Start
();
carrier
.
Start
();
for
(
auto
*
micro_scop
:
microbatch_scopes_
)
{
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.h
浏览文件 @
843435ff
...
...
@@ -28,9 +28,9 @@ class Scope;
namespace
distributed
{
class
RuntimeGraph
;
class
Carrier
;
class
MessageBus
;
class
TaskNode
;
class
Carrier
;
class
FleetExecutor
final
{
public:
...
...
@@ -42,6 +42,8 @@ class FleetExecutor final {
const
std
::
vector
<
TaskNode
*>&
task_nodes
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
);
void
Run
();
// TODO(liyurui): Change to use registry table for multi-carrier.
static
Carrier
&
GetCarrier
();
private:
DISABLE_COPY_AND_ASSIGN
(
FleetExecutor
);
...
...
@@ -54,6 +56,9 @@ class FleetExecutor final {
framework
::
Scope
*
minibatch_scope_
;
platform
::
Place
place_
;
std
::
vector
<
framework
::
Scope
*>
microbatch_scopes_
;
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
};
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/interceptor.cc
浏览文件 @
843435ff
...
...
@@ -14,7 +14,6 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace
paddle
{
...
...
@@ -46,8 +45,9 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
}
void
Interceptor
::
StopCarrier
()
{
Carrier
&
carrier_instance
=
Carrier
::
Instance
();
std
::
condition_variable
&
cond_var
=
carrier_instance
.
GetCondVar
();
PADDLE_ENFORCE_NOT_NULL
(
carrier_
,
platform
::
errors
::
PreconditionNotMet
(
"Carrier is not registered."
));
std
::
condition_variable
&
cond_var
=
carrier_
->
GetCondVar
();
// probably double notify, but ok for ut
cond_var
.
notify_all
();
}
...
...
@@ -73,9 +73,11 @@ bool Interceptor::EnqueueRemoteInterceptorMessage(
}
bool
Interceptor
::
Send
(
int64_t
dst_id
,
InterceptorMessage
&
msg
)
{
PADDLE_ENFORCE_NOT_NULL
(
carrier_
,
platform
::
errors
::
PreconditionNotMet
(
"Carrier is not registered."
));
msg
.
set_src_id
(
interceptor_id_
);
msg
.
set_dst_id
(
dst_id
);
return
MessageBus
::
Instance
().
Send
(
msg
);
return
carrier_
->
Send
(
msg
);
}
void
Interceptor
::
PoolTheMailbox
()
{
...
...
paddle/fluid/distributed/fleet_executor/interceptor.h
浏览文件 @
843435ff
...
...
@@ -36,6 +36,7 @@ class GarbageCollector;
namespace
distributed
{
class
TaskNode
;
class
Carrier
;
class
Interceptor
{
public:
...
...
@@ -77,6 +78,7 @@ class Interceptor {
void
SetGC
(
const
std
::
shared_ptr
<
framework
::
GarbageCollector
>&
gc
)
{
gc_
=
gc
;
}
void
RegisterCarrier
(
Carrier
*
carrier
)
{
carrier_
=
carrier
;
}
TaskNode
*
GetTaskNode
()
const
{
return
node_
;
}
...
...
@@ -100,6 +102,8 @@ class Interceptor {
std
::
vector
<
framework
::
Scope
*>
microbatch_scopes_
{};
std
::
shared_ptr
<
framework
::
GarbageCollector
>
gc_
{
nullptr
};
Carrier
*
carrier_
;
private:
// pool the local mailbox, parse the Message
void
PoolTheMailbox
();
...
...
paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc
浏览文件 @
843435ff
...
...
@@ -29,9 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG
(
3
)
<<
"Interceptor Message Service receives a message from interceptor "
<<
request
->
src_id
()
<<
" to interceptor "
<<
request
->
dst_id
()
<<
", with the message: "
<<
request
->
message_type
();
FleetExecutor
::
GetCarrier
().
EnqueueInterceptorMessage
(
*
request
);
response
->
set_rst
(
true
);
// call interceptor manager's method to handle the message
Carrier
::
Instance
().
EnqueueInterceptorMessage
(
*
request
);
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/message_bus.cc
浏览文件 @
843435ff
...
...
@@ -57,10 +57,6 @@ void MessageBus::Init(
bool
MessageBus
::
IsInit
()
const
{
return
is_init_
;
}
MessageBus
::~
MessageBus
()
{
// NOTE: fleet_executor inits carrier before message bus,
// therefore the message bus's destructor will be called first
Carrier
&
carrier
=
Carrier
::
Instance
();
carrier
.
Release
();
VLOG
(
3
)
<<
"Message bus releases resource."
;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
...
...
@@ -245,7 +241,8 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
bool
MessageBus
::
SendIntraRank
(
const
InterceptorMessage
&
interceptor_message
)
{
// send the message intra rank (dst is the same rank with src)
return
Carrier
::
Instance
().
EnqueueInterceptorMessage
(
interceptor_message
);
return
FleetExecutor
::
GetCarrier
().
EnqueueInterceptorMessage
(
interceptor_message
);
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/message_bus.h
浏览文件 @
843435ff
...
...
@@ -39,10 +39,8 @@ class Carrier;
// A singleton MessageBus
class
MessageBus
final
{
public:
static
MessageBus
&
Instance
()
{
static
MessageBus
msg_bus
;
return
msg_bus
;
}
MessageBus
()
=
default
;
~
MessageBus
();
void
Init
(
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_map
<
int64_t
,
std
::
string
>&
rank_to_addr
,
...
...
@@ -53,12 +51,8 @@ class MessageBus final {
// called by Interceptor, send InterceptorMessage to dst
bool
Send
(
const
InterceptorMessage
&
interceptor_message
);
~
MessageBus
();
DISABLE_COPY_AND_ASSIGN
(
MessageBus
);
private:
MessageBus
()
=
default
;
DISABLE_COPY_AND_ASSIGN
(
MessageBus
)
;
// function keep listen the port and handle the message
void
ListenPort
();
...
...
@@ -72,12 +66,11 @@ class MessageBus final {
bool
SendInterRank
(
const
InterceptorMessage
&
interceptor_message
);
#endif
bool
is_init_
{
false
};
// send the message intra rank (dst is the same rank with src)
bool
SendIntraRank
(
const
InterceptorMessage
&
interceptor_message
);
bool
is_init_
{
false
};
std
::
once_flag
once_flag_
;
// handed by above layer, save the info mapping interceptor id to rank id
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
浏览文件 @
843435ff
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -61,10 +62,12 @@ TEST(ComputeInterceptor, Compute) {
std
::
vector
<
framework
::
Scope
*>
scopes
=
{
scope
,
scope
};
platform
::
Place
place
=
platform
::
CPUPlace
();
Carrier
&
carrier
=
Carrier
::
Instance
();
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
// FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode
*
node_a
=
...
...
@@ -90,6 +93,9 @@ TEST(ComputeInterceptor, Compute) {
msg
.
set_src_id
(
-
1
);
msg
.
set_dst_id
(
0
);
carrier
.
EnqueueInterceptorMessage
(
msg
);
carrier
.
Wait
();
carrier
.
Release
();
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
浏览文件 @
843435ff
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -46,9 +47,12 @@ class StartInterceptor : public Interceptor {
};
TEST
(
ComputeInterceptor
,
Compute
)
{
Carrier
&
carrier
=
Carrier
::
Instance
();
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
3
,
0
);
// role, rank, task_id
...
...
@@ -74,6 +78,9 @@ TEST(ComputeInterceptor, Compute) {
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
carrier
.
Wait
();
carrier
.
Release
();
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
浏览文件 @
843435ff
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...
...
@@ -44,6 +45,7 @@ class PingPongInterceptor : public Interceptor {
stop
.
set_message_type
(
STOP
);
Send
(
0
,
stop
);
Send
(
1
,
stop
);
StopCarrier
();
return
;
}
...
...
@@ -58,10 +60,12 @@ class PingPongInterceptor : public Interceptor {
REGISTER_INTERCEPTOR
(
PingPong
,
PingPongInterceptor
);
TEST
(
InterceptorTest
,
PingPong
)
{
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
(
);
Carrier
&
carrier
=
Carrier
::
Instance
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
Interceptor
*
a
=
carrier
.
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
...
...
@@ -71,6 +75,8 @@ TEST(InterceptorTest, PingPong) {
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
carrier
.
Wait
();
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
浏览文件 @
843435ff
...
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...
...
@@ -36,6 +37,7 @@ class PingPongInterceptor : public Interceptor {
void
PingPong
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
STOP
)
{
stop_
=
true
;
StopCarrier
();
return
;
}
std
::
cout
<<
GetInterceptorId
()
<<
" recv msg, count="
<<
count_
...
...
@@ -105,10 +107,12 @@ TEST(InterceptorTest, PingPong) {
int
pid
=
fork
();
if
(
pid
==
0
)
{
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
1
}},
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
1
}},
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
Carrier
&
carrier
=
Carrier
::
Instance
();
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
carrier
.
SetMsgBus
(
msg_bus
);
Interceptor
*
a
=
carrier
.
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
...
...
@@ -116,15 +120,19 @@ TEST(InterceptorTest, PingPong) {
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
carrier
.
Wait
();
}
else
{
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
1
}},
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
1
}},
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
Carrier
&
carrier
=
Carrier
::
Instance
();
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
carrier
.
SetMsgBus
(
msg_bus
);
carrier
.
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
carrier
.
SetCreatingFlag
(
false
);
carrier
.
Wait
();
}
}
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
浏览文件 @
843435ff
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -51,10 +52,12 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) {
}
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
Carrier
&
carrier
=
Carrier
::
Instance
();
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
carrier
.
SetMsgBus
(
msg_bus
);
int64_t
micro_steps
=
3
;
...
...
@@ -88,6 +91,8 @@ TEST(AmplifierInterceptor, Amplifier) {
msg
.
set_src_id
(
-
1
);
msg
.
set_dst_id
(
0
);
carrier
.
EnqueueInterceptorMessage
(
msg
);
carrier
.
Wait
();
carrier
.
Release
();
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
浏览文件 @
843435ff
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -69,9 +70,11 @@ void LinkNodes(const std::vector<TaskNode*>& nodes,
}
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
Carrier
&
carrier
=
Carrier
::
Instance
();
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{{
0
,
""
}},
""
);
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{{
0
,
""
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
int64_t
micro_steps
=
6
;
...
...
@@ -103,6 +106,8 @@ TEST(AmplifierInterceptor, Amplifier) {
msg
.
set_src_id
(
-
1
);
msg
.
set_dst_id
(
0
);
carrier
.
EnqueueInterceptorMessage
(
msg
);
carrier
.
Wait
();
carrier
.
Release
();
}
}
// namespace distributed
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录