Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3658405c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3658405c
编写于
12月 30, 2021
作者:
L
LiYuRio
提交者:
GitHub
12月 30, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Fleet Executor] Support multi carrier (#38535)
上级
2421a25a
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
338 addition
and
149 deletion
+338
-149
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+42
-17
paddle/fluid/distributed/fleet_executor/carrier.h
paddle/fluid/distributed/fleet_executor/carrier.h
+16
-16
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+27
-26
paddle/fluid/distributed/fleet_executor/fleet_executor.h
paddle/fluid/distributed/fleet_executor/fleet_executor.h
+0
-11
paddle/fluid/distributed/fleet_executor/global_map.h
paddle/fluid/distributed/fleet_executor/global_map.h
+49
-0
paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc
...distributed/fleet_executor/interceptor_message_service.cc
+10
-2
paddle/fluid/distributed/fleet_executor/runtime_graph.h
paddle/fluid/distributed/fleet_executor/runtime_graph.h
+11
-0
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
+3
-0
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
...ed/fleet_executor/test/compute_interceptor_run_op_test.cc
+9
-8
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
...stributed/fleet_executor/test/compute_interceptor_test.cc
+9
-8
paddle/fluid/distributed/fleet_executor/test/interceptor_pass_the_parcel_test.cc
...d/fleet_executor/test/interceptor_pass_the_parcel_test.cc
+101
-0
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
...ributed/fleet_executor/test/interceptor_ping_pong_test.cc
+7
-5
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
...eet_executor/test/interceptor_ping_pong_with_brpc_test.cc
+25
-36
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
...leet_executor/test/interceptor_pipeline_long_path_test.cc
+16
-11
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
...eet_executor/test/interceptor_pipeline_short_path_test.cc
+13
-9
未找到文件。
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
3658405c
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...
...
@@ -27,16 +28,32 @@ namespace distributed {
USE_INTERCEPTOR
(
Compute
);
USE_INTERCEPTOR
(
Amplifier
);
void
Carrier
::
Init
(
int64_t
rank
,
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
platform
::
Place
&
place
)
{
PADDLE_ENFORCE_EQ
(
is_init_
,
false
,
platform
::
errors
::
AlreadyExists
(
"Carrier is already init."
));
void
Carrier
::
Init
(
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_set
<
int64_t
>&
interceptor_ids
)
{
rank_
=
rank
;
runtime_graph_
=
runtime_graph
;
interceptor_id_to_rank_
=
runtime_graph_
->
interceptor_id_to_rank
();
interceptor_id_to_rank_
=
interceptor_id_to_rank
;
interceptor_ids_
=
interceptor_ids
;
// TODO(fleet_exe dev): thread pool
thread_num_
=
1
;
thread_pool_
.
SetThreadNum
(
thread_num_
);
thread_pool_
.
Start
();
}
void
Carrier
::
Init
(
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_set
<
int64_t
>&
interceptor_ids
,
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
interceptor_id_to_node
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
platform
::
Place
&
place
)
{
rank_
=
rank
;
interceptor_id_to_rank_
=
interceptor_id_to_rank
;
interceptor_ids_
=
interceptor_ids
;
interceptor_id_to_node_
=
interceptor_id_to_node
;
minibatch_scope_
=
minibatch_scope
;
microbatch_scopes_
=
microbatch_scopes
;
place_
=
place
;
...
...
@@ -72,8 +89,6 @@ bool Carrier::EnqueueInterceptorMessage(
return
true
;
}
void
Carrier
::
Barrier
()
{
msg_bus_
->
Barrier
();
}
Interceptor
*
Carrier
::
GetInterceptor
(
int64_t
interceptor_id
)
{
auto
iter
=
interceptor_idx_to_interceptor_
.
find
(
interceptor_id
);
PADDLE_ENFORCE_NE
(
iter
,
interceptor_idx_to_interceptor_
.
end
(),
...
...
@@ -100,7 +115,8 @@ void Carrier::Start() {
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."
));
PADDLE_ENFORCE_EQ
(
is_init_
,
true
,
platform
::
errors
::
PreconditionNotMet
(
"Using carrier before initialized."
));
for
(
int64_t
id
:
source_interceptor_ids_
)
{
VLOG
(
3
)
<<
"Carrier Start is sending start to source interceptor "
<<
id
<<
"."
;
...
...
@@ -140,7 +156,9 @@ bool Carrier::Send(const InterceptorMessage& msg) {
if
(
src_rank
==
dst_rank
)
{
VLOG
(
3
)
<<
"Send a message from interceptor "
<<
src_id
<<
" to interceptor "
<<
dst_id
<<
", which are in the same ranks."
;
return
EnqueueInterceptorMessage
(
msg
);
int64_t
carrier_id
=
*
GlobalMap
<
int64_t
,
int64_t
>::
Get
(
dst_id
);
return
GlobalMap
<
int64_t
,
Carrier
>::
Get
(
carrier_id
)
->
EnqueueInterceptorMessage
(
msg
);
}
else
{
PADDLE_ENFORCE_NOT_NULL
(
msg_bus_
.
get
(),
...
...
@@ -174,6 +192,9 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
loop
,
platform
::
errors
::
Fatal
(
"thread task loop must not null"
));
interceptor
->
RegisterTaskLoop
(
loop
);
// TODO(liyurui): Using struct InterceptorID replace int64_t
GlobalMap
<
int64_t
,
int64_t
>::
Create
(
interceptor_id
,
carrier_id_
);
auto
*
ptr
=
interceptor
.
get
();
interceptor_idx_to_interceptor_
.
insert
(
std
::
make_pair
(
interceptor_id
,
std
::
move
(
interceptor
)));
...
...
@@ -199,15 +220,19 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
}
void
Carrier
::
CreateInterceptors
()
{
if
(
runtime_graph_
->
interceptor_id_to_node
()
.
empty
())
return
;
if
(
interceptor_ids_
.
empty
())
return
;
auto
gc
=
GetGC
(
place_
);
// create each Interceptor
// no auto init since there is no config
for
(
const
auto
&
item
:
runtime_graph_
->
interceptor_id_to_node
())
{
int64_t
interceptor_id
=
item
.
first
;
TaskNode
*
task_node
=
item
.
second
;
for
(
int64_t
interceptor_id
:
interceptor_ids_
)
{
const
auto
&
task_node_iter
=
interceptor_id_to_node_
.
find
(
interceptor_id
);
PADDLE_ENFORCE_NE
(
task_node_iter
,
interceptor_id_to_node_
.
end
(),
platform
::
errors
::
NotFound
(
"Can not find task node for interceptor %ld"
,
interceptor_id
));
TaskNode
*
task_node
=
task_node_iter
->
second
;
PADDLE_ENFORCE_LT
(
task_node
->
run_at_offset
(),
task_node
->
run_per_steps
(),
...
...
paddle/fluid/distributed/fleet_executor/carrier.h
浏览文件 @
3658405c
...
...
@@ -45,19 +45,19 @@ class MessageBus;
class
Carrier
final
{
public:
Carrier
()
=
default
;
Carrier
(
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
)
:
rank_
(
rank
),
interceptor_id_to_rank_
(
interceptor_id_to_rank
)
{
thread_num_
=
1
;
thread_pool_
.
SetThreadNum
(
thread_num_
);
thread_pool_
.
Start
();
}
explicit
Carrier
(
int64_t
carrier_id
)
:
carrier_id_
(
carrier_id
)
{}
~
Carrier
();
void
Init
(
int64_t
rank
,
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
platform
::
Place
&
place
);
void
Init
(
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_set
<
int64_t
>&
interceptor_ids
);
void
Init
(
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_set
<
int64_t
>&
interceptor_ids
,
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
interceptor_id_to_node
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
platform
::
Place
&
place
);
void
Release
();
void
Wait
();
...
...
@@ -83,10 +83,9 @@ class Carrier final {
bool
Send
(
const
InterceptorMessage
&
msg
);
void
Barrier
();
private:
DISABLE_COPY_AND_ASSIGN
(
Carrier
);
Carrier
()
=
delete
;
// create each Interceptor
void
CreateInterceptors
();
...
...
@@ -108,13 +107,14 @@ class Carrier final {
framework
::
Scope
*
minibatch_scope_
;
paddle
::
platform
::
Place
place_
;
paddle
::
platform
::
DeviceContext
*
dev_ctx_
{
nullptr
};
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph_
;
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
int64_t
rank_
;
int64_t
carrier_id_
;
std
::
unordered_map
<
int64_t
,
TaskNode
*>
interceptor_id_to_node_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
int
thread_num_
;
TaskLoopThreadPool
thread_pool_
;
std
::
unordered_set
<
int64_t
>
interceptor_ids_
;
};
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
浏览文件 @
3658405c
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -27,8 +28,6 @@
namespace
paddle
{
namespace
distributed
{
std
::
unique_ptr
<
Carrier
>
FleetExecutor
::
carrier_
;
FleetExecutor
::
FleetExecutor
(
const
std
::
string
&
exe_desc_str
)
{
bool
parse_flag
=
exe_desc_
.
ParseFromString
(
exe_desc_str
);
PADDLE_ENFORCE
(
parse_flag
,
platform
::
errors
::
PreconditionNotMet
(
...
...
@@ -37,13 +36,9 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
FleetExecutor
::~
FleetExecutor
()
{
root_scope_
->
DropKids
();
GetCarrier
()
->
Release
();
}
Carrier
*
FleetExecutor
::
GetCarrier
()
{
PADDLE_ENFORCE_NOT_NULL
(
carrier_
.
get
(),
platform
::
errors
::
NotFound
(
"Carrier has not been created."
));
return
carrier_
.
get
();
for
(
const
auto
&
item
:
runtime_graph_
->
carrier_id_to_interceptor_ids
())
{
GlobalMap
<
int64_t
,
Carrier
>::
Get
(
item
.
first
)
->
Release
();
}
}
void
FleetExecutor
::
Init
(
...
...
@@ -63,13 +58,19 @@ void FleetExecutor::Init(
auto
unused_vars
=
framework
::
GetUnusedVars
(
program_desc
.
Block
(
0
),
ops
,
{});
runtime_graph_
=
std
::
make_shared
<
RuntimeGraph
>
();
std
::
unordered_map
<
int64_t
,
TaskNode
*>
interceptor_id_to_task
;
std
::
unordered_map
<
int64_t
,
std
::
unordered_set
<
int64_t
>>
carrier_id_to_interceptor_ids
;
std
::
unordered_set
<
int64_t
>
interceptor_ids
;
for
(
auto
task_node
:
task_nodes
)
{
task_node
->
SetUnusedVars
(
unused_vars
);
int64_t
interceptor_id
=
task_node
->
task_id
();
interceptor_id_to_task
.
emplace
(
interceptor_id
,
task_node
);
interceptor_ids
.
insert
(
interceptor_id
);
}
carrier_id_to_interceptor_ids
.
emplace
(
0
,
interceptor_ids
);
runtime_graph_
->
SetInterceptorIdToRank
(
task_id_to_rank
);
runtime_graph_
->
SetInterceptorIdToNode
(
interceptor_id_to_task
);
runtime_graph_
->
SetCarrierIdToInterceptorIds
(
carrier_id_to_interceptor_ids
);
for
(
auto
&
unique_op
:
ops
)
{
unique_op
.
release
();
}
...
...
@@ -86,21 +87,26 @@ void FleetExecutor::Init(
}
VLOG
(
5
)
<<
runtime_graph_
->
DebugString
();
msg_bus_
=
std
::
make_shared
<
MessageBus
>
();
CreateCarrier
();
for
(
const
auto
&
item
:
runtime_graph_
->
carrier_id_to_interceptor_ids
())
{
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
item
.
first
,
item
.
first
);
}
InitCarrier
();
InitMessageBus
();
// refine this? wait all carrier ready
// NOTE(wangxi): must add after Carrier::SetMsgBus, for we use
// MessageBus::IncreaseBarrierCount when receive barrier msg.
GetCarrier
()
->
Barrier
();
// Wait for all message bus connected.
msg_bus_
->
Barrier
();
}
void
FleetExecutor
::
InitCarrier
()
{
if
(
!
GetCarrier
()
->
IsInit
())
{
GetCarrier
()
->
SetMsgBus
(
msg_bus_
);
GetCarrier
()
->
Init
(
exe_desc_
.
cur_rank
(),
runtime_graph_
,
root_scope_
,
minibatch_scope_
,
microbatch_scopes_
,
place_
);
for
(
const
auto
&
item
:
runtime_graph_
->
carrier_id_to_interceptor_ids
())
{
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Get
(
item
.
first
);
PADDLE_ENFORCE_NOT_NULL
(
carrier
,
platform
::
errors
::
InvalidArgument
(
"Carrier has not been created."
));
carrier
->
SetMsgBus
(
msg_bus_
);
carrier
->
Init
(
exe_desc_
.
cur_rank
(),
runtime_graph_
->
interceptor_id_to_rank
(),
item
.
second
,
runtime_graph_
->
interceptor_id_to_node
(),
root_scope_
,
minibatch_scope_
,
microbatch_scopes_
,
place_
);
}
}
...
...
@@ -140,14 +146,9 @@ void FleetExecutor::InitMessageBus() {
}
void
FleetExecutor
::
Run
()
{
// Run
PADDLE_ENFORCE_EQ
(
GetCarrier
()
->
IsInit
(),
true
,
platform
::
errors
::
Unavailable
(
"Carrier has not been init yet."
));
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
Unavailable
(
"MessageBus has not been init yet."
));
GetCarrier
()
->
Start
();
for
(
const
auto
&
item
:
runtime_graph_
->
carrier_id_to_interceptor_ids
())
{
GlobalMap
<
int64_t
,
Carrier
>::
Get
(
item
.
first
)
->
Start
();
}
for
(
auto
*
micro_scop
:
microbatch_scopes_
)
{
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.h
浏览文件 @
3658405c
...
...
@@ -42,16 +42,6 @@ class FleetExecutor final {
const
std
::
vector
<
TaskNode
*>&
task_nodes
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
);
void
Run
();
// TODO(liyurui): Change to use registry table for multi-carrier.
static
Carrier
*
GetCarrier
();
template
<
typename
...
Args
>
static
Carrier
*
CreateCarrier
(
Args
&&
...
args
)
{
PADDLE_ENFORCE_EQ
(
carrier_
.
get
(),
nullptr
,
platform
::
errors
::
AlreadyExists
(
"Carrier has been created already."
));
carrier_
=
std
::
make_unique
<
Carrier
>
(
std
::
forward
<
Args
>
(
args
)...);
return
carrier_
.
get
();
}
private:
DISABLE_COPY_AND_ASSIGN
(
FleetExecutor
);
...
...
@@ -67,7 +57,6 @@ class FleetExecutor final {
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
static
std
::
unique_ptr
<
Carrier
>
carrier_
;
};
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/global_map.h
0 → 100644
浏览文件 @
3658405c
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace
paddle
{
namespace
distributed
{
template
<
typename
KeyT
,
typename
ValueT
>
class
GlobalMap
final
{
public:
static
ValueT
*
Get
(
KeyT
id
)
{
ValueT
*
item
=
GetPPtr
(
id
)
->
get
();
PADDLE_ENFORCE_NOT_NULL
(
item
,
platform
::
errors
::
NotFound
(
"This value is not in global map."
));
return
item
;
}
template
<
typename
...
Args
>
static
ValueT
*
Create
(
KeyT
id
,
Args
&&
...
args
)
{
auto
*
ptr
=
GetPPtr
(
id
);
PADDLE_ENFORCE_EQ
(
ptr
->
get
(),
nullptr
,
platform
::
errors
::
AlreadyExists
(
"This value has already in global map."
));
ValueT
*
item
=
new
ValueT
(
std
::
forward
<
Args
>
(
args
)...);
ptr
->
reset
(
item
);
return
item
;
}
private:
static
std
::
unique_ptr
<
ValueT
>*
GetPPtr
(
KeyT
id
)
{
static
std
::
mutex
mutex
;
static
std
::
unordered_map
<
KeyT
,
std
::
unique_ptr
<
ValueT
>>
id_to_ptr
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex
);
return
&
id_to_ptr
[
id
];
}
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc
浏览文件 @
3658405c
...
...
@@ -16,7 +16,7 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/
fleet_executor
.h"
#include "paddle/fluid/distributed/fleet_executor/
global_map
.h"
namespace
paddle
{
namespace
distributed
{
...
...
@@ -29,7 +29,15 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG
(
3
)
<<
"Interceptor Message Service receives a message from interceptor "
<<
request
->
src_id
()
<<
" to interceptor "
<<
request
->
dst_id
()
<<
", with the message: "
<<
request
->
message_type
();
bool
flag
=
FleetExecutor
::
GetCarrier
()
->
EnqueueInterceptorMessage
(
*
request
);
// TODO(liyurui): Remove this hard code.
int64_t
carrier_id
;
if
(
request
->
ctrl_message
())
{
carrier_id
=
0
;
}
else
{
carrier_id
=
*
GlobalMap
<
int64_t
,
int64_t
>::
Get
(
request
->
dst_id
());
}
bool
flag
=
GlobalMap
<
int64_t
,
Carrier
>::
Get
(
carrier_id
)
->
EnqueueInterceptorMessage
(
*
request
);
response
->
set_rst
(
flag
);
}
...
...
paddle/fluid/distributed/fleet_executor/runtime_graph.h
浏览文件 @
3658405c
...
...
@@ -35,6 +35,10 @@ class RuntimeGraph final {
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
()
const
{
return
interceptor_id_to_rank_
;
}
const
std
::
unordered_map
<
int64_t
,
std
::
unordered_set
<
int64_t
>>&
carrier_id_to_interceptor_ids
()
const
{
return
carrier_id_to_interceptor_ids_
;
}
void
SetInterceptorIdToRank
(
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
)
{
interceptor_id_to_rank_
=
interceptor_id_to_rank
;
...
...
@@ -43,12 +47,19 @@ class RuntimeGraph final {
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
interceptor_id_to_node
)
{
interceptor_id_to_node_
=
interceptor_id_to_node
;
}
void
SetCarrierIdToInterceptorIds
(
const
std
::
unordered_map
<
int64_t
,
std
::
unordered_set
<
int64_t
>>&
carrier_id_to_interceptor_ids
)
{
carrier_id_to_interceptor_ids_
=
carrier_id_to_interceptor_ids
;
}
std
::
string
DebugString
()
const
;
private:
DISABLE_COPY_AND_ASSIGN
(
RuntimeGraph
);
std
::
unordered_map
<
int64_t
,
TaskNode
*>
interceptor_id_to_node_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
std
::
unordered_map
<
int64_t
,
std
::
unordered_set
<
int64_t
>>
carrier_id_to_interceptor_ids_
;
};
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
浏览文件 @
3658405c
...
...
@@ -13,6 +13,9 @@ cc_test(interceptor_pipeline_long_path_test SRCS interceptor_pipeline_long_path_
set_source_files_properties
(
compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
compute_interceptor_run_op_test SRCS compute_interceptor_run_op_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
op_registry fill_constant_op elementwise_add_op scope device_context
)
set_source_files_properties
(
interceptor_pass_the_parcel_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
interceptor_pass_the_parcel_test SRCS interceptor_pass_the_parcel_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
)
if
(
WITH_DISTRIBUTE AND WITH_PSCORE AND
NOT
(
WITH_ASCEND OR WITH_ASCEND_CL
))
set_source_files_properties
(
interceptor_ping_pong_with_brpc_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
interceptor_ping_pong_with_brpc_test SRCS interceptor_ping_pong_with_brpc_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
)
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
浏览文件 @
3658405c
...
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/
fleet_executor
.h"
#include "paddle/fluid/distributed/fleet_executor/
global_map
.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -62,11 +62,12 @@ TEST(ComputeInterceptor, Compute) {
std
::
vector
<
framework
::
Scope
*>
scopes
=
{
scope
,
scope
};
platform
::
Place
place
=
platform
::
CPUPlace
();
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
}});
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
}},
{
0
,
1
});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
carrier
->
SetMsgBus
(
msg_bus
);
// FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode
*
node_a
=
...
...
@@ -77,9 +78,9 @@ TEST(ComputeInterceptor, Compute) {
node_a
->
AddDownstreamTask
(
1
);
node_b
->
AddUpstreamTask
(
0
);
auto
*
a
=
carrier
.
SetInterceptor
(
auto
*
a
=
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Compute"
,
0
,
node_a
));
carrier
.
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
a
->
SetPlace
(
place
);
a
->
SetMicroBatchScope
(
scopes
);
...
...
@@ -89,10 +90,10 @@ TEST(ComputeInterceptor, Compute) {
msg
.
set_message_type
(
DATA_IS_READY
);
msg
.
set_src_id
(
-
1
);
msg
.
set_dst_id
(
0
);
carrier
.
EnqueueInterceptorMessage
(
msg
);
carrier
->
EnqueueInterceptorMessage
(
msg
);
carrier
.
Wait
();
carrier
.
Release
();
carrier
->
Wait
();
carrier
->
Release
();
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
浏览文件 @
3658405c
...
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/
fleet_executor
.h"
#include "paddle/fluid/distributed/fleet_executor/
global_map
.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -47,11 +47,12 @@ class StartInterceptor : public Interceptor {
};
TEST
(
ComputeInterceptor
,
Compute
)
{
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
}});
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
}},
{
0
,
1
,
2
});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
carrier
->
SetMsgBus
(
msg_bus
);
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
3
,
0
);
// role, rank, task_id
...
...
@@ -65,9 +66,9 @@ TEST(ComputeInterceptor, Compute) {
node_c
->
AddUpstreamTask
(
1
);
Interceptor
*
a
=
carrier
.
SetInterceptor
(
0
,
std
::
make_unique
<
StartInterceptor
>
(
0
,
node_a
));
carrier
.
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
.
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Compute"
,
2
,
node_c
));
carrier
->
SetInterceptor
(
0
,
std
::
make_unique
<
StartInterceptor
>
(
0
,
node_a
));
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
->
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Compute"
,
2
,
node_c
));
InterceptorMessage
msg
;
msg
.
set_message_type
(
DATA_IS_READY
);
...
...
@@ -76,8 +77,8 @@ TEST(ComputeInterceptor, Compute) {
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
carrier
.
Wait
();
carrier
.
Release
();
carrier
->
Wait
();
carrier
->
Release
();
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pass_the_parcel_test.cc
0 → 100644
浏览文件 @
3658405c
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
namespace
paddle
{
namespace
distributed
{
class
ParcelInterceptor
:
public
Interceptor
{
public:
ParcelInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
Interceptor
(
interceptor_id
,
node
)
{
RegisterMsgHandle
(
[
this
](
const
InterceptorMessage
&
msg
)
{
PassParcel
(
msg
);
});
}
void
PassParcel
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
STOP
)
{
stop_
=
true
;
return
;
}
std
::
cout
<<
GetInterceptorId
()
<<
" recv msg, count="
<<
count_
<<
std
::
endl
;
if
(
count_
==
5
&&
interceptor_id_
==
0
)
{
InterceptorMessage
stop
;
stop
.
set_message_type
(
STOP
);
Send
(
0
,
stop
);
Send
(
1
,
stop
);
Send
(
2
,
stop
);
Send
(
3
,
stop
);
StopCarrier
();
return
;
}
++
count_
;
InterceptorMessage
new_msg
;
if
(
msg
.
dst_id
()
==
3
)
{
Send
(
0
,
new_msg
);
}
else
{
Send
(
msg
.
dst_id
()
+
1
,
new_msg
);
}
}
private:
int
count_
{
0
};
};
REGISTER_INTERCEPTOR
(
Parcel
,
ParcelInterceptor
);
TEST
(
InterceptorTest
,
PassTheParcel
)
{
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
Carrier
*
carrier_0
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
carrier_0
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{
0
});
carrier_0
->
SetMsgBus
(
msg_bus
);
Carrier
*
carrier_1
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
1
,
1
);
carrier_1
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{
1
});
carrier_1
->
SetMsgBus
(
msg_bus
);
Carrier
*
carrier_2
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
2
,
2
);
carrier_2
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{
2
});
carrier_2
->
SetMsgBus
(
msg_bus
);
Carrier
*
carrier_3
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
3
,
3
);
carrier_3
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{
3
});
carrier_3
->
SetMsgBus
(
msg_bus
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
Interceptor
*
a
=
carrier_0
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Parcel"
,
0
,
nullptr
));
carrier_1
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Parcel"
,
1
,
nullptr
));
carrier_2
->
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Parcel"
,
2
,
nullptr
));
carrier_3
->
SetInterceptor
(
3
,
InterceptorFactory
::
Create
(
"Parcel"
,
3
,
nullptr
));
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
carrier_0
->
Wait
();
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
浏览文件 @
3658405c
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...
...
@@ -59,20 +60,21 @@ class PingPongInterceptor : public Interceptor {
REGISTER_INTERCEPTOR
(
PingPong
,
PingPongInterceptor
);
TEST
(
InterceptorTest
,
PingPong
)
{
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
}});
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
}},
{
0
,
1
});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
carrier
->
SetMsgBus
(
msg_bus
);
Interceptor
*
a
=
carrier
.
SetInterceptor
(
Interceptor
*
a
=
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
carrier
.
SetInterceptor
(
1
,
std
::
make_unique
<
PingPongInterceptor
>
(
1
,
nullptr
));
carrier
->
SetInterceptor
(
1
,
std
::
make_unique
<
PingPongInterceptor
>
(
1
,
nullptr
));
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
carrier
.
Wait
();
carrier
->
Wait
();
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
浏览文件 @
3658405c
...
...
@@ -20,7 +20,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/
fleet_executor
.h"
#include "paddle/fluid/distributed/fleet_executor/
global_map
.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...
...
@@ -107,42 +107,31 @@ TEST(InterceptorTest, PingPong) {
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank
=
{{
0
,
0
},
{
1
,
1
}};
int
exe_pid
=
fork
();
if
(
exe_pid
==
0
)
{
int
pid
=
fork
();
if
(
pid
==
0
)
{
Carrier
*
carrier
=
FleetExecutor
::
CreateCarrier
(
0
,
interceptor_id_to_rank
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
carrier
->
SetMsgBus
(
msg_bus
);
// NOTE: need Init msg_bus after carrier SetMsgBus
msg_bus
->
Init
(
0
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
Interceptor
*
a
=
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
carrier
->
Barrier
();
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
carrier
->
Wait
();
}
else
{
Carrier
*
carrier
=
FleetExecutor
::
CreateCarrier
(
1
,
interceptor_id_to_rank
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
carrier
->
SetMsgBus
(
msg_bus
);
msg_bus
->
Init
(
1
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
carrier
->
Barrier
();
carrier
->
Wait
();
int
status
;
int
ret
=
waitpid
(
pid
,
&
status
,
0
);
CHECK_EQ
(
ret
,
pid
);
}
int
pid
=
fork
();
if
(
pid
==
0
)
{
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
carrier
->
SetMsgBus
(
msg_bus
);
// NOTE: need Init msg_bus after carrier SetMsgBus
carrier
->
Init
(
0
,
interceptor_id_to_rank
,
{
0
});
msg_bus
->
Init
(
0
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
carrier
->
SetMsgBus
(
msg_bus
);
Interceptor
*
a
=
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
msg_bus
->
Barrier
();
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
carrier
->
Wait
();
}
else
{
int
status
;
int
ret
=
waitpid
(
exe_pid
,
&
status
,
0
);
CHECK_EQ
(
ret
,
exe_pid
);
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
carrier
->
SetMsgBus
(
msg_bus
);
carrier
->
Init
(
1
,
interceptor_id_to_rank
,
{
1
});
msg_bus
->
Init
(
1
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
msg_bus
->
Barrier
();
carrier
->
Wait
();
}
}
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
浏览文件 @
3658405c
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -51,10 +52,12 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) {
}
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
}});
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
}},
{
0
,
1
,
2
,
3
,
4
,
5
});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
carrier
.
SetMsgBus
(
msg_bus
);
carrier
->
SetMsgBus
(
msg_bus
);
int64_t
micro_steps
=
3
;
...
...
@@ -73,21 +76,23 @@ TEST(AmplifierInterceptor, Amplifier) {
node_b
->
SetReplyUpPerSteps
(
micro_steps
);
node_e
->
SetSendDownPerSteps
(
micro_steps
);
carrier
.
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Compute"
,
0
,
node_a
));
carrier
.
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Amplifier"
,
1
,
node_b
));
carrier
.
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Compute"
,
2
,
node_c
));
carrier
.
SetInterceptor
(
3
,
InterceptorFactory
::
Create
(
"Compute"
,
3
,
node_d
));
carrier
.
SetInterceptor
(
4
,
InterceptorFactory
::
Create
(
"Amplifier"
,
4
,
node_e
));
carrier
.
SetInterceptor
(
5
,
InterceptorFactory
::
Create
(
"Compute"
,
5
,
node_f
));
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Compute"
,
0
,
node_a
));
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Amplifier"
,
1
,
node_b
));
carrier
->
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Compute"
,
2
,
node_c
));
carrier
->
SetInterceptor
(
3
,
InterceptorFactory
::
Create
(
"Compute"
,
3
,
node_d
));
carrier
->
SetInterceptor
(
4
,
InterceptorFactory
::
Create
(
"Amplifier"
,
4
,
node_e
));
carrier
->
SetInterceptor
(
5
,
InterceptorFactory
::
Create
(
"Compute"
,
5
,
node_f
));
// start
InterceptorMessage
msg
;
msg
.
set_message_type
(
DATA_IS_READY
);
msg
.
set_src_id
(
-
1
);
msg
.
set_dst_id
(
0
);
carrier
.
EnqueueInterceptorMessage
(
msg
);
carrier
.
Wait
();
carrier
.
Release
();
carrier
->
EnqueueInterceptorMessage
(
msg
);
carrier
->
Wait
();
carrier
->
Release
();
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
浏览文件 @
3658405c
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -69,10 +70,11 @@ void LinkNodes(const std::vector<TaskNode*>& nodes,
}
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}});
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{
0
,
1
,
2
,
3
});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
0
,
{{
0
,
""
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
carrier
->
SetMsgBus
(
msg_bus
);
int64_t
micro_steps
=
6
;
...
...
@@ -91,19 +93,21 @@ TEST(AmplifierInterceptor, Amplifier) {
node_d
->
SetRunPerSteps
(
micro_steps
);
node_d
->
SetRunAtOffset
(
micro_steps
-
1
);
carrier
.
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Amplifier"
,
0
,
node_a
));
carrier
.
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
.
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Compute"
,
2
,
node_c
));
carrier
.
SetInterceptor
(
3
,
InterceptorFactory
::
Create
(
"Amplifier"
,
3
,
node_d
));
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Amplifier"
,
0
,
node_a
));
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
->
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Compute"
,
2
,
node_c
));
carrier
->
SetInterceptor
(
3
,
InterceptorFactory
::
Create
(
"Amplifier"
,
3
,
node_d
));
// start
InterceptorMessage
msg
;
msg
.
set_message_type
(
DATA_IS_READY
);
msg
.
set_src_id
(
-
1
);
msg
.
set_dst_id
(
0
);
carrier
.
EnqueueInterceptorMessage
(
msg
);
carrier
.
Wait
();
carrier
.
Release
();
carrier
->
EnqueueInterceptorMessage
(
msg
);
carrier
->
Wait
();
carrier
->
Release
();
}
}
// namespace distributed
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录