Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
769e5bc4
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
769e5bc4
编写于
1月 07, 2022
作者:
L
LiYuRio
提交者:
GitHub
1月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[fleet_executor] Support multi carriers (#38709)
上级
7f3b0877
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
97 addition
and
88 deletion
+97
-88
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
+3
-0
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+9
-27
paddle/fluid/distributed/fleet_executor/carrier.h
paddle/fluid/distributed/fleet_executor/carrier.h
+0
-5
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+16
-14
paddle/fluid/distributed/fleet_executor/fleet_executor.h
paddle/fluid/distributed/fleet_executor/fleet_executor.h
+0
-3
paddle/fluid/distributed/fleet_executor/global.h
paddle/fluid/distributed/fleet_executor/global.h
+26
-9
paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc
...distributed/fleet_executor/interceptor_message_service.cc
+3
-5
paddle/fluid/distributed/fleet_executor/message_bus.cc
paddle/fluid/distributed/fleet_executor/message_bus.cc
+22
-0
paddle/fluid/distributed/fleet_executor/message_bus.h
paddle/fluid/distributed/fleet_executor/message_bus.h
+1
-0
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
...ed/fleet_executor/test/compute_interceptor_run_op_test.cc
+2
-3
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
...stributed/fleet_executor/test/compute_interceptor_test.cc
+2
-3
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
...ributed/fleet_executor/test/interceptor_ping_pong_test.cc
+2
-3
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
...eet_executor/test/interceptor_ping_pong_with_brpc_test.cc
+7
-10
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
...leet_executor/test/interceptor_pipeline_long_path_test.cc
+2
-3
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
...eet_executor/test/interceptor_pipeline_short_path_test.cc
+2
-3
未找到文件。
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
浏览文件 @
769e5bc4
...
...
@@ -19,6 +19,9 @@ cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime
if
(
WITH_DISTRIBUTE
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
if
(
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"
${
DISTRIBUTE_COMPILE_FLAGS
}
-faligned-new"
)
endif
()
set_source_files_properties
(
interceptor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
compute_interceptor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
...
...
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
769e5bc4
...
...
@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global
_map
.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...
...
@@ -71,17 +71,13 @@ Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }
bool
Carrier
::
EnqueueInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
)
{
if
(
interceptor_message
.
ctrl_message
())
{
VLOG
(
3
)
<<
"Receiving control message from rank "
<<
interceptor_message
.
src_id
()
<<
" to rank "
<<
interceptor_message
.
dst_id
();
// for barrier
msg_bus_
->
IncreaseBarrierCount
();
}
else
{
int64_t
dst_id
=
interceptor_message
.
dst_id
();
Interceptor
*
dst_interceptor
=
GetInterceptor
(
dst_id
);
dst_interceptor
->
EnqueueRemoteInterceptorMessage
(
interceptor_message
);
}
PADDLE_ENFORCE_EQ
(
interceptor_message
.
ctrl_message
(),
false
,
platform
::
errors
::
Fatal
(
"Control message should be only send inter rank using message bus."
));
int64_t
dst_id
=
interceptor_message
.
dst_id
();
Interceptor
*
dst_interceptor
=
GetInterceptor
(
dst_id
);
dst_interceptor
->
EnqueueRemoteInterceptorMessage
(
interceptor_message
);
return
true
;
}
...
...
@@ -106,11 +102,6 @@ void Carrier::WakeUp() {
}
void
Carrier
::
Start
()
{
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."
));
PADDLE_ENFORCE_EQ
(
is_init_
,
true
,
platform
::
errors
::
PreconditionNotMet
(
"Using carrier before initialized."
));
for
(
int64_t
id
:
source_interceptor_ids_
)
{
...
...
@@ -154,19 +145,10 @@ bool Carrier::Send(const InterceptorMessage& msg) {
<<
" to interceptor "
<<
dst_id
<<
", which are in the same ranks."
;
return
EnqueueInterceptorMessage
(
msg
);
}
else
{
PADDLE_ENFORCE_NOT_NULL
(
msg_bus_
.
get
(),
platform
::
errors
::
Unavailable
(
"Message bus is released accidently"
));
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."
));
VLOG
(
3
)
<<
"Send a message from interceptor "
<<
src_id
<<
" to interceptor "
<<
dst_id
<<
", which are in different ranks."
;
return
msg_bus_
->
Send
(
dst_rank
,
msg
);
return
GlobalVal
<
MessageBus
>::
Get
()
->
Send
(
dst_rank
,
msg
);
}
}
...
...
paddle/fluid/distributed/fleet_executor/carrier.h
浏览文件 @
769e5bc4
...
...
@@ -73,10 +73,6 @@ class Carrier final {
Interceptor
*
SetInterceptor
(
int64_t
interceptor_id
,
std
::
unique_ptr
<
Interceptor
>
);
void
SetMsgBus
(
const
std
::
shared_ptr
<
MessageBus
>&
msg_bus
)
{
msg_bus_
=
msg_bus
;
}
void
Start
();
bool
IsInit
()
const
;
...
...
@@ -107,7 +103,6 @@ class Carrier final {
framework
::
Scope
*
minibatch_scope_
;
paddle
::
platform
::
Place
place_
;
paddle
::
platform
::
DeviceContext
*
dev_ctx_
{
nullptr
};
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
int64_t
rank_
;
std
::
string
carrier_id_
;
std
::
unordered_map
<
int64_t
,
TaskNode
*>
interceptor_id_to_node_
;
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
浏览文件 @
769e5bc4
...
...
@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/global
_map
.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -32,6 +32,9 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
bool
parse_flag
=
exe_desc_
.
ParseFromString
(
exe_desc_str
);
PADDLE_ENFORCE
(
parse_flag
,
platform
::
errors
::
PreconditionNotMet
(
"Error occurs while parsing string to proto"
));
// Message bus will be created and inited only once
GlobalVal
<
MessageBus
>::
Create
();
InitMessageBus
();
}
FleetExecutor
::~
FleetExecutor
()
{
...
...
@@ -81,21 +84,16 @@ void FleetExecutor::Init(
CopyParameters
(
i
,
program_desc
);
}
VLOG
(
5
)
<<
runtime_graph_
->
DebugString
();
msg_bus_
=
std
::
make_shared
<
MessageBus
>
();
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
carrier_ids_
.
insert
(
carrier_id
);
GlobalVal
<
std
::
string
>::
Set
(
carrier_id
);
// TODO(liyurui): Maybe message bus should be created only once
// Set current running carrier
GlobalVal
<
std
::
string
>::
Set
(
new
std
::
string
(
carrier_id
));
InitCarrier
(
carrier
);
InitMessageBus
();
// Wait for all message bus connected.
msg_bus_
->
Barrier
();
GlobalVal
<
MessageBus
>::
Get
()
->
Barrier
();
}
void
FleetExecutor
::
InitCarrier
(
Carrier
*
carrier
)
{
carrier
->
SetMsgBus
(
msg_bus_
);
carrier
->
Init
(
exe_desc_
.
cur_rank
(),
runtime_graph_
->
interceptor_id_to_rank
(),
runtime_graph_
->
interceptor_id_to_node
(),
root_scope_
,
minibatch_scope_
,
microbatch_scopes_
,
place_
);
...
...
@@ -131,14 +129,18 @@ void FleetExecutor::InitMessageBus() {
VLOG
(
3
)
<<
"The number of ranks are "
<<
(
rank_to_addr
.
size
()
==
0
?
1
:
rank_to_addr
.
size
())
<<
"."
;
VLOG
(
5
)
<<
ss
.
str
();
if
(
!
msg_bus_
->
IsInit
())
{
msg_bus_
->
Init
(
cur_rank
,
rank_to_addr
,
addr
);
}
GlobalVal
<
MessageBus
>::
Get
()
->
Init
(
cur_rank
,
rank_to_addr
,
addr
);
}
void
FleetExecutor
::
Run
(
const
std
::
string
&
carrier_id
)
{
GlobalMap
<
std
::
string
,
Carrier
>::
Get
(
carrier_id
)
->
Start
();
GlobalVal
<
std
::
string
>::
Set
(
carrier_id
);
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Get
(
carrier_id
);
// Set current running carrier
if
(
*
GlobalVal
<
std
::
string
>::
Get
()
!=
carrier_id
)
{
GlobalVal
<
std
::
string
>::
Set
(
new
std
::
string
(
carrier_id
));
// TODO(liyurui): Move barrier to service
GlobalVal
<
MessageBus
>::
Get
()
->
Barrier
();
}
carrier
->
Start
();
for
(
auto
*
micro_scop
:
microbatch_scopes_
)
{
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.h
浏览文件 @
769e5bc4
...
...
@@ -55,9 +55,6 @@ class FleetExecutor final {
framework
::
Scope
*
minibatch_scope_
;
platform
::
Place
place_
;
std
::
vector
<
framework
::
Scope
*>
microbatch_scopes_
;
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
std
::
unordered_set
<
std
::
string
>
carrier_ids_
;
};
...
...
paddle/fluid/distributed/fleet_executor/global
_map
.h
→
paddle/fluid/distributed/fleet_executor/global.h
浏览文件 @
769e5bc4
...
...
@@ -14,24 +14,41 @@
#pragma once
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
distributed
{
// TODO(liyurui): Change this file to global.h
template
<
typename
T
>
class
GlobalVal
final
{
public:
static
T
Get
()
{
return
*
GetPtr
();
}
static
T
Set
(
T
val
)
{
auto
*
ptr
=
GetPtr
();
*
ptr
=
val
;
return
val
;
static
T
*
Get
()
{
T
*
ptr
=
GetPPtr
()
->
get
();
PADDLE_ENFORCE_NOT_NULL
(
ptr
,
platform
::
errors
::
NotFound
(
"This value is not global value."
));
return
ptr
;
}
template
<
typename
...
Args
>
static
T
*
Create
(
Args
&&
...
args
)
{
auto
*
ptr
=
GetPPtr
();
PADDLE_ENFORCE_EQ
(
ptr
->
get
(),
nullptr
,
platform
::
errors
::
AlreadyExists
(
"This value is already a global value."
));
T
*
item
=
new
T
(
std
::
forward
<
Args
>
(
args
)...);
ptr
->
reset
(
item
);
return
item
;
}
static
T
*
Set
(
T
*
new_item
)
{
auto
*
ptr
=
GetPPtr
();
ptr
->
reset
(
new_item
);
return
ptr
->
get
();
}
private:
static
T
*
Get
Ptr
()
{
static
T
value
;
return
&
value
;
static
std
::
unique_ptr
<
T
>*
GetP
Ptr
()
{
static
std
::
unique_ptr
<
T
>
ptr
;
return
&
ptr
;
}
};
...
...
paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc
浏览文件 @
769e5bc4
...
...
@@ -15,8 +15,8 @@
!defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/
carrier
.h"
#include "paddle/fluid/distributed/fleet_executor/
global_map
.h"
#include "paddle/fluid/distributed/fleet_executor/
global
.h"
#include "paddle/fluid/distributed/fleet_executor/
message_bus
.h"
namespace
paddle
{
namespace
distributed
{
...
...
@@ -29,9 +29,7 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG
(
3
)
<<
"Interceptor Message Service receives a message from interceptor "
<<
request
->
src_id
()
<<
" to interceptor "
<<
request
->
dst_id
()
<<
", with the message: "
<<
request
->
message_type
();
const
auto
&
carrier_id
=
GlobalVal
<
std
::
string
>::
Get
();
bool
flag
=
GlobalMap
<
std
::
string
,
Carrier
>::
Get
(
carrier_id
)
->
EnqueueInterceptorMessage
(
*
request
);
bool
flag
=
GlobalVal
<
MessageBus
>::
Get
()
->
DispatchMsgToCarrier
(
*
request
);
response
->
set_rst
(
flag
);
}
...
...
paddle/fluid/distributed/fleet_executor/message_bus.cc
浏览文件 @
769e5bc4
...
...
@@ -17,6 +17,8 @@
#include <set>
#include <thread>
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
...
...
@@ -81,6 +83,10 @@ const std::string& MessageBus::GetAddr(int64_t rank) const {
bool
MessageBus
::
Send
(
int64_t
dst_rank
,
const
InterceptorMessage
&
interceptor_message
)
{
PADDLE_ENFORCE_EQ
(
IsInit
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
"Using message bus since it has not been initialized."
));
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
int
retry_time
=
0
;
// message bus will retry sending for 10 times
...
...
@@ -155,6 +161,22 @@ void MessageBus::Barrier() {
}
}
bool
MessageBus
::
DispatchMsgToCarrier
(
const
InterceptorMessage
&
interceptor_message
)
{
if
(
interceptor_message
.
ctrl_message
())
{
VLOG
(
3
)
<<
"Receiving control message from rank "
<<
interceptor_message
.
src_id
()
<<
" to rank "
<<
interceptor_message
.
dst_id
();
// for barrier
IncreaseBarrierCount
();
return
true
;
}
else
{
const
std
::
string
&
carrier_id
=
*
GlobalVal
<
std
::
string
>::
Get
();
return
GlobalMap
<
std
::
string
,
Carrier
>::
Get
(
carrier_id
)
->
EnqueueInterceptorMessage
(
interceptor_message
);
}
}
void
MessageBus
::
ListenPort
()
{
if
(
addr_
==
""
)
{
LOG
(
INFO
)
<<
"No need listen to port since training on single card."
;
...
...
paddle/fluid/distributed/fleet_executor/message_bus.h
浏览文件 @
769e5bc4
...
...
@@ -54,6 +54,7 @@ class MessageBus final {
void
IncreaseBarrierCount
();
void
Barrier
();
bool
DispatchMsgToCarrier
(
const
InterceptorMessage
&
interceptor_message
);
private:
DISABLE_COPY_AND_ASSIGN
(
MessageBus
);
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
浏览文件 @
769e5bc4
...
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global
_map
.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -67,9 +67,8 @@ TEST(ComputeInterceptor, Compute) {
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
}});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
->
SetMsgBus
(
msg_bus
);
// FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode
*
node_a
=
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
浏览文件 @
769e5bc4
...
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global
_map
.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -52,9 +52,8 @@ TEST(ComputeInterceptor, Compute) {
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
}});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
->
SetMsgBus
(
msg_bus
);
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
3
,
0
);
// role, rank, task_id
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
浏览文件 @
769e5bc4
...
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global
_map
.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...
...
@@ -64,9 +64,8 @@ TEST(InterceptorTest, PingPong) {
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
}});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
->
SetMsgBus
(
msg_bus
);
Interceptor
*
a
=
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
浏览文件 @
769e5bc4
...
...
@@ -20,7 +20,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global
_map
.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...
...
@@ -112,12 +112,10 @@ TEST(InterceptorTest, PingPong) {
if
(
pid
==
0
)
{
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
GlobalVal
<
std
::
string
>::
Set
(
carrier_id
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
carrier
->
SetMsgBus
(
msg_bus
);
// NOTE: need Init msg_bus after carrier SetMsgBus
carrier
->
Init
(
0
,
interceptor_id_to_rank
);
GlobalVal
<
std
::
string
>::
Set
(
new
std
::
string
(
carrier_id
));
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
msg_bus
->
Init
(
0
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
carrier
->
Init
(
0
,
interceptor_id_to_rank
);
Interceptor
*
a
=
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
msg_bus
->
Barrier
();
...
...
@@ -127,11 +125,10 @@ TEST(InterceptorTest, PingPong) {
}
else
{
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
GlobalVal
<
std
::
string
>::
Set
(
carrier_id
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
carrier
->
SetMsgBus
(
msg_bus
);
carrier
->
Init
(
1
,
interceptor_id_to_rank
);
GlobalVal
<
std
::
string
>::
Set
(
new
std
::
string
(
carrier_id
));
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
msg_bus
->
Init
(
1
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
carrier
->
Init
(
1
,
interceptor_id_to_rank
);
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
msg_bus
->
Barrier
();
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
浏览文件 @
769e5bc4
...
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global
_map
.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -56,9 +56,8 @@ TEST(AmplifierInterceptor, Amplifier) {
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
}});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
carrier
->
SetMsgBus
(
msg_bus
);
int64_t
micro_steps
=
3
;
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
浏览文件 @
769e5bc4
...
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global
_map
.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
...
@@ -74,9 +74,8 @@ TEST(AmplifierInterceptor, Amplifier) {
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
msg_bus
->
Init
(
0
,
{{
0
,
""
}},
""
);
carrier
->
SetMsgBus
(
msg_bus
);
int64_t
micro_steps
=
6
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录