Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
50f75fb5
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
50f75fb5
编写于
11月 25, 2021
作者:
W
WangXi
提交者:
GitHub
11月 25, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[fleet_executor] Compute Interceptor stop along data flow (#37531)
上级
992d4ebb
变更
7
隐藏空白更改
内联
并排
Showing
7 changed files
with
100 additions
and
34 deletions
+100
-34
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+9
-0
paddle/fluid/distributed/fleet_executor/carrier.h
paddle/fluid/distributed/fleet_executor/carrier.h
+1
-1
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
...e/fluid/distributed/fleet_executor/compute_interceptor.cc
+49
-0
paddle/fluid/distributed/fleet_executor/compute_interceptor.h
...le/fluid/distributed/fleet_executor/compute_interceptor.h
+7
-0
paddle/fluid/distributed/fleet_executor/interceptor.cc
paddle/fluid/distributed/fleet_executor/interceptor.cc
+17
-3
paddle/fluid/distributed/fleet_executor/interceptor.h
paddle/fluid/distributed/fleet_executor/interceptor.h
+5
-0
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
...stributed/fleet_executor/test/compute_interceptor_test.cc
+12
-30
未找到文件。
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
50f75fb5
...
...
@@ -32,6 +32,15 @@ void Carrier::Init(
is_init_
=
true
;
}
// Calls Join() on every registered interceptor before the carrier's members
// are torn down.
Carrier::~Carrier() {
  // NOTE(wangxi): must join before `Derived Interceptor` destruct,
  // otherwise Derived object will be destructed before thread complete.
  // TODO(wangxi): Maybe need a better way to use thread.
  for (auto& id_and_interceptor : interceptor_idx_to_interceptor_) {
    auto& owned_interceptor = id_and_interceptor.second;
    owned_interceptor->Join();
  }
}
bool
Carrier
::
EnqueueInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
)
{
// enqueue message to interceptor
...
...
paddle/fluid/distributed/fleet_executor/carrier.h
浏览文件 @
50f75fb5
...
...
@@ -42,7 +42,7 @@ class Carrier final {
void
Init
(
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
interceptor_id_to_node
);
~
Carrier
()
=
default
;
~
Carrier
();
// Enqueue a message to corresponding interceptor id
bool
EnqueueInterceptorMessage
(
const
InterceptorMessage
&
interceptor_message
);
...
...
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
浏览文件 @
50f75fb5
...
...
@@ -35,6 +35,7 @@ void ComputeInterceptor::PrepareDeps() {
for
(
auto
up_id
:
upstream
)
{
in_readys_
.
emplace
(
up_id
,
std
::
make_pair
(
in_buff_size
,
0
));
in_stops_
.
emplace
(
up_id
,
false
);
}
for
(
auto
down_id
:
downstream
)
{
out_buffs_
.
emplace
(
down_id
,
std
::
make_pair
(
out_buff_size
,
0
));
...
...
@@ -144,6 +145,52 @@ void ComputeInterceptor::Run() {
}
}
// Records that a STOP message has arrived.
//
// Always sets `received_stop_`. When `up_id` is not -1, the matching entry
// in `in_stops_` is additionally flipped to true.
//
// up_id: id of the sender of the STOP message; -1 means the STOP did not
//        come from an upstream interceptor.
//
// Enforces (via PADDLE_ENFORCE_*) that `up_id` is a known upstream and that
// no STOP has been recorded for it before.
void ComputeInterceptor::ReceivedStop(int64_t up_id) {
  received_stop_ = true;

  // source node has no upstream, stop is sent by carrier or others
  if (up_id == -1) return;

  auto it = in_stops_.find(up_id);
  PADDLE_ENFORCE_NE(it, in_stops_.end(),
                    platform::errors::NotFound(
                        "Cannot find upstream=%lld in in_stops.", up_id));
  // The message contains a %lld placeholder, so up_id must be supplied;
  // without it the duplicate-STOP error cannot name the offending upstream.
  PADDLE_ENFORCE_EQ(
      it->second, false,
      platform::errors::AlreadyExists("Already received stop from %lld, stop "
                                      "cannot be send more than once.",
                                      up_id));
  it->second = true;
}
void
ComputeInterceptor
::
TryStop
()
{
if
(
!
received_stop_
)
return
;
// can stop only when all upstream is stop and
// downstream complete
for
(
auto
&
in_stop
:
in_stops_
)
{
if
(
!
in_stop
.
second
)
return
;
}
for
(
auto
&
out_buff
:
out_buffs_
)
{
auto
used_size
=
out_buff
.
second
.
second
;
if
(
used_size
!=
0
)
return
;
}
// send stop to downstream
for
(
auto
&
out
:
out_buffs_
)
{
auto
down_id
=
out
.
first
;
InterceptorMessage
stop
;
stop
.
set_message_type
(
STOP
);
Send
(
down_id
,
stop
);
}
stop_
=
true
;
}
// STOP handler for the compute interceptor: records the STOP from the
// message's sender, then attempts to stop.
void ComputeInterceptor::HandleStop(const InterceptorMessage& msg) {
  const auto sender_id = msg.src_id();
  ReceivedStop(sender_id);
  TryStop();
}
void
ComputeInterceptor
::
Compute
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
DATA_IS_READY
)
{
IncreaseReady
(
msg
.
src_id
());
...
...
@@ -152,6 +199,8 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
DecreaseBuff
(
msg
.
src_id
());
Run
();
}
TryStop
();
}
REGISTER_INTERCEPTOR
(
Compute
,
ComputeInterceptor
);
...
...
paddle/fluid/distributed/fleet_executor/compute_interceptor.h
浏览文件 @
50f75fb5
...
...
@@ -38,6 +38,10 @@ class ComputeInterceptor : public Interceptor {
void
Run
();
void
Compute
(
const
InterceptorMessage
&
msg
);
void
HandleStop
(
const
InterceptorMessage
&
msg
)
override
;
void
ReceivedStop
(
int64_t
up_id
);
void
TryStop
();
private:
// FIXME(wangxi): if use step_ and max_steps_, how to restart step_ from 0
int64_t
step_
{
0
};
...
...
@@ -45,6 +49,9 @@ class ComputeInterceptor : public Interceptor {
std
::
map
<
int64_t
,
std
::
pair
<
int64_t
,
int64_t
>>
in_readys_
{};
// downstream_id-->(max_buffer_size, used_size)
std
::
map
<
int64_t
,
std
::
pair
<
int64_t
,
int64_t
>>
out_buffs_
{};
bool
received_stop_
{
false
};
std
::
map
<
int64_t
,
bool
>
in_stops_
{};
};
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/interceptor.cc
浏览文件 @
50f75fb5
...
...
@@ -28,7 +28,13 @@ Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node)
});
}
Interceptor
::~
Interceptor
()
{
interceptor_thread_
.
join
();
}
// Delegates to Join() on destruction.
Interceptor::~Interceptor() {
  this->Join();
}
// Joins `interceptor_thread_` when it is joinable; otherwise does nothing.
void Interceptor::Join() {
  if (!interceptor_thread_.joinable()) return;
  interceptor_thread_.join();
}
void
Interceptor
::
RegisterMsgHandle
(
MsgHandle
handle
)
{
handle_
=
handle
;
}
...
...
@@ -74,6 +80,9 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
return
MessageBus
::
Instance
().
Send
(
msg
);
}
// maybe need a better method for interceptor base
void
Interceptor
::
HandleStop
(
const
InterceptorMessage
&
msg
)
{
stop_
=
true
;
}
void
Interceptor
::
PoolTheMailbox
()
{
// pool the local mailbox, parse the Message
for
(;;)
{
...
...
@@ -91,13 +100,18 @@ void Interceptor::PoolTheMailbox() {
VLOG
(
3
)
<<
"Interceptor "
<<
interceptor_id_
<<
" has received a message"
<<
" from interceptor "
<<
interceptor_message
.
src_id
()
<<
" with message: "
<<
message_type
<<
"."
;
if
(
message_type
==
STOP
)
{
HandleStop
(
interceptor_message
);
}
else
{
Handle
(
interceptor_message
);
}
if
(
stop_
)
{
// break the pooling thread
VLOG
(
3
)
<<
"Interceptor "
<<
interceptor_id_
<<
" is quiting."
;
break
;
}
Handle
(
interceptor_message
);
}
}
...
...
paddle/fluid/distributed/fleet_executor/interceptor.h
浏览文件 @
50f75fb5
...
...
@@ -43,9 +43,13 @@ class Interceptor {
virtual
~
Interceptor
();
void
Join
();
// register interceptor handle
void
RegisterMsgHandle
(
MsgHandle
handle
);
virtual
void
HandleStop
(
const
InterceptorMessage
&
msg
);
void
Handle
(
const
InterceptorMessage
&
msg
);
// return the interceptor id
...
...
@@ -64,6 +68,7 @@ class Interceptor {
protected:
TaskNode
*
GetTaskNode
()
const
{
return
node_
;
}
bool
stop_
{
false
};
private:
// pool the local mailbox, parse the Message
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
浏览文件 @
50f75fb5
...
...
@@ -25,28 +25,6 @@ limitations under the License. */
namespace
paddle
{
namespace
distributed
{
// Test interceptor whose message handler is Stop(): the first message is only
// logged and counted; every later message also sends STOP to interceptors
// 0, 1, 2 and 3.
class StopInterceptor : public Interceptor {
 public:
  StopInterceptor(int64_t interceptor_id, TaskNode* node)
      : Interceptor(interceptor_id, node) {
    RegisterMsgHandle(
        [this](const InterceptorMessage& received) { Stop(received); });
  }

  // Logs the sender, bumps the message counter and, from the second message
  // onwards, sends STOP to ids 0..3 in ascending order.
  void Stop(const InterceptorMessage& msg) {
    std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
              << std::endl;
    ++count_;
    if (count_ == 1) return;

    InterceptorMessage stop;
    stop.set_message_type(STOP);
    for (int64_t target_id = 0; target_id <= 3; ++target_id) {
      Send(target_id, stop);
    }
  }

  // Number of messages handled so far.
  int count_{0};
};
class
StartInterceptor
:
public
Interceptor
{
public:
StartInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
...
...
@@ -57,13 +35,20 @@ class StartInterceptor : public Interceptor {
void
NOP
(
const
InterceptorMessage
&
msg
)
{
std
::
cout
<<
GetInterceptorId
()
<<
" recv msg from "
<<
msg
.
src_id
()
<<
std
::
endl
;
++
count_
;
if
(
count_
==
3
)
{
InterceptorMessage
stop
;
stop
.
set_message_type
(
STOP
);
Send
(
msg
.
dst_id
(),
stop
);
// stop 0, this
Send
(
msg
.
src_id
(),
stop
);
// stop 1, compute
}
}
int
count_
{
0
};
};
TEST
(
ComputeInterceptor
,
Compute
)
{
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
Carrier
&
carrier
=
Carrier
::
Instance
();
...
...
@@ -71,27 +56,24 @@ TEST(ComputeInterceptor, Compute) {
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
0
,
0
);
// role, rank, task_id
TaskNode
*
node_b
=
new
TaskNode
(
0
,
0
,
1
,
0
,
0
);
TaskNode
*
node_c
=
new
TaskNode
(
0
,
0
,
2
,
0
,
0
);
TaskNode
*
node_d
=
new
TaskNode
(
0
,
0
,
3
,
0
,
0
);
// a->b->c
->d
// a->b->c
node_a
->
AddDownstreamTask
(
1
);
node_b
->
AddUpstreamTask
(
0
);
node_b
->
AddDownstreamTask
(
2
);
node_c
->
AddUpstreamTask
(
1
);
node_c
->
AddDownstreamTask
(
3
);
node_d
->
AddUpstreamTask
(
2
);
Interceptor
*
a
=
carrier
.
SetInterceptor
(
0
,
std
::
make_unique
<
StartInterceptor
>
(
0
,
node_a
));
carrier
.
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
.
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Compute"
,
2
,
node_c
));
carrier
.
SetInterceptor
(
3
,
std
::
make_unique
<
StopInterceptor
>
(
3
,
node_c
));
carrier
.
SetCreatingFlag
(
false
);
InterceptorMessage
msg
;
msg
.
set_message_type
(
DATA_IS_READY
);
// double buff, send twice
// test run three times
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录