Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
843435ff
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
843435ff
编写于
12月 17, 2021
作者:
L
LiYuRio
提交者:
GitHub
12月 17, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[fleet_executor] Fix the problem in fleet executor stop (#38114)
上级
e3b033f9
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
140 addition
and
83 deletion
+140
-83
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+22
-11
paddle/fluid/distributed/fleet_executor/carrier.h
paddle/fluid/distributed/fleet_executor/carrier.h
+12
-10
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
...e/fluid/distributed/fleet_executor/compute_interceptor.cc
+1
-2
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+22
-14
paddle/fluid/distributed/fleet_executor/fleet_executor.h
paddle/fluid/distributed/fleet_executor/fleet_executor.h
+6
-1
paddle/fluid/distributed/fleet_executor/interceptor.cc
paddle/fluid/distributed/fleet_executor/interceptor.cc
+6
-4
paddle/fluid/distributed/fleet_executor/interceptor.h
paddle/fluid/distributed/fleet_executor/interceptor.h
+4
-0
paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc
...distributed/fleet_executor/interceptor_message_service.cc
+1
-2
paddle/fluid/distributed/fleet_executor/message_bus.cc
paddle/fluid/distributed/fleet_executor/message_bus.cc
+2
-5
paddle/fluid/distributed/fleet_executor/message_bus.h
paddle/fluid/distributed/fleet_executor/message_bus.h
+5
-12
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
...ed/fleet_executor/test/compute_interceptor_run_op_test.cc
+9
-3
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
...stributed/fleet_executor/test/compute_interceptor_test.cc
+10
-3
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
...ributed/fleet_executor/test/interceptor_ping_pong_test.cc
+9
-3
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
...eet_executor/test/interceptor_ping_pong_with_brpc_test.cc
+14
-6
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
...leet_executor/test/interceptor_pipeline_long_path_test.cc
+9
-4
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
...eet_executor/test/interceptor_pipeline_short_path_test.cc
+8
-3
未找到文件。
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
843435ff
...
...
@@ -49,10 +49,11 @@ void Carrier::Release() {
// otherwise Derived object will be destructed before thread complete.
// Sending STOP msg to the source interceptor
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
PADDLE_ENFORCE_EQ
(
msg_bus
.
IsInit
(),
true
,
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
"Message bus has not been initialized."
));
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."
));
for
(
int64_t
id
:
source_interceptor_ids_
)
{
VLOG
(
3
)
<<
"Carrier Release is sending stop to source interceptor "
<<
id
<<
"."
;
...
...
@@ -61,7 +62,7 @@ void Carrier::Release() {
stop_msg
.
set_src_id
(
-
1
);
stop_msg
.
set_dst_id
(
id
);
stop_msg
.
set_message_type
(
STOP
);
msg_bus
.
Send
(
stop_msg
);
Send
(
stop_msg
);
}
// TODO(wangxi): Maybe need a better to use thread.
...
...
@@ -113,11 +114,17 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
return
iter
->
second
.
get
();
}
void
Carrier
::
Wait
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
running_mutex_
);
cond_var_
.
wait
(
lock
);
}
void
Carrier
::
Start
()
{
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
PADDLE_ENFORCE_EQ
(
msg_bus
.
IsInit
(),
true
,
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
"Message bus has not been initialized."
));
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."
));
for
(
int64_t
id
:
source_interceptor_ids_
)
{
VLOG
(
3
)
<<
"Carrier Start is sending start to source interceptor "
<<
id
...
...
@@ -127,11 +134,9 @@ void Carrier::Start() {
start_msg
.
set_src_id
(
-
1
);
start_msg
.
set_dst_id
(
id
);
start_msg
.
set_message_type
(
DATA_IS_READY
);
msg_bus
.
Send
(
start_msg
);
Send
(
start_msg
);
}
std
::
unique_lock
<
std
::
mutex
>
lock
(
running_mutex_
);
cond_var_
.
wait
(
lock
);
Wait
();
dev_ctx_
->
Wait
();
}
...
...
@@ -139,6 +144,11 @@ std::condition_variable& Carrier::GetCondVar() { return cond_var_; }
bool
Carrier
::
IsInit
()
const
{
return
is_init_
;
}
// TODO(liyurui): Move SendIntra into carrier
bool
Carrier
::
Send
(
const
InterceptorMessage
&
msg
)
const
{
return
msg_bus_
->
Send
(
msg
);
}
Interceptor
*
Carrier
::
SetInterceptor
(
int64_t
interceptor_id
,
std
::
unique_ptr
<
Interceptor
>
interceptor
)
{
auto
iter
=
interceptor_idx_to_interceptor_
.
find
(
interceptor_id
);
...
...
@@ -147,6 +157,7 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
"The interceptor id %lld has already been created! "
"The interceptor id should be unique."
,
interceptor_id
));
interceptor
->
RegisterCarrier
(
this
);
auto
*
ptr
=
interceptor
.
get
();
interceptor_idx_to_interceptor_
.
insert
(
std
::
make_pair
(
interceptor_id
,
std
::
move
(
interceptor
)));
...
...
paddle/fluid/distributed/fleet_executor/carrier.h
浏览文件 @
843435ff
...
...
@@ -40,22 +40,19 @@ namespace distributed {
class
TaskNode
;
class
InterceptorMessageServiceImpl
;
class
RuntimeGraph
;
class
MessageBus
;
// A singleton MessageBus
class
Carrier
final
{
public:
static
Carrier
&
Instance
()
{
static
Carrier
carrier
;
return
carrier
;
}
Carrier
()
=
default
;
~
Carrier
();
void
Init
(
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
platform
::
Place
&
place
);
~
Carrier
();
void
Release
();
void
Wait
();
// Enqueue a message to corresponding interceptor id
bool
EnqueueInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
);
...
...
@@ -68,6 +65,9 @@ class Carrier final {
std
::
unique_ptr
<
Interceptor
>
);
void
SetCreatingFlag
(
bool
flag
);
void
SetMsgBus
(
const
std
::
shared_ptr
<
MessageBus
>&
msg_bus
)
{
msg_bus_
=
msg_bus
;
}
std
::
condition_variable
&
GetCondVar
();
...
...
@@ -75,15 +75,15 @@ class Carrier final {
bool
IsInit
()
const
;
bool
Send
(
const
InterceptorMessage
&
msg
)
const
;
// NOTE: This mutex will be used in interceptor's RunOps function.
// This mutex is used for avoiding forward ops and backward ops run
// simultaneously, which will lead to a random hang for some sync ops.
std
::
mutex
run
;
DISABLE_COPY_AND_ASSIGN
(
Carrier
);
private:
Carrier
()
=
default
;
DISABLE_COPY_AND_ASSIGN
(
Carrier
)
;
// create each Interceptor
void
CreateInterceptors
();
...
...
@@ -110,6 +110,8 @@ class Carrier final {
paddle
::
platform
::
Place
place_
;
paddle
::
platform
::
DeviceContext
*
dev_ctx_
{
nullptr
};
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph_
;
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
};
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
浏览文件 @
843435ff
...
...
@@ -170,8 +170,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
}
void
ComputeInterceptor
::
RunOps
()
{
Carrier
&
carrier_instance
=
Carrier
::
Instance
();
std
::
unique_lock
<
std
::
mutex
>
lock
(
carrier_instance
.
run
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
carrier_
->
run
);
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
<<
" running ops for the "
<<
step_
+
1
<<
" time."
;
for
(
auto
op
:
node_
->
ops
())
{
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
浏览文件 @
843435ff
...
...
@@ -34,7 +34,15 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
"Error occurs while parsing string to proto"
));
}
FleetExecutor
::~
FleetExecutor
()
{
root_scope_
->
DropKids
();
}
FleetExecutor
::~
FleetExecutor
()
{
root_scope_
->
DropKids
();
GetCarrier
().
Release
();
}
Carrier
&
FleetExecutor
::
GetCarrier
()
{
static
Carrier
carrier
;
return
carrier
;
}
void
FleetExecutor
::
Init
(
const
framework
::
ProgramDesc
&
program_desc
,
framework
::
Scope
*
scope
,
...
...
@@ -78,15 +86,17 @@ void FleetExecutor::Init(
CopyParameters
(
i
,
program_desc
);
}
VLOG
(
5
)
<<
runtime_graph_
->
DebugString
();
msg_bus_
=
std
::
make_shared
<
MessageBus
>
();
InitCarrier
();
InitMessageBus
();
}
void
FleetExecutor
::
InitCarrier
()
{
Carrier
&
carrier_instance
=
Carrier
::
Instance
();
if
(
!
carrier_instance
.
IsInit
())
{
carrier_instance
.
Init
(
runtime_graph_
,
root_scope_
,
minibatch_scope_
,
microbatch_scopes_
,
place_
);
Carrier
&
carrier
=
GetCarrier
();
if
(
!
carrier
.
IsInit
())
{
carrier
.
SetMsgBus
(
msg_bus_
);
carrier
.
Init
(
runtime_graph_
,
root_scope_
,
minibatch_scope_
,
microbatch_scopes_
,
place_
);
}
}
...
...
@@ -120,24 +130,22 @@ void FleetExecutor::InitMessageBus() {
VLOG
(
3
)
<<
"The number of ranks are "
<<
(
rank_to_addr
.
size
()
==
0
?
1
:
rank_to_addr
.
size
())
<<
"."
;
VLOG
(
5
)
<<
ss
.
str
();
MessageBus
&
message_bus_instance
=
MessageBus
::
Instance
();
if
(
!
message_bus_instance
.
IsInit
())
{
message_bus_instance
.
Init
(
runtime_graph_
->
intercepter_id_to_rank
(),
rank_to_addr
,
addr
);
if
(
!
msg_bus_
->
IsInit
())
{
msg_bus_
->
Init
(
runtime_graph_
->
intercepter_id_to_rank
(),
rank_to_addr
,
addr
);
}
}
void
FleetExecutor
::
Run
()
{
// Run
Carrier
&
carrier_instance
=
Carrier
::
Instance
();
MessageBus
&
message_bus_instance
=
MessageBus
::
Instance
();
Carrier
&
carrier
=
GetCarrier
();
PADDLE_ENFORCE_EQ
(
carrier
_instance
.
IsInit
(),
true
,
carrier
.
IsInit
(),
true
,
platform
::
errors
::
Unavailable
(
"Carrier has not been init yet."
));
PADDLE_ENFORCE_EQ
(
m
essage_bus_instance
.
IsInit
(),
true
,
m
sg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
Unavailable
(
"MessageBus has not been init yet."
));
carrier
_instance
.
Start
();
carrier
.
Start
();
for
(
auto
*
micro_scop
:
microbatch_scopes_
)
{
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.h
浏览文件 @
843435ff
...
...
@@ -28,9 +28,9 @@ class Scope;
namespace
distributed
{
class
RuntimeGraph
;
class
Carrier
;
class
MessageBus
;
class
TaskNode
;
class
Carrier
;
class
FleetExecutor
final
{
public:
...
...
@@ -42,6 +42,8 @@ class FleetExecutor final {
const
std
::
vector
<
TaskNode
*>&
task_nodes
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
);
void
Run
();
// TODO(liyurui): Change to use registry table for multi-carrier.
static
Carrier
&
GetCarrier
();
private:
DISABLE_COPY_AND_ASSIGN
(
FleetExecutor
);
...
...
@@ -54,6 +56,9 @@ class FleetExecutor final {
framework
::
Scope
*
minibatch_scope_
;
platform
::
Place
place_
;
std
::
vector
<
framework
::
Scope
*>
microbatch_scopes_
;
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
};
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/interceptor.cc
浏览文件 @
843435ff
...
...
@@ -14,7 +14,6 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace
paddle
{
...
...
@@ -46,8 +45,9 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
}
void
Interceptor
::
StopCarrier
()
{
Carrier
&
carrier_instance
=
Carrier
::
Instance
();
std
::
condition_variable
&
cond_var
=
carrier_instance
.
GetCondVar
();
PADDLE_ENFORCE_NOT_NULL
(
carrier_
,
platform
::
errors
::
PreconditionNotMet
(
"Carrier is not registered."
));
std
::
condition_variable
&
cond_var
=
carrier_
->
GetCondVar
();
// probably double notify, but ok for ut
cond_var
.
notify_all
();
}
...
...
@@ -73,9 +73,11 @@ bool Interceptor::EnqueueRemoteInterceptorMessage(
}
bool
Interceptor
::
Send
(
int64_t
dst_id
,
InterceptorMessage
&
msg
)
{
PADDLE_ENFORCE_NOT_NULL
(
carrier_
,
platform
::
errors
::
PreconditionNotMet
(
"Carrier is not registered."
));
msg
.
set_src_id
(
interceptor_id_
);
msg
.
set_dst_id
(
dst_id
);
return
MessageBus
::
Instance
().
Send
(
msg
);
return
carrier_
->
Send
(
msg
);
}
void
Interceptor
::
PoolTheMailbox
()
{
...
...
paddle/fluid/distributed/fleet_executor/interceptor.h
浏览文件 @
843435ff
...
...
@@ -36,6 +36,7 @@ class GarbageCollector;
namespace
distributed
{
class
TaskNode
;
class
Carrier
;
class
Interceptor
{
public:
...
...
@@ -77,6 +78,7 @@ class Interceptor {
void
SetGC
(
const
std
::
shared_ptr
<
framework
::
GarbageCollector
>&
gc
)
{
gc_
=
gc
;
}
void
RegisterCarrier
(
Carrier
*
carrier
)
{
carrier_
=
carrier
;
}
TaskNode
*
GetTaskNode
()
const
{
return
node_
;
}
...
...
@@ -100,6 +102,8 @@ class Interceptor {
std
::
vector
<
framework
::
Scope
*>
microbatch_scopes_
{};
std
::
shared_ptr
<
framework
::
GarbageCollector
>
gc_
{
nullptr
};
Carrier
*
carrier_
;
private:
// pool the local mailbox, parse the Message
void
PoolTheMailbox
();
...
...
paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc
浏览文件 @
843435ff
...
...
@@ -29,9 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG
(
3
)
<<
"Interceptor Message Service receives a message from interceptor "
<<
request
->
src_id
()
<<
" to interceptor "
<<
request
->
dst_id
()
<<
", with the message: "
<<
request
->
message_type
();
FleetExecutor
::
GetCarrier
().
EnqueueInterceptorMessage
(
*
request
);
response
->
set_rst
(
true
);
// call interceptor manager's method to handle the message
Carrier
::
Instance
().
EnqueueInterceptorMessage
(
*
request
);
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/message_bus.cc
浏览文件 @
843435ff
...
...
@@ -57,10 +57,6 @@ void MessageBus::Init(
bool
MessageBus
::
IsInit
()
const
{
return
is_init_
;
}
MessageBus
::~
MessageBus
()
{
// NOTE: fleet_executor inits carrier before message bus,
// therefore the message bus's destructor will be called first
Carrier
&
carrier
=
Carrier
::
Instance
();
carrier
.
Release
();
VLOG
(
3
)
<<
"Message bus releases resource."
;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
...
...
@@ -245,7 +241,8 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
bool
MessageBus
::
SendIntraRank
(
const
InterceptorMessage
&
interceptor_message
)
{
// send the message intra rank (dst is the same rank with src)
return
Carrier
::
Instance
().
EnqueueInterceptorMessage
(
interceptor_message
);
return
FleetExecutor
::
GetCarrier
().
EnqueueInterceptorMessage
(
interceptor_message
);
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/message_bus.h
浏览文件 @
843435ff
...
...
@@ -39,10 +39,8 @@ class Carrier;
// A singleton MessageBus
class
MessageBus
final
{
public:
static
MessageBus
&
Instance
()
{
static
MessageBus
msg_bus
;
return
msg_bus
;
}
MessageBus
()
=
default
;
~
MessageBus
();
void
Init
(
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_map
<
int64_t
,
std
::
string
>&
rank_to_addr
,
...
...
@@ -53,12 +51,8 @@ class MessageBus final {
// called by Interceptor, send InterceptorMessage to dst
bool
Send
(
const
InterceptorMessage
&
interceptor_message
);
~
MessageBus
();
DISABLE_COPY_AND_ASSIGN
(
MessageBus
);
private:
MessageBus
()
=
default
;
DISABLE_COPY_AND_ASSIGN
(
MessageBus
)
;
// function keep listen the port and handle the message
void
ListenPort
();
...
...
@@ -72,12 +66,11 @@ class MessageBus final {
bool
SendInterRank
(
const
InterceptorMessage
&
interceptor_message
);
#endif
bool
is_init_
{
false
};
// send the message intra rank (dst is the same rank with src)
bool
SendIntraRank
(
const
InterceptorMessage
&
interceptor_message
);
bool
is_init_
{
false
};
std
::
once_flag
once_flag_
;
// handed by above layer, save the info mapping interceptor id to rank id
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
浏览文件 @
843435ff
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -61,10 +62,12 @@ TEST(ComputeInterceptor, Compute) {
std
::
vector
<
framework
::
Scope
*>
scopes
=
{
scope
,
scope
};
platform
::
Place
place
=
platform
::
CPUPlace
();
Carrier
&
carrier
=
Carrier
::
Instance
();
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
// FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode
*
node_a
=
...
...
@@ -90,6 +93,9 @@ TEST(ComputeInterceptor, Compute) {
msg
.
set_src_id
(
-
1
);
msg
.
set_dst_id
(
0
);
carrier
.
EnqueueInterceptorMessage
(
msg
);
carrier
.
Wait
();
carrier
.
Release
();
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
浏览文件 @
843435ff
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -46,9 +47,12 @@ class StartInterceptor : public Interceptor {
};
TEST
(
ComputeInterceptor
,
Compute
)
{
Carrier
&
carrier
=
Carrier
::
Instance
();
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
3
,
0
);
// role, rank, task_id
...
...
@@ -74,6 +78,9 @@ TEST(ComputeInterceptor, Compute) {
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
carrier
.
Wait
();
carrier
.
Release
();
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
浏览文件 @
843435ff
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...
...
@@ -44,6 +45,7 @@ class PingPongInterceptor : public Interceptor {
stop
.
set_message_type
(
STOP
);
Send
(
0
,
stop
);
Send
(
1
,
stop
);
StopCarrier
();
return
;
}
...
...
@@ -58,10 +60,12 @@ class PingPongInterceptor : public Interceptor {
REGISTER_INTERCEPTOR
(
PingPong
,
PingPongInterceptor
);
TEST
(
InterceptorTest
,
PingPong
)
{
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
(
);
Carrier
&
carrier
=
Carrier
::
Instance
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
Interceptor
*
a
=
carrier
.
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
...
...
@@ -71,6 +75,8 @@ TEST(InterceptorTest, PingPong) {
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
carrier
.
Wait
();
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
浏览文件 @
843435ff
...
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...
...
@@ -36,6 +37,7 @@ class PingPongInterceptor : public Interceptor {
void
PingPong
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
STOP
)
{
stop_
=
true
;
StopCarrier
();
return
;
}
std
::
cout
<<
GetInterceptorId
()
<<
" recv msg, count="
<<
count_
...
...
@@ -105,10 +107,12 @@ TEST(InterceptorTest, PingPong) {
int
pid
=
fork
();
if
(
pid
==
0
)
{
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
1
}},
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
1
}},
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
Carrier
&
carrier
=
Carrier
::
Instance
();
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
carrier
.
SetMsgBus
(
msg_bus
);
Interceptor
*
a
=
carrier
.
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
...
...
@@ -116,15 +120,19 @@ TEST(InterceptorTest, PingPong) {
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
carrier
.
Wait
();
}
else
{
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
1
}},
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
1
}},
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
Carrier
&
carrier
=
Carrier
::
Instance
();
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
carrier
.
SetMsgBus
(
msg_bus
);
carrier
.
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
carrier
.
SetCreatingFlag
(
false
);
carrier
.
Wait
();
}
}
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
浏览文件 @
843435ff
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -51,10 +52,12 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) {
}
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
Carrier
&
carrier
=
Carrier
::
Instance
();
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
carrier
.
SetMsgBus
(
msg_bus
);
int64_t
micro_steps
=
3
;
...
...
@@ -88,6 +91,8 @@ TEST(AmplifierInterceptor, Amplifier) {
msg
.
set_src_id
(
-
1
);
msg
.
set_dst_id
(
0
);
carrier
.
EnqueueInterceptorMessage
(
msg
);
carrier
.
Wait
();
carrier
.
Release
();
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
浏览文件 @
843435ff
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -69,9 +70,11 @@ void LinkNodes(const std::vector<TaskNode*>& nodes,
}
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
Carrier
&
carrier
=
Carrier
::
Instance
();
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{{
0
,
""
}},
""
);
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{{
0
,
""
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
int64_t
micro_steps
=
6
;
...
...
@@ -103,6 +106,8 @@ TEST(AmplifierInterceptor, Amplifier) {
msg
.
set_src_id
(
-
1
);
msg
.
set_dst_id
(
0
);
carrier
.
EnqueueInterceptorMessage
(
msg
);
carrier
.
Wait
();
carrier
.
Release
();
}
}
// namespace distributed
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录