Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
4974fdfd
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4974fdfd
编写于
3月 31, 2022
作者:
L
LiYuRio
提交者:
GitHub
3月 31, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[FleetExecutor] Add source interceptor and test (#41122)
上级
7c555f4e
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
193 addition
and
4 deletion
+193
-4
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
+2
-1
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+1
-0
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
...e/fluid/distributed/fleet_executor/compute_interceptor.cc
+2
-2
paddle/fluid/distributed/fleet_executor/interceptor_message.proto
...luid/distributed/fleet_executor/interceptor_message.proto
+3
-1
paddle/fluid/distributed/fleet_executor/source_interceptor.cc
...le/fluid/distributed/fleet_executor/source_interceptor.cc
+57
-0
paddle/fluid/distributed/fleet_executor/source_interceptor.h
paddle/fluid/distributed/fleet_executor/source_interceptor.h
+41
-0
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
+3
-0
paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc
...istributed/fleet_executor/test/source_interceptor_test.cc
+84
-0
未找到文件。
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
浏览文件 @
4974fdfd
...
...
@@ -13,7 +13,7 @@ endif()
cc_library
(
task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog
)
cc_library
(
fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc dist_model.cc interceptor.cc
compute_interceptor.cc amplifier_interceptor.cc message_service.cc message_bus.cc dist_model_tensor_wrapper.cc
compute_interceptor.cc amplifier_interceptor.cc
source_interceptor.cc
message_service.cc message_bus.cc dist_model_tensor_wrapper.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto task_loop_thread_pool collective_helper
op_registry executor_gc_helper gflags glog
${
BRPC_DEPS
}
)
...
...
@@ -25,6 +25,7 @@ if(WITH_DISTRIBUTE)
set_source_files_properties
(
interceptor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
compute_interceptor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
source_interceptor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
message_bus.h PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
message_bus.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
fleet_executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
...
...
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
4974fdfd
...
...
@@ -28,6 +28,7 @@
namespace
paddle
{
namespace
distributed
{
USE_INTERCEPTOR
(
Source
);
USE_INTERCEPTOR
(
Compute
);
USE_INTERCEPTOR
(
Amplifier
);
...
...
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
浏览文件 @
4974fdfd
...
...
@@ -164,7 +164,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
if
(
up_id
==
-
1
)
return
;
InterceptorMessage
reply_msg
;
reply_msg
.
set_message_type
(
DAT
E
_IS_USELESS
);
reply_msg
.
set_message_type
(
DAT
A
_IS_USELESS
);
Send
(
up_id
,
reply_msg
);
}
}
...
...
@@ -247,7 +247,7 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
if
(
msg
.
message_type
()
==
DATA_IS_READY
)
{
IncreaseReady
(
msg
.
src_id
());
Run
();
}
else
if
(
msg
.
message_type
()
==
DAT
E
_IS_USELESS
)
{
}
else
if
(
msg
.
message_type
()
==
DAT
A
_IS_USELESS
)
{
DecreaseBuff
(
msg
.
src_id
());
Run
();
}
else
if
(
msg
.
message_type
()
==
STOP
)
{
...
...
paddle/fluid/distributed/fleet_executor/interceptor_message.proto
浏览文件 @
4974fdfd
...
...
@@ -20,9 +20,10 @@ option cc_enable_arenas = true;
enum
MessageType
{
STOP
=
1
;
// STOP an Interceptor
DATA_IS_READY
=
2
;
// upstream data is ready
DAT
E
_IS_USELESS
=
3
;
// downstream has used the data
DAT
A
_IS_USELESS
=
3
;
// downstream has used the data
ERR
=
4
;
// current Interceptor encounters error
RESET
=
5
;
// reset the status
START
=
6
;
}
message
InterceptorMessage
{
...
...
@@ -30,6 +31,7 @@ message InterceptorMessage {
optional
int64
dst_id
=
2
[
default
=
0
];
optional
MessageType
message_type
=
3
[
default
=
RESET
];
optional
bool
ctrl_message
=
4
[
default
=
false
];
optional
int64
scope_idx
=
5
[
default
=
0
];
}
message
InterceptorResponse
{
optional
bool
rst
=
1
[
default
=
false
];
}
...
...
paddle/fluid/distributed/fleet_executor/source_interceptor.cc
0 → 100644
浏览文件 @
4974fdfd
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/source_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace
paddle
{
namespace
distributed
{
SourceInterceptor
::
SourceInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
Interceptor
(
interceptor_id
,
node
),
max_run_times_
(
node
->
max_run_times
())
{
// prepare the downstream running status
for
(
const
auto
&
down
:
node
->
downstream
())
{
downstream_step_
.
emplace
(
down
.
first
,
0
);
}
RegisterMsgHandle
([
this
](
const
InterceptorMessage
&
msg
)
{
Run
(
msg
);
});
}
void
SourceInterceptor
::
SendDataReadyToDownStream
(
int64_t
downstream_id
)
{
int64_t
micro_step
=
downstream_step_
.
at
(
downstream_id
);
if
(
micro_step
>=
max_run_times_
)
{
return
;
}
int64_t
scope_idx
=
micro_step
%
max_run_times_
;
InterceptorMessage
ready_msg
;
ready_msg
.
set_message_type
(
DATA_IS_READY
);
ready_msg
.
set_scope_idx
(
scope_idx
);
Send
(
downstream_id
,
ready_msg
);
downstream_step_
.
at
(
downstream_id
)
=
micro_step
+
1
;
}
void
SourceInterceptor
::
Run
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
START
)
{
// start run in a new step, reset the previous running status
for
(
const
auto
&
down
:
downstream_step_
)
{
downstream_step_
.
at
(
down
.
first
)
=
0
;
SendDataReadyToDownStream
(
down
.
first
);
}
}
else
if
(
msg
.
message_type
()
==
DATA_IS_USELESS
)
{
SendDataReadyToDownStream
(
msg
.
src_id
());
}
}
REGISTER_INTERCEPTOR
(
Source
,
SourceInterceptor
);
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/source_interceptor.h
0 → 100644
浏览文件 @
4974fdfd
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace
paddle
{
namespace
distributed
{
/*
* Source interceptor
* There is only one source in the runtime graph
* Take charge of:
* 1. receive `start` message from carrier
* 2. send num_of_steps `data_is_ready` message to downstream
*/
class
SourceInterceptor
:
public
Interceptor
{
public:
SourceInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
);
private:
void
SendDataReadyToDownStream
(
int64_t
down_id
);
void
Run
(
const
InterceptorMessage
&
msg
);
int64_t
max_run_times_
;
// downstream_id->cur_step
std
::
map
<
int64_t
,
int64_t
>
downstream_step_
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
浏览文件 @
4974fdfd
...
...
@@ -4,6 +4,9 @@ cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet
set_source_files_properties
(
compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
)
set_source_files_properties
(
source_interceptor_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
source_interceptor_test SRCS source_interceptor_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
)
set_source_files_properties
(
interceptor_pipeline_short_path_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
interceptor_pipeline_short_path_test SRCS interceptor_pipeline_short_path_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
)
...
...
paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc
0 → 100644
浏览文件 @
4974fdfd
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace
paddle
{
namespace
distributed
{
class
FakeInterceptor
:
public
Interceptor
{
public:
FakeInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
Interceptor
(
interceptor_id
,
node
)
{
step_
=
0
;
RegisterMsgHandle
([
this
](
const
InterceptorMessage
&
msg
)
{
NOP
(
msg
);
});
}
void
NOP
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
DATA_IS_READY
)
{
std
::
cout
<<
"FakeInterceptor run in scope "
<<
msg
.
scope_idx
()
<<
std
::
endl
;
InterceptorMessage
reply
;
reply
.
set_message_type
(
DATA_IS_USELESS
);
Send
(
-
1
,
reply
);
step_
++
;
if
(
step_
==
node_
->
max_run_times
())
{
carrier_
->
WakeUp
();
}
}
}
private:
int64_t
step_
;
};
TEST
(
SourceInterceptor
,
Source
)
{
std
::
string
carrier_id
=
"0"
;
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
carrier
->
Init
(
0
,
{{
-
1
,
0
},
{
0
,
0
}});
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode
*
source
=
new
TaskNode
(
0
,
-
1
,
0
,
3
,
0
);
// role, rank, task_id
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
3
,
0
);
// role, rank, task_id
source
->
AddDownstreamTask
(
0
,
1
);
node_a
->
AddUpstreamTask
(
-
1
,
1
);
carrier
->
SetInterceptor
(
-
1
,
InterceptorFactory
::
Create
(
"Source"
,
-
1
,
source
));
carrier
->
SetInterceptor
(
0
,
std
::
make_unique
<
FakeInterceptor
>
(
0
,
node_a
));
// start
InterceptorMessage
msg
;
msg
.
set_message_type
(
START
);
msg
.
set_dst_id
(
-
1
);
carrier
->
EnqueueInterceptorMessage
(
msg
);
carrier
->
Wait
();
carrier
->
Release
();
}
}
// namespace distributed
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录