Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b3e79731
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b3e79731
编写于
4月 09, 2022
作者:
L
LiYuRio
提交者:
GitHub
4月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[fleet executor] Add sink interceptor and test (#41497)
上级
330582e2
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
201 addition
and
1 deletion
+201
-1
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
+2
-1
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+1
-0
paddle/fluid/distributed/fleet_executor/sink_interceptor.cc
paddle/fluid/distributed/fleet_executor/sink_interceptor.cc
+65
-0
paddle/fluid/distributed/fleet_executor/sink_interceptor.h
paddle/fluid/distributed/fleet_executor/sink_interceptor.h
+41
-0
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
+3
-0
paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc
.../distributed/fleet_executor/test/sink_interceptor_test.cc
+89
-0
未找到文件。
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
浏览文件 @
b3e79731
...
...
@@ -13,7 +13,7 @@ endif()
cc_library
(
task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog
)
cc_library
(
fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc dist_model.cc interceptor.cc
compute_interceptor.cc amplifier_interceptor.cc source_interceptor.cc message_service.cc message_bus.cc dist_model_tensor_wrapper.cc
compute_interceptor.cc amplifier_interceptor.cc source_interceptor.cc
sink_interceptor.cc
message_service.cc message_bus.cc dist_model_tensor_wrapper.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto task_loop_thread_pool collective_helper
op_registry executor_gc_helper gflags glog
${
BRPC_DEPS
}
)
...
...
@@ -26,6 +26,7 @@ if(WITH_DISTRIBUTE)
set_source_files_properties
(
compute_interceptor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
source_interceptor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
sink_interceptor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
message_bus.h PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
message_bus.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
fleet_executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
...
...
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
b3e79731
...
...
@@ -31,6 +31,7 @@ namespace distributed {
USE_INTERCEPTOR
(
Source
);
USE_INTERCEPTOR
(
Compute
);
USE_INTERCEPTOR
(
Amplifier
);
USE_INTERCEPTOR
(
Sink
);
void
Carrier
::
Init
(
int64_t
rank
,
...
...
paddle/fluid/distributed/fleet_executor/sink_interceptor.cc
0 → 100644
浏览文件 @
b3e79731
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/sink_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace
paddle
{
namespace
distributed
{
SinkInterceptor
::
SinkInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
Interceptor
(
interceptor_id
,
node
),
max_run_times_
(
node
->
max_run_times
())
{
// prepare the upstream running status
for
(
const
auto
&
up
:
node
->
upstream
())
{
upstream_step_
.
emplace
(
up
.
first
,
0
);
}
RegisterMsgHandle
([
this
](
const
InterceptorMessage
&
msg
)
{
Run
(
msg
);
});
}
void
SinkInterceptor
::
StopCarrierIfComplete
()
{
bool
flag
=
true
;
for
(
const
auto
&
up
:
upstream_step_
)
{
flag
=
flag
&
(
up
.
second
==
max_run_times_
);
}
if
(
flag
)
{
VLOG
(
3
)
<<
"Sink Interceptor is stopping carrier"
;
StopCarrier
();
for
(
const
auto
&
up
:
upstream_step_
)
{
upstream_step_
.
at
(
up
.
first
)
=
0
;
}
}
}
void
SinkInterceptor
::
ReplyCompletedToUpStream
(
int64_t
upstream_id
)
{
int64_t
micro_step
=
upstream_step_
.
at
(
upstream_id
);
int64_t
scope_idx
=
micro_step
%
max_run_times_
;
InterceptorMessage
msg
;
msg
.
set_message_type
(
DATA_IS_USELESS
);
msg
.
set_scope_idx
(
scope_idx
);
Send
(
upstream_id
,
msg
);
upstream_step_
.
at
(
upstream_id
)
=
micro_step
+
1
;
if
(
micro_step
==
max_run_times_
-
1
)
{
StopCarrierIfComplete
();
}
}
void
SinkInterceptor
::
Run
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
DATA_IS_READY
)
{
ReplyCompletedToUpStream
(
msg
.
src_id
());
}
}
REGISTER_INTERCEPTOR
(
Sink
,
SinkInterceptor
);
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/sink_interceptor.h
0 → 100644
浏览文件 @
b3e79731
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace
paddle
{
namespace
distributed
{
/*
* Sink interceptor
* There is only one sink in the runtime graph
* Take charge of:
* 1. record the num of micro-step
* 2. check whether to notify carrier the current step is finished
*/
class
SinkInterceptor
:
public
Interceptor
{
public:
SinkInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
);
private:
void
ReplyCompletedToUpStream
(
int64_t
up_id
);
void
Run
(
const
InterceptorMessage
&
msg
);
void
StopCarrierIfComplete
();
int64_t
max_run_times_
;
// upstream_id->cur_step
std
::
map
<
int64_t
,
int64_t
>
upstream_step_
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
浏览文件 @
b3e79731
...
...
@@ -7,6 +7,9 @@ cc_test(compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_exe
set_source_files_properties
(
source_interceptor_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
source_interceptor_test SRCS source_interceptor_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
)
set_source_files_properties
(
sink_interceptor_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
sink_interceptor_test SRCS sink_interceptor_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
)
set_source_files_properties
(
interceptor_pipeline_short_path_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
interceptor_pipeline_short_path_test SRCS interceptor_pipeline_short_path_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
)
...
...
paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc
0 → 100644
浏览文件 @
b3e79731
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace
paddle
{
namespace
distributed
{
class
FakeInterceptor
:
public
Interceptor
{
public:
FakeInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
Interceptor
(
interceptor_id
,
node
)
{
RegisterMsgHandle
([
this
](
const
InterceptorMessage
&
msg
)
{
NOP
(
msg
);
});
}
void
NOP
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
DATA_IS_READY
)
{
std
::
cout
<<
"FakeInterceptor run in scope "
<<
msg
.
scope_idx
()
<<
std
::
endl
;
InterceptorMessage
reply
;
reply
.
set_message_type
(
DATA_IS_USELESS
);
Send
(
-
1
,
reply
);
InterceptorMessage
ready
;
ready
.
set_message_type
(
DATA_IS_READY
);
Send
(
-
2
,
ready
);
}
else
if
(
msg
.
message_type
()
==
DATA_IS_USELESS
)
{
std
::
cout
<<
"FakeInterceptor remove result in scope "
<<
msg
.
scope_idx
()
<<
std
::
endl
;
}
}
private:
int64_t
step_
;
};
TEST
(
SourceInterceptor
,
Source
)
{
std
::
string
carrier_id
=
"0"
;
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
carrier
->
Init
(
0
,
{{
-
1
,
0
},
{
0
,
0
},
{
-
2
,
0
}});
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode
*
source
=
new
TaskNode
(
0
,
-
1
,
0
,
3
,
0
);
// role, rank, task_id
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
3
,
0
);
// role, rank, task_id
TaskNode
*
sink
=
new
TaskNode
(
0
,
-
2
,
0
,
3
,
0
);
// role, rank, task_id
source
->
AddDownstreamTask
(
0
,
1
);
node_a
->
AddUpstreamTask
(
-
1
,
1
);
node_a
->
AddDownstreamTask
(
-
2
,
1
);
sink
->
AddUpstreamTask
(
0
,
1
);
carrier
->
SetInterceptor
(
-
1
,
InterceptorFactory
::
Create
(
"Source"
,
-
1
,
source
));
carrier
->
SetInterceptor
(
0
,
std
::
make_unique
<
FakeInterceptor
>
(
0
,
node_a
));
carrier
->
SetInterceptor
(
-
2
,
InterceptorFactory
::
Create
(
"Sink"
,
-
2
,
sink
));
// start
InterceptorMessage
msg
;
msg
.
set_message_type
(
START
);
msg
.
set_dst_id
(
-
1
);
carrier
->
EnqueueInterceptorMessage
(
msg
);
carrier
->
Wait
();
carrier
->
Release
();
}
}
// namespace distributed
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录