Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
964e20e0
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
964e20e0
编写于
11月 22, 2021
作者:
W
WangXi
提交者:
GitHub
11月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[fleet_executor] Add compute interceptor (#37376)
上级
9d3e1896
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
208 addition
and
8 deletion
+208
-8
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
+2
-1
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+2
-0
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
...e/fluid/distributed/fleet_executor/compute_interceptor.cc
+61
-0
paddle/fluid/distributed/fleet_executor/compute_interceptor.h
...le/fluid/distributed/fleet_executor/compute_interceptor.h
+37
-0
paddle/fluid/distributed/fleet_executor/interceptor.cc
paddle/fluid/distributed/fleet_executor/interceptor.cc
+1
-1
paddle/fluid/distributed/fleet_executor/interceptor.h
paddle/fluid/distributed/fleet_executor/interceptor.h
+20
-6
paddle/fluid/distributed/fleet_executor/task_node.h
paddle/fluid/distributed/fleet_executor/task_node.h
+8
-0
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
+2
-0
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
...stributed/fleet_executor/test/compute_interceptor_test.cc
+75
-0
未找到文件。
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
浏览文件 @
964e20e0
...
...
@@ -11,12 +11,13 @@ else()
endif
()
cc_library
(
fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc interceptor_message_service.cc message_bus.cc
interceptor.cc
compute_interceptor.cc
interceptor_message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto
${
BRPC_DEPS
}
)
if
(
WITH_DISTRIBUTE
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
interceptor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
compute_interceptor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
message_bus.h PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
message_bus.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
fleet_executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
...
...
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
964e20e0
...
...
@@ -21,6 +21,8 @@
namespace
paddle
{
namespace
distributed
{
USE_INTERCEPTOR
(
Compute
);
void
Carrier
::
Init
(
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
interceptor_id_to_node
)
{
PADDLE_ENFORCE_EQ
(
is_init_
,
false
,
platform
::
errors
::
AlreadyExists
(
...
...
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
0 → 100644
浏览文件 @
964e20e0
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace
paddle
{
namespace
distributed
{
ComputeInterceptor
::
ComputeInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
Interceptor
(
interceptor_id
,
node
)
{
PrepareDeps
();
RegisterMsgHandle
([
this
](
const
InterceptorMessage
&
msg
)
{
Compute
(
msg
);
});
}
void
ComputeInterceptor
::
PrepareDeps
()
{
auto
&
upstream
=
GetTaskNode
()
->
upstream
();
upstream_deps_
.
insert
(
upstream
.
begin
(),
upstream
.
end
());
}
void
ComputeInterceptor
::
SendDataReadyToDownStream
()
{
auto
&
downstream
=
GetTaskNode
()
->
downstream
();
for
(
auto
dst_id
:
downstream
)
{
InterceptorMessage
dst_msg
;
dst_msg
.
set_message_type
(
DATA_IS_READY
);
VLOG
(
3
)
<<
"ComputeInterceptor Send msg to "
<<
dst_id
;
Send
(
dst_id
,
dst_msg
);
}
}
void
ComputeInterceptor
::
Compute
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
DATA_IS_READY
)
{
auto
src_id
=
msg
.
src_id
();
upstream_deps_
.
erase
(
src_id
);
// all input is ready
if
(
upstream_deps_
.
empty
())
{
// TODO(wangxi): op run
VLOG
(
3
)
<<
"id="
<<
GetInterceptorId
()
<<
" ComputeInterceptor running"
;
SendDataReadyToDownStream
();
PrepareDeps
();
}
}
}
REGISTER_INTERCEPTOR
(
Compute
,
ComputeInterceptor
);
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/compute_interceptor.h
0 → 100644
浏览文件 @
964e20e0
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace
paddle
{
namespace
distributed
{
class
ComputeInterceptor
:
public
Interceptor
{
public:
ComputeInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
);
void
PrepareDeps
();
void
SendDataReadyToDownStream
();
void
Compute
(
const
InterceptorMessage
&
msg
);
private:
std
::
unordered_set
<
int64_t
>
upstream_deps_
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/interceptor.cc
浏览文件 @
964e20e0
...
...
@@ -76,7 +76,7 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
void
Interceptor
::
PoolTheMailbox
()
{
// pool the local mailbox, parse the Message
while
(
true
)
{
for
(;;
)
{
if
(
local_mailbox_
.
empty
())
{
// local mailbox is empty, fetch the remote mailbox
VLOG
(
3
)
<<
interceptor_id_
<<
"'s local mailbox is empty. "
...
...
paddle/fluid/distributed/fleet_executor/interceptor.h
浏览文件 @
964e20e0
...
...
@@ -62,6 +62,9 @@ class Interceptor {
DISABLE_COPY_AND_ASSIGN
(
Interceptor
);
protected:
TaskNode
*
GetTaskNode
()
const
{
return
node_
;
}
private:
// pool the local mailbox, parse the Message
void
PoolTheMailbox
();
...
...
@@ -114,19 +117,30 @@ class InterceptorFactory {
int64_t
id
,
TaskNode
*
node
);
};
template
<
typename
InterceptorClass
>
std
::
unique_ptr
<
Interceptor
>
CreatorInterceptor
(
int64_t
id
,
TaskNode
*
node
)
{
return
std
::
make_unique
<
InterceptorClass
>
(
id
,
node
);
}
#define REGISTER_INTERCEPTOR(interceptor_type, interceptor_class) \
std::unique_ptr<Interceptor> CreatorInterceptor_##interceptor_type( \
int64_t id, TaskNode* node) { \
return std::make_unique<interceptor_class>(id, node); \
} \
class __RegisterInterceptor_##interceptor_type { \
public: \
__RegisterInterceptor_##interceptor_type() { \
InterceptorFactory::Register(#interceptor_type, \
CreatorInterceptor
_##interceptor_type
); \
CreatorInterceptor
<interceptor_class>
); \
} \
void Touch() {} \
}; \
__RegisterInterceptor_##interceptor_type g_register_##interceptor_type;
__RegisterInterceptor_##interceptor_type g_register_##interceptor_type; \
int TouchRegisterInterceptor_##interceptor_type() { \
g_register_##interceptor_type.Touch(); \
return 0; \
}
#define USE_INTERCEPTOR(interceptor_type) \
extern int TouchRegisterInterceptor_##interceptor_type(); \
UNUSED static int use_interceptor_##interceptor_type = \
TouchRegisterInterceptor_##interceptor_type();
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/task_node.h
浏览文件 @
964e20e0
...
...
@@ -15,8 +15,10 @@
#pragma once
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/platform/macros.h"
namespace
paddle
{
...
...
@@ -33,6 +35,7 @@ class TaskNode final {
TaskNode
(
int32_t
role
,
const
std
::
vector
<
OperatorBase
*>&
ops
,
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
max_slot_nums
);
~
TaskNode
()
=
default
;
int64_t
rank
()
const
{
return
rank_
;
}
int64_t
task_id
()
const
{
return
task_id_
;
}
int32_t
role
()
const
{
return
role_
;
}
...
...
@@ -40,9 +43,12 @@ class TaskNode final {
int64_t
max_slot_nums
()
const
{
return
max_slot_nums_
;
}
const
std
::
unordered_set
<
int64_t
>&
upstream
()
const
{
return
upstream_
;
}
const
std
::
unordered_set
<
int64_t
>&
downstream
()
const
{
return
downstream_
;
}
const
std
::
string
&
type
()
const
{
return
type_
;
}
void
AddUpstreamTask
(
int64_t
task_id
);
void
AddDownstreamTask
(
int64_t
task_id
);
std
::
string
DebugString
()
const
;
static
std
::
unique_ptr
<
TaskNode
>
CreateEmptyTaskNode
(
int32_t
role
,
int64_t
rank
,
int64_t
task_id
,
...
...
@@ -63,6 +69,8 @@ class TaskNode final {
int64_t
task_id_
;
int64_t
max_run_times_
;
int64_t
max_slot_nums_
;
std
::
string
type_
;
};
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
浏览文件 @
964e20e0
set_source_files_properties
(
interceptor_ping_pong_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
)
cc_test
(
compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
)
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
0 → 100644
浏览文件 @
964e20e0
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace
paddle
{
namespace
distributed
{
class
StopInterceptor
:
public
Interceptor
{
public:
StopInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
Interceptor
(
interceptor_id
,
node
)
{
RegisterMsgHandle
([
this
](
const
InterceptorMessage
&
msg
)
{
Stop
(
msg
);
});
}
void
Stop
(
const
InterceptorMessage
&
msg
)
{
std
::
cout
<<
GetInterceptorId
()
<<
" recv msg from "
<<
msg
.
src_id
()
<<
std
::
endl
;
InterceptorMessage
stop
;
stop
.
set_message_type
(
STOP
);
Send
(
0
,
stop
);
Send
(
1
,
stop
);
Send
(
2
,
stop
);
}
};
TEST
(
ComputeInterceptor
,
Compute
)
{
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
Carrier
&
carrier
=
Carrier
::
Instance
();
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
0
,
0
);
// role, rank, task_id
TaskNode
*
node_b
=
new
TaskNode
(
0
,
0
,
1
,
0
,
0
);
TaskNode
*
node_c
=
new
TaskNode
(
0
,
0
,
2
,
0
,
0
);
// a->b->c
node_a
->
AddDownstreamTask
(
1
);
node_b
->
AddUpstreamTask
(
0
);
node_b
->
AddDownstreamTask
(
2
);
Interceptor
*
a
=
carrier
.
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Compute"
,
0
,
node_a
));
carrier
.
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
.
SetInterceptor
(
2
,
std
::
make_unique
<
StopInterceptor
>
(
2
,
node_c
));
carrier
.
SetCreatingFlag
(
false
);
InterceptorMessage
msg
;
msg
.
set_message_type
(
DATA_IS_READY
);
a
->
Send
(
1
,
msg
);
}
}
// namespace distributed
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录