Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
8a4460f5
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8a4460f5
编写于
11月 30, 2021
作者:
W
WangXi
提交者:
GitHub
11月 30, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[fleet_executor] interceptor run from python interface (#37693)
上级
82b55961
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
58 addition
and
30 deletion
+58
-30
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+28
-15
paddle/fluid/distributed/fleet_executor/carrier.h
paddle/fluid/distributed/fleet_executor/carrier.h
+3
-0
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
...e/fluid/distributed/fleet_executor/compute_interceptor.cc
+12
-12
paddle/fluid/distributed/fleet_executor/interceptor.cc
paddle/fluid/distributed/fleet_executor/interceptor.cc
+11
-3
paddle/fluid/distributed/fleet_executor/message_bus.cc
paddle/fluid/distributed/fleet_executor/message_bus.cc
+3
-0
paddle/fluid/distributed/fleet_executor/runtime_graph.cc
paddle/fluid/distributed/fleet_executor/runtime_graph.cc
+1
-0
未找到文件。
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
8a4460f5
...
@@ -92,19 +92,22 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
...
@@ -92,19 +92,22 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
}
}
void
Carrier
::
Start
()
{
void
Carrier
::
Start
()
{
// TODO(fleet_executor dev): this start is a faked one, need replace
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
for
(
const
auto
&
pair
:
interceptor_idx_to_interceptor_
)
{
PADDLE_ENFORCE_EQ
(
msg_bus
.
IsInit
(),
true
,
VLOG
(
3
)
<<
"Fake run is sending start to interceptor "
<<
pair
.
first
<<
"."
;
platform
::
errors
::
PreconditionNotMet
(
InterceptorMessage
tmp_msg
;
"Message bus has not been initialized."
));
tmp_msg
.
set_src_id
(
pair
.
first
);
tmp_msg
.
set_dst_id
(
pair
.
first
);
for
(
int64_t
id
:
source_interceptor_ids_
)
{
tmp_msg
.
set_message_type
(
DATA_IS_READY
);
VLOG
(
3
)
<<
"Carrier Start is sending start to source interceptor "
<<
id
MessageBus
&
message_bus_instance
=
MessageBus
::
Instance
();
<<
"."
;
PADDLE_ENFORCE_EQ
(
message_bus_instance
.
IsInit
(),
true
,
InterceptorMessage
start_msg
;
platform
::
errors
::
PreconditionNotMet
(
// source node data_is_ready is send by carrier, so set src_id=-1
"Message bus has not been initialized."
));
start_msg
.
set_src_id
(
-
1
);
message_bus_instance
.
Send
(
tmp_msg
);
start_msg
.
set_dst_id
(
id
);
start_msg
.
set_message_type
(
DATA_IS_READY
);
msg_bus
.
Send
(
start_msg
);
}
}
std
::
unique_lock
<
std
::
mutex
>
lock
(
running_mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
running_mutex_
);
cond_var_
.
wait
(
lock
);
cond_var_
.
wait
(
lock
);
dev_ctx_
->
Wait
();
dev_ctx_
->
Wait
();
...
@@ -164,16 +167,26 @@ void Carrier::CreateInterceptors() {
...
@@ -164,16 +167,26 @@ void Carrier::CreateInterceptors() {
int64_t
interceptor_id
=
item
.
first
;
int64_t
interceptor_id
=
item
.
first
;
TaskNode
*
task_node
=
item
.
second
;
TaskNode
*
task_node
=
item
.
second
;
// TODO(wangxi): use node_type to select different Interceptor
std
::
unique_ptr
<
Interceptor
>
interceptor
;
auto
interceptor
=
if
(
task_node
->
type
().
empty
())
{
std
::
make_unique
<
Interceptor
>
(
interceptor_id
,
task_node
);
// TODO(wangxi): delete this in future
interceptor
.
reset
(
new
Interceptor
(
interceptor_id
,
task_node
));
}
else
{
interceptor
=
InterceptorFactory
::
Create
(
task_node
->
type
(),
interceptor_id
,
task_node
);
}
interceptor
->
SetPlace
(
place_
);
interceptor
->
SetPlace
(
place_
);
interceptor
->
SetMiniBatchScope
(
minibatch_scope_
);
interceptor
->
SetMiniBatchScope
(
minibatch_scope_
);
interceptor
->
SetMicroBatchScope
(
microbatch_scopes_
);
interceptor
->
SetMicroBatchScope
(
microbatch_scopes_
);
interceptor
->
SetRootScope
(
root_scope_
);
interceptor
->
SetRootScope
(
root_scope_
);
SetInterceptor
(
interceptor_id
,
std
::
move
(
interceptor
));
SetInterceptor
(
interceptor_id
,
std
::
move
(
interceptor
));
VLOG
(
3
)
<<
"Create Interceptor with interceptor id: "
<<
interceptor_id
VLOG
(
3
)
<<
"Create Interceptor with interceptor id: "
<<
interceptor_id
<<
"."
;
<<
"."
;
if
(
task_node
->
upstream
().
empty
())
{
source_interceptor_ids_
.
emplace_back
(
interceptor_id
);
}
}
}
// The carrier will be always waiting for outside initializer
// The carrier will be always waiting for outside initializer
// since there is no interceptor has been created during auto init
// since there is no interceptor has been created during auto init
...
...
paddle/fluid/distributed/fleet_executor/carrier.h
浏览文件 @
8a4460f5
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include <condition_variable>
#include <condition_variable>
#include <memory>
#include <memory>
#include <mutex>
#include <mutex>
#include <set>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
...
@@ -90,6 +91,8 @@ class Carrier final {
...
@@ -90,6 +91,8 @@ class Carrier final {
std
::
unordered_map
<
int64_t
,
std
::
unique_ptr
<
Interceptor
>>
std
::
unordered_map
<
int64_t
,
std
::
unique_ptr
<
Interceptor
>>
interceptor_idx_to_interceptor_
;
interceptor_idx_to_interceptor_
;
std
::
vector
<
int64_t
>
source_interceptor_ids_
;
std
::
vector
<
InterceptorMessage
>
message_tmp_
{};
std
::
vector
<
InterceptorMessage
>
message_tmp_
{};
std
::
mutex
tmp_message_mutex_
;
std
::
mutex
tmp_message_mutex_
;
bool
creating_interceptors_
{
true
};
bool
creating_interceptors_
{
true
};
...
...
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
浏览文件 @
8a4460f5
...
@@ -154,18 +154,6 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
...
@@ -154,18 +154,6 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
}
}
void
ComputeInterceptor
::
Run
()
{
void
ComputeInterceptor
::
Run
()
{
// If there is no limit, source interceptor can be executed
// an unlimited number of times.
// Now source node can only run
if
(
ShouldReset
())
{
for
(
auto
&
out_buff
:
out_buffs_
)
{
// buffer is using
if
(
out_buff
.
second
.
second
!=
0
)
return
;
}
step_
=
0
;
// reset
return
;
}
while
(
IsInputReady
()
&&
CanWriteOutput
()
&&
!
ShouldReset
())
{
while
(
IsInputReady
()
&&
CanWriteOutput
()
&&
!
ShouldReset
())
{
VLOG
(
3
)
<<
"id="
<<
GetInterceptorId
()
<<
" ComputeInterceptor running"
;
VLOG
(
3
)
<<
"id="
<<
GetInterceptorId
()
<<
" ComputeInterceptor running"
;
...
@@ -181,6 +169,18 @@ void ComputeInterceptor::Run() {
...
@@ -181,6 +169,18 @@ void ComputeInterceptor::Run() {
// reply to upstream and decrease ready data
// reply to upstream and decrease ready data
ReplyCompletedToUpStream
();
ReplyCompletedToUpStream
();
}
}
// If there is no limit, source interceptor can be executed
// an unlimited number of times.
// Now source node can only run max_run_times.
if
(
ShouldReset
())
{
for
(
auto
&
out_buff
:
out_buffs_
)
{
// buffer is using
if
(
out_buff
.
second
.
second
!=
0
)
return
;
}
step_
=
0
;
// reset
return
;
}
}
}
void
ComputeInterceptor
::
ReceivedStop
(
int64_t
up_id
)
{
void
ComputeInterceptor
::
ReceivedStop
(
int64_t
up_id
)
{
...
...
paddle/fluid/distributed/fleet_executor/interceptor.cc
浏览文件 @
8a4460f5
...
@@ -46,11 +46,19 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
...
@@ -46,11 +46,19 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
VLOG
(
3
)
<<
"Interceptor is using default message handler. This handler is "
VLOG
(
3
)
<<
"Interceptor is using default message handler. This handler is "
"only used for test purpose. Check whether you init interceptor "
"only used for test purpose. Check whether you init interceptor "
"in the proper way."
;
"in the proper way."
;
if
(
msg
.
message_type
()
==
DATA_IS_READY
)
{
if
(
msg
.
message_type
()
==
DATA_IS_READY
)
{
if
(
node_
->
role
()
!=
2
)
{
VLOG
(
3
)
<<
"Fake handler is sending DATA_IS_READY message to: "
<<
interceptor_id_
+
1
<<
"."
;
InterceptorMessage
data_is_ready_msg
;
data_is_ready_msg
.
set_message_type
(
DATA_IS_READY
);
Send
(
interceptor_id_
+
1
,
data_is_ready_msg
);
}
VLOG
(
3
)
<<
"Fake handler is sending stop message to it self."
;
VLOG
(
3
)
<<
"Fake handler is sending stop message to it self."
;
InterceptorMessage
msg
;
InterceptorMessage
stop_
msg
;
msg
.
set_message_type
(
STOP
);
stop_
msg
.
set_message_type
(
STOP
);
Send
(
interceptor_id_
,
msg
);
Send
(
interceptor_id_
,
stop_
msg
);
}
else
if
(
msg
.
message_type
()
==
STOP
)
{
}
else
if
(
msg
.
message_type
()
==
STOP
)
{
stop_
=
true
;
stop_
=
true
;
StopCarrier
();
StopCarrier
();
...
...
paddle/fluid/distributed/fleet_executor/message_bus.cc
浏览文件 @
8a4460f5
...
@@ -136,6 +136,9 @@ void MessageBus::ListenPort() {
...
@@ -136,6 +136,9 @@ void MessageBus::ListenPort() {
}
}
bool
MessageBus
::
IsSameRank
(
int64_t
src_id
,
int64_t
dst_id
)
{
bool
MessageBus
::
IsSameRank
(
int64_t
src_id
,
int64_t
dst_id
)
{
// -1 is sent by carrier to source interceptor
if
(
src_id
==
-
1
)
src_id
=
dst_id
;
// check whether the dst is the same rank or different rank with src
// check whether the dst is the same rank or different rank with src
const
auto
&
src_rank
=
interceptor_id_to_rank_
.
find
(
src_id
);
const
auto
&
src_rank
=
interceptor_id_to_rank_
.
find
(
src_id
);
const
auto
&
dst_rank
=
interceptor_id_to_rank_
.
find
(
dst_id
);
const
auto
&
dst_rank
=
interceptor_id_to_rank_
.
find
(
dst_id
);
...
...
paddle/fluid/distributed/fleet_executor/runtime_graph.cc
浏览文件 @
8a4460f5
...
@@ -112,6 +112,7 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
...
@@ -112,6 +112,7 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
for
(
const
auto
&
op_desc
:
program
.
Block
(
0
).
AllOps
())
{
for
(
const
auto
&
op_desc
:
program
.
Block
(
0
).
AllOps
())
{
ops_
.
emplace_back
(
OpRegistry
::
CreateOp
(
*
op_desc
));
ops_
.
emplace_back
(
OpRegistry
::
CreateOp
(
*
op_desc
));
}
}
std
::
unordered_map
<
int32_t
,
std
::
vector
<
OperatorBase
*>>
role_to_ops
;
std
::
unordered_map
<
int32_t
,
std
::
vector
<
OperatorBase
*>>
role_to_ops
;
for
(
const
auto
&
op
:
ops_
)
{
for
(
const
auto
&
op
:
ops_
)
{
int32_t
op_role
=
op
->
Attr
<
int32_t
>
(
"op_role"
);
int32_t
op_role
=
op
->
Attr
<
int32_t
>
(
"op_role"
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录