Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
60356f67
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
60356f67
编写于
4月 19, 2022
作者:
L
LiYuRio
提交者:
GitHub
4月 19, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[FleetExecutor] Modified test cases using source and sink (#41926)
上级
771a4144
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
89 addition
and
40 deletion
+89
-40
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+7
-1
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
...e/fluid/distributed/fleet_executor/compute_interceptor.cc
+1
-1
paddle/fluid/distributed/fleet_executor/interceptor.h
paddle/fluid/distributed/fleet_executor/interceptor.h
+3
-0
paddle/fluid/distributed/fleet_executor/interceptor_message.proto
...luid/distributed/fleet_executor/interceptor_message.proto
+2
-2
paddle/fluid/distributed/fleet_executor/sink_interceptor.cc
paddle/fluid/distributed/fleet_executor/sink_interceptor.cc
+1
-1
paddle/fluid/distributed/fleet_executor/task_node.cc
paddle/fluid/distributed/fleet_executor/task_node.cc
+3
-0
paddle/fluid/distributed/fleet_executor/task_node.h
paddle/fluid/distributed/fleet_executor/task_node.h
+1
-0
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
...ed/fleet_executor/test/compute_interceptor_run_op_test.cc
+15
-5
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
...leet_executor/test/interceptor_pipeline_long_path_test.cc
+19
-6
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
...eet_executor/test/interceptor_pipeline_short_path_test.cc
+14
-6
paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc
.../distributed/fleet_executor/test/sink_interceptor_test.cc
+14
-11
paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc
...istributed/fleet_executor/test/source_interceptor_test.cc
+9
-7
未找到文件。
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
60356f67
...
@@ -186,7 +186,13 @@ int64_t Carrier::GetRank(int64_t interceptor_id) const {
...
@@ -186,7 +186,13 @@ int64_t Carrier::GetRank(int64_t interceptor_id) const {
}
}
bool
Carrier
::
Send
(
const
InterceptorMessage
&
msg
)
{
bool
Carrier
::
Send
(
const
InterceptorMessage
&
msg
)
{
int64_t
src_id
=
(
msg
.
src_id
()
==
-
1
)
?
msg
.
dst_id
()
:
msg
.
src_id
();
int64_t
src_id
=
msg
.
src_id
();
// TODO(liyurui): compatible solution, will be removed completely in the
// future
if
(
interceptor_id_to_rank_
.
find
(
src_id
)
==
interceptor_id_to_rank_
.
end
()
&&
src_id
==
SOURCE_ID
)
{
src_id
=
msg
.
dst_id
();
}
int64_t
dst_id
=
msg
.
dst_id
();
int64_t
dst_id
=
msg
.
dst_id
();
int64_t
src_rank
=
GetRank
(
src_id
);
int64_t
src_rank
=
GetRank
(
src_id
);
int64_t
dst_rank
=
GetRank
(
dst_id
);
int64_t
dst_rank
=
GetRank
(
dst_id
);
...
...
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
浏览文件 @
60356f67
...
@@ -161,7 +161,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
...
@@ -161,7 +161,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
<<
" Reply data_is_useless msg to "
<<
up_id
<<
" Reply data_is_useless msg to "
<<
up_id
<<
" for step: "
<<
step_
;
<<
" for step: "
<<
step_
;
if
(
up_id
==
-
1
)
return
;
if
(
is_source_
&&
up_id
==
-
1
)
return
;
InterceptorMessage
reply_msg
;
InterceptorMessage
reply_msg
;
reply_msg
.
set_message_type
(
DATA_IS_USELESS
);
reply_msg
.
set_message_type
(
DATA_IS_USELESS
);
...
...
paddle/fluid/distributed/fleet_executor/interceptor.h
浏览文件 @
60356f67
...
@@ -40,6 +40,9 @@ class TaskNode;
...
@@ -40,6 +40,9 @@ class TaskNode;
class
Carrier
;
class
Carrier
;
class
TaskLoop
;
class
TaskLoop
;
constexpr
int64_t
SOURCE_ID
=
-
1
;
constexpr
int64_t
SINK_ID
=
-
2
;
class
Interceptor
{
class
Interceptor
{
public:
public:
using
MsgHandle
=
std
::
function
<
void
(
const
InterceptorMessage
&
)
>
;
using
MsgHandle
=
std
::
function
<
void
(
const
InterceptorMessage
&
)
>
;
...
...
paddle/fluid/distributed/fleet_executor/interceptor_message.proto
浏览文件 @
60356f67
...
@@ -27,8 +27,8 @@ enum MessageType {
...
@@ -27,8 +27,8 @@ enum MessageType {
}
}
message
InterceptorMessage
{
message
InterceptorMessage
{
optional
int64
src_id
=
1
[
default
=
0
];
optional
s
int64
src_id
=
1
[
default
=
0
];
optional
int64
dst_id
=
2
[
default
=
0
];
optional
s
int64
dst_id
=
2
[
default
=
0
];
optional
MessageType
message_type
=
3
[
default
=
RESET
];
optional
MessageType
message_type
=
3
[
default
=
RESET
];
optional
bool
ctrl_message
=
4
[
default
=
false
];
optional
bool
ctrl_message
=
4
[
default
=
false
];
optional
int64
scope_idx
=
5
[
default
=
0
];
optional
int64
scope_idx
=
5
[
default
=
0
];
...
...
paddle/fluid/distributed/fleet_executor/sink_interceptor.cc
浏览文件 @
60356f67
...
@@ -30,7 +30,7 @@ SinkInterceptor::SinkInterceptor(int64_t interceptor_id, TaskNode* node)
...
@@ -30,7 +30,7 @@ SinkInterceptor::SinkInterceptor(int64_t interceptor_id, TaskNode* node)
void
SinkInterceptor
::
StopCarrierIfComplete
()
{
void
SinkInterceptor
::
StopCarrierIfComplete
()
{
bool
flag
=
true
;
bool
flag
=
true
;
for
(
const
auto
&
up
:
upstream_step_
)
{
for
(
const
auto
&
up
:
upstream_step_
)
{
flag
=
flag
&
(
up
.
second
==
max_run_times_
);
flag
=
flag
&
&
(
up
.
second
==
max_run_times_
);
}
}
if
(
flag
)
{
if
(
flag
)
{
VLOG
(
3
)
<<
"Sink Interceptor is stopping carrier"
;
VLOG
(
3
)
<<
"Sink Interceptor is stopping carrier"
;
...
...
paddle/fluid/distributed/fleet_executor/task_node.cc
浏览文件 @
60356f67
...
@@ -74,6 +74,9 @@ void TaskNode::Init(bool use_feed_fetch_ops) {
...
@@ -74,6 +74,9 @@ void TaskNode::Init(bool use_feed_fetch_ops) {
}
}
}
}
TaskNode
::
TaskNode
(
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
)
:
rank_
(
rank
),
task_id_
(
task_id
),
max_run_times_
(
max_run_times
)
{}
TaskNode
::
TaskNode
(
int32_t
role
,
TaskNode
::
TaskNode
(
int32_t
role
,
const
std
::
vector
<
framework
::
OpDesc
*>&
op_descs
,
const
std
::
vector
<
framework
::
OpDesc
*>&
op_descs
,
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
,
...
...
paddle/fluid/distributed/fleet_executor/task_node.h
浏览文件 @
60356f67
...
@@ -32,6 +32,7 @@ namespace distributed {
...
@@ -32,6 +32,7 @@ namespace distributed {
class
TaskNode
final
{
class
TaskNode
final
{
public:
public:
using
OperatorBase
=
paddle
::
framework
::
OperatorBase
;
using
OperatorBase
=
paddle
::
framework
::
OperatorBase
;
TaskNode
(
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
);
TaskNode
(
int32_t
role
,
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
,
TaskNode
(
int32_t
role
,
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
max_slot_nums
);
int64_t
max_slot_nums
);
TaskNode
(
int32_t
role
,
const
std
::
vector
<
framework
::
OpDesc
*>&
op_descs
,
TaskNode
(
int32_t
role
,
const
std
::
vector
<
framework
::
OpDesc
*>&
op_descs
,
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
浏览文件 @
60356f67
...
@@ -69,32 +69,42 @@ TEST(ComputeInterceptor, Compute) {
...
@@ -69,32 +69,42 @@ TEST(ComputeInterceptor, Compute) {
std
::
string
carrier_id
=
"0"
;
std
::
string
carrier_id
=
"0"
;
Carrier
*
carrier
=
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
}});
carrier
->
Init
(
0
,
{{
SOURCE_ID
,
0
},
{
0
,
0
},
{
1
,
0
},
{
SINK_ID
,
0
}});
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
// FIXME: don't delete, otherwise interceptor will use undefined node
// FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode
*
source
=
new
TaskNode
(
0
,
SOURCE_ID
,
2
);
// rank, task_id, max_run_times
TaskNode
*
node_a
=
TaskNode
*
node_a
=
new
TaskNode
(
0
,
ops
,
0
,
0
,
2
,
0
);
// role, ops, rank, task_id
new
TaskNode
(
0
,
ops
,
0
,
0
,
2
,
0
);
// role, ops, rank, task_id
TaskNode
*
node_b
=
new
TaskNode
(
0
,
0
,
1
,
2
,
0
);
TaskNode
*
node_b
=
new
TaskNode
(
0
,
0
,
1
,
2
,
0
);
TaskNode
*
sink
=
new
TaskNode
(
0
,
SINK_ID
,
2
);
// a->b
// source->a->b->sink
source
->
AddDownstreamTask
(
0
);
node_a
->
AddUpstreamTask
(
SOURCE_ID
);
node_a
->
AddDownstreamTask
(
1
);
node_a
->
AddDownstreamTask
(
1
);
node_b
->
AddUpstreamTask
(
0
);
node_b
->
AddUpstreamTask
(
0
);
sink
->
AddUpstreamTask
(
1
);
node_b
->
AddDownstreamTask
(
SINK_ID
);
carrier
->
SetInterceptor
(
SOURCE_ID
,
InterceptorFactory
::
Create
(
"Source"
,
SOURCE_ID
,
source
));
auto
*
a
=
carrier
->
SetInterceptor
(
auto
*
a
=
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Compute"
,
0
,
node_a
));
0
,
InterceptorFactory
::
Create
(
"Compute"
,
0
,
node_a
));
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
->
SetInterceptor
(
SINK_ID
,
InterceptorFactory
::
Create
(
"Sink"
,
SINK_ID
,
sink
));
a
->
SetPlace
(
place
);
a
->
SetPlace
(
place
);
a
->
SetMicroBatchScope
(
scopes
);
a
->
SetMicroBatchScope
(
scopes
);
// start
// start
InterceptorMessage
msg
;
InterceptorMessage
msg
;
msg
.
set_message_type
(
DATA_IS_READY
);
msg
.
set_message_type
(
START
);
msg
.
set_src_id
(
-
1
);
msg
.
set_dst_id
(
SOURCE_ID
);
msg
.
set_dst_id
(
0
);
carrier
->
EnqueueInterceptorMessage
(
msg
);
carrier
->
EnqueueInterceptorMessage
(
msg
);
carrier
->
Wait
();
carrier
->
Wait
();
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
浏览文件 @
60356f67
...
@@ -55,27 +55,39 @@ TEST(AmplifierInterceptor, Amplifier) {
...
@@ -55,27 +55,39 @@ TEST(AmplifierInterceptor, Amplifier) {
std
::
string
carrier_id
=
"0"
;
std
::
string
carrier_id
=
"0"
;
Carrier
*
carrier
=
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
}});
carrier
->
Init
(
0
,
{{
SOURCE_ID
,
0
},
{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
},
{
SINK_ID
,
0
}});
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
int64_t
micro_steps
=
3
;
int64_t
micro_steps
=
3
;
// NOTE: don't delete, otherwise interceptor will use undefined node
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode
*
source
=
new
TaskNode
(
0
,
SOURCE_ID
,
micro_steps
);
// rank, task_id, max_run_times
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
1
,
0
);
// role, rank, task_id
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
1
,
0
);
// role, rank, task_id
TaskNode
*
node_b
=
new
TaskNode
(
0
,
0
,
1
,
1
,
0
);
TaskNode
*
node_b
=
new
TaskNode
(
0
,
0
,
1
,
1
,
0
);
TaskNode
*
node_c
=
new
TaskNode
(
0
,
0
,
2
,
1
,
0
);
TaskNode
*
node_c
=
new
TaskNode
(
0
,
0
,
2
,
1
,
0
);
TaskNode
*
node_d
=
new
TaskNode
(
0
,
0
,
3
,
1
,
0
);
TaskNode
*
node_d
=
new
TaskNode
(
0
,
0
,
3
,
1
,
0
);
TaskNode
*
node_e
=
new
TaskNode
(
0
,
0
,
4
,
1
,
0
);
TaskNode
*
node_e
=
new
TaskNode
(
0
,
0
,
4
,
1
,
0
);
TaskNode
*
node_f
=
new
TaskNode
(
0
,
0
,
5
,
1
,
0
);
TaskNode
*
node_f
=
new
TaskNode
(
0
,
0
,
5
,
1
,
0
);
TaskNode
*
sink
=
new
TaskNode
(
0
,
SINK_ID
,
micro_steps
);
//
a->b->c->d->e->f
//
source->a->b->c->d->e->f->sink
LinkNodes
({
node_a
,
node_b
,
node_c
,
node_d
,
node_e
,
node_f
});
LinkNodes
({
source
,
node_a
,
node_b
,
node_c
,
node_d
,
node_e
,
node_f
,
sink
});
// LR->b(1:3)->F->B->e(3:1)->U
// LR->b(1:3)->F->B->e(3:1)->U
node_b
->
SetReplyUpPerSteps
(
micro_steps
);
node_b
->
SetReplyUpPerSteps
(
micro_steps
);
node_e
->
SetSendDownPerSteps
(
micro_steps
);
node_e
->
SetSendDownPerSteps
(
micro_steps
);
carrier
->
SetInterceptor
(
SOURCE_ID
,
InterceptorFactory
::
Create
(
"Source"
,
SOURCE_ID
,
source
));
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Compute"
,
0
,
node_a
));
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Compute"
,
0
,
node_a
));
carrier
->
SetInterceptor
(
1
,
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Amplifier"
,
1
,
node_b
));
InterceptorFactory
::
Create
(
"Amplifier"
,
1
,
node_b
));
...
@@ -84,12 +96,13 @@ TEST(AmplifierInterceptor, Amplifier) {
...
@@ -84,12 +96,13 @@ TEST(AmplifierInterceptor, Amplifier) {
carrier
->
SetInterceptor
(
4
,
carrier
->
SetInterceptor
(
4
,
InterceptorFactory
::
Create
(
"Amplifier"
,
4
,
node_e
));
InterceptorFactory
::
Create
(
"Amplifier"
,
4
,
node_e
));
carrier
->
SetInterceptor
(
5
,
InterceptorFactory
::
Create
(
"Compute"
,
5
,
node_f
));
carrier
->
SetInterceptor
(
5
,
InterceptorFactory
::
Create
(
"Compute"
,
5
,
node_f
));
carrier
->
SetInterceptor
(
SINK_ID
,
InterceptorFactory
::
Create
(
"Sink"
,
SINK_ID
,
sink
));
// start
// start
InterceptorMessage
msg
;
InterceptorMessage
msg
;
msg
.
set_message_type
(
DATA_IS_READY
);
msg
.
set_message_type
(
START
);
msg
.
set_src_id
(
-
1
);
msg
.
set_dst_id
(
SOURCE_ID
);
msg
.
set_dst_id
(
0
);
carrier
->
EnqueueInterceptorMessage
(
msg
);
carrier
->
EnqueueInterceptorMessage
(
msg
);
carrier
->
Wait
();
carrier
->
Wait
();
carrier
->
Release
();
carrier
->
Release
();
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
浏览文件 @
60356f67
...
@@ -73,39 +73,47 @@ TEST(AmplifierInterceptor, Amplifier) {
...
@@ -73,39 +73,47 @@ TEST(AmplifierInterceptor, Amplifier) {
std
::
string
carrier_id
=
"0"
;
std
::
string
carrier_id
=
"0"
;
Carrier
*
carrier
=
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}});
carrier
->
Init
(
0
,
{{
SOURCE_ID
,
0
},
{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
SINK_ID
,
0
}});
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
msg_bus
->
Init
(
0
,
{{
0
,
""
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
""
}},
""
);
int64_t
micro_steps
=
6
;
int64_t
micro_steps
=
6
;
// NOTE: don't delete, otherwise interceptor will use undefined node
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode
*
source
=
new
TaskNode
(
0
,
SOURCE_ID
,
micro_steps
);
// rank, task_id, max_run_times
TaskNode
*
node_a
=
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
micro_steps
,
0
);
// role, rank, task_id
new
TaskNode
(
0
,
0
,
0
,
micro_steps
,
0
);
// role, rank, task_id
TaskNode
*
node_b
=
new
TaskNode
(
0
,
0
,
1
,
3
,
0
);
TaskNode
*
node_b
=
new
TaskNode
(
0
,
0
,
1
,
3
,
0
);
TaskNode
*
node_c
=
new
TaskNode
(
0
,
0
,
2
,
3
,
0
);
TaskNode
*
node_c
=
new
TaskNode
(
0
,
0
,
2
,
3
,
0
);
TaskNode
*
node_d
=
new
TaskNode
(
0
,
0
,
3
,
micro_steps
,
0
);
TaskNode
*
node_d
=
new
TaskNode
(
0
,
0
,
3
,
micro_steps
,
0
);
TaskNode
*
sink
=
new
TaskNode
(
0
,
SINK_ID
,
micro_steps
);
//
a->b->c->d
//
source->a->b->c->d->sink
// LR->F->B->U
// LR->F->B->U
LinkNodes
({
node_a
,
node_b
,
node_c
,
node_d
},
{{{
node_b
,
node_c
},
1
}});
LinkNodes
({
source
,
node_a
,
node_b
,
node_c
,
node_d
,
sink
},
{{{
node_b
,
node_c
},
1
}});
node_a
->
SetRunPerSteps
(
micro_steps
);
node_a
->
SetRunPerSteps
(
micro_steps
);
node_d
->
SetRunPerSteps
(
micro_steps
);
node_d
->
SetRunPerSteps
(
micro_steps
);
node_d
->
SetRunAtOffset
(
micro_steps
-
1
);
node_d
->
SetRunAtOffset
(
micro_steps
-
1
);
carrier
->
SetInterceptor
(
SOURCE_ID
,
InterceptorFactory
::
Create
(
"Source"
,
SOURCE_ID
,
source
));
carrier
->
SetInterceptor
(
0
,
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Amplifier"
,
0
,
node_a
));
InterceptorFactory
::
Create
(
"Amplifier"
,
0
,
node_a
));
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
->
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Compute"
,
2
,
node_c
));
carrier
->
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Compute"
,
2
,
node_c
));
carrier
->
SetInterceptor
(
3
,
carrier
->
SetInterceptor
(
3
,
InterceptorFactory
::
Create
(
"Amplifier"
,
3
,
node_d
));
InterceptorFactory
::
Create
(
"Amplifier"
,
3
,
node_d
));
carrier
->
SetInterceptor
(
SINK_ID
,
InterceptorFactory
::
Create
(
"Sink"
,
SINK_ID
,
sink
));
// start
// start
InterceptorMessage
msg
;
InterceptorMessage
msg
;
msg
.
set_message_type
(
DATA_IS_READY
);
msg
.
set_message_type
(
START
);
msg
.
set_src_id
(
-
1
);
msg
.
set_dst_id
(
SOURCE_ID
);
msg
.
set_dst_id
(
0
);
carrier
->
EnqueueInterceptorMessage
(
msg
);
carrier
->
EnqueueInterceptorMessage
(
msg
);
carrier
->
Wait
();
carrier
->
Wait
();
carrier
->
Release
();
carrier
->
Release
();
...
...
paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc
浏览文件 @
60356f67
...
@@ -39,10 +39,10 @@ class FakeInterceptor : public Interceptor {
...
@@ -39,10 +39,10 @@ class FakeInterceptor : public Interceptor {
<<
std
::
endl
;
<<
std
::
endl
;
InterceptorMessage
reply
;
InterceptorMessage
reply
;
reply
.
set_message_type
(
DATA_IS_USELESS
);
reply
.
set_message_type
(
DATA_IS_USELESS
);
Send
(
-
1
,
reply
);
Send
(
SOURCE_ID
,
reply
);
InterceptorMessage
ready
;
InterceptorMessage
ready
;
ready
.
set_message_type
(
DATA_IS_READY
);
ready
.
set_message_type
(
DATA_IS_READY
);
Send
(
-
2
,
ready
);
Send
(
SINK_ID
,
ready
);
}
else
if
(
msg
.
message_type
()
==
DATA_IS_USELESS
)
{
}
else
if
(
msg
.
message_type
()
==
DATA_IS_USELESS
)
{
std
::
cout
<<
"FakeInterceptor remove result in scope "
<<
msg
.
scope_idx
()
std
::
cout
<<
"FakeInterceptor remove result in scope "
<<
msg
.
scope_idx
()
<<
std
::
endl
;
<<
std
::
endl
;
...
@@ -57,28 +57,31 @@ TEST(SourceInterceptor, Source) {
...
@@ -57,28 +57,31 @@ TEST(SourceInterceptor, Source) {
std
::
string
carrier_id
=
"0"
;
std
::
string
carrier_id
=
"0"
;
Carrier
*
carrier
=
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
carrier
->
Init
(
0
,
{{
-
1
,
0
},
{
0
,
0
},
{
-
2
,
0
}});
carrier
->
Init
(
0
,
{{
SOURCE_ID
,
0
},
{
0
,
0
},
{
SINK_ID
,
0
}});
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
// NOTE: don't delete, otherwise interceptor will use undefined node
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode
*
source
=
new
TaskNode
(
0
,
-
1
,
0
,
3
,
0
);
// role, rank, task_id
TaskNode
*
source
=
new
TaskNode
(
0
,
SOURCE_ID
,
0
,
3
,
0
);
// role, rank, task_id
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
3
,
0
);
// role, rank, task_id
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
3
,
0
);
// role, rank, task_id
TaskNode
*
sink
=
new
TaskNode
(
0
,
-
2
,
0
,
3
,
0
);
// role, rank, task_id
TaskNode
*
sink
=
new
TaskNode
(
0
,
SINK_ID
,
0
,
3
,
0
);
// role, rank, task_id
source
->
AddDownstreamTask
(
0
,
1
);
source
->
AddDownstreamTask
(
0
,
1
);
node_a
->
AddUpstreamTask
(
-
1
,
1
);
node_a
->
AddUpstreamTask
(
SOURCE_ID
,
1
);
node_a
->
AddDownstreamTask
(
-
2
,
1
);
node_a
->
AddDownstreamTask
(
SINK_ID
,
1
);
sink
->
AddUpstreamTask
(
0
,
1
);
sink
->
AddUpstreamTask
(
0
,
1
);
carrier
->
SetInterceptor
(
-
1
,
InterceptorFactory
::
Create
(
"Source"
,
-
1
,
source
));
carrier
->
SetInterceptor
(
SOURCE_ID
,
InterceptorFactory
::
Create
(
"Source"
,
SOURCE_ID
,
source
));
carrier
->
SetInterceptor
(
0
,
std
::
make_unique
<
FakeInterceptor
>
(
0
,
node_a
));
carrier
->
SetInterceptor
(
0
,
std
::
make_unique
<
FakeInterceptor
>
(
0
,
node_a
));
carrier
->
SetInterceptor
(
-
2
,
InterceptorFactory
::
Create
(
"Sink"
,
-
2
,
sink
));
carrier
->
SetInterceptor
(
SINK_ID
,
InterceptorFactory
::
Create
(
"Sink"
,
SINK_ID
,
sink
));
// start
// start
InterceptorMessage
msg
;
InterceptorMessage
msg
;
msg
.
set_message_type
(
START
);
msg
.
set_message_type
(
START
);
msg
.
set_dst_id
(
-
1
);
msg
.
set_dst_id
(
SOURCE_ID
);
carrier
->
EnqueueInterceptorMessage
(
msg
);
carrier
->
EnqueueInterceptorMessage
(
msg
);
carrier
->
Wait
();
carrier
->
Wait
();
...
...
paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc
浏览文件 @
60356f67
...
@@ -40,7 +40,7 @@ class FakeInterceptor : public Interceptor {
...
@@ -40,7 +40,7 @@ class FakeInterceptor : public Interceptor {
<<
std
::
endl
;
<<
std
::
endl
;
InterceptorMessage
reply
;
InterceptorMessage
reply
;
reply
.
set_message_type
(
DATA_IS_USELESS
);
reply
.
set_message_type
(
DATA_IS_USELESS
);
Send
(
-
1
,
reply
);
Send
(
SOURCE_ID
,
reply
);
step_
++
;
step_
++
;
if
(
step_
==
node_
->
max_run_times
())
{
if
(
step_
==
node_
->
max_run_times
())
{
carrier_
->
WakeUp
();
carrier_
->
WakeUp
();
...
@@ -56,24 +56,26 @@ TEST(SourceInterceptor, Source) {
...
@@ -56,24 +56,26 @@ TEST(SourceInterceptor, Source) {
std
::
string
carrier_id
=
"0"
;
std
::
string
carrier_id
=
"0"
;
Carrier
*
carrier
=
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
carrier
->
Init
(
0
,
{{
-
1
,
0
},
{
0
,
0
}});
carrier
->
Init
(
0
,
{{
SOURCE_ID
,
0
},
{
0
,
0
}});
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
// NOTE: don't delete, otherwise interceptor will use undefined node
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode
*
source
=
new
TaskNode
(
0
,
-
1
,
0
,
3
,
0
);
// role, rank, task_id
TaskNode
*
source
=
new
TaskNode
(
0
,
SOURCE_ID
,
0
,
3
,
0
);
// role, rank, task_id
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
3
,
0
);
// role, rank, task_id
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
3
,
0
);
// role, rank, task_id
source
->
AddDownstreamTask
(
0
,
1
);
source
->
AddDownstreamTask
(
0
,
1
);
node_a
->
AddUpstreamTask
(
-
1
,
1
);
node_a
->
AddUpstreamTask
(
SOURCE_ID
,
1
);
carrier
->
SetInterceptor
(
-
1
,
InterceptorFactory
::
Create
(
"Source"
,
-
1
,
source
));
carrier
->
SetInterceptor
(
SOURCE_ID
,
InterceptorFactory
::
Create
(
"Source"
,
SOURCE_ID
,
source
));
carrier
->
SetInterceptor
(
0
,
std
::
make_unique
<
FakeInterceptor
>
(
0
,
node_a
));
carrier
->
SetInterceptor
(
0
,
std
::
make_unique
<
FakeInterceptor
>
(
0
,
node_a
));
// start
// start
InterceptorMessage
msg
;
InterceptorMessage
msg
;
msg
.
set_message_type
(
START
);
msg
.
set_message_type
(
START
);
msg
.
set_dst_id
(
-
1
);
msg
.
set_dst_id
(
SOURCE_ID
);
carrier
->
EnqueueInterceptorMessage
(
msg
);
carrier
->
EnqueueInterceptorMessage
(
msg
);
carrier
->
Wait
();
carrier
->
Wait
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录