Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
be3b7740
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
be3b7740
编写于
11月 24, 2021
作者:
W
WangXi
提交者:
GitHub
11月 24, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[fleet_executor] Complete compute interceptor (#37485)
上级
1799c032
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
161 addition
and
23 deletion
+161
-23
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
...e/fluid/distributed/fleet_executor/compute_interceptor.cc
+116
-17
paddle/fluid/distributed/fleet_executor/compute_interceptor.h
...le/fluid/distributed/fleet_executor/compute_interceptor.h
+15
-1
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
...stributed/fleet_executor/test/compute_interceptor_test.cc
+30
-5
未找到文件。
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
浏览文件 @
be3b7740
...
...
@@ -27,31 +27,130 @@ ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node)
void
ComputeInterceptor
::
PrepareDeps
()
{
auto
&
upstream
=
GetTaskNode
()
->
upstream
();
upstream_deps_
.
insert
(
upstream
.
begin
(),
upstream
.
end
());
auto
&
downstream
=
GetTaskNode
()
->
downstream
();
// TODO(wangxi): get from task node
int64_t
in_buff_size
=
std
::
numeric_limits
<
int64_t
>::
max
();
int64_t
out_buff_size
=
2
;
for
(
auto
up_id
:
upstream
)
{
in_readys_
.
emplace
(
up_id
,
std
::
make_pair
(
in_buff_size
,
0
));
}
for
(
auto
down_id
:
downstream
)
{
out_buffs_
.
emplace
(
down_id
,
std
::
make_pair
(
out_buff_size
,
0
));
}
}
void
ComputeInterceptor
::
IncreaseReady
(
int64_t
up_id
)
{
auto
it
=
in_readys_
.
find
(
up_id
);
PADDLE_ENFORCE_NE
(
it
,
in_readys_
.
end
(),
platform
::
errors
::
NotFound
(
"Cannot find upstream=%lld in in_readys."
,
up_id
));
auto
max_ready_size
=
it
->
second
.
first
;
auto
ready_size
=
it
->
second
.
second
;
ready_size
+=
1
;
PADDLE_ENFORCE_LE
(
ready_size
,
max_ready_size
,
platform
::
errors
::
OutOfRange
(
"upstream=%lld ready_size must <= max_ready_size, but "
"now ready_size=%lld, max_ready_size=%lld"
,
up_id
,
ready_size
,
max_ready_size
));
it
->
second
.
second
=
ready_size
;
}
void
ComputeInterceptor
::
DecreaseBuff
(
int64_t
down_id
)
{
auto
it
=
out_buffs_
.
find
(
down_id
);
PADDLE_ENFORCE_NE
(
it
,
out_buffs_
.
end
(),
platform
::
errors
::
NotFound
(
"Cannot find downstream=%lld in out_buffs."
,
down_id
));
auto
used_size
=
it
->
second
.
second
;
used_size
-=
1
;
PADDLE_ENFORCE_GE
(
used_size
,
0
,
platform
::
errors
::
OutOfRange
(
"downstream=%lld used buff size must >= 0, but now equal %lld"
,
down_id
,
used_size
));
it
->
second
.
second
=
used_size
;
}
bool
ComputeInterceptor
::
IsInputReady
()
{
for
(
auto
&
ins
:
in_readys_
)
{
auto
ready_size
=
ins
.
second
.
second
;
// not ready, return false
if
(
ready_size
==
0
)
return
false
;
}
return
true
;
}
bool
ComputeInterceptor
::
CanWriteOutput
()
{
for
(
auto
&
outs
:
out_buffs_
)
{
auto
max_buffer_size
=
outs
.
second
.
first
;
auto
used_size
=
outs
.
second
.
second
;
// full, return false
if
(
used_size
==
max_buffer_size
)
return
false
;
}
return
true
;
}
void
ComputeInterceptor
::
SendDataReadyToDownStream
()
{
auto
&
downstream
=
GetTaskNode
()
->
downstream
();
for
(
auto
dst_id
:
downstream
)
{
InterceptorMessage
dst_msg
;
dst_msg
.
set_message_type
(
DATA_IS_READY
);
VLOG
(
3
)
<<
"ComputeInterceptor Send msg to "
<<
dst_id
;
Send
(
dst_id
,
dst_msg
);
for
(
auto
&
outs
:
out_buffs_
)
{
auto
down_id
=
outs
.
first
;
auto
max_buff_size
=
outs
.
second
.
first
;
auto
used_size
=
outs
.
second
.
second
;
used_size
+=
1
;
PADDLE_ENFORCE_LE
(
used_size
,
max_buff_size
,
platform
::
errors
::
OutOfRange
(
"downstream=%lld used buff size must <= "
"max_buff_size, but now used_size=%lld, "
"max_buff_size=%lld"
,
down_id
,
used_size
,
max_buff_size
));
outs
.
second
.
second
=
used_size
;
InterceptorMessage
ready_msg
;
ready_msg
.
set_message_type
(
DATA_IS_READY
);
VLOG
(
3
)
<<
"ComputeInterceptor Send data_is_ready msg to "
<<
down_id
;
Send
(
down_id
,
ready_msg
);
}
}
void
ComputeInterceptor
::
ReplyCompletedToUpStream
()
{
for
(
auto
&
ins
:
in_readys_
)
{
auto
up_id
=
ins
.
first
;
auto
ready_size
=
ins
.
second
.
second
;
ready_size
-=
1
;
PADDLE_ENFORCE_GE
(
ready_size
,
0
,
platform
::
errors
::
OutOfRange
(
"upstream=%lld ready_size must >= 0, but now got %lld"
,
up_id
,
ready_size
));
ins
.
second
.
second
=
ready_size
;
InterceptorMessage
reply_msg
;
reply_msg
.
set_message_type
(
DATE_IS_USELESS
);
VLOG
(
3
)
<<
"ComputeInterceptor Reply data_is_useless msg to "
<<
up_id
;
Send
(
up_id
,
reply_msg
);
}
}
void
ComputeInterceptor
::
Run
()
{
while
(
IsInputReady
()
&&
CanWriteOutput
())
{
VLOG
(
3
)
<<
"id="
<<
GetInterceptorId
()
<<
" ComputeInterceptor running"
;
// TODO(wangxi): add op run
// send to downstream and increase buff used
SendDataReadyToDownStream
();
// reply to upstream and decrease ready data
ReplyCompletedToUpStream
();
}
}
void
ComputeInterceptor
::
Compute
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
DATA_IS_READY
)
{
auto
src_id
=
msg
.
src_id
();
upstream_deps_
.
erase
(
src_id
);
// all input is ready
if
(
upstream_deps_
.
empty
())
{
// TODO(wangxi): op run
VLOG
(
3
)
<<
"id="
<<
GetInterceptorId
()
<<
" ComputeInterceptor running"
;
SendDataReadyToDownStream
();
PrepareDeps
();
}
IncreaseReady
(
msg
.
src_id
());
Run
();
}
else
if
(
msg
.
message_type
()
==
DATE_IS_USELESS
)
{
DecreaseBuff
(
msg
.
src_id
());
Run
();
}
}
...
...
paddle/fluid/distributed/fleet_executor/compute_interceptor.h
浏览文件 @
be3b7740
...
...
@@ -14,6 +14,8 @@
#pragma once
#include <utility>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace
paddle
{
...
...
@@ -25,12 +27,24 @@ class ComputeInterceptor : public Interceptor {
void
PrepareDeps
();
void
IncreaseReady
(
int64_t
up_id
);
void
DecreaseBuff
(
int64_t
down_id
);
bool
IsInputReady
();
bool
CanWriteOutput
();
void
SendDataReadyToDownStream
();
void
ReplyCompletedToUpStream
();
void
Run
();
void
Compute
(
const
InterceptorMessage
&
msg
);
private:
std
::
unordered_set
<
int64_t
>
upstream_deps_
;
// FIXME(wangxi): if use step_ and max_steps_, how to restart step_ from 0
int64_t
step_
{
0
};
// upstream_id-->(max_ready_size, ready_size)
std
::
map
<
int64_t
,
std
::
pair
<
int64_t
,
int64_t
>>
in_readys_
{};
// downstream_id-->(max_buffer_size, used_size)
std
::
map
<
int64_t
,
std
::
pair
<
int64_t
,
int64_t
>>
out_buffs_
{};
};
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
浏览文件 @
be3b7740
...
...
@@ -35,17 +35,35 @@ class StopInterceptor : public Interceptor {
void
Stop
(
const
InterceptorMessage
&
msg
)
{
std
::
cout
<<
GetInterceptorId
()
<<
" recv msg from "
<<
msg
.
src_id
()
<<
std
::
endl
;
count_
+=
1
;
if
(
count_
==
1
)
return
;
InterceptorMessage
stop
;
stop
.
set_message_type
(
STOP
);
Send
(
0
,
stop
);
Send
(
1
,
stop
);
Send
(
2
,
stop
);
Send
(
3
,
stop
);
}
int
count_
{
0
};
};
class
StartInterceptor
:
public
Interceptor
{
public:
StartInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
Interceptor
(
interceptor_id
,
node
)
{
RegisterMsgHandle
([
this
](
const
InterceptorMessage
&
msg
)
{
NOP
(
msg
);
});
}
void
NOP
(
const
InterceptorMessage
&
msg
)
{
std
::
cout
<<
GetInterceptorId
()
<<
" recv msg from "
<<
msg
.
src_id
()
<<
std
::
endl
;
}
};
TEST
(
ComputeInterceptor
,
Compute
)
{
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
Carrier
&
carrier
=
Carrier
::
Instance
();
...
...
@@ -53,21 +71,28 @@ TEST(ComputeInterceptor, Compute) {
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
0
,
0
);
// role, rank, task_id
TaskNode
*
node_b
=
new
TaskNode
(
0
,
0
,
1
,
0
,
0
);
TaskNode
*
node_c
=
new
TaskNode
(
0
,
0
,
2
,
0
,
0
);
TaskNode
*
node_d
=
new
TaskNode
(
0
,
0
,
3
,
0
,
0
);
// a->b->c
// a->b->c
->d
node_a
->
AddDownstreamTask
(
1
);
node_b
->
AddUpstreamTask
(
0
);
node_b
->
AddDownstreamTask
(
2
);
node_c
->
AddUpstreamTask
(
1
);
node_c
->
AddDownstreamTask
(
3
);
node_d
->
AddUpstreamTask
(
2
);
Interceptor
*
a
=
carrier
.
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Compute"
,
0
,
node_a
));
Interceptor
*
a
=
carrier
.
SetInterceptor
(
0
,
std
::
make_unique
<
StartInterceptor
>
(
0
,
node_a
));
carrier
.
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
.
SetInterceptor
(
2
,
std
::
make_unique
<
StopInterceptor
>
(
2
,
node_c
));
carrier
.
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Compute"
,
2
,
node_c
));
carrier
.
SetInterceptor
(
3
,
std
::
make_unique
<
StopInterceptor
>
(
3
,
node_c
));
carrier
.
SetCreatingFlag
(
false
);
InterceptorMessage
msg
;
msg
.
set_message_type
(
DATA_IS_READY
);
// double buff, send twice
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录