Crayon鑫 / Paddle (forked from PaddlePaddle / Paddle)

Commit 0adc2006 (unverified)
Authored on Dec 01, 2021 by Yuang Liu; committed via GitHub on Dec 01, 2021

[fleet_executor] auto STOP msg and auto notify carrier (#37742)

Parent: 79095918

Showing 10 changed files with 92 additions and 60 deletions (+92 / -60):

paddle/fluid/distributed/fleet_executor/carrier.cc                               +40  -9
paddle/fluid/distributed/fleet_executor/carrier.h                                 +8  -9
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc                   +14  -8
paddle/fluid/distributed/fleet_executor/compute_interceptor.h                     +1  -0
paddle/fluid/distributed/fleet_executor/fleet_executor.cc                         +3  -3
paddle/fluid/distributed/fleet_executor/fleet_executor.h                          +1  -1
paddle/fluid/distributed/fleet_executor/interceptor.cc                           +10  -6
paddle/fluid/distributed/fleet_executor/message_bus.cc                            +4  -0
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc   +4  -11
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc          +7  -13
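
Taken together, the diffs below automate shutdown: the MessageBus destructor now calls Carrier::Release(), which sends a STOP message to every source interceptor, and the last ComputeInterceptor on a rank calls StopCarrier() once it has run max_run_times steps. The following self-contained sketch mimics that control flow; ToyCarrier, ToyNode and the other names are made up for illustration and are not Paddle's actual classes.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

enum class MessageType { DATA_IS_READY, STOP };

class ToyCarrier;  // forward declaration, ToyNode only stores a pointer

class ToyNode {
 public:
  ToyNode(int64_t id, bool is_last, int64_t max_run_times, ToyCarrier* carrier)
      : id_(id), is_last_(is_last), max_run_times_(max_run_times),
        carrier_(carrier) {}
  void Handle(MessageType msg);

 private:
  int64_t id_;
  bool is_last_;
  int64_t max_run_times_;
  int64_t step_{0};
  ToyCarrier* carrier_;
};

class ToyCarrier {
 public:
  void AddSource(ToyNode* node) { sources_.push_back(node); }
  // Mirrors Carrier::Release(): fan STOP out to every source node.
  void Release() {
    for (ToyNode* src : sources_) src->Handle(MessageType::STOP);
  }
  // Mirrors StopCarrier(): the real code wakes the waiting thread here.
  void Stop() { std::cout << "carrier notified to stop\n"; }

 private:
  std::vector<ToyNode*> sources_;
};

void ToyNode::Handle(MessageType msg) {
  if (msg == MessageType::STOP) {
    std::cout << "node " << id_ << " got STOP\n";
    return;
  }
  ++step_;
  // Same condition the commit adds to ComputeInterceptor::Run().
  if (step_ % max_run_times_ == 0 && is_last_) carrier_->Stop();
}

int main() {
  ToyCarrier carrier;
  ToyNode last(/*id=*/1, /*is_last=*/true, /*max_run_times=*/2, &carrier);
  carrier.AddSource(&last);
  last.Handle(MessageType::DATA_IS_READY);  // step_ = 1, nothing happens
  last.Handle(MessageType::DATA_IS_READY);  // step_ = 2 -> carrier notified
  carrier.Release();                        // STOP fan-out, as in Release()
  return 0;
}
```

Handling the second DATA_IS_READY triggers the carrier notification, and Release() then pushes STOP into the source node, which are exactly the two triggers this commit wires up.
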
paddle/fluid/distributed/fleet_executor/carrier.cc

@@ -16,6 +16,7 @@
 #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
 #include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
 #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
+#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
 #include "paddle/fluid/distributed/fleet_executor/task_node.h"
 #include "paddle/fluid/framework/scope.h"
@@ -24,14 +25,14 @@ namespace distributed {
 USE_INTERCEPTOR(Compute);
 
 void Carrier::Init(
-    const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
-    framework::Scope* root_scope, framework::Scope* minibatch_scope,
+    std::shared_ptr<RuntimeGraph> runtime_graph, framework::Scope* root_scope,
+    framework::Scope* minibatch_scope,
     const std::vector<framework::Scope*>& microbatch_scopes,
     const platform::Place& place) {
   PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
                                          "Carrier is already init."));
-  interceptor_id_to_node_ = interceptor_id_to_node;
+  runtime_graph_ = runtime_graph;
   minibatch_scope_ = minibatch_scope;
   microbatch_scopes_ = microbatch_scopes;
   place_ = place;
@@ -41,15 +42,34 @@ void Carrier::Init(
   is_init_ = true;
 }
 
-Carrier::~Carrier() {
+void Carrier::Release() {
   // NOTE(wangxi): must join before `Derived Interceptor` destruct,
   // otherwise Derived object will be destructed before thread complete.
+
+  // Sending STOP msg to the source interceptor
+  MessageBus& msg_bus = MessageBus::Instance();
+  PADDLE_ENFORCE_EQ(msg_bus.IsInit(), true,
+                    platform::errors::PreconditionNotMet(
+                        "Message bus has not been initialized."));
+  for (int64_t id : source_interceptor_ids_) {
+    VLOG(3) << "Carrier Release is sending stop to source interceptor " << id
+            << ".";
+    InterceptorMessage stop_msg;
+    // source node STOP is send by carrier, so set src_id=-1
+    stop_msg.set_src_id(-1);
+    stop_msg.set_dst_id(id);
+    stop_msg.set_message_type(STOP);
+    msg_bus.Send(stop_msg);
+  }
+
   // TODO(wangxi): Maybe need a better to use thread.
   for (auto& interceptor : interceptor_idx_to_interceptor_) {
     interceptor.second->Join();
   }
 }
 
+Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }
+
 bool Carrier::EnqueueInterceptorMessage(
     const InterceptorMessage& interceptor_message) {
   // enqueue message to interceptor
@@ -139,6 +159,17 @@ void Carrier::SetCreatingFlag(bool flag) {
   creating_interceptors_ = flag;
   creating_flag_mutex_.unlock();
   if (!flag) {
+    for (auto& pair : interceptor_idx_to_interceptor_) {
+      // update the source interceptor id
+      if (std::find(source_interceptor_ids_.begin(),
+                    source_interceptor_ids_.end(),
+                    pair.first) == source_interceptor_ids_.end()) {
+        auto task = pair.second->GetTaskNode();
+        if (task != nullptr && task->upstream().empty()) {
+          source_interceptor_ids_.emplace_back(pair.first);
+        }
+      }
+    }
     // finish create interceptors outside, handle tmp messsages
     HandleTmpMessages();
   }
@@ -161,9 +192,9 @@ void Carrier::HandleTmpMessages() {
 
 void Carrier::CreateInterceptors() {
   // create each Interceptor
-  if (!interceptor_id_to_node_.empty()) {
+  if (!(runtime_graph_->intercepter_id_to_node().empty())) {
     // no auto init since there is no config
-    for (const auto& item : interceptor_id_to_node_) {
+    for (const auto& item : runtime_graph_->intercepter_id_to_node()) {
       int64_t interceptor_id = item.first;
       TaskNode* task_node = item.second;
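
The block added to Carrier::SetCreatingFlag() treats every interceptor whose task node has no upstream as a source, using std::find so an id is never recorded twice. A standalone toy version of that bookkeeping, with a hypothetical upstream map standing in for TaskNode (illustrative only, not Paddle's API):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

int main() {
  // interceptor id -> ids of its upstream tasks; an empty list means "source"
  std::unordered_map<int64_t, std::vector<int64_t>> upstream = {
      {0, {}}, {1, {0}}, {2, {1}}};

  std::vector<int64_t> source_ids;
  for (const auto& pair : upstream) {
    // Same duplicate guard + empty-upstream test as the new Carrier code.
    if (std::find(source_ids.begin(), source_ids.end(), pair.first) ==
            source_ids.end() &&
        pair.second.empty()) {
      source_ids.emplace_back(pair.first);
    }
  }

  for (int64_t id : source_ids) {
    std::cout << "source interceptor: " << id << "\n";  // prints only id 0
  }
  return 0;
}
```
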
paddle/fluid/distributed/fleet_executor/carrier.h

@@ -39,6 +39,7 @@ namespace distributed {
 class TaskNode;
 class InterceptorMessageServiceImpl;
+class RuntimeGraph;
 
 // A singleton MessageBus
 class Carrier final {
@@ -48,13 +49,13 @@ class Carrier final {
     return carrier;
   }
 
-  void Init(
-      const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
-      framework::Scope* root_scope, framework::Scope* minibatch_scope,
-      const std::vector<framework::Scope*>& microbatch_scopes,
-      const platform::Place& place);
+  void Init(std::shared_ptr<RuntimeGraph> runtime_graph,
+            framework::Scope* root_scope, framework::Scope* minibatch_scope,
+            const std::vector<framework::Scope*>& microbatch_scopes,
+            const platform::Place& place);
 
   ~Carrier();
+  void Release();
 
   // Enqueue a message to corresponding interceptor id
   bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message);
@@ -84,9 +85,6 @@ class Carrier final {
   void HandleTmpMessages();
 
-  // interceptor logic id to the Nodes info
-  std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
-
   // interceptor logic id to actually interceptor
   std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
       interceptor_idx_to_interceptor_;
@@ -105,7 +103,8 @@ class Carrier final {
   framework::Scope* root_scope_;
   framework::Scope* minibatch_scope_;
   paddle::platform::Place place_;
-  paddle::platform::DeviceContext* dev_ctx_ = nullptr;
+  paddle::platform::DeviceContext* dev_ctx_{nullptr};
+  std::shared_ptr<RuntimeGraph> runtime_graph_;
 };
 
 }  // namespace distributed
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc

@@ -51,6 +51,11 @@ void ComputeInterceptor::PrepareDeps() {
                           "times, but now max_run_times=%ld",
                           node_->max_run_times()));
   }
+
+  // If there is no downstream or every downstream is in different rank,
+  // then this interceptor is the last one for current rank.
+  // This can be get during init, can be cached for later use.
+  is_last_ = downstream.empty();
 }
 
 void ComputeInterceptor::IncreaseReady(int64_t up_id) {
@@ -129,7 +134,8 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
     InterceptorMessage ready_msg;
     ready_msg.set_message_type(DATA_IS_READY);
-    VLOG(3) << "ComputeInterceptor Send data_is_ready msg to " << down_id;
+    VLOG(3) << "ComputeInterceptor " << interceptor_id_
+            << " Send data_is_ready msg to " << down_id;
     Send(down_id, ready_msg);
   }
 }
@@ -148,7 +154,8 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
     InterceptorMessage reply_msg;
     reply_msg.set_message_type(DATE_IS_USELESS);
-    VLOG(3) << "ComputeInterceptor Reply data_is_useless msg to " << up_id;
+    VLOG(3) << "ComputeInterceptor " << interceptor_id_
+            << " Reply data_is_useless msg to " << up_id;
     Send(up_id, reply_msg);
   }
 }
@@ -159,7 +166,7 @@ void ComputeInterceptor::Run() {
   // step_ %= node_->max_run_times();
   for (auto op : node_->ops()) {
-    auto* scope = microbatch_scopes_[step_ % node_->max_slot_nums()];
+    auto* scope = microbatch_scopes_[step_ % node_->max_run_times()];
     op->Run(*scope, place_);
   }
   ++step_;
@@ -168,6 +175,10 @@ void ComputeInterceptor::Run() {
     SendDataReadyToDownStream();
     // reply to upstream and decrease ready data
     ReplyCompletedToUpStream();
+    // Try to stop Carrier
+    if (step_ % node_->max_run_times() == 0 && is_last_) {
+      StopCarrier();
+    }
   }
 
   // If there is no limit, source interceptor can be executed
@@ -221,11 +232,6 @@ void ComputeInterceptor::TryStop() {
     Send(down_id, stop);
   }
   stop_ = true;
-
-  if (out_buffs_.size() == 0) {
-    // TODO(fleet executor dev) need a better place to notify
-    StopCarrier();
-  }
 }
 
 void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
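
ComputeInterceptor::Run() now picks the microbatch scope with step_ % max_run_times() instead of step_ % max_slot_nums(). A tiny standalone check of that indexing rule, under the assumption (matching the updated tests) that the number of microbatch scopes equals max_run_times; the values are illustrative only:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t max_run_times = 2;                     // as in the updated tests
  std::vector<int> microbatch_scopes(max_run_times);   // stand-in for Scope*
  for (int64_t step = 0; step < 6; ++step) {
    const int64_t idx = step % max_run_times;          // new indexing rule in Run()
    // Cycles 0, 1, 0, 1, ... so every step lands on a valid microbatch scope.
    assert(idx >= 0 && idx < static_cast<int64_t>(microbatch_scopes.size()));
    (void)microbatch_scopes[idx];
  }
  return 0;
}
```
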
paddle/fluid/distributed/fleet_executor/compute_interceptor.h

@@ -44,6 +44,7 @@ class ComputeInterceptor : public Interceptor {
  private:
   bool is_source_{false};
+  bool is_last_{false};
   int64_t step_{0};
 
   // upstream_id-->(max_ready_size, ready_size)
paddle/fluid/distributed/fleet_executor/fleet_executor.cc

@@ -38,7 +38,7 @@ FleetExecutor::~FleetExecutor() {
 void FleetExecutor::Init(const framework::ProgramDesc& program_desc,
                          framework::Scope* scope,
                          const platform::Place& place) {
-  runtime_graph_ = std::make_unique<RuntimeGraph>(program_desc, exe_desc_);
+  runtime_graph_ = std::make_shared<RuntimeGraph>(program_desc, exe_desc_);
   root_scope_ = scope;
   place_ = place;
   PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument(
@@ -58,8 +58,8 @@ void FleetExecutor::Init(const framework::ProgramDesc& program_desc,
 
 void FleetExecutor::InitCarrier() {
   Carrier& carrier_instance = Carrier::Instance();
   if (!carrier_instance.IsInit()) {
-    carrier_instance.Init(runtime_graph_->intercepter_id_to_node(), root_scope_,
-                          minibatch_scope_, microbatch_scopes_, place_);
+    carrier_instance.Init(runtime_graph_, root_scope_, minibatch_scope_,
+                          microbatch_scopes_, place_);
   }
 }
paddle/fluid/distributed/fleet_executor/fleet_executor.h

@@ -47,7 +47,7 @@ class FleetExecutor final {
   void InitCarrier();
   void CopyParameters(int microbatch_id, const framework::ProgramDesc& program);
   FleetExecutorDesc exe_desc_;
-  std::unique_ptr<RuntimeGraph> runtime_graph_;
+  std::shared_ptr<RuntimeGraph> runtime_graph_;
   framework::Scope* root_scope_;
   framework::Scope* minibatch_scope_;
   platform::Place place_;
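
fleet_executor.h switches runtime_graph_ from std::unique_ptr to std::shared_ptr because Carrier::Init() now keeps its own reference to the same RuntimeGraph. A minimal sketch of that shared-ownership shape; ToyExecutor, ToyCarrier and ToyRuntimeGraph are illustrative stand-ins, not Paddle types:

```cpp
#include <cassert>
#include <memory>

struct ToyRuntimeGraph {
  int node_count{0};
};

struct ToyCarrier {
  // Mirrors the new Carrier::Init(std::shared_ptr<RuntimeGraph>, ...) shape.
  void Init(std::shared_ptr<ToyRuntimeGraph> graph) { graph_ = std::move(graph); }
  std::shared_ptr<ToyRuntimeGraph> graph_;
};

struct ToyExecutor {
  void Init() { graph_ = std::make_shared<ToyRuntimeGraph>(); }
  void InitCarrier(ToyCarrier* carrier) { carrier->Init(graph_); }
  std::shared_ptr<ToyRuntimeGraph> graph_;
};

int main() {
  ToyExecutor exe;
  ToyCarrier carrier;
  exe.Init();
  exe.InitCarrier(&carrier);
  // Both objects keep the same graph alive, which unique_ptr could not express.
  assert(exe.graph_.use_count() == 2);
  return 0;
}
```
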
paddle/fluid/distributed/fleet_executor/interceptor.cc

@@ -46,7 +46,6 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
     VLOG(3) << "Interceptor is using default message handler. This handler is "
                "only used for test purpose. Check whether you init interceptor "
                "in the proper way.";
-
     if (msg.message_type() == DATA_IS_READY) {
       if (node_->role() != 2) {
         VLOG(3) << "Fake handler is sending DATA_IS_READY message to: "
@@ -54,14 +53,19 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
         InterceptorMessage data_is_ready_msg;
         data_is_ready_msg.set_message_type(DATA_IS_READY);
         Send(interceptor_id_ + 1, data_is_ready_msg);
+      } else {
+        // NOTE: max run time is reach for last interceptor
+        StopCarrier();
       }
-      VLOG(3) << "Fake handler is sending stop message to it self.";
-      InterceptorMessage stop_msg;
-      stop_msg.set_message_type(STOP);
-      Send(interceptor_id_, stop_msg);
     } else if (msg.message_type() == STOP) {
       stop_ = true;
-      StopCarrier();
+      if (node_->role() != 2) {
+        VLOG(3) << "Fake handler is sending STOP message to: "
+                << interceptor_id_ + 1 << ".";
+        InterceptorMessage stop_msg;
+        stop_msg.set_message_type(STOP);
+        Send(interceptor_id_ + 1, stop_msg);
+      }
     }
   }
 }
paddle/fluid/distributed/fleet_executor/message_bus.cc

@@ -57,6 +57,10 @@ bool MessageBus::IsInit() const { return is_init_; }
 
 MessageBus::~MessageBus() {
   VLOG(3) << "Message bus releases resource.";
+  // NOTE: fleet_executor inits carrier before message bus,
+  // therefore the message bus's destructor will be called first
+  Carrier& carrier = Carrier::Instance();
+  carrier.Release();
 #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
     !defined(PADDLE_WITH_ASCEND_CL)
   server_.Stop(1000);
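
The NOTE added here leans on how C++ tears down function-local statics: they are destroyed in reverse order of construction, so a bus singleton created after the carrier singleton is destroyed first and can still call the carrier's Release() from its destructor. A minimal, hypothetical illustration of that ordering (ToyCarrier and ToyMessageBus are not the real classes):

```cpp
#include <iostream>

struct ToyCarrier {
  static ToyCarrier& Instance() {
    static ToyCarrier c;  // constructed first (the executor touches it first)
    return c;
  }
  void Release() { std::cout << "carrier released\n"; }
  ~ToyCarrier() { std::cout << "carrier destroyed\n"; }
};

struct ToyMessageBus {
  static ToyMessageBus& Instance() {
    static ToyMessageBus b;  // constructed second, so destroyed first at exit
    return b;
  }
  ~ToyMessageBus() {
    // Same pattern as the MessageBus destructor in this commit.
    ToyCarrier::Instance().Release();
    std::cout << "message bus destroyed\n";
  }
};

int main() {
  ToyCarrier::Instance();     // carrier initialized before the bus
  ToyMessageBus::Instance();
  // Expected output at program exit:
  //   carrier released
  //   message bus destroyed
  //   carrier destroyed
  return 0;
}
```
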
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc

@@ -61,15 +61,15 @@ TEST(ComputeInterceptor, Compute) {
   std::vector<framework::Scope*> scopes = {scope, scope};
   platform::Place place = platform::CPUPlace();
 
-  Carrier& carrier = Carrier::Instance();
   MessageBus& msg_bus = MessageBus::Instance();
   msg_bus.Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
+  Carrier& carrier = Carrier::Instance();
 
   // FIXME: don't delete, otherwise interceptor will use undefined node
   TaskNode* node_a =
-      new TaskNode(0, ops, 0, 0, 2, 2);  // role, ops, rank, task_id
-  TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0);
+      new TaskNode(0, ops, 0, 0, 2, 0);  // role, ops, rank, task_id
+  TaskNode* node_b = new TaskNode(0, 0, 1, 2, 0);
 
   // a->b
   node_a->AddDownstreamTask(1);
@@ -90,13 +90,6 @@ TEST(ComputeInterceptor, Compute) {
   msg.set_src_id(-1);
   msg.set_dst_id(0);
   carrier.EnqueueInterceptorMessage(msg);
-
-  // stop
-  InterceptorMessage stop;
-  stop.set_message_type(STOP);
-  stop.set_src_id(-1);
-  stop.set_dst_id(0);
-  carrier.EnqueueInterceptorMessage(stop);
 }
 
 }  // namespace distributed
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc

@@ -35,31 +35,25 @@ class StartInterceptor : public Interceptor {
   void NOP(const InterceptorMessage& msg) {
     if (msg.message_type() == STOP) {
       stop_ = true;
+      InterceptorMessage stop;
+      stop.set_message_type(STOP);
+      Send(1, stop);  // stop 1, compute
       return;
     }
     std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
               << std::endl;
-    ++count_;
-    if (count_ == 3) {
-      InterceptorMessage stop;
-      stop.set_message_type(STOP);
-      Send(msg.dst_id(), stop);  // stop 0, this
-      Send(msg.src_id(), stop);  // stop 1, compute
-    }
   }
-
-  int count_{0};
 };
 
 TEST(ComputeInterceptor, Compute) {
-  Carrier& carrier = Carrier::Instance();
   MessageBus& msg_bus = MessageBus::Instance();
   msg_bus.Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
+  Carrier& carrier = Carrier::Instance();
 
   // NOTE: don't delete, otherwise interceptor will use undefined node
-  TaskNode* node_a = new TaskNode(0, 0, 0, 0, 0);  // role, rank, task_id
-  TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0);
-  TaskNode* node_c = new TaskNode(0, 0, 2, 0, 0);
+  TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0);  // role, rank, task_id
+  TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
+  TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);
 
   // a->b->c
   node_a->AddDownstreamTask(1);