Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
3658405c
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3658405c
编写于
12月 30, 2021
作者:
L
LiYuRio
提交者:
GitHub
12月 30, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Fleet Executor] Support multi carrier (#38535)
上级
2421a25a
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
338 addition
and
149 deletion
+338
-149
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+42
-17
paddle/fluid/distributed/fleet_executor/carrier.h
paddle/fluid/distributed/fleet_executor/carrier.h
+16
-16
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+27
-26
paddle/fluid/distributed/fleet_executor/fleet_executor.h
paddle/fluid/distributed/fleet_executor/fleet_executor.h
+0
-11
paddle/fluid/distributed/fleet_executor/global_map.h
paddle/fluid/distributed/fleet_executor/global_map.h
+49
-0
paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc
...distributed/fleet_executor/interceptor_message_service.cc
+10
-2
paddle/fluid/distributed/fleet_executor/runtime_graph.h
paddle/fluid/distributed/fleet_executor/runtime_graph.h
+11
-0
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
+3
-0
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
...ed/fleet_executor/test/compute_interceptor_run_op_test.cc
+9
-8
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
...stributed/fleet_executor/test/compute_interceptor_test.cc
+9
-8
paddle/fluid/distributed/fleet_executor/test/interceptor_pass_the_parcel_test.cc
...d/fleet_executor/test/interceptor_pass_the_parcel_test.cc
+101
-0
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
...ributed/fleet_executor/test/interceptor_ping_pong_test.cc
+7
-5
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
...eet_executor/test/interceptor_ping_pong_with_brpc_test.cc
+25
-36
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
...leet_executor/test/interceptor_pipeline_long_path_test.cc
+16
-11
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
...eet_executor/test/interceptor_pipeline_short_path_test.cc
+13
-9
未找到文件。
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
3658405c
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...
...
@@ -27,16 +28,32 @@ namespace distributed {
USE_INTERCEPTOR
(
Compute
);
USE_INTERCEPTOR
(
Amplifier
);
void
Carrier
::
Init
(
int64_t
rank
,
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
platform
::
Place
&
place
)
{
PADDLE_ENFORCE_EQ
(
is_init_
,
false
,
platform
::
errors
::
AlreadyExists
(
"Carrier is already init."
));
void
Carrier
::
Init
(
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_set
<
int64_t
>&
interceptor_ids
)
{
rank_
=
rank
;
runtime_graph_
=
runtime_graph
;
interceptor_id_to_rank_
=
runtime_graph_
->
interceptor_id_to_rank
();
interceptor_id_to_rank_
=
interceptor_id_to_rank
;
interceptor_ids_
=
interceptor_ids
;
// TODO(fleet_exe dev): thread pool
thread_num_
=
1
;
thread_pool_
.
SetThreadNum
(
thread_num_
);
thread_pool_
.
Start
();
}
void
Carrier
::
Init
(
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_set
<
int64_t
>&
interceptor_ids
,
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
interceptor_id_to_node
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
platform
::
Place
&
place
)
{
rank_
=
rank
;
interceptor_id_to_rank_
=
interceptor_id_to_rank
;
interceptor_ids_
=
interceptor_ids
;
interceptor_id_to_node_
=
interceptor_id_to_node
;
minibatch_scope_
=
minibatch_scope
;
microbatch_scopes_
=
microbatch_scopes
;
place_
=
place
;
...
...
@@ -72,8 +89,6 @@ bool Carrier::EnqueueInterceptorMessage(
return
true
;
}
void
Carrier
::
Barrier
()
{
msg_bus_
->
Barrier
();
}
Interceptor
*
Carrier
::
GetInterceptor
(
int64_t
interceptor_id
)
{
auto
iter
=
interceptor_idx_to_interceptor_
.
find
(
interceptor_id
);
PADDLE_ENFORCE_NE
(
iter
,
interceptor_idx_to_interceptor_
.
end
(),
...
...
@@ -100,7 +115,8 @@ void Carrier::Start() {
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."
));
PADDLE_ENFORCE_EQ
(
is_init_
,
true
,
platform
::
errors
::
PreconditionNotMet
(
"Using carrier before initialized."
));
for
(
int64_t
id
:
source_interceptor_ids_
)
{
VLOG
(
3
)
<<
"Carrier Start is sending start to source interceptor "
<<
id
<<
"."
;
...
...
@@ -140,7 +156,9 @@ bool Carrier::Send(const InterceptorMessage& msg) {
if
(
src_rank
==
dst_rank
)
{
VLOG
(
3
)
<<
"Send a message from interceptor "
<<
src_id
<<
" to interceptor "
<<
dst_id
<<
", which are in the same ranks."
;
return
EnqueueInterceptorMessage
(
msg
);
int64_t
carrier_id
=
*
GlobalMap
<
int64_t
,
int64_t
>::
Get
(
dst_id
);
return
GlobalMap
<
int64_t
,
Carrier
>::
Get
(
carrier_id
)
->
EnqueueInterceptorMessage
(
msg
);
}
else
{
PADDLE_ENFORCE_NOT_NULL
(
msg_bus_
.
get
(),
...
...
@@ -174,6 +192,9 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
loop
,
platform
::
errors
::
Fatal
(
"thread task loop must not null"
));
interceptor
->
RegisterTaskLoop
(
loop
);
// TODO(liyurui): Using struct InterceptorID replace int64_t
GlobalMap
<
int64_t
,
int64_t
>::
Create
(
interceptor_id
,
carrier_id_
);
auto
*
ptr
=
interceptor
.
get
();
interceptor_idx_to_interceptor_
.
insert
(
std
::
make_pair
(
interceptor_id
,
std
::
move
(
interceptor
)));
...
...
@@ -199,15 +220,19 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
}
void
Carrier
::
CreateInterceptors
()
{
if
(
runtime_graph_
->
interceptor_id_to_node
()
.
empty
())
return
;
if
(
interceptor_ids_
.
empty
())
return
;
auto
gc
=
GetGC
(
place_
);
// create each Interceptor
// no auto init since there is no config
for
(
const
auto
&
item
:
runtime_graph_
->
interceptor_id_to_node
())
{
int64_t
interceptor_id
=
item
.
first
;
TaskNode
*
task_node
=
item
.
second
;
for
(
int64_t
interceptor_id
:
interceptor_ids_
)
{
const
auto
&
task_node_iter
=
interceptor_id_to_node_
.
find
(
interceptor_id
);
PADDLE_ENFORCE_NE
(
task_node_iter
,
interceptor_id_to_node_
.
end
(),
platform
::
errors
::
NotFound
(
"Can not find task node for interceptor %ld"
,
interceptor_id
));
TaskNode
*
task_node
=
task_node_iter
->
second
;
PADDLE_ENFORCE_LT
(
task_node
->
run_at_offset
(),
task_node
->
run_per_steps
(),
...
...
paddle/fluid/distributed/fleet_executor/carrier.h
浏览文件 @
3658405c
...
...
@@ -45,19 +45,19 @@ class MessageBus;
class
Carrier
final
{
public:
Carrier
()
=
default
;
Carrier
(
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
)
:
rank_
(
rank
),
interceptor_id_to_rank_
(
interceptor_id_to_rank
)
{
thread_num_
=
1
;
thread_pool_
.
SetThreadNum
(
thread_num_
);
thread_pool_
.
Start
();
}
explicit
Carrier
(
int64_t
carrier_id
)
:
carrier_id_
(
carrier_id
)
{}
~
Carrier
();
void
Init
(
int64_t
rank
,
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
platform
::
Place
&
place
);
void
Init
(
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_set
<
int64_t
>&
interceptor_ids
);
void
Init
(
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_set
<
int64_t
>&
interceptor_ids
,
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
interceptor_id_to_node
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
platform
::
Place
&
place
);
void
Release
();
void
Wait
();
...
...
@@ -83,10 +83,9 @@ class Carrier final {
bool
Send
(
const
InterceptorMessage
&
msg
);
void
Barrier
();
private:
DISABLE_COPY_AND_ASSIGN
(
Carrier
);
Carrier
()
=
delete
;
// create each Interceptor
void
CreateInterceptors
();
...
...
@@ -108,13 +107,14 @@ class Carrier final {
framework
::
Scope
*
minibatch_scope_
;
paddle
::
platform
::
Place
place_
;
paddle
::
platform
::
DeviceContext
*
dev_ctx_
{
nullptr
};
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph_
;
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
int64_t
rank_
;
int64_t
carrier_id_
;
std
::
unordered_map
<
int64_t
,
TaskNode
*>
interceptor_id_to_node_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
int
thread_num_
;
TaskLoopThreadPool
thread_pool_
;
std
::
unordered_set
<
int64_t
>
interceptor_ids_
;
};
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
浏览文件 @
3658405c
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -27,8 +28,6 @@
namespace
paddle
{
namespace
distributed
{
std
::
unique_ptr
<
Carrier
>
FleetExecutor
::
carrier_
;
FleetExecutor
::
FleetExecutor
(
const
std
::
string
&
exe_desc_str
)
{
bool
parse_flag
=
exe_desc_
.
ParseFromString
(
exe_desc_str
);
PADDLE_ENFORCE
(
parse_flag
,
platform
::
errors
::
PreconditionNotMet
(
...
...
@@ -37,13 +36,9 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
FleetExecutor
::~
FleetExecutor
()
{
root_scope_
->
DropKids
();
GetCarrier
()
->
Release
();
}
Carrier
*
FleetExecutor
::
GetCarrier
()
{
PADDLE_ENFORCE_NOT_NULL
(
carrier_
.
get
(),
platform
::
errors
::
NotFound
(
"Carrier has not been created."
));
return
carrier_
.
get
();
for
(
const
auto
&
item
:
runtime_graph_
->
carrier_id_to_interceptor_ids
())
{
GlobalMap
<
int64_t
,
Carrier
>::
Get
(
item
.
first
)
->
Release
();
}
}
void
FleetExecutor
::
Init
(
...
...
@@ -63,13 +58,19 @@ void FleetExecutor::Init(
auto
unused_vars
=
framework
::
GetUnusedVars
(
program_desc
.
Block
(
0
),
ops
,
{});
runtime_graph_
=
std
::
make_shared
<
RuntimeGraph
>
();
std
::
unordered_map
<
int64_t
,
TaskNode
*>
interceptor_id_to_task
;
std
::
unordered_map
<
int64_t
,
std
::
unordered_set
<
int64_t
>>
carrier_id_to_interceptor_ids
;
std
::
unordered_set
<
int64_t
>
interceptor_ids
;
for
(
auto
task_node
:
task_nodes
)
{
task_node
->
SetUnusedVars
(
unused_vars
);
int64_t
interceptor_id
=
task_node
->
task_id
();
interceptor_id_to_task
.
emplace
(
interceptor_id
,
task_node
);
interceptor_ids
.
insert
(
interceptor_id
);
}
carrier_id_to_interceptor_ids
.
emplace
(
0
,
interceptor_ids
);
runtime_graph_
->
SetInterceptorIdToRank
(
task_id_to_rank
);
runtime_graph_
->
SetInterceptorIdToNode
(
interceptor_id_to_task
);
runtime_graph_
->
SetCarrierIdToInterceptorIds
(
carrier_id_to_interceptor_ids
);
for
(
auto
&
unique_op
:
ops
)
{
unique_op
.
release
();
}
...
...
@@ -86,21 +87,26 @@ void FleetExecutor::Init(
}
VLOG
(
5
)
<<
runtime_graph_
->
DebugString
();
msg_bus_
=
std
::
make_shared
<
MessageBus
>
();
CreateCarrier
();
for
(
const
auto
&
item
:
runtime_graph_
->
carrier_id_to_interceptor_ids
())
{
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
item
.
first
,
item
.
first
);
}
InitCarrier
();
InitMessageBus
();
// refine this? wait all carrier ready
// NOTE(wangxi): must add after Carrier::SetMsgBus, for we use
// MessageBus::IncreaseBarrierCount when receive barrier msg.
GetCarrier
()
->
Barrier
();
// Wait for all message bus connected.
msg_bus_
->
Barrier
();
}
void
FleetExecutor
::
InitCarrier
()
{
if
(
!
GetCarrier
()
->
IsInit
())
{
GetCarrier
()
->
SetMsgBus
(
msg_bus_
);
GetCarrier
()
->
Init
(
exe_desc_
.
cur_rank
(),
runtime_graph_
,
root_scope_
,
minibatch_scope_
,
microbatch_scopes_
,
place_
);
for
(
const
auto
&
item
:
runtime_graph_
->
carrier_id_to_interceptor_ids
())
{
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Get
(
item
.
first
);
PADDLE_ENFORCE_NOT_NULL
(
carrier
,
platform
::
errors
::
InvalidArgument
(
"Carrier has not been created."
));
carrier
->
SetMsgBus
(
msg_bus_
);
carrier
->
Init
(
exe_desc_
.
cur_rank
(),
runtime_graph_
->
interceptor_id_to_rank
(),
item
.
second
,
runtime_graph_
->
interceptor_id_to_node
(),
root_scope_
,
minibatch_scope_
,
microbatch_scopes_
,
place_
);
}
}
...
...
@@ -140,14 +146,9 @@ void FleetExecutor::InitMessageBus() {
}
void
FleetExecutor
::
Run
()
{
// Run
PADDLE_ENFORCE_EQ
(
GetCarrier
()
->
IsInit
(),
true
,
platform
::
errors
::
Unavailable
(
"Carrier has not been init yet."
));
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
Unavailable
(
"MessageBus has not been init yet."
));
GetCarrier
()
->
Start
();
for
(
const
auto
&
item
:
runtime_graph_
->
carrier_id_to_interceptor_ids
())
{
GlobalMap
<
int64_t
,
Carrier
>::
Get
(
item
.
first
)
->
Start
();
}
for
(
auto
*
micro_scop
:
microbatch_scopes_
)
{
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.h
浏览文件 @
3658405c
...
...
@@ -42,16 +42,6 @@ class FleetExecutor final {
const
std
::
vector
<
TaskNode
*>&
task_nodes
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
);
void
Run
();
// TODO(liyurui): Change to use registry table for multi-carrier.
static
Carrier
*
GetCarrier
();
template
<
typename
...
Args
>
static
Carrier
*
CreateCarrier
(
Args
&&
...
args
)
{
PADDLE_ENFORCE_EQ
(
carrier_
.
get
(),
nullptr
,
platform
::
errors
::
AlreadyExists
(
"Carrier has been created already."
));
carrier_
=
std
::
make_unique
<
Carrier
>
(
std
::
forward
<
Args
>
(
args
)...);
return
carrier_
.
get
();
}
private:
DISABLE_COPY_AND_ASSIGN
(
FleetExecutor
);
...
...
@@ -67,7 +57,6 @@ class FleetExecutor final {
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
static
std
::
unique_ptr
<
Carrier
>
carrier_
;
};
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/global_map.h
0 → 100644
浏览文件 @
3658405c
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace
paddle
{
namespace
distributed
{
template
<
typename
KeyT
,
typename
ValueT
>
class
GlobalMap
final
{
public:
static
ValueT
*
Get
(
KeyT
id
)
{
ValueT
*
item
=
GetPPtr
(
id
)
->
get
();
PADDLE_ENFORCE_NOT_NULL
(
item
,
platform
::
errors
::
NotFound
(
"This value is not in global map."
));
return
item
;
}
template
<
typename
...
Args
>
static
ValueT
*
Create
(
KeyT
id
,
Args
&&
...
args
)
{
auto
*
ptr
=
GetPPtr
(
id
);
PADDLE_ENFORCE_EQ
(
ptr
->
get
(),
nullptr
,
platform
::
errors
::
AlreadyExists
(
"This value has already in global map."
));
ValueT
*
item
=
new
ValueT
(
std
::
forward
<
Args
>
(
args
)...);
ptr
->
reset
(
item
);
return
item
;
}
private:
static
std
::
unique_ptr
<
ValueT
>*
GetPPtr
(
KeyT
id
)
{
static
std
::
mutex
mutex
;
static
std
::
unordered_map
<
KeyT
,
std
::
unique_ptr
<
ValueT
>>
id_to_ptr
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex
);
return
&
id_to_ptr
[
id
];
}
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc
浏览文件 @
3658405c
...
...
@@ -16,7 +16,7 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/
fleet_executor
.h"
#include "paddle/fluid/distributed/fleet_executor/
global_map
.h"
namespace
paddle
{
namespace
distributed
{
...
...
@@ -29,7 +29,15 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG
(
3
)
<<
"Interceptor Message Service receives a message from interceptor "
<<
request
->
src_id
()
<<
" to interceptor "
<<
request
->
dst_id
()
<<
", with the message: "
<<
request
->
message_type
();
bool
flag
=
FleetExecutor
::
GetCarrier
()
->
EnqueueInterceptorMessage
(
*
request
);
// TODO(liyurui): Remove this hard code.
int64_t
carrier_id
;
if
(
request
->
ctrl_message
())
{
carrier_id
=
0
;
}
else
{
carrier_id
=
*
GlobalMap
<
int64_t
,
int64_t
>::
Get
(
request
->
dst_id
());
}
bool
flag
=
GlobalMap
<
int64_t
,
Carrier
>::
Get
(
carrier_id
)
->
EnqueueInterceptorMessage
(
*
request
);
response
->
set_rst
(
flag
);
}
...
...
paddle/fluid/distributed/fleet_executor/runtime_graph.h
浏览文件 @
3658405c
...
...
@@ -35,6 +35,10 @@ class RuntimeGraph final {
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
()
const
{
return
interceptor_id_to_rank_
;
}
const
std
::
unordered_map
<
int64_t
,
std
::
unordered_set
<
int64_t
>>&
carrier_id_to_interceptor_ids
()
const
{
return
carrier_id_to_interceptor_ids_
;
}
void
SetInterceptorIdToRank
(
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
)
{
interceptor_id_to_rank_
=
interceptor_id_to_rank
;
...
...
@@ -43,12 +47,19 @@ class RuntimeGraph final {
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
interceptor_id_to_node
)
{
interceptor_id_to_node_
=
interceptor_id_to_node
;
}
void
SetCarrierIdToInterceptorIds
(
const
std
::
unordered_map
<
int64_t
,
std
::
unordered_set
<
int64_t
>>&
carrier_id_to_interceptor_ids
)
{
carrier_id_to_interceptor_ids_
=
carrier_id_to_interceptor_ids
;
}
std
::
string
DebugString
()
const
;
private:
DISABLE_COPY_AND_ASSIGN
(
RuntimeGraph
);
std
::
unordered_map
<
int64_t
,
TaskNode
*>
interceptor_id_to_node_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
std
::
unordered_map
<
int64_t
,
std
::
unordered_set
<
int64_t
>>
carrier_id_to_interceptor_ids_
;
};
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
浏览文件 @
3658405c
...
...
@@ -13,6 +13,9 @@ cc_test(interceptor_pipeline_long_path_test SRCS interceptor_pipeline_long_path_
set_source_files_properties
(
compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
compute_interceptor_run_op_test SRCS compute_interceptor_run_op_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
op_registry fill_constant_op elementwise_add_op scope device_context
)
set_source_files_properties
(
interceptor_pass_the_parcel_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
interceptor_pass_the_parcel_test SRCS interceptor_pass_the_parcel_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
)
if
(
WITH_DISTRIBUTE AND WITH_PSCORE AND
NOT
(
WITH_ASCEND OR WITH_ASCEND_CL
))
set_source_files_properties
(
interceptor_ping_pong_with_brpc_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
interceptor_ping_pong_with_brpc_test SRCS interceptor_ping_pong_with_brpc_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
)
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
浏览文件 @
3658405c
...
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/
fleet_executor
.h"
#include "paddle/fluid/distributed/fleet_executor/
global_map
.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -62,11 +62,12 @@ TEST(ComputeInterceptor, Compute) {
std
::
vector
<
framework
::
Scope
*>
scopes
=
{
scope
,
scope
};
platform
::
Place
place
=
platform
::
CPUPlace
();
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
}});
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
}},
{
0
,
1
});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
carrier
->
SetMsgBus
(
msg_bus
);
// FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode
*
node_a
=
...
...
@@ -77,9 +78,9 @@ TEST(ComputeInterceptor, Compute) {
node_a
->
AddDownstreamTask
(
1
);
node_b
->
AddUpstreamTask
(
0
);
auto
*
a
=
carrier
.
SetInterceptor
(
auto
*
a
=
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Compute"
,
0
,
node_a
));
carrier
.
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
a
->
SetPlace
(
place
);
a
->
SetMicroBatchScope
(
scopes
);
...
...
@@ -89,10 +90,10 @@ TEST(ComputeInterceptor, Compute) {
msg
.
set_message_type
(
DATA_IS_READY
);
msg
.
set_src_id
(
-
1
);
msg
.
set_dst_id
(
0
);
carrier
.
EnqueueInterceptorMessage
(
msg
);
carrier
->
EnqueueInterceptorMessage
(
msg
);
carrier
.
Wait
();
carrier
.
Release
();
carrier
->
Wait
();
carrier
->
Release
();
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
浏览文件 @
3658405c
...
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/
fleet_executor
.h"
#include "paddle/fluid/distributed/fleet_executor/
global_map
.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -47,11 +47,12 @@ class StartInterceptor : public Interceptor {
};
TEST
(
ComputeInterceptor
,
Compute
)
{
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
}});
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
}},
{
0
,
1
,
2
});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
carrier
->
SetMsgBus
(
msg_bus
);
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
3
,
0
);
// role, rank, task_id
...
...
@@ -65,9 +66,9 @@ TEST(ComputeInterceptor, Compute) {
node_c
->
AddUpstreamTask
(
1
);
Interceptor
*
a
=
carrier
.
SetInterceptor
(
0
,
std
::
make_unique
<
StartInterceptor
>
(
0
,
node_a
));
carrier
.
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
.
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Compute"
,
2
,
node_c
));
carrier
->
SetInterceptor
(
0
,
std
::
make_unique
<
StartInterceptor
>
(
0
,
node_a
));
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
->
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Compute"
,
2
,
node_c
));
InterceptorMessage
msg
;
msg
.
set_message_type
(
DATA_IS_READY
);
...
...
@@ -76,8 +77,8 @@ TEST(ComputeInterceptor, Compute) {
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
carrier
.
Wait
();
carrier
.
Release
();
carrier
->
Wait
();
carrier
->
Release
();
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pass_the_parcel_test.cc
0 → 100644
浏览文件 @
3658405c
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
namespace
paddle
{
namespace
distributed
{
class
ParcelInterceptor
:
public
Interceptor
{
public:
ParcelInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
Interceptor
(
interceptor_id
,
node
)
{
RegisterMsgHandle
(
[
this
](
const
InterceptorMessage
&
msg
)
{
PassParcel
(
msg
);
});
}
void
PassParcel
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
STOP
)
{
stop_
=
true
;
return
;
}
std
::
cout
<<
GetInterceptorId
()
<<
" recv msg, count="
<<
count_
<<
std
::
endl
;
if
(
count_
==
5
&&
interceptor_id_
==
0
)
{
InterceptorMessage
stop
;
stop
.
set_message_type
(
STOP
);
Send
(
0
,
stop
);
Send
(
1
,
stop
);
Send
(
2
,
stop
);
Send
(
3
,
stop
);
StopCarrier
();
return
;
}
++
count_
;
InterceptorMessage
new_msg
;
if
(
msg
.
dst_id
()
==
3
)
{
Send
(
0
,
new_msg
);
}
else
{
Send
(
msg
.
dst_id
()
+
1
,
new_msg
);
}
}
private:
int
count_
{
0
};
};
REGISTER_INTERCEPTOR
(
Parcel
,
ParcelInterceptor
);
TEST
(
InterceptorTest
,
PassTheParcel
)
{
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
Carrier
*
carrier_0
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
carrier_0
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{
0
});
carrier_0
->
SetMsgBus
(
msg_bus
);
Carrier
*
carrier_1
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
1
,
1
);
carrier_1
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{
1
});
carrier_1
->
SetMsgBus
(
msg_bus
);
Carrier
*
carrier_2
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
2
,
2
);
carrier_2
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{
2
});
carrier_2
->
SetMsgBus
(
msg_bus
);
Carrier
*
carrier_3
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
3
,
3
);
carrier_3
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{
3
});
carrier_3
->
SetMsgBus
(
msg_bus
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
Interceptor
*
a
=
carrier_0
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Parcel"
,
0
,
nullptr
));
carrier_1
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Parcel"
,
1
,
nullptr
));
carrier_2
->
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Parcel"
,
2
,
nullptr
));
carrier_3
->
SetInterceptor
(
3
,
InterceptorFactory
::
Create
(
"Parcel"
,
3
,
nullptr
));
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
carrier_0
->
Wait
();
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
浏览文件 @
3658405c
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...
...
@@ -59,20 +60,21 @@ class PingPongInterceptor : public Interceptor {
REGISTER_INTERCEPTOR
(
PingPong
,
PingPongInterceptor
);
TEST
(
InterceptorTest
,
PingPong
)
{
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
}});
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
}},
{
0
,
1
});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
carrier
->
SetMsgBus
(
msg_bus
);
Interceptor
*
a
=
carrier
.
SetInterceptor
(
Interceptor
*
a
=
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
carrier
.
SetInterceptor
(
1
,
std
::
make_unique
<
PingPongInterceptor
>
(
1
,
nullptr
));
carrier
->
SetInterceptor
(
1
,
std
::
make_unique
<
PingPongInterceptor
>
(
1
,
nullptr
));
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
carrier
.
Wait
();
carrier
->
Wait
();
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
浏览文件 @
3658405c
...
...
@@ -20,7 +20,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/
fleet_executor
.h"
#include "paddle/fluid/distributed/fleet_executor/
global_map
.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...
...
@@ -107,42 +107,31 @@ TEST(InterceptorTest, PingPong) {
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank
=
{{
0
,
0
},
{
1
,
1
}};
int
exe_pid
=
fork
();
if
(
exe_pid
==
0
)
{
int
pid
=
fork
();
if
(
pid
==
0
)
{
Carrier
*
carrier
=
FleetExecutor
::
CreateCarrier
(
0
,
interceptor_id_to_rank
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
carrier
->
SetMsgBus
(
msg_bus
);
// NOTE: need Init msg_bus after carrier SetMsgBus
msg_bus
->
Init
(
0
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
Interceptor
*
a
=
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
carrier
->
Barrier
();
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
carrier
->
Wait
();
}
else
{
Carrier
*
carrier
=
FleetExecutor
::
CreateCarrier
(
1
,
interceptor_id_to_rank
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
carrier
->
SetMsgBus
(
msg_bus
);
msg_bus
->
Init
(
1
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
carrier
->
Barrier
();
carrier
->
Wait
();
int
status
;
int
ret
=
waitpid
(
pid
,
&
status
,
0
);
CHECK_EQ
(
ret
,
pid
);
}
int
pid
=
fork
();
if
(
pid
==
0
)
{
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
carrier
->
SetMsgBus
(
msg_bus
);
// NOTE: need Init msg_bus after carrier SetMsgBus
carrier
->
Init
(
0
,
interceptor_id_to_rank
,
{
0
});
msg_bus
->
Init
(
0
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
carrier
->
SetMsgBus
(
msg_bus
);
Interceptor
*
a
=
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
msg_bus
->
Barrier
();
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
carrier
->
Wait
();
}
else
{
int
status
;
int
ret
=
waitpid
(
exe_pid
,
&
status
,
0
);
CHECK_EQ
(
ret
,
exe_pid
);
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
carrier
->
SetMsgBus
(
msg_bus
);
carrier
->
Init
(
1
,
interceptor_id_to_rank
,
{
1
});
msg_bus
->
Init
(
1
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
msg_bus
->
Barrier
();
carrier
->
Wait
();
}
}
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
浏览文件 @
3658405c
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -51,10 +52,12 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) {
}
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
}});
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
}},
{
0
,
1
,
2
,
3
,
4
,
5
});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
carrier
.
SetMsgBus
(
msg_bus
);
carrier
->
SetMsgBus
(
msg_bus
);
int64_t
micro_steps
=
3
;
...
...
@@ -73,21 +76,23 @@ TEST(AmplifierInterceptor, Amplifier) {
node_b
->
SetReplyUpPerSteps
(
micro_steps
);
node_e
->
SetSendDownPerSteps
(
micro_steps
);
carrier
.
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Compute"
,
0
,
node_a
));
carrier
.
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Amplifier"
,
1
,
node_b
));
carrier
.
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Compute"
,
2
,
node_c
));
carrier
.
SetInterceptor
(
3
,
InterceptorFactory
::
Create
(
"Compute"
,
3
,
node_d
));
carrier
.
SetInterceptor
(
4
,
InterceptorFactory
::
Create
(
"Amplifier"
,
4
,
node_e
));
carrier
.
SetInterceptor
(
5
,
InterceptorFactory
::
Create
(
"Compute"
,
5
,
node_f
));
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Compute"
,
0
,
node_a
));
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Amplifier"
,
1
,
node_b
));
carrier
->
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Compute"
,
2
,
node_c
));
carrier
->
SetInterceptor
(
3
,
InterceptorFactory
::
Create
(
"Compute"
,
3
,
node_d
));
carrier
->
SetInterceptor
(
4
,
InterceptorFactory
::
Create
(
"Amplifier"
,
4
,
node_e
));
carrier
->
SetInterceptor
(
5
,
InterceptorFactory
::
Create
(
"Compute"
,
5
,
node_f
));
// start
InterceptorMessage
msg
;
msg
.
set_message_type
(
DATA_IS_READY
);
msg
.
set_src_id
(
-
1
);
msg
.
set_dst_id
(
0
);
carrier
.
EnqueueInterceptorMessage
(
msg
);
carrier
.
Wait
();
carrier
.
Release
();
carrier
->
EnqueueInterceptorMessage
(
msg
);
carrier
->
Wait
();
carrier
->
Release
();
}
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
浏览文件 @
3658405c
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -69,10 +70,11 @@ void LinkNodes(const std::vector<TaskNode*>& nodes,
}
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}});
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{
0
,
1
,
2
,
3
});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
0
,
{{
0
,
""
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
carrier
->
SetMsgBus
(
msg_bus
);
int64_t
micro_steps
=
6
;
...
...
@@ -91,19 +93,21 @@ TEST(AmplifierInterceptor, Amplifier) {
node_d
->
SetRunPerSteps
(
micro_steps
);
node_d
->
SetRunAtOffset
(
micro_steps
-
1
);
carrier
.
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Amplifier"
,
0
,
node_a
));
carrier
.
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
.
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Compute"
,
2
,
node_c
));
carrier
.
SetInterceptor
(
3
,
InterceptorFactory
::
Create
(
"Amplifier"
,
3
,
node_d
));
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Amplifier"
,
0
,
node_a
));
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
->
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Compute"
,
2
,
node_c
));
carrier
->
SetInterceptor
(
3
,
InterceptorFactory
::
Create
(
"Amplifier"
,
3
,
node_d
));
// start
InterceptorMessage
msg
;
msg
.
set_message_type
(
DATA_IS_READY
);
msg
.
set_src_id
(
-
1
);
msg
.
set_dst_id
(
0
);
carrier
.
EnqueueInterceptorMessage
(
msg
);
carrier
.
Wait
();
carrier
.
Release
();
carrier
->
EnqueueInterceptorMessage
(
msg
);
carrier
->
Wait
();
carrier
->
Release
();
}
}
// namespace distributed
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录