Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
989e39a5
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
989e39a5
编写于
1月 17, 2023
作者:
L
LiYuRio
提交者:
GitHub
1月 17, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Modified compute and amplifier interceptor (#42044)
上级
39c6765a
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
95 addition
and
161 deletion
+95
-161
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.cc
...fluid/distributed/fleet_executor/amplifier_interceptor.cc
+3
-3
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h
.../fluid/distributed/fleet_executor/amplifier_interceptor.h
+1
-1
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+48
-13
paddle/fluid/distributed/fleet_executor/carrier.h
paddle/fluid/distributed/fleet_executor/carrier.h
+0
-2
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
...e/fluid/distributed/fleet_executor/compute_interceptor.cc
+19
-90
paddle/fluid/distributed/fleet_executor/compute_interceptor.h
...le/fluid/distributed/fleet_executor/compute_interceptor.h
+3
-10
paddle/fluid/distributed/fleet_executor/interceptor.h
paddle/fluid/distributed/fleet_executor/interceptor.h
+0
-4
paddle/fluid/distributed/fleet_executor/sink_interceptor.h
paddle/fluid/distributed/fleet_executor/sink_interceptor.h
+1
-1
paddle/fluid/distributed/fleet_executor/source_interceptor.h
paddle/fluid/distributed/fleet_executor/source_interceptor.h
+1
-1
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
...stributed/fleet_executor/test/compute_interceptor_test.cc
+19
-34
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
...ributed/fleet_executor/test/interceptor_ping_pong_test.cc
+0
-1
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
...eet_executor/test/interceptor_ping_pong_with_brpc_test.cc
+0
-1
未找到文件。
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.cc
浏览文件 @
989e39a5
...
@@ -33,7 +33,7 @@ void AmplifierInterceptor::RunOps() {
...
@@ -33,7 +33,7 @@ void AmplifierInterceptor::RunOps() {
// run_per_steps_, run_at_offset_
// run_per_steps_, run_at_offset_
// 4, 0 --> run at step 0, 4, 8, 12
// 4, 0 --> run at step 0, 4, 8, 12
// 4, 3 --> run at step 3, 7, 11, 15
// 4, 3 --> run at step 3, 7, 11, 15
if
((
step
_
%
run_per_steps_
)
==
run_at_offset_
)
{
if
((
cur_scope_id
_
%
run_per_steps_
)
==
run_at_offset_
)
{
ComputeInterceptor
::
RunOps
();
ComputeInterceptor
::
RunOps
();
}
}
}
}
...
@@ -41,7 +41,7 @@ void AmplifierInterceptor::RunOps() {
...
@@ -41,7 +41,7 @@ void AmplifierInterceptor::RunOps() {
void
AmplifierInterceptor
::
SendDataReadyToDownStream
()
{
void
AmplifierInterceptor
::
SendDataReadyToDownStream
()
{
// run multi times, send ready one times to downstream, that is
// run multi times, send ready one times to downstream, that is
// input multi times, output one times
// input multi times, output one times
if
(
step
_
%
send_down_per_steps_
==
0
)
{
if
(
cur_scope_id
_
%
send_down_per_steps_
==
0
)
{
ComputeInterceptor
::
SendDataReadyToDownStream
();
ComputeInterceptor
::
SendDataReadyToDownStream
();
}
}
}
}
...
@@ -49,7 +49,7 @@ void AmplifierInterceptor::SendDataReadyToDownStream() {
...
@@ -49,7 +49,7 @@ void AmplifierInterceptor::SendDataReadyToDownStream() {
void
AmplifierInterceptor
::
ReplyCompletedToUpStream
()
{
void
AmplifierInterceptor
::
ReplyCompletedToUpStream
()
{
// run multi times, reply one times to upstream, that is
// run multi times, reply one times to upstream, that is
// input one times, output multi times
// input one times, output multi times
if
(
step
_
%
reply_up_per_steps_
==
0
)
{
if
(
cur_scope_id
_
%
reply_up_per_steps_
==
0
)
{
ComputeInterceptor
::
ReplyCompletedToUpStream
();
ComputeInterceptor
::
ReplyCompletedToUpStream
();
}
}
}
}
...
...
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h
浏览文件 @
989e39a5
...
@@ -21,7 +21,7 @@
...
@@ -21,7 +21,7 @@
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
class
AmplifierInterceptor
:
public
ComputeInterceptor
{
class
AmplifierInterceptor
final
:
public
ComputeInterceptor
{
public:
public:
AmplifierInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
);
AmplifierInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
);
...
...
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
989e39a5
...
@@ -71,6 +71,9 @@ void Carrier::Init(
...
@@ -71,6 +71,9 @@ void Carrier::Init(
microbatch_scopes_
[
i
]
=
&
minibatch_scope_
->
NewScope
();
microbatch_scopes_
[
i
]
=
&
minibatch_scope_
->
NewScope
();
CopyParameters
(
i
,
program
,
inference_root_scope_vars
);
CopyParameters
(
i
,
program
,
inference_root_scope_vars
);
}
}
// Add source and sink interceptor id to rank
interceptor_id_to_rank_
.
emplace
(
SOURCE_ID
,
rank
);
interceptor_id_to_rank_
.
emplace
(
SINK_ID
,
rank
);
// TODO(fleet_exe dev): thread pool
// TODO(fleet_exe dev): thread pool
thread_num_
=
1
;
thread_num_
=
1
;
...
@@ -159,16 +162,10 @@ void Carrier::Start() {
...
@@ -159,16 +162,10 @@ void Carrier::Start() {
true
,
true
,
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
"Using carrier before initialized."
));
"Using carrier before initialized."
));
for
(
int64_t
id
:
source_interceptor_ids_
)
{
VLOG
(
3
)
<<
"Carrier Start is sending start to source interceptor "
<<
id
<<
"."
;
InterceptorMessage
start_msg
;
InterceptorMessage
start_msg
;
// source node data_is_ready is send by carrier, so set src_id=-1
start_msg
.
set_dst_id
(
SOURCE_ID
);
start_msg
.
set_src_id
(
-
1
);
start_msg
.
set_message_type
(
START
);
start_msg
.
set_dst_id
(
id
);
start_msg
.
set_message_type
(
DATA_IS_READY
);
Send
(
start_msg
);
Send
(
start_msg
);
}
// TODO(wangxi): async step
// TODO(wangxi): async step
Wait
();
Wait
();
dev_ctx_
->
Wait
();
dev_ctx_
->
Wait
();
...
@@ -270,6 +267,38 @@ void Carrier::CreateInterceptors() {
...
@@ -270,6 +267,38 @@ void Carrier::CreateInterceptors() {
auto
gc
=
GetGC
(
place_
);
auto
gc
=
GetGC
(
place_
);
// create source and sink task node
auto
max_run_times
=
microbatch_scopes_
.
size
();
TaskNode
*
source
=
new
TaskNode
(
rank_
,
SOURCE_ID
,
max_run_times
);
// rank, task_id, max_run_times
TaskNode
*
sink
=
new
TaskNode
(
rank_
,
SINK_ID
,
max_run_times
);
// find nodes without upstreams or without downstreams
std
::
vector
<
TaskNode
*>
origin_sources
,
origin_sinks
;
for
(
const
auto
&
item
:
interceptor_id_to_node_
)
{
TaskNode
*
task_node
=
item
.
second
;
if
(
task_node
->
upstream
().
empty
())
{
origin_sources
.
emplace_back
(
task_node
);
}
if
(
task_node
->
downstream
().
empty
())
{
origin_sinks
.
emplace_back
(
task_node
);
}
}
// link source node with origin source
for
(
const
auto
&
node
:
origin_sources
)
{
source
->
AddDownstreamTask
(
node
->
task_id
(),
std
::
numeric_limits
<
int64_t
>::
max
());
node
->
AddUpstreamTask
(
SOURCE_ID
,
std
::
numeric_limits
<
int64_t
>::
max
());
}
// link sink node with origin sink
for
(
const
auto
&
node
:
origin_sinks
)
{
sink
->
AddUpstreamTask
(
node
->
task_id
(),
std
::
numeric_limits
<
int64_t
>::
max
());
node
->
AddDownstreamTask
(
SINK_ID
,
std
::
numeric_limits
<
int64_t
>::
max
());
}
// create source and sink interceptor
SetInterceptor
(
SOURCE_ID
,
InterceptorFactory
::
Create
(
"Source"
,
SOURCE_ID
,
source
));
SetInterceptor
(
SINK_ID
,
InterceptorFactory
::
Create
(
"Sink"
,
SINK_ID
,
sink
));
// create each Interceptor
// create each Interceptor
// no auto init since there is no config
// no auto init since there is no config
for
(
const
auto
&
item
:
interceptor_id_to_node_
)
{
for
(
const
auto
&
item
:
interceptor_id_to_node_
)
{
...
@@ -303,9 +332,15 @@ void Carrier::CreateInterceptors() {
...
@@ -303,9 +332,15 @@ void Carrier::CreateInterceptors() {
VLOG
(
3
)
<<
"Create Interceptor with interceptor id: "
<<
interceptor_id
VLOG
(
3
)
<<
"Create Interceptor with interceptor id: "
<<
interceptor_id
<<
" with type: "
<<
task_node
->
type
()
<<
"."
;
<<
" with type: "
<<
task_node
->
type
()
<<
"."
;
if
(
task_node
->
upstream
().
empty
())
{
PADDLE_ENFORCE_EQ
(
source_interceptor_ids_
.
emplace_back
(
interceptor_id
);
task_node
->
upstream
().
empty
(),
}
false
,
platform
::
errors
::
PreconditionNotMet
(
"There should not have normal nodes as source nodes"
));
PADDLE_ENFORCE_EQ
(
task_node
->
downstream
().
empty
(),
false
,
platform
::
errors
::
PreconditionNotMet
(
"There should not have normal nodes as sink nodes"
));
}
}
}
}
...
...
paddle/fluid/distributed/fleet_executor/carrier.h
浏览文件 @
989e39a5
...
@@ -100,8 +100,6 @@ class Carrier final {
...
@@ -100,8 +100,6 @@ class Carrier final {
std
::
unordered_map
<
int64_t
,
std
::
unique_ptr
<
Interceptor
>>
std
::
unordered_map
<
int64_t
,
std
::
unique_ptr
<
Interceptor
>>
interceptor_idx_to_interceptor_
;
interceptor_idx_to_interceptor_
;
std
::
vector
<
int64_t
>
source_interceptor_ids_
;
bool
is_init_
{
false
};
bool
is_init_
{
false
};
std
::
mutex
running_mutex_
;
std
::
mutex
running_mutex_
;
...
...
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
浏览文件 @
989e39a5
...
@@ -34,29 +34,10 @@ void ComputeInterceptor::PrepareDeps() {
...
@@ -34,29 +34,10 @@ void ComputeInterceptor::PrepareDeps() {
for
(
auto
up
:
upstream
)
{
for
(
auto
up
:
upstream
)
{
in_readys_
.
emplace
(
up
.
first
,
std
::
make_pair
(
up
.
second
,
0
));
in_readys_
.
emplace
(
up
.
first
,
std
::
make_pair
(
up
.
second
,
0
));
in_stops_
.
emplace
(
up
.
first
,
false
);
}
}
for
(
auto
down
:
downstream
)
{
for
(
auto
down
:
downstream
)
{
out_buffs_
.
emplace
(
down
.
first
,
std
::
make_pair
(
down
.
second
,
0
));
out_buffs_
.
emplace
(
down
.
first
,
std
::
make_pair
(
down
.
second
,
0
));
}
}
// source compute node, should we add a new SourceInterceptor?
if
(
upstream
.
empty
())
{
is_source_
=
true
;
PADDLE_ENFORCE_GT
(
node_
->
max_run_times
(),
0
,
platform
::
errors
::
InvalidArgument
(
"Source ComputeInterceptor must run at least one "
"times, but now max_run_times=%ld"
,
node_
->
max_run_times
()));
in_readys_
.
emplace
(
-
1
,
std
::
make_pair
(
std
::
numeric_limits
<
int64_t
>::
max
(),
0
));
}
// If there is no downstream or every downstream is in different rank,
// then this interceptor is the last one for current rank.
// This can be get during init, can be cached for later use.
is_last_
=
downstream
.
empty
();
}
}
void
ComputeInterceptor
::
IncreaseReady
(
int64_t
up_id
)
{
void
ComputeInterceptor
::
IncreaseReady
(
int64_t
up_id
)
{
...
@@ -66,12 +47,6 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) {
...
@@ -66,12 +47,6 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) {
platform
::
errors
::
NotFound
(
platform
::
errors
::
NotFound
(
"Cannot find upstream=%lld in in_readys."
,
up_id
));
"Cannot find upstream=%lld in in_readys."
,
up_id
));
// source node has no upstream, data_is_ready is send by carrier or others
if
(
is_source_
&&
up_id
==
-
1
)
{
it
->
second
.
second
+=
GetTaskNode
()
->
max_run_times
();
return
;
}
auto
max_ready_size
=
it
->
second
.
first
;
auto
max_ready_size
=
it
->
second
.
first
;
auto
ready_size
=
it
->
second
.
second
;
auto
ready_size
=
it
->
second
.
second
;
ready_size
+=
1
;
ready_size
+=
1
;
...
@@ -152,7 +127,7 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
...
@@ -152,7 +127,7 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
ready_msg
.
set_message_type
(
DATA_IS_READY
);
ready_msg
.
set_message_type
(
DATA_IS_READY
);
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
<<
" Send data_is_ready msg to "
<<
down_id
<<
" Send data_is_ready msg to "
<<
down_id
<<
"
for step: "
<<
step
_
;
<<
"
in scope: "
<<
cur_scope_id
_
;
Send
(
down_id
,
ready_msg
);
Send
(
down_id
,
ready_msg
);
}
}
}
}
...
@@ -173,8 +148,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
...
@@ -173,8 +148,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
<<
" Reply data_is_useless msg to "
<<
up_id
<<
" Reply data_is_useless msg to "
<<
up_id
<<
" for step: "
<<
step_
;
<<
" in scope: "
<<
cur_scope_id_
;
if
(
is_source_
&&
up_id
==
-
1
)
return
;
InterceptorMessage
reply_msg
;
InterceptorMessage
reply_msg
;
reply_msg
.
set_message_type
(
DATA_IS_USELESS
);
reply_msg
.
set_message_type
(
DATA_IS_USELESS
);
...
@@ -183,13 +157,17 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
...
@@ -183,13 +157,17 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
}
}
void
ComputeInterceptor
::
RunOps
()
{
void
ComputeInterceptor
::
RunOps
()
{
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
<<
" running ops for the "
<<
step_
+
1
<<
" time."
;
for
(
auto
op
:
node_
->
ops
())
{
for
(
auto
op
:
node_
->
ops
())
{
op
->
Run
(
*
microbatch_scopes_
[
step_
%
node_
->
max_run_times
()],
place_
);
PADDLE_ENFORCE_LT
(
cur_scope_id_
,
microbatch_scopes_
.
size
(),
platform
::
errors
::
InvalidArgument
(
"Step out of range. There are %ld "
"microbatch_scopes, but recevice scope index %ld"
,
microbatch_scopes_
.
size
(),
cur_scope_id_
));
op
->
Run
(
*
microbatch_scopes_
[
cur_scope_id_
],
place_
);
if
(
gc_
)
{
if
(
gc_
)
{
framework
::
DeleteUnusedTensors
(
framework
::
DeleteUnusedTensors
(
*
microbatch_scopes_
[
cur_scope_id_
],
*
microbatch_scopes_
[
step_
%
node_
->
max_run_times
()],
op
,
op
,
node_
->
unused_vars
(),
node_
->
unused_vars
(),
gc_
.
get
());
gc_
.
get
());
...
@@ -201,77 +179,28 @@ void ComputeInterceptor::Run() {
...
@@ -201,77 +179,28 @@ void ComputeInterceptor::Run() {
while
(
IsInputReady
()
&&
CanWriteOutput
())
{
while
(
IsInputReady
()
&&
CanWriteOutput
())
{
VLOG
(
3
)
<<
"id="
<<
GetInterceptorId
()
<<
" ComputeInterceptor running"
;
VLOG
(
3
)
<<
"id="
<<
GetInterceptorId
()
<<
" ComputeInterceptor running"
;
// get the ready scope id from queue
cur_scope_id_
=
ready_queue_
.
front
();
ready_queue_
.
pop
();
RunOps
();
RunOps
();
++
step_
;
// send to downstream and increase buff used
// send to downstream and increase buff used
SendDataReadyToDownStream
();
SendDataReadyToDownStream
();
// reply to upstream and decrease ready data
// reply to upstream and decrease ready data
ReplyCompletedToUpStream
();
ReplyCompletedToUpStream
();
// Try to stop Carrier
if
(
is_last_
&&
(
step_
%
node_
->
max_run_times
()
==
0
))
{
VLOG
(
3
)
<<
"Interceptor "
<<
GetInterceptorId
()
<<
" is stopping carrier."
;
// FIXME(wangxi): with multi sink interceptor
StopCarrier
();
}
}
}
}
}
void
ComputeInterceptor
::
ReceivedStop
(
int64_t
up_id
)
{
received_stop_
=
true
;
// source node has no upstream, stop is send by carrier or others
if
(
is_source_
&&
up_id
==
-
1
)
return
;
auto
it
=
in_stops_
.
find
(
up_id
);
PADDLE_ENFORCE_NE
(
it
,
in_stops_
.
end
(),
platform
::
errors
::
NotFound
(
"Cannot find upstream=%lld in in_stops."
,
up_id
));
PADDLE_ENFORCE_EQ
(
it
->
second
,
false
,
platform
::
errors
::
AlreadyExists
(
"Already received stop from %lld, stop "
"cannot be send more than once."
));
it
->
second
=
true
;
}
void
ComputeInterceptor
::
TryStop
()
{
if
(
!
received_stop_
)
return
;
// can stop only when all upstream is stop and
// downstream complete
for
(
auto
&
in_stop
:
in_stops_
)
{
if
(
!
in_stop
.
second
)
return
;
}
for
(
auto
&
out_buff
:
out_buffs_
)
{
auto
used_size
=
out_buff
.
second
.
second
;
if
(
used_size
!=
0
)
return
;
}
// send stop to downstream
for
(
auto
&
out
:
out_buffs_
)
{
auto
down_id
=
out
.
first
;
InterceptorMessage
stop
;
stop
.
set_message_type
(
STOP
);
Send
(
down_id
,
stop
);
}
stop_
=
true
;
}
void
ComputeInterceptor
::
Compute
(
const
InterceptorMessage
&
msg
)
{
void
ComputeInterceptor
::
Compute
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
DATA_IS_READY
)
{
if
(
msg
.
message_type
()
==
DATA_IS_READY
)
{
IncreaseReady
(
msg
.
src_id
());
IncreaseReady
(
msg
.
src_id
());
ready_queue_
.
push
(
msg
.
scope_idx
());
Run
();
Run
();
}
else
if
(
msg
.
message_type
()
==
DATA_IS_USELESS
)
{
}
else
if
(
msg
.
message_type
()
==
DATA_IS_USELESS
)
{
DecreaseBuff
(
msg
.
src_id
());
DecreaseBuff
(
msg
.
src_id
());
Run
();
Run
();
}
else
if
(
msg
.
message_type
()
==
STOP
)
{
ReceivedStop
(
msg
.
src_id
());
}
}
TryStop
();
}
}
REGISTER_INTERCEPTOR
(
Compute
,
ComputeInterceptor
);
REGISTER_INTERCEPTOR
(
Compute
,
ComputeInterceptor
);
...
...
paddle/fluid/distributed/fleet_executor/compute_interceptor.h
浏览文件 @
989e39a5
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#pragma once
#pragma once
#include <queue>
#include <utility>
#include <utility>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
...
@@ -30,7 +31,8 @@ class ComputeInterceptor : public Interceptor {
...
@@ -30,7 +31,8 @@ class ComputeInterceptor : public Interceptor {
virtual
void
SendDataReadyToDownStream
();
virtual
void
SendDataReadyToDownStream
();
virtual
void
ReplyCompletedToUpStream
();
virtual
void
ReplyCompletedToUpStream
();
int64_t
step_
{
0
};
std
::
queue
<
int64_t
>
ready_queue_
;
int64_t
cur_scope_id_
;
private:
private:
void
PrepareDeps
();
void
PrepareDeps
();
...
@@ -43,19 +45,10 @@ class ComputeInterceptor : public Interceptor {
...
@@ -43,19 +45,10 @@ class ComputeInterceptor : public Interceptor {
void
Run
();
void
Run
();
void
Compute
(
const
InterceptorMessage
&
msg
);
void
Compute
(
const
InterceptorMessage
&
msg
);
void
ReceivedStop
(
int64_t
up_id
);
void
TryStop
();
bool
is_source_
{
false
};
bool
is_last_
{
false
};
// upstream_id-->(max_ready_size, ready_size)
// upstream_id-->(max_ready_size, ready_size)
std
::
map
<
int64_t
,
std
::
pair
<
int64_t
,
int64_t
>>
in_readys_
{};
std
::
map
<
int64_t
,
std
::
pair
<
int64_t
,
int64_t
>>
in_readys_
{};
// downstream_id-->(max_buffer_size, used_size)
// downstream_id-->(max_buffer_size, used_size)
std
::
map
<
int64_t
,
std
::
pair
<
int64_t
,
int64_t
>>
out_buffs_
{};
std
::
map
<
int64_t
,
std
::
pair
<
int64_t
,
int64_t
>>
out_buffs_
{};
bool
received_stop_
{
false
};
std
::
map
<
int64_t
,
bool
>
in_stops_
{};
};
};
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/interceptor.h
浏览文件 @
989e39a5
...
@@ -93,7 +93,6 @@ class Interceptor {
...
@@ -93,7 +93,6 @@ class Interceptor {
TaskNode
*
node_
;
TaskNode
*
node_
;
// for stop
// for stop
bool
stop_
{
false
};
void
StopCarrier
();
void
StopCarrier
();
// for runtime
// for runtime
...
@@ -114,9 +113,6 @@ class Interceptor {
...
@@ -114,9 +113,6 @@ class Interceptor {
std
::
mutex
mutex_
;
std
::
mutex
mutex_
;
std
::
deque
<
InterceptorMessage
>
messages_
;
std
::
deque
<
InterceptorMessage
>
messages_
;
int64_t
already_run_times_
{
0
};
int64_t
used_slot_nums_
{
0
};
};
};
class
InterceptorFactory
{
class
InterceptorFactory
{
...
...
paddle/fluid/distributed/fleet_executor/sink_interceptor.h
浏览文件 @
989e39a5
...
@@ -25,7 +25,7 @@ namespace distributed {
...
@@ -25,7 +25,7 @@ namespace distributed {
* 1. record the num of micro-step
* 1. record the num of micro-step
* 2. check whether to notify carrier the current step is finished
* 2. check whether to notify carrier the current step is finished
*/
*/
class
SinkInterceptor
:
public
Interceptor
{
class
SinkInterceptor
final
:
public
Interceptor
{
public:
public:
SinkInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
);
SinkInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
);
...
...
paddle/fluid/distributed/fleet_executor/source_interceptor.h
浏览文件 @
989e39a5
...
@@ -25,7 +25,7 @@ namespace distributed {
...
@@ -25,7 +25,7 @@ namespace distributed {
* 1. receive `start` message from carrier
* 1. receive `start` message from carrier
* 2. send num_of_steps `data_is_ready` message to downstream
* 2. send num_of_steps `data_is_ready` message to downstream
*/
*/
class
SourceInterceptor
:
public
Interceptor
{
class
SourceInterceptor
final
:
public
Interceptor
{
public:
public:
SourceInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
);
SourceInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
);
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
浏览文件 @
989e39a5
...
@@ -25,57 +25,42 @@ limitations under the License. */
...
@@ -25,57 +25,42 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
class
StartInterceptor
:
public
Interceptor
{
public:
StartInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
Interceptor
(
interceptor_id
,
node
)
{
RegisterMsgHandle
([
this
](
const
InterceptorMessage
&
msg
)
{
NOP
(
msg
);
});
}
void
NOP
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
STOP
)
{
stop_
=
true
;
InterceptorMessage
stop
;
stop
.
set_message_type
(
STOP
);
Send
(
1
,
stop
);
// stop 1, compute
return
;
}
std
::
cout
<<
GetInterceptorId
()
<<
" recv msg from "
<<
msg
.
src_id
()
<<
std
::
endl
;
}
};
TEST
(
ComputeInterceptor
,
Compute
)
{
TEST
(
ComputeInterceptor
,
Compute
)
{
std
::
string
carrier_id
=
"0"
;
std
::
string
carrier_id
=
"0"
;
Carrier
*
carrier
=
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
}});
carrier
->
Init
(
0
,
{{
SOURCE_ID
,
0
},
{
0
,
0
},
{
1
,
0
},
{
SINK_ID
,
0
}});
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
// NOTE: don't delete, otherwise interceptor will use undefined node
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
3
,
0
);
// role, rank, task_id
TaskNode
*
source
=
new
TaskNode
(
0
,
SOURCE_ID
,
3
);
// rank, task_id, max_run_times
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
3
,
0
);
TaskNode
*
node_b
=
new
TaskNode
(
0
,
0
,
1
,
3
,
0
);
TaskNode
*
node_b
=
new
TaskNode
(
0
,
0
,
1
,
3
,
0
);
TaskNode
*
node_c
=
new
TaskNode
(
0
,
0
,
2
,
3
,
0
);
TaskNode
*
sink
=
new
TaskNode
(
0
,
SINK_ID
,
3
);
// a->b->c
// source->a->b->sink
source
->
AddDownstreamTask
(
0
);
node_a
->
AddUpstreamTask
(
SOURCE_ID
);
node_a
->
AddDownstreamTask
(
1
,
3
);
node_a
->
AddDownstreamTask
(
1
,
3
);
node_b
->
AddUpstreamTask
(
0
,
3
);
node_b
->
AddUpstreamTask
(
0
,
3
);
node_b
->
AddDownstreamTask
(
2
);
node_b
->
AddDownstreamTask
(
SINK_ID
);
node_c
->
AddUpstreamTask
(
1
);
sink
->
AddUpstreamTask
(
1
);
Interceptor
*
a
=
carrier
->
SetInterceptor
(
carrier
->
SetInterceptor
(
0
,
std
::
make_unique
<
StartInterceptor
>
(
0
,
node_a
));
SOURCE_ID
,
InterceptorFactory
::
Create
(
"Source"
,
SOURCE_ID
,
source
));
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Compute"
,
0
,
node_a
));
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
->
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Compute"
,
2
,
node_c
));
carrier
->
SetInterceptor
(
SINK_ID
,
InterceptorFactory
::
Create
(
"Sink"
,
SINK_ID
,
sink
));
// start
InterceptorMessage
msg
;
InterceptorMessage
msg
;
msg
.
set_message_type
(
DATA_IS_READY
);
msg
.
set_message_type
(
START
);
// test run three times
msg
.
set_dst_id
(
SOURCE_ID
);
a
->
Send
(
1
,
msg
);
carrier
->
EnqueueInterceptorMessage
(
msg
);
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
carrier
->
Wait
();
carrier
->
Wait
();
carrier
->
Release
();
carrier
->
Release
();
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
浏览文件 @
989e39a5
...
@@ -33,7 +33,6 @@ class PingPongInterceptor : public Interceptor {
...
@@ -33,7 +33,6 @@ class PingPongInterceptor : public Interceptor {
void
PingPong
(
const
InterceptorMessage
&
msg
)
{
void
PingPong
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
STOP
)
{
if
(
msg
.
message_type
()
==
STOP
)
{
stop_
=
true
;
return
;
return
;
}
}
std
::
cout
<<
GetInterceptorId
()
<<
" recv msg, count="
<<
count_
std
::
cout
<<
GetInterceptorId
()
<<
" recv msg, count="
<<
count_
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
浏览文件 @
989e39a5
...
@@ -36,7 +36,6 @@ class PingPongInterceptor : public Interceptor {
...
@@ -36,7 +36,6 @@ class PingPongInterceptor : public Interceptor {
void
PingPong
(
const
InterceptorMessage
&
msg
)
{
void
PingPong
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
STOP
)
{
if
(
msg
.
message_type
()
==
STOP
)
{
stop_
=
true
;
StopCarrier
();
StopCarrier
();
return
;
return
;
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录