Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ddc15a18
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ddc15a18
编写于
12月 22, 2021
作者:
L
LiYuRio
提交者:
GitHub
12月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[fleet_executor] Move IntraSend to Carrier. Using blocking queue (#38322)
上级
142ea171
变更
19
显示空白变更内容
内联
并排
Showing
19 changed file
with
228 addition
and
262 deletion
+228
-262
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
+2
-2
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+48
-23
paddle/fluid/distributed/fleet_executor/carrier.h
paddle/fluid/distributed/fleet_executor/carrier.h
+8
-2
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+15
-15
paddle/fluid/distributed/fleet_executor/fleet_executor.h
paddle/fluid/distributed/fleet_executor/fleet_executor.h
+11
-2
paddle/fluid/distributed/fleet_executor/interceptor.cc
paddle/fluid/distributed/fleet_executor/interceptor.cc
+5
-23
paddle/fluid/distributed/fleet_executor/interceptor.h
paddle/fluid/distributed/fleet_executor/interceptor.h
+5
-14
paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc
...distributed/fleet_executor/interceptor_message_service.cc
+2
-2
paddle/fluid/distributed/fleet_executor/message_bus.cc
paddle/fluid/distributed/fleet_executor/message_bus.cc
+62
-110
paddle/fluid/distributed/fleet_executor/message_bus.h
paddle/fluid/distributed/fleet_executor/message_bus.h
+8
-10
paddle/fluid/distributed/fleet_executor/runtime_graph.cc
paddle/fluid/distributed/fleet_executor/runtime_graph.cc
+1
-1
paddle/fluid/distributed/fleet_executor/runtime_graph.h
paddle/fluid/distributed/fleet_executor/runtime_graph.h
+10
-10
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
...ed/fleet_executor/test/compute_interceptor_run_op_test.cc
+2
-3
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
...stributed/fleet_executor/test/compute_interceptor_test.cc
+2
-3
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
...ributed/fleet_executor/test/interceptor_ping_pong_test.cc
+2
-5
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
...eet_executor/test/interceptor_ping_pong_with_brpc_test.cc
+35
-28
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
...leet_executor/test/interceptor_pipeline_long_path_test.cc
+2
-5
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
...eet_executor/test/interceptor_pipeline_short_path_test.cc
+2
-4
paddle/fluid/framework/blocking_queue.h
paddle/fluid/framework/blocking_queue.h
+6
-0
未找到文件。
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
浏览文件 @
ddc15a18
...
@@ -5,7 +5,7 @@ endif()
...
@@ -5,7 +5,7 @@ endif()
proto_library
(
interceptor_message_proto SRCS interceptor_message.proto
)
proto_library
(
interceptor_message_proto SRCS interceptor_message.proto
)
if
(
WITH_DISTRIBUTE AND WITH_PSCORE AND
NOT
(
WITH_ASCEND OR WITH_ASCEND_CL
))
if
(
WITH_DISTRIBUTE AND WITH_PSCORE AND
NOT
(
WITH_ASCEND OR WITH_ASCEND_CL
))
set
(
BRPC_DEPS brpc ssl crypto protobuf
gflags glog
zlib leveldb snappy gflags glog
)
set
(
BRPC_DEPS brpc ssl crypto protobuf zlib leveldb snappy gflags glog
)
else
()
else
()
set
(
BRPC_DEPS
""
)
set
(
BRPC_DEPS
""
)
endif
()
endif
()
...
@@ -13,7 +13,7 @@ endif()
...
@@ -13,7 +13,7 @@ endif()
cc_library
(
fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
cc_library
(
fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper op_registry
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper op_registry
executor_gc_helper
${
BRPC_DEPS
}
)
executor_gc_helper
gflags glog
${
BRPC_DEPS
}
)
if
(
WITH_DISTRIBUTE
)
if
(
WITH_DISTRIBUTE
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
...
...
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
ddc15a18
...
@@ -27,14 +27,16 @@ namespace distributed {
...
@@ -27,14 +27,16 @@ namespace distributed {
USE_INTERCEPTOR
(
Compute
);
USE_INTERCEPTOR
(
Compute
);
USE_INTERCEPTOR
(
Amplifier
);
USE_INTERCEPTOR
(
Amplifier
);
void
Carrier
::
Init
(
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
void
Carrier
::
Init
(
int64_t
rank
,
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
framework
::
Scope
*
minibatch_scope
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
platform
::
Place
&
place
)
{
const
platform
::
Place
&
place
)
{
PADDLE_ENFORCE_EQ
(
is_init_
,
false
,
platform
::
errors
::
AlreadyExists
(
PADDLE_ENFORCE_EQ
(
is_init_
,
false
,
platform
::
errors
::
AlreadyExists
(
"Carrier is already init."
));
"Carrier is already init."
));
rank_
=
rank
;
runtime_graph_
=
runtime_graph
;
runtime_graph_
=
runtime_graph
;
interceptor_id_to_rank_
=
runtime_graph_
->
interceptor_id_to_rank
();
minibatch_scope_
=
minibatch_scope
;
minibatch_scope_
=
minibatch_scope
;
microbatch_scopes_
=
microbatch_scopes
;
microbatch_scopes_
=
microbatch_scopes
;
place_
=
place
;
place_
=
place
;
...
@@ -48,12 +50,6 @@ void Carrier::Release() {
...
@@ -48,12 +50,6 @@ void Carrier::Release() {
// NOTE(wangxi): must join before `Derived Interceptor` destruct,
// NOTE(wangxi): must join before `Derived Interceptor` destruct,
// otherwise Derived object will be destructed before thread complete.
// otherwise Derived object will be destructed before thread complete.
// Sending STOP msg to the source interceptor
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."
));
for
(
int64_t
id
:
source_interceptor_ids_
)
{
for
(
int64_t
id
:
source_interceptor_ids_
)
{
VLOG
(
3
)
<<
"Carrier Release is sending stop to source interceptor "
<<
id
VLOG
(
3
)
<<
"Carrier Release is sending stop to source interceptor "
<<
id
<<
"."
;
<<
"."
;
...
@@ -75,10 +71,10 @@ Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }
...
@@ -75,10 +71,10 @@ Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }
bool
Carrier
::
EnqueueInterceptorMessage
(
bool
Carrier
::
EnqueueInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
)
{
const
InterceptorMessage
&
interceptor_message
)
{
// enqueue message to interceptor
if
(
interceptor_message
.
ctrl_message
())
{
if
(
interceptor_message
.
ctrl_message
())
{
// handle control message
VLOG
(
3
)
<<
"Receiving control message from rank "
return
true
;
<<
interceptor_message
.
src_id
()
<<
" to rank "
<<
interceptor_message
.
dst_id
();
}
else
{
}
else
{
{
{
std
::
unique_lock
<
std
::
mutex
>
lock_creating
(
creating_flag_mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock_creating
(
creating_flag_mutex_
);
...
@@ -93,15 +89,9 @@ bool Carrier::EnqueueInterceptorMessage(
...
@@ -93,15 +89,9 @@ bool Carrier::EnqueueInterceptorMessage(
}
}
int64_t
dst_id
=
interceptor_message
.
dst_id
();
int64_t
dst_id
=
interceptor_message
.
dst_id
();
Interceptor
*
dst_interceptor
=
GetInterceptor
(
dst_id
);
Interceptor
*
dst_interceptor
=
GetInterceptor
(
dst_id
);
bool
rst
=
dst_interceptor
->
EnqueueRemoteInterceptorMessage
(
interceptor_message
);
dst_interceptor
->
EnqueueRemoteInterceptorMessage
(
interceptor_message
);
if
(
rst
)
{
std
::
condition_variable
&
interceptor_cond_var
=
dst_interceptor
->
GetCondVar
();
interceptor_cond_var
.
notify_all
();
}
return
rst
;
}
}
return
true
;
}
}
Interceptor
*
Carrier
::
GetInterceptor
(
int64_t
interceptor_id
)
{
Interceptor
*
Carrier
::
GetInterceptor
(
int64_t
interceptor_id
)
{
...
@@ -144,9 +134,44 @@ std::condition_variable& Carrier::GetCondVar() { return cond_var_; }
...
@@ -144,9 +134,44 @@ std::condition_variable& Carrier::GetCondVar() { return cond_var_; }
bool
Carrier
::
IsInit
()
const
{
return
is_init_
;
}
bool
Carrier
::
IsInit
()
const
{
return
is_init_
;
}
// TODO(liyurui): Move SendIntra into carrier
int64_t
Carrier
::
GetRank
(
int64_t
interceptor_id
)
const
{
bool
Carrier
::
Send
(
const
InterceptorMessage
&
msg
)
const
{
PADDLE_ENFORCE_NE
(
return
msg_bus_
->
Send
(
msg
);
interceptor_id_to_rank_
.
find
(
interceptor_id
),
interceptor_id_to_rank_
.
end
(),
platform
::
errors
::
NotFound
(
"Cannot find rank for interceptor id %lld."
,
interceptor_id
));
return
interceptor_id_to_rank_
.
at
(
interceptor_id
);
}
bool
Carrier
::
Send
(
const
InterceptorMessage
&
msg
)
{
int64_t
src_id
=
(
msg
.
src_id
()
==
-
1
)
?
msg
.
dst_id
()
:
msg
.
src_id
();
int64_t
dst_id
=
msg
.
dst_id
();
int64_t
src_rank
=
GetRank
(
src_id
);
int64_t
dst_rank
=
GetRank
(
dst_id
);
PADDLE_ENFORCE_EQ
(
src_rank
,
rank_
,
platform
::
errors
::
Fatal
(
"The source rank id %lld, which is not equal to "
"the carrier rank id %lld."
,
src_rank
,
rank_
));
if
(
src_rank
==
dst_rank
)
{
VLOG
(
3
)
<<
"Send a message from interceptor "
<<
src_id
<<
" to interceptor "
<<
dst_id
<<
", which are in the same ranks."
;
return
EnqueueInterceptorMessage
(
msg
);
}
else
{
PADDLE_ENFORCE_NOT_NULL
(
msg_bus_
.
get
(),
platform
::
errors
::
Unavailable
(
"Message bus is released accidently"
));
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."
));
VLOG
(
3
)
<<
"Send a message from interceptor "
<<
src_id
<<
" to interceptor "
<<
dst_id
<<
", which are in different ranks."
;
return
msg_bus_
->
Send
(
dst_rank
,
msg
);
}
}
}
Interceptor
*
Carrier
::
SetInterceptor
(
int64_t
interceptor_id
,
Interceptor
*
Carrier
::
SetInterceptor
(
int64_t
interceptor_id
,
...
@@ -222,13 +247,13 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
...
@@ -222,13 +247,13 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
}
}
void
Carrier
::
CreateInterceptors
()
{
void
Carrier
::
CreateInterceptors
()
{
if
(
runtime_graph_
->
intercept
e
r_id_to_node
().
empty
())
return
;
if
(
runtime_graph_
->
intercept
o
r_id_to_node
().
empty
())
return
;
auto
gc
=
GetGC
(
place_
);
auto
gc
=
GetGC
(
place_
);
// create each Interceptor
// create each Interceptor
// no auto init since there is no config
// no auto init since there is no config
for
(
const
auto
&
item
:
runtime_graph_
->
intercept
e
r_id_to_node
())
{
for
(
const
auto
&
item
:
runtime_graph_
->
intercept
o
r_id_to_node
())
{
int64_t
interceptor_id
=
item
.
first
;
int64_t
interceptor_id
=
item
.
first
;
TaskNode
*
task_node
=
item
.
second
;
TaskNode
*
task_node
=
item
.
second
;
...
...
paddle/fluid/distributed/fleet_executor/carrier.h
浏览文件 @
ddc15a18
...
@@ -45,8 +45,11 @@ class MessageBus;
...
@@ -45,8 +45,11 @@ class MessageBus;
class
Carrier
final
{
class
Carrier
final
{
public:
public:
Carrier
()
=
default
;
Carrier
()
=
default
;
Carrier
(
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
)
:
rank_
(
rank
),
interceptor_id_to_rank_
(
interceptor_id_to_rank
)
{}
~
Carrier
();
~
Carrier
();
void
Init
(
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
void
Init
(
int64_t
rank
,
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
platform
::
Place
&
place
);
const
platform
::
Place
&
place
);
...
@@ -75,7 +78,7 @@ class Carrier final {
...
@@ -75,7 +78,7 @@ class Carrier final {
bool
IsInit
()
const
;
bool
IsInit
()
const
;
bool
Send
(
const
InterceptorMessage
&
msg
)
const
;
bool
Send
(
const
InterceptorMessage
&
msg
);
// NOTE: This mutex will be used in interceptor's RunOps function.
// NOTE: This mutex will be used in interceptor's RunOps function.
// This mutex is used for avoiding forward ops and backward ops run
// This mutex is used for avoiding forward ops and backward ops run
...
@@ -90,6 +93,8 @@ class Carrier final {
...
@@ -90,6 +93,8 @@ class Carrier final {
void
HandleTmpMessages
();
void
HandleTmpMessages
();
int64_t
GetRank
(
int64_t
interceptor_id
)
const
;
// interceptor logic id to actually interceptor
// interceptor logic id to actually interceptor
std
::
unordered_map
<
int64_t
,
std
::
unique_ptr
<
Interceptor
>>
std
::
unordered_map
<
int64_t
,
std
::
unique_ptr
<
Interceptor
>>
interceptor_idx_to_interceptor_
;
interceptor_idx_to_interceptor_
;
...
@@ -111,6 +116,7 @@ class Carrier final {
...
@@ -111,6 +116,7 @@ class Carrier final {
paddle
::
platform
::
DeviceContext
*
dev_ctx_
{
nullptr
};
paddle
::
platform
::
DeviceContext
*
dev_ctx_
{
nullptr
};
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph_
;
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph_
;
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
int64_t
rank_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
};
};
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
浏览文件 @
ddc15a18
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
@@ -28,6 +27,8 @@
...
@@ -28,6 +27,8 @@
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
std
::
unique_ptr
<
Carrier
>
FleetExecutor
::
carrier_
;
FleetExecutor
::
FleetExecutor
(
const
std
::
string
&
exe_desc_str
)
{
FleetExecutor
::
FleetExecutor
(
const
std
::
string
&
exe_desc_str
)
{
bool
parse_flag
=
exe_desc_
.
ParseFromString
(
exe_desc_str
);
bool
parse_flag
=
exe_desc_
.
ParseFromString
(
exe_desc_str
);
PADDLE_ENFORCE
(
parse_flag
,
platform
::
errors
::
PreconditionNotMet
(
PADDLE_ENFORCE
(
parse_flag
,
platform
::
errors
::
PreconditionNotMet
(
...
@@ -36,12 +37,13 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
...
@@ -36,12 +37,13 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
FleetExecutor
::~
FleetExecutor
()
{
FleetExecutor
::~
FleetExecutor
()
{
root_scope_
->
DropKids
();
root_scope_
->
DropKids
();
GetCarrier
()
.
Release
();
GetCarrier
()
->
Release
();
}
}
Carrier
&
FleetExecutor
::
GetCarrier
()
{
Carrier
*
FleetExecutor
::
GetCarrier
()
{
static
Carrier
carrier
;
PADDLE_ENFORCE_NOT_NULL
(
carrier_
.
get
(),
platform
::
errors
::
NotFound
(
return
carrier
;
"Carrier has not been created."
));
return
carrier_
.
get
();
}
}
void
FleetExecutor
::
Init
(
void
FleetExecutor
::
Init
(
...
@@ -84,16 +86,16 @@ void FleetExecutor::Init(
...
@@ -84,16 +86,16 @@ void FleetExecutor::Init(
}
}
VLOG
(
5
)
<<
runtime_graph_
->
DebugString
();
VLOG
(
5
)
<<
runtime_graph_
->
DebugString
();
msg_bus_
=
std
::
make_shared
<
MessageBus
>
();
msg_bus_
=
std
::
make_shared
<
MessageBus
>
();
CreateCarrier
();
InitCarrier
();
InitCarrier
();
InitMessageBus
();
InitMessageBus
();
}
}
void
FleetExecutor
::
InitCarrier
()
{
void
FleetExecutor
::
InitCarrier
()
{
Carrier
&
carrier
=
GetCarrier
();
if
(
!
GetCarrier
()
->
IsInit
())
{
if
(
!
carrier
.
IsInit
())
{
GetCarrier
()
->
SetMsgBus
(
msg_bus_
);
carrier
.
SetMsgBus
(
msg_bus_
);
GetCarrier
()
->
Init
(
exe_desc_
.
cur_rank
(),
runtime_graph_
,
root_scope_
,
carrier
.
Init
(
runtime_graph_
,
root_scope_
,
minibatch_scope_
,
minibatch_scope_
,
microbatch_scopes_
,
place_
);
microbatch_scopes_
,
place_
);
}
}
}
}
...
@@ -128,21 +130,19 @@ void FleetExecutor::InitMessageBus() {
...
@@ -128,21 +130,19 @@ void FleetExecutor::InitMessageBus() {
<<
(
rank_to_addr
.
size
()
==
0
?
1
:
rank_to_addr
.
size
())
<<
"."
;
<<
(
rank_to_addr
.
size
()
==
0
?
1
:
rank_to_addr
.
size
())
<<
"."
;
VLOG
(
5
)
<<
ss
.
str
();
VLOG
(
5
)
<<
ss
.
str
();
if
(
!
msg_bus_
->
IsInit
())
{
if
(
!
msg_bus_
->
IsInit
())
{
msg_bus_
->
Init
(
runtime_graph_
->
intercepter_id_to_rank
(),
rank_to_addr
,
msg_bus_
->
Init
(
cur_rank
,
rank_to_addr
,
addr
);
addr
);
}
}
}
}
void
FleetExecutor
::
Run
()
{
void
FleetExecutor
::
Run
()
{
// Run
// Run
Carrier
&
carrier
=
GetCarrier
();
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
carrier
.
IsInit
(),
true
,
GetCarrier
()
->
IsInit
(),
true
,
platform
::
errors
::
Unavailable
(
"Carrier has not been init yet."
));
platform
::
errors
::
Unavailable
(
"Carrier has not been init yet."
));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
msg_bus_
->
IsInit
(),
true
,
msg_bus_
->
IsInit
(),
true
,
platform
::
errors
::
Unavailable
(
"MessageBus has not been init yet."
));
platform
::
errors
::
Unavailable
(
"MessageBus has not been init yet."
));
carrier
.
Start
();
GetCarrier
()
->
Start
();
for
(
auto
*
micro_scop
:
microbatch_scopes_
)
{
for
(
auto
*
micro_scop
:
microbatch_scopes_
)
{
// By default, we should delete all kid scopes after run executor because
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
// some operators may create local scope when running, such as while_op.
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.h
浏览文件 @
ddc15a18
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include <memory>
#include <memory>
#include <string>
#include <string>
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
...
@@ -30,7 +31,6 @@ namespace distributed {
...
@@ -30,7 +31,6 @@ namespace distributed {
class
RuntimeGraph
;
class
RuntimeGraph
;
class
MessageBus
;
class
MessageBus
;
class
TaskNode
;
class
TaskNode
;
class
Carrier
;
class
FleetExecutor
final
{
class
FleetExecutor
final
{
public:
public:
...
@@ -43,7 +43,15 @@ class FleetExecutor final {
...
@@ -43,7 +43,15 @@ class FleetExecutor final {
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
);
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
);
void
Run
();
void
Run
();
// TODO(liyurui): Change to use registry table for multi-carrier.
// TODO(liyurui): Change to use registry table for multi-carrier.
static
Carrier
&
GetCarrier
();
static
Carrier
*
GetCarrier
();
template
<
typename
...
Args
>
static
Carrier
*
CreateCarrier
(
Args
&&
...
args
)
{
PADDLE_ENFORCE_EQ
(
carrier_
.
get
(),
nullptr
,
platform
::
errors
::
AlreadyExists
(
"Carrier has been created already."
));
carrier_
=
std
::
make_unique
<
Carrier
>
(
std
::
forward
<
Args
>
(
args
)...);
return
carrier_
.
get
();
}
private:
private:
DISABLE_COPY_AND_ASSIGN
(
FleetExecutor
);
DISABLE_COPY_AND_ASSIGN
(
FleetExecutor
);
...
@@ -59,6 +67,7 @@ class FleetExecutor final {
...
@@ -59,6 +67,7 @@ class FleetExecutor final {
// The carriers under FleetExecutor will share message bus,
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
// using shared_ptr to manage lifetime and condition race.
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
static
std
::
unique_ptr
<
Carrier
>
carrier_
;
};
};
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/interceptor.cc
浏览文件 @
ddc15a18
...
@@ -52,24 +52,17 @@ void Interceptor::StopCarrier() {
...
@@ -52,24 +52,17 @@ void Interceptor::StopCarrier() {
cond_var
.
notify_all
();
cond_var
.
notify_all
();
}
}
std
::
condition_variable
&
Interceptor
::
GetCondVar
()
{
// get the conditional var
return
cond_var_
;
}
int64_t
Interceptor
::
GetInterceptorId
()
const
{
int64_t
Interceptor
::
GetInterceptorId
()
const
{
// return the interceptor id
// return the interceptor id
return
interceptor_id_
;
return
interceptor_id_
;
}
}
bool
Interceptor
::
EnqueueRemoteInterceptorMessage
(
void
Interceptor
::
EnqueueRemoteInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
)
{
const
InterceptorMessage
&
interceptor_message
)
{
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
VLOG
(
3
)
<<
"Enqueue message: "
<<
interceptor_message
.
message_type
()
VLOG
(
3
)
<<
"Enqueue message: "
<<
interceptor_message
.
message_type
()
<<
" into "
<<
interceptor_id_
<<
"'s remote mailbox."
;
<<
" into "
<<
interceptor_id_
<<
"'s remote mailbox."
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
remote_mailbox_mutex_
);
remote_mailbox_
.
Push
(
interceptor_message
);
remote_mailbox_
.
push
(
interceptor_message
);
return
true
;
}
}
bool
Interceptor
::
Send
(
int64_t
dst_id
,
InterceptorMessage
&
msg
)
{
bool
Interceptor
::
Send
(
int64_t
dst_id
,
InterceptorMessage
&
msg
)
{
...
@@ -92,7 +85,7 @@ void Interceptor::PoolTheMailbox() {
...
@@ -92,7 +85,7 @@ void Interceptor::PoolTheMailbox() {
"Error encountered when fetch remote mailbox."
));
"Error encountered when fetch remote mailbox."
));
}
}
const
InterceptorMessage
interceptor_message
=
local_mailbox_
.
front
();
const
InterceptorMessage
interceptor_message
=
local_mailbox_
.
front
();
local_mailbox_
.
pop
();
local_mailbox_
.
pop
_front
();
const
MessageType
message_type
=
interceptor_message
.
message_type
();
const
MessageType
message_type
=
interceptor_message
.
message_type
();
VLOG
(
3
)
<<
"Interceptor "
<<
interceptor_id_
<<
" has received a message"
VLOG
(
3
)
<<
"Interceptor "
<<
interceptor_id_
<<
" has received a message"
<<
" from interceptor "
<<
interceptor_message
.
src_id
()
<<
" from interceptor "
<<
interceptor_message
.
src_id
()
...
@@ -109,19 +102,8 @@ void Interceptor::PoolTheMailbox() {
...
@@ -109,19 +102,8 @@ void Interceptor::PoolTheMailbox() {
}
}
bool
Interceptor
::
FetchRemoteMailbox
()
{
bool
Interceptor
::
FetchRemoteMailbox
()
{
// fetch all Message from remote mailbox to local mailbox
remote_mailbox_
.
PopAll
(
&
local_mailbox_
);
// return true if remote mailbox not empty, otherwise return false
return
!
local_mailbox_
.
empty
();
std
::
unique_lock
<
std
::
mutex
>
lock
(
remote_mailbox_mutex_
);
cond_var_
.
wait
(
lock
,
[
this
]()
{
return
!
remote_mailbox_
.
empty
();
});
if
(
remote_mailbox_
.
empty
())
{
// the thread has been unblocked accidentally
return
false
;
}
while
(
!
remote_mailbox_
.
empty
())
{
local_mailbox_
.
push
(
std
::
move
(
remote_mailbox_
.
front
()));
remote_mailbox_
.
pop
();
}
return
true
;
}
}
static
InterceptorFactory
::
CreateInterceptorMap
&
GetInterceptorMap
()
{
static
InterceptorFactory
::
CreateInterceptorMap
&
GetInterceptorMap
()
{
...
...
paddle/fluid/distributed/fleet_executor/interceptor.h
浏览文件 @
ddc15a18
...
@@ -15,14 +15,15 @@
...
@@ -15,14 +15,15 @@
#pragma once
#pragma once
#include <condition_variable>
#include <condition_variable>
#include <deque>
#include <functional>
#include <functional>
#include <map>
#include <map>
#include <memory>
#include <memory>
#include <queue>
#include <thread>
#include <thread>
#include <vector>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/macros.h"
...
@@ -59,11 +60,8 @@ class Interceptor {
...
@@ -59,11 +60,8 @@ class Interceptor {
// return the interceptor id
// return the interceptor id
int64_t
GetInterceptorId
()
const
;
int64_t
GetInterceptorId
()
const
;
// return the conditional var
std
::
condition_variable
&
GetCondVar
();
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
bool
EnqueueRemoteInterceptorMessage
(
void
EnqueueRemoteInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
);
const
InterceptorMessage
&
interceptor_message
);
bool
Send
(
int64_t
dst_id
,
InterceptorMessage
&
msg
);
// NOLINT
bool
Send
(
int64_t
dst_id
,
InterceptorMessage
&
msg
);
// NOLINT
...
@@ -115,23 +113,16 @@ class Interceptor {
...
@@ -115,23 +113,16 @@ class Interceptor {
// interceptor handle which process message
// interceptor handle which process message
MsgHandle
handle_
{
nullptr
};
MsgHandle
handle_
{
nullptr
};
// mutex to control read/write conflict for remote mailbox
std
::
mutex
remote_mailbox_mutex_
;
// interceptor runs PoolTheMailbox() function to poll local mailbox
// interceptor runs PoolTheMailbox() function to poll local mailbox
std
::
thread
interceptor_thread_
;
std
::
thread
interceptor_thread_
;
// conditional variable for blocking the thread when
// fetch an empty remote mailbox
std
::
condition_variable
cond_var_
;
// remote mailbox, written by EnqueueRemoteMessage()
// remote mailbox, written by EnqueueRemoteMessage()
// read by FetchRemoteMailbox()
// read by FetchRemoteMailbox()
std
::
q
ueue
<
InterceptorMessage
>
remote_mailbox_
;
framework
::
BlockingQ
ueue
<
InterceptorMessage
>
remote_mailbox_
;
// local mailbox, written by FetchRemoteMailbox()
// local mailbox, written by FetchRemoteMailbox()
// read by PoolTheMailbox()
// read by PoolTheMailbox()
std
::
que
ue
<
InterceptorMessage
>
local_mailbox_
;
std
::
deq
ue
<
InterceptorMessage
>
local_mailbox_
;
int64_t
already_run_times_
{
0
};
int64_t
already_run_times_
{
0
};
int64_t
used_slot_nums_
{
0
};
int64_t
used_slot_nums_
{
0
};
...
...
paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc
浏览文件 @
ddc15a18
...
@@ -29,8 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
...
@@ -29,8 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG
(
3
)
<<
"Interceptor Message Service receives a message from interceptor "
VLOG
(
3
)
<<
"Interceptor Message Service receives a message from interceptor "
<<
request
->
src_id
()
<<
" to interceptor "
<<
request
->
dst_id
()
<<
request
->
src_id
()
<<
" to interceptor "
<<
request
->
dst_id
()
<<
", with the message: "
<<
request
->
message_type
();
<<
", with the message: "
<<
request
->
message_type
();
FleetExecutor
::
GetCarrier
().
EnqueueInterceptorMessage
(
*
request
);
bool
flag
=
FleetExecutor
::
GetCarrier
()
->
EnqueueInterceptorMessage
(
*
request
);
response
->
set_rst
(
true
);
response
->
set_rst
(
flag
);
}
}
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/message_bus.cc
浏览文件 @
ddc15a18
...
@@ -17,8 +17,6 @@
...
@@ -17,8 +17,6 @@
#include <set>
#include <set>
#include <thread>
#include <thread>
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
...
@@ -26,16 +24,25 @@ namespace paddle {
...
@@ -26,16 +24,25 @@ namespace paddle {
namespace
distributed
{
namespace
distributed
{
void
MessageBus
::
Init
(
void
MessageBus
::
Init
(
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
std
::
string
>&
rank_to_addr
,
const
std
::
unordered_map
<
int64_t
,
std
::
string
>&
rank_to_addr
,
const
std
::
string
&
addr
)
{
const
std
::
string
&
addr
)
{
PADDLE_ENFORCE_EQ
(
is_init_
,
false
,
platform
::
errors
::
AlreadyExists
(
PADDLE_ENFORCE_EQ
(
is_init_
,
false
,
platform
::
errors
::
AlreadyExists
(
"MessageBus is already init."
));
"MessageBus is already init."
));
rank_
=
rank
;
is_init_
=
true
;
is_init_
=
true
;
interceptor_id_to_rank_
=
interceptor_id_to_rank
;
rank_to_addr_
=
rank_to_addr
;
rank_to_addr_
=
rank_to_addr
;
addr_
=
addr
;
addr_
=
addr
;
if
(
addr_
!=
""
)
{
const
auto
&
addr
=
GetAddr
(
rank_
);
PADDLE_ENFORCE_EQ
(
addr
,
addr_
,
platform
::
errors
::
Fatal
(
"The current rank's addr is %s, while the "
"message bus's addr is %s, which are different. "
"Init error."
,
addr
,
addr_
));
}
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL)
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL)
// NOTE: To make the brpc is compatible with collective,
// NOTE: To make the brpc is compatible with collective,
...
@@ -65,26 +72,23 @@ MessageBus::~MessageBus() {
...
@@ -65,26 +72,23 @@ MessageBus::~MessageBus() {
#endif
#endif
}
}
bool
MessageBus
::
Send
(
const
InterceptorMessage
&
interceptor_message
)
{
const
std
::
string
&
MessageBus
::
GetAddr
(
int64_t
rank
)
const
{
// called by Interceptor, send InterceptorMessage to dst
PADDLE_ENFORCE_NE
(
int64_t
src_id
=
interceptor_message
.
src_id
();
rank_to_addr_
.
find
(
rank
),
rank_to_addr_
.
end
(),
int64_t
dst_id
=
interceptor_message
.
dst_id
();
platform
::
errors
::
NotFound
(
"Cannot find addr rank id %lld."
,
rank
));
if
(
IsSameRank
(
src_id
,
dst_id
))
{
return
rank_to_addr_
.
at
(
rank
);
VLOG
(
3
)
<<
"Send a message from interceptor "
<<
src_id
}
<<
" to interceptor "
<<
dst_id
<<
", which are in the same ranks."
;
return
SendIntraRank
(
interceptor_message
);
bool
MessageBus
::
Send
(
int64_t
dst_rank
,
}
else
{
const
InterceptorMessage
&
interceptor_message
)
{
VLOG
(
3
)
<<
"Send a message from interceptor "
<<
src_id
<<
" to interceptor "
<<
dst_id
<<
", which are in different ranks."
;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
!defined(PADDLE_WITH_ASCEND_CL)
int
retry_time
=
0
;
// message bus will retry sending for 10 times
int
retry_time
=
0
;
// message bus will retry sending for 10 times
while
(
retry_time
<
10
)
{
while
(
retry_time
<
10
)
{
++
retry_time
;
++
retry_time
;
if
(
SendInterRank
(
interceptor_message
))
{
if
(
SendInterRank
(
dst_rank
,
interceptor_message
))
{
VLOG
(
3
)
<<
"Message bus sends inter rank successfully with "
VLOG
(
3
)
<<
"Message bus sends inter rank successfully with "
<<
retry_time
<<
retry_time
<<
" times retries."
;
<<
" times retries."
;
return
true
;
return
true
;
}
}
VLOG
(
3
)
<<
"Message bus sends failed, retry after 1 seconds."
;
VLOG
(
3
)
<<
"Message bus sends failed, retry after 1 seconds."
;
...
@@ -98,10 +102,27 @@ bool MessageBus::Send(const InterceptorMessage& interceptor_message) {
...
@@ -98,10 +102,27 @@ bool MessageBus::Send(const InterceptorMessage& interceptor_message) {
"ranks when Paddle is compiled with npu or "
"ranks when Paddle is compiled with npu or "
"isn't compiled with distributed for now."
));
"isn't compiled with distributed for now."
));
#endif
#endif
}
return
true
;
return
true
;
}
}
void
MessageBus
::
TestConnection
()
{
InterceptorMessage
ctrl_msg
;
ctrl_msg
.
set_ctrl_message
(
true
);
ctrl_msg
.
set_src_id
(
rank_
);
for
(
const
auto
&
dst_rank_pair
:
rank_to_addr_
)
{
int64_t
dst_rank
=
dst_rank_pair
.
first
;
if
(
dst_rank
!=
rank_
)
{
ctrl_msg
.
set_dst_id
(
dst_rank
);
VLOG
(
3
)
<<
"Send control message bus from rank "
<<
rank_
<<
" to rank "
<<
dst_rank
;
while
(
!
Send
(
dst_rank
,
ctrl_msg
))
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
}
VLOG
(
3
)
<<
"Message bus has connected to rank: "
<<
dst_rank
<<
"."
;
}
}
}
void
MessageBus
::
ListenPort
()
{
void
MessageBus
::
ListenPort
()
{
if
(
addr_
==
""
)
{
if
(
addr_
==
""
)
{
LOG
(
INFO
)
<<
"No need listen to port since training on single card."
;
LOG
(
INFO
)
<<
"No need listen to port since training on single card."
;
...
@@ -130,30 +151,7 @@ void MessageBus::ListenPort() {
...
@@ -130,30 +151,7 @@ void MessageBus::ListenPort() {
interval
+=
500
;
interval
+=
500
;
}
}
LOG
(
INFO
)
<<
"Message bus's listen port thread starts successful."
;
LOG
(
INFO
)
<<
"Message bus's listen port thread starts successful."
;
TestConnection
();
std
::
set
<
int64_t
>
visit
;
InterceptorMessage
tmp_msg
;
tmp_msg
.
set_ctrl_message
(
true
);
for
(
auto
pair
:
interceptor_id_to_rank_
)
{
if
(
rank_to_addr_
.
at
(
pair
.
second
)
==
addr_
)
{
tmp_msg
.
set_src_id
(
pair
.
first
);
}
}
for
(
auto
pair
:
interceptor_id_to_rank_
)
{
int64_t
rank
=
pair
.
second
;
if
(
rank_to_addr_
.
at
(
rank
)
==
addr_
)
{
continue
;
}
tmp_msg
.
set_dst_id
(
pair
.
first
);
if
(
visit
.
find
(
rank
)
==
visit
.
end
())
{
VLOG
(
3
)
<<
"Message bus is testing connection for rank: "
<<
rank
<<
"."
;
visit
.
insert
(
rank
);
while
(
!
Send
(
tmp_msg
))
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
}
VLOG
(
3
)
<<
"Message bus has connected to rank: "
<<
rank
<<
"."
;
}
}
#else
#else
LOG
(
WARNING
)
LOG
(
WARNING
)
<<
"Fleet executor's ListenPort() is a fake function when Paddle is "
<<
"Fleet executor's ListenPort() is a fake function when Paddle is "
...
@@ -162,53 +160,13 @@ void MessageBus::ListenPort() {
...
@@ -162,53 +160,13 @@ void MessageBus::ListenPort() {
#endif
#endif
}
}
bool
MessageBus
::
IsSameRank
(
int64_t
src_id
,
int64_t
dst_id
)
{
// -1 is sent by carrier to source interceptor
if
(
src_id
==
-
1
)
src_id
=
dst_id
;
// check whether the dst is the same rank or different rank with src
const
auto
&
src_rank
=
interceptor_id_to_rank_
.
find
(
src_id
);
const
auto
&
dst_rank
=
interceptor_id_to_rank_
.
find
(
dst_id
);
PADDLE_ENFORCE_NE
(
src_rank
,
interceptor_id_to_rank_
.
end
(),
platform
::
errors
::
NotFound
(
"Cannot find rank for src interceptor id %lld. Init error."
,
src_id
));
PADDLE_ENFORCE_NE
(
dst_rank
,
interceptor_id_to_rank_
.
end
(),
platform
::
errors
::
NotFound
(
"Cannot find rank for dst interceptor id %lld. Init error."
,
dst_id
));
if
(
addr_
==
""
)
{
// single card training, must be same rank
return
true
;
}
const
auto
&
src_ip
=
rank_to_addr_
.
find
(
src_rank
->
second
);
PADDLE_ENFORCE_NE
(
src_ip
,
rank_to_addr_
.
end
(),
platform
::
errors
::
NotFound
(
"Cannot find addr for src rank id %lld. Init error."
,
src_rank
->
second
));
PADDLE_ENFORCE_EQ
(
src_ip
->
second
,
addr_
,
platform
::
errors
::
Fatal
(
"The src interceptor's addr is %s, while the "
"message bus's addr is %s, which are different. "
"Init error."
,
src_ip
->
second
,
addr_
));
return
src_rank
->
second
==
dst_rank
->
second
;
}
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
!defined(PADDLE_WITH_ASCEND_CL)
bool
MessageBus
::
SendInterRank
(
const
InterceptorMessage
&
interceptor_message
)
{
bool
MessageBus
::
SendInterRank
(
int64_t
dst_rank
,
// send the message inter rank (dst is different rank with src)
const
InterceptorMessage
&
interceptor_message
)
{
int64_t
dst_id
=
interceptor_message
.
dst_id
();
const
auto
&
dst_addr
=
GetAddr
(
dst_rank
);
int64_t
dst_rank
=
interceptor_id_to_rank_
[
dst_id
];
VLOG
(
3
)
<<
"Message bus sending to addr: "
<<
dst_addr
;
auto
dst_ip
=
rank_to_addr_
.
find
(
dst_rank
);
const
char
*
dst_addr_for_brpc
=
dst_addr
.
c_str
();
PADDLE_ENFORCE_NE
(
dst_ip
,
rank_to_addr_
.
end
(),
platform
::
errors
::
InvalidArgument
(
"Cannot find rank for dst interceptor id %lld. "
"Init error."
,
dst_id
));
VLOG
(
3
)
<<
"Message bus sending to addr: "
<<
dst_ip
->
second
;
const
char
*
dst_ip_for_brpc
=
dst_ip
->
second
.
c_str
();
brpc
::
Channel
channel
;
brpc
::
Channel
channel
;
brpc
::
ChannelOptions
options
;
brpc
::
ChannelOptions
options
;
options
.
protocol
=
"baidu_std"
;
options
.
protocol
=
"baidu_std"
;
...
@@ -216,7 +174,7 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
...
@@ -216,7 +174,7 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
options
.
timeout_ms
=
1000
;
options
.
timeout_ms
=
1000
;
options
.
max_retry
=
5
;
options
.
max_retry
=
5
;
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
channel
.
Init
(
dst_
ip
_for_brpc
,
&
options
),
0
,
channel
.
Init
(
dst_
addr
_for_brpc
,
&
options
),
0
,
platform
::
errors
::
Unavailable
(
"Message bus: init brpc channel error."
));
platform
::
errors
::
Unavailable
(
"Message bus: init brpc channel error."
));
TheInterceptorMessageService_Stub
stub
(
&
channel
);
TheInterceptorMessageService_Stub
stub
(
&
channel
);
InterceptorResponse
response
;
InterceptorResponse
response
;
...
@@ -239,11 +197,5 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
...
@@ -239,11 +197,5 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
}
}
#endif
#endif
bool
MessageBus
::
SendIntraRank
(
const
InterceptorMessage
&
interceptor_message
)
{
// send the message intra rank (dst is the same rank with src)
return
FleetExecutor
::
GetCarrier
().
EnqueueInterceptorMessage
(
interceptor_message
);
}
}
// namespace distributed
}
// namespace distributed
}
// namespace paddle
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/message_bus.h
浏览文件 @
ddc15a18
...
@@ -42,14 +42,14 @@ class MessageBus final {
...
@@ -42,14 +42,14 @@ class MessageBus final {
MessageBus
()
=
default
;
MessageBus
()
=
default
;
~
MessageBus
();
~
MessageBus
();
void
Init
(
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_
rank
,
void
Init
(
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
std
::
string
>&
rank_to_addr
,
const
std
::
unordered_map
<
int64_t
,
std
::
string
>&
rank_to_addr
,
const
std
::
string
&
addr
);
const
std
::
string
&
addr
);
bool
IsInit
()
const
;
bool
IsInit
()
const
;
// called by Interceptor, send InterceptorMessage to dst
// called by Interceptor, send InterceptorMessage to dst
bool
Send
(
const
InterceptorMessage
&
interceptor_message
);
bool
Send
(
int64_t
dst_rank
,
const
InterceptorMessage
&
interceptor_message
);
private:
private:
DISABLE_COPY_AND_ASSIGN
(
MessageBus
);
DISABLE_COPY_AND_ASSIGN
(
MessageBus
);
...
@@ -57,22 +57,20 @@ class MessageBus final {
...
@@ -57,22 +57,20 @@ class MessageBus final {
// function keep listen the port and handle the message
// function keep listen the port and handle the message
void
ListenPort
();
void
ListenPort
();
// check whether the dst is the same rank or different rank with src
void
TestConnection
();
bool
IsSameRank
(
int64_t
src_id
,
int64_t
dst_id
);
const
std
::
string
&
GetAddr
(
int64_t
rank
)
const
;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
!defined(PADDLE_WITH_ASCEND_CL)
// send the message inter rank (dst is different rank with src)
// send the message inter rank (dst is different rank with src)
bool
SendInterRank
(
const
InterceptorMessage
&
interceptor_message
);
bool
SendInterRank
(
int64_t
dst_rank
,
const
InterceptorMessage
&
interceptor_message
);
#endif
#endif
bool
is_init_
{
false
};
bool
is_init_
{
false
};
// send the message intra rank (dst is the same rank with src)
int64_t
rank_
;
bool
SendIntraRank
(
const
InterceptorMessage
&
interceptor_message
);
// handed by above layer, save the info mapping interceptor id to rank id
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
// handed by above layer, save the info mapping rank id to addr
// handed by above layer, save the info mapping rank id to addr
std
::
unordered_map
<
int64_t
,
std
::
string
>
rank_to_addr_
;
std
::
unordered_map
<
int64_t
,
std
::
string
>
rank_to_addr_
;
...
...
paddle/fluid/distributed/fleet_executor/runtime_graph.cc
浏览文件 @
ddc15a18
...
@@ -21,7 +21,7 @@ namespace distributed {
...
@@ -21,7 +21,7 @@ namespace distributed {
std
::
string
RuntimeGraph
::
DebugString
()
const
{
std
::
string
RuntimeGraph
::
DebugString
()
const
{
std
::
ostringstream
os
;
std
::
ostringstream
os
;
os
<<
"
\n
Runtime Graph Debug:
\n
"
;
os
<<
"
\n
Runtime Graph Debug:
\n
"
;
for
(
const
auto
&
pair
:
intercept
e
r_id_to_node_
)
{
for
(
const
auto
&
pair
:
intercept
o
r_id_to_node_
)
{
os
<<
pair
.
second
->
DebugString
();
os
<<
pair
.
second
->
DebugString
();
os
<<
"
\n
"
;
os
<<
"
\n
"
;
}
}
...
...
paddle/fluid/distributed/fleet_executor/runtime_graph.h
浏览文件 @
ddc15a18
...
@@ -29,26 +29,26 @@ class RuntimeGraph final {
...
@@ -29,26 +29,26 @@ class RuntimeGraph final {
public:
public:
RuntimeGraph
()
=
default
;
RuntimeGraph
()
=
default
;
~
RuntimeGraph
()
=
default
;
~
RuntimeGraph
()
=
default
;
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
intercept
e
r_id_to_node
()
const
{
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
intercept
o
r_id_to_node
()
const
{
return
intercept
e
r_id_to_node_
;
return
intercept
o
r_id_to_node_
;
}
}
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
intercept
e
r_id_to_rank
()
const
{
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
intercept
o
r_id_to_rank
()
const
{
return
intercept
e
r_id_to_rank_
;
return
intercept
o
r_id_to_rank_
;
}
}
void
SetInterceptorIdToRank
(
void
SetInterceptorIdToRank
(
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
intercept
e
r_id_to_rank
)
{
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
intercept
o
r_id_to_rank
)
{
intercept
er_id_to_rank_
=
intercepte
r_id_to_rank
;
intercept
or_id_to_rank_
=
intercepto
r_id_to_rank
;
}
}
void
SetInterceptorIdToNode
(
void
SetInterceptorIdToNode
(
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
intercept
e
r_id_to_node
)
{
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
intercept
o
r_id_to_node
)
{
intercept
er_id_to_node_
=
intercepte
r_id_to_node
;
intercept
or_id_to_node_
=
intercepto
r_id_to_node
;
}
}
std
::
string
DebugString
()
const
;
std
::
string
DebugString
()
const
;
private:
private:
DISABLE_COPY_AND_ASSIGN
(
RuntimeGraph
);
DISABLE_COPY_AND_ASSIGN
(
RuntimeGraph
);
std
::
unordered_map
<
int64_t
,
TaskNode
*>
intercept
e
r_id_to_node_
;
std
::
unordered_map
<
int64_t
,
TaskNode
*>
intercept
o
r_id_to_node_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
intercept
e
r_id_to_rank_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
intercept
o
r_id_to_rank_
;
};
};
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
浏览文件 @
ddc15a18
...
@@ -62,11 +62,10 @@ TEST(ComputeInterceptor, Compute) {
...
@@ -62,11 +62,10 @@ TEST(ComputeInterceptor, Compute) {
std
::
vector
<
framework
::
Scope
*>
scopes
=
{
scope
,
scope
};
std
::
vector
<
framework
::
Scope
*>
scopes
=
{
scope
,
scope
};
platform
::
Place
place
=
platform
::
CPUPlace
();
platform
::
Place
place
=
platform
::
CPUPlace
();
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
}});
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
{{
0
,
0
},
{
1
,
0
}}
,
{{
0
,
"127.0.0.0:0"
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
carrier
.
SetMsgBus
(
msg_bus
);
// FIXME: don't delete, otherwise interceptor will use undefined node
// FIXME: don't delete, otherwise interceptor will use undefined node
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
浏览文件 @
ddc15a18
...
@@ -47,11 +47,10 @@ class StartInterceptor : public Interceptor {
...
@@ -47,11 +47,10 @@ class StartInterceptor : public Interceptor {
};
};
TEST
(
ComputeInterceptor
,
Compute
)
{
TEST
(
ComputeInterceptor
,
Compute
)
{
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
}});
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
}}
,
{{
0
,
"127.0.0.0:0"
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
carrier
.
SetMsgBus
(
msg_bus
);
// NOTE: don't delete, otherwise interceptor will use undefined node
// NOTE: don't delete, otherwise interceptor will use undefined node
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
浏览文件 @
ddc15a18
...
@@ -18,7 +18,6 @@ limitations under the License. */
...
@@ -18,7 +18,6 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...
@@ -60,11 +59,9 @@ class PingPongInterceptor : public Interceptor {
...
@@ -60,11 +59,9 @@ class PingPongInterceptor : public Interceptor {
REGISTER_INTERCEPTOR
(
PingPong
,
PingPongInterceptor
);
REGISTER_INTERCEPTOR
(
PingPong
,
PingPongInterceptor
);
TEST
(
InterceptorTest
,
PingPong
)
{
TEST
(
InterceptorTest
,
PingPong
)
{
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
}});
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
{{
0
,
0
},
{
1
,
0
}}
,
{{
0
,
"127.0.0.0:0"
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
carrier
.
SetMsgBus
(
msg_bus
);
Interceptor
*
a
=
carrier
.
SetInterceptor
(
Interceptor
*
a
=
carrier
.
SetInterceptor
(
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
浏览文件 @
ddc15a18
...
@@ -104,35 +104,42 @@ TEST(InterceptorTest, PingPong) {
...
@@ -104,35 +104,42 @@ TEST(InterceptorTest, PingPong) {
std
::
string
ip1
=
"127.0.0.1:"
+
std
::
to_string
(
port1
);
std
::
string
ip1
=
"127.0.0.1:"
+
std
::
to_string
(
port1
);
std
::
cout
<<
"ip0: "
<<
ip0
<<
std
::
endl
;
std
::
cout
<<
"ip0: "
<<
ip0
<<
std
::
endl
;
std
::
cout
<<
"ip1: "
<<
ip1
<<
std
::
endl
;
std
::
cout
<<
"ip1: "
<<
ip1
<<
std
::
endl
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank
=
{{
0
,
0
},
{
1
,
1
}};
int
exe_pid
=
fork
();
if
(
exe_pid
==
0
)
{
int
pid
=
fork
();
int
pid
=
fork
();
if
(
pid
==
0
)
{
if
(
pid
==
0
)
{
Carrier
*
carrier
=
FleetExecutor
::
CreateCarrier
(
0
,
interceptor_id_to_rank
);
carrier
->
SetCreatingFlag
(
false
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
1
}},
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
msg_bus
->
Init
(
0
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
carrier
->
SetMsgBus
(
msg_bus
);
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Interceptor
*
a
=
carrier
->
SetInterceptor
(
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
carrier
.
SetMsgBus
(
msg_bus
);
Interceptor
*
a
=
carrier
.
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
carrier
.
SetCreatingFlag
(
false
);
InterceptorMessage
msg
;
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
carrier
.
Wait
();
carrier
->
Wait
();
}
else
{
}
else
{
Carrier
*
carrier
=
FleetExecutor
::
CreateCarrier
(
1
,
interceptor_id_to_rank
);
carrier
->
SetCreatingFlag
(
false
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
1
}},
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
msg_bus
->
Init
(
1
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
carrier
->
SetMsgBus
(
msg_bus
);
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
carrier
->
SetInterceptor
(
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
1
,
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
carrier
.
SetMsgBus
(
msg_bus
);
carrier
->
Wait
();
int
status
;
carrier
.
SetInterceptor
(
1
,
int
ret
=
waitpid
(
pid
,
&
status
,
0
);
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
CHECK_EQ
(
ret
,
pid
);
carrier
.
SetCreatingFlag
(
false
);
}
carrier
.
Wait
();
}
else
{
int
status
;
int
ret
=
waitpid
(
exe_pid
,
&
status
,
0
);
CHECK_EQ
(
ret
,
exe_pid
);
}
}
}
}
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
浏览文件 @
ddc15a18
...
@@ -18,7 +18,6 @@ limitations under the License. */
...
@@ -18,7 +18,6 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
@@ -52,11 +51,9 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) {
...
@@ -52,11 +51,9 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) {
}
}
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
}});
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
}},
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
carrier
.
SetMsgBus
(
msg_bus
);
carrier
.
SetMsgBus
(
msg_bus
);
int64_t
micro_steps
=
3
;
int64_t
micro_steps
=
3
;
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
浏览文件 @
ddc15a18
...
@@ -18,7 +18,6 @@ limitations under the License. */
...
@@ -18,7 +18,6 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
...
@@ -70,10 +69,9 @@ void LinkNodes(const std::vector<TaskNode*>& nodes,
...
@@ -70,10 +69,9 @@ void LinkNodes(const std::vector<TaskNode*>& nodes,
}
}
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier
carrier
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}});
Carrier
&
carrier
=
FleetExecutor
::
GetCarrier
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}}
,
{{
0
,
""
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
""
}},
""
);
carrier
.
SetMsgBus
(
msg_bus
);
carrier
.
SetMsgBus
(
msg_bus
);
int64_t
micro_steps
=
6
;
int64_t
micro_steps
=
6
;
...
...
paddle/fluid/framework/blocking_queue.h
浏览文件 @
ddc15a18
...
@@ -75,6 +75,12 @@ class BlockingQueue {
...
@@ -75,6 +75,12 @@ class BlockingQueue {
return
ret
;
return
ret
;
}
}
void
PopAll
(
std
::
deque
<
T
>
*
empty_queue
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cv_
.
wait
(
lock
,
[
this
]
{
return
!
q_
.
empty
();
});
std
::
swap
(
*
empty_queue
,
q_
);
}
T
Pop
()
{
T
Pop
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cv_
.
wait
(
lock
,
[
=
]
{
return
!
q_
.
empty
();
});
cv_
.
wait
(
lock
,
[
=
]
{
return
!
q_
.
empty
();
});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录