Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
ddf38a3f
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ddf38a3f
编写于
12月 02, 2021
作者:
W
WangXi
提交者:
GitHub
12月 02, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[fleet_executor] Add amplifier interceptor and 1F1B scheduler test (#37755)
上级
c0d5b7ec
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
352 addition
and
12 deletion
+352
-12
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
+2
-1
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.cc
...fluid/distributed/fleet_executor/amplifier_interceptor.cc
+82
-0
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h
.../fluid/distributed/fleet_executor/amplifier_interceptor.h
+43
-0
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+1
-0
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
...e/fluid/distributed/fleet_executor/compute_interceptor.cc
+9
-6
paddle/fluid/distributed/fleet_executor/compute_interceptor.h
...le/fluid/distributed/fleet_executor/compute_interceptor.h
+8
-5
paddle/fluid/distributed/fleet_executor/task_node.h
paddle/fluid/distributed/fleet_executor/task_node.h
+17
-0
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
+6
-0
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
...leet_executor/test/interceptor_pipeline_long_path_test.cc
+94
-0
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
...eet_executor/test/interceptor_pipeline_short_path_test.cc
+90
-0
未找到文件。
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
浏览文件 @
ddf38a3f
...
...
@@ -11,7 +11,7 @@ else()
endif
()
cc_library
(
fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc interceptor_message_service.cc message_bus.cc
interceptor.cc compute_interceptor.cc
amplifier_interceptor.cc
interceptor_message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper
${
BRPC_DEPS
}
)
...
...
@@ -19,6 +19,7 @@ if(WITH_DISTRIBUTE)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
interceptor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
compute_interceptor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
message_bus.h PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
message_bus.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
fleet_executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
...
...
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.cc
0 → 100644
浏览文件 @
ddf38a3f
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/operator.h"
namespace
paddle
{
namespace
distributed
{
AmplifierInterceptor
::
AmplifierInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
ComputeInterceptor
(
interceptor_id
,
node
)
{
run_per_steps_
=
node
->
run_per_steps
();
run_at_offset_
=
node
->
run_at_offset
();
reply_up_per_steps_
=
node
->
reply_up_per_steps
();
send_down_per_steps_
=
node
->
send_down_per_steps
();
PADDLE_ENFORCE_GE
(
run_per_steps_
,
1
,
platform
::
errors
::
InvalidArgument
(
"run_per_steps must >= 1, but now is %ld"
,
run_per_steps_
));
PADDLE_ENFORCE_GE
(
run_at_offset_
,
0
,
platform
::
errors
::
InvalidArgument
(
"run_at_offset must >= 0, but now is %ld"
,
run_at_offset_
));
PADDLE_ENFORCE_LT
(
run_at_offset_
,
run_per_steps_
,
platform
::
errors
::
InvalidArgument
(
"run_at_offset must < run_per_steps, must now "
"run_at_offset=%ld run_per_steps=%ld"
,
run_at_offset_
,
run_per_steps_
));
PADDLE_ENFORCE_GE
(
reply_up_per_steps_
,
1
,
platform
::
errors
::
InvalidArgument
(
"reply_up_per_steps must >= 1, but now is %ld"
,
reply_up_per_steps_
));
PADDLE_ENFORCE_GE
(
send_down_per_steps_
,
1
,
platform
::
errors
::
InvalidArgument
(
"send_down_per_steps must >= 1, but now is %ld"
,
send_down_per_steps_
));
}
void
AmplifierInterceptor
::
RunOps
()
{
// run_per_steps_, run_at_offset_
// 4, 0 --> run at step 0, 4, 8, 12
// 4, 3 --> run at step 3, 7, 11, 15
if
((
step_
%
run_per_steps_
)
==
run_at_offset_
)
{
ComputeInterceptor
::
RunOps
();
}
}
void
AmplifierInterceptor
::
SendDataReadyToDownStream
()
{
// run multi times, send ready one times to downstream, that is
// input multi times, output one times
if
(
step_
%
send_down_per_steps_
==
0
)
{
ComputeInterceptor
::
SendDataReadyToDownStream
();
}
}
void
AmplifierInterceptor
::
ReplyCompletedToUpStream
()
{
// run multi times, reply one times to upstream, that is
// input one times, output multi times
if
(
step_
%
reply_up_per_steps_
==
0
)
{
ComputeInterceptor
::
ReplyCompletedToUpStream
();
}
}
REGISTER_INTERCEPTOR
(
Amplifier
,
AmplifierInterceptor
);
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h
0 → 100644
浏览文件 @
ddf38a3f
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <utility>
#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"
namespace
paddle
{
namespace
distributed
{
class
AmplifierInterceptor
:
public
ComputeInterceptor
{
public:
AmplifierInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
);
private:
void
RunOps
()
override
;
void
SendDataReadyToDownStream
()
override
;
void
ReplyCompletedToUpStream
()
override
;
int64_t
run_per_steps_
{
1
};
int64_t
run_at_offset_
{
0
};
// one input produces multi times output
int64_t
reply_up_per_steps_
{
1
};
// one output need multi times input
int64_t
send_down_per_steps_
{
1
};
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
ddf38a3f
...
...
@@ -24,6 +24,7 @@ namespace paddle {
namespace
distributed
{
USE_INTERCEPTOR
(
Compute
);
USE_INTERCEPTOR
(
Amplifier
);
void
Carrier
::
Init
(
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph
,
framework
::
Scope
*
root_scope
,
...
...
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
浏览文件 @
ddf38a3f
...
...
@@ -160,15 +160,18 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
}
}
void
ComputeInterceptor
::
RunOps
()
{
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
<<
" running ops."
;
for
(
auto
op
:
node_
->
ops
())
{
op
->
Run
(
*
microbatch_scopes_
[
step_
%
node_
->
max_run_times
()],
place_
);
}
}
void
ComputeInterceptor
::
Run
()
{
while
(
IsInputReady
()
&&
CanWriteOutput
()
&&
!
ShouldReset
())
{
VLOG
(
3
)
<<
"id="
<<
GetInterceptorId
()
<<
" ComputeInterceptor running"
;
// step_ %= node_->max_run_times();
for
(
auto
op
:
node_
->
ops
())
{
auto
*
scope
=
microbatch_scopes_
[
step_
%
node_
->
max_run_times
()];
op
->
Run
(
*
scope
,
place_
);
}
RunOps
();
++
step_
;
// send to downstream and increase buff used
...
...
@@ -176,7 +179,7 @@ void ComputeInterceptor::Run() {
// reply to upstream and decrease ready data
ReplyCompletedToUpStream
();
// Try to stop Carrier
if
(
step_
%
node_
->
max_run_times
()
==
0
&&
is_last_
)
{
if
(
is_last_
&&
(
step_
%
node_
->
max_run_times
()
==
0
)
)
{
StopCarrier
();
}
}
...
...
paddle/fluid/distributed/fleet_executor/compute_interceptor.h
浏览文件 @
ddf38a3f
...
...
@@ -25,6 +25,14 @@ class ComputeInterceptor : public Interceptor {
public:
ComputeInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
);
protected:
virtual
void
RunOps
();
virtual
void
SendDataReadyToDownStream
();
virtual
void
ReplyCompletedToUpStream
();
int64_t
step_
{
0
};
private:
void
PrepareDeps
();
void
IncreaseReady
(
int64_t
up_id
);
...
...
@@ -33,19 +41,14 @@ class ComputeInterceptor : public Interceptor {
bool
CanWriteOutput
();
bool
ShouldReset
();
void
SendDataReadyToDownStream
();
void
ReplyCompletedToUpStream
();
void
Run
();
void
Compute
(
const
InterceptorMessage
&
msg
);
void
ReceivedStop
(
int64_t
up_id
);
void
TryStop
();
private:
bool
is_source_
{
false
};
bool
is_last_
{
false
};
int64_t
step_
{
0
};
// upstream_id-->(max_ready_size, ready_size)
std
::
map
<
int64_t
,
std
::
pair
<
int64_t
,
int64_t
>>
in_readys_
{};
...
...
paddle/fluid/distributed/fleet_executor/task_node.h
浏览文件 @
ddf38a3f
...
...
@@ -44,12 +44,22 @@ class TaskNode final {
int32_t
role
()
const
{
return
role_
;
}
int64_t
max_run_times
()
const
{
return
max_run_times_
;
}
int64_t
max_slot_nums
()
const
{
return
max_slot_nums_
;
}
int64_t
run_per_steps
()
const
{
return
run_per_steps_
;
}
int64_t
run_at_offset
()
const
{
return
run_at_offset_
;
}
int64_t
reply_up_per_steps
()
const
{
return
reply_up_per_steps_
;
}
int64_t
send_down_per_steps
()
const
{
return
send_down_per_steps_
;
}
const
std
::
unordered_set
<
int64_t
>&
upstream
()
const
{
return
upstream_
;
}
const
std
::
unordered_set
<
int64_t
>&
downstream
()
const
{
return
downstream_
;
}
const
std
::
string
&
type
()
const
{
return
type_
;
}
const
paddle
::
framework
::
ProgramDesc
&
program
()
const
{
return
program_
;
}
const
std
::
vector
<
OperatorBase
*>&
ops
()
const
{
return
ops_
;
}
void
SetRunPerSteps
(
int64_t
value
)
{
run_per_steps_
=
value
;
}
void
SetRunAtOffset
(
int64_t
value
)
{
run_at_offset_
=
value
;
}
void
SetReplyUpPerSteps
(
int64_t
value
)
{
reply_up_per_steps_
=
value
;
}
void
SetSendDownPerSteps
(
int64_t
value
)
{
send_down_per_steps_
=
value
;
}
void
SetType
(
const
std
::
string
&
type
)
{
type_
=
type
;
}
bool
AddUpstreamTask
(
int64_t
task_id
);
bool
AddDownstreamTask
(
int64_t
task_id
);
std
::
string
DebugString
()
const
;
...
...
@@ -76,6 +86,13 @@ class TaskNode final {
int64_t
max_run_times_
;
int64_t
max_slot_nums_
;
int64_t
run_per_steps_
{
1
};
int64_t
run_at_offset_
{
0
};
// one input produces multi times output
int64_t
reply_up_per_steps_
{
1
};
// one output need multi times input
int64_t
send_down_per_steps_
{
1
};
std
::
string
type_
;
};
...
...
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
浏览文件 @
ddf38a3f
...
...
@@ -4,6 +4,12 @@ cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet
set_source_files_properties
(
compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
)
set_source_files_properties
(
interceptor_pipeline_short_path_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
interceptor_pipeline_short_path_test SRCS interceptor_pipeline_short_path_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
)
set_source_files_properties
(
interceptor_pipeline_long_path_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
interceptor_pipeline_long_path_test SRCS interceptor_pipeline_long_path_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
)
set_source_files_properties
(
compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
compute_interceptor_run_op_test SRCS compute_interceptor_run_op_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
op_registry fill_constant_op elementwise_add_op scope device_context
)
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
0 → 100644
浏览文件 @
ddf38a3f
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace
paddle
{
namespace
distributed
{
void
LinkNodes
(
const
std
::
vector
<
TaskNode
*>&
nodes
)
{
size_t
size
=
nodes
.
size
();
if
(
size
<=
1
)
return
;
{
// i = 0
TaskNode
*
now
=
nodes
[
0
];
TaskNode
*
next
=
nodes
[
1
];
now
->
AddDownstreamTask
(
next
->
task_id
());
}
{
// i = size - 1
TaskNode
*
prev
=
nodes
[
size
-
2
];
TaskNode
*
now
=
nodes
[
size
-
1
];
now
->
AddUpstreamTask
(
prev
->
task_id
());
}
for
(
size_t
i
=
1
;
i
<
size
-
1
;
++
i
)
{
TaskNode
*
prev
=
nodes
[
i
-
1
];
TaskNode
*
now
=
nodes
[
i
];
TaskNode
*
next
=
nodes
[
i
+
1
];
now
->
AddUpstreamTask
(
prev
->
task_id
());
now
->
AddDownstreamTask
(
next
->
task_id
());
}
}
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
Carrier
&
carrier
=
Carrier
::
Instance
();
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
int64_t
micro_steps
=
3
;
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
1
,
0
);
// role, rank, task_id
TaskNode
*
node_b
=
new
TaskNode
(
0
,
0
,
1
,
1
,
0
);
TaskNode
*
node_c
=
new
TaskNode
(
0
,
0
,
2
,
1
,
0
);
TaskNode
*
node_d
=
new
TaskNode
(
0
,
0
,
3
,
1
,
0
);
TaskNode
*
node_e
=
new
TaskNode
(
0
,
0
,
4
,
1
,
0
);
TaskNode
*
node_f
=
new
TaskNode
(
0
,
0
,
5
,
1
,
0
);
// a->b->c->d->e->f
LinkNodes
({
node_a
,
node_b
,
node_c
,
node_d
,
node_e
,
node_f
});
// LR->b(1:3)->F->B->e(3:1)->U
node_b
->
SetReplyUpPerSteps
(
micro_steps
);
node_e
->
SetSendDownPerSteps
(
micro_steps
);
carrier
.
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Compute"
,
0
,
node_a
));
carrier
.
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Amplifier"
,
1
,
node_b
));
carrier
.
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Compute"
,
2
,
node_c
));
carrier
.
SetInterceptor
(
3
,
InterceptorFactory
::
Create
(
"Compute"
,
3
,
node_d
));
carrier
.
SetInterceptor
(
4
,
InterceptorFactory
::
Create
(
"Amplifier"
,
4
,
node_e
));
carrier
.
SetInterceptor
(
5
,
InterceptorFactory
::
Create
(
"Compute"
,
5
,
node_f
));
carrier
.
SetCreatingFlag
(
false
);
// start
InterceptorMessage
msg
;
msg
.
set_message_type
(
DATA_IS_READY
);
msg
.
set_src_id
(
-
1
);
msg
.
set_dst_id
(
0
);
carrier
.
EnqueueInterceptorMessage
(
msg
);
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
0 → 100644
浏览文件 @
ddf38a3f
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace
paddle
{
namespace
distributed
{
void
LinkNodes
(
const
std
::
vector
<
TaskNode
*>&
nodes
)
{
size_t
size
=
nodes
.
size
();
if
(
size
<=
1
)
return
;
{
// i = 0
TaskNode
*
now
=
nodes
[
0
];
TaskNode
*
next
=
nodes
[
1
];
now
->
AddDownstreamTask
(
next
->
task_id
());
}
{
// i = size - 1
TaskNode
*
prev
=
nodes
[
size
-
2
];
TaskNode
*
now
=
nodes
[
size
-
1
];
now
->
AddUpstreamTask
(
prev
->
task_id
());
}
for
(
size_t
i
=
1
;
i
<
size
-
1
;
++
i
)
{
TaskNode
*
prev
=
nodes
[
i
-
1
];
TaskNode
*
now
=
nodes
[
i
];
TaskNode
*
next
=
nodes
[
i
+
1
];
now
->
AddUpstreamTask
(
prev
->
task_id
());
now
->
AddDownstreamTask
(
next
->
task_id
());
}
}
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
Carrier
&
carrier
=
Carrier
::
Instance
();
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{{
0
,
""
}},
""
);
int64_t
micro_steps
=
3
;
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
micro_steps
,
0
);
// role, rank, task_id
TaskNode
*
node_b
=
new
TaskNode
(
0
,
0
,
1
,
3
,
0
);
TaskNode
*
node_c
=
new
TaskNode
(
0
,
0
,
2
,
3
,
0
);
TaskNode
*
node_d
=
new
TaskNode
(
0
,
0
,
3
,
micro_steps
,
0
);
// a->b->c->d
LinkNodes
({
node_a
,
node_b
,
node_c
,
node_d
});
node_a
->
SetRunPerSteps
(
micro_steps
);
node_d
->
SetRunPerSteps
(
micro_steps
);
node_d
->
SetRunAtOffset
(
micro_steps
-
1
);
carrier
.
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Amplifier"
,
0
,
node_a
));
carrier
.
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
.
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Compute"
,
2
,
node_c
));
carrier
.
SetInterceptor
(
3
,
InterceptorFactory
::
Create
(
"Amplifier"
,
3
,
node_d
));
carrier
.
SetCreatingFlag
(
false
);
// start
InterceptorMessage
msg
;
msg
.
set_message_type
(
DATA_IS_READY
);
msg
.
set_src_id
(
-
1
);
msg
.
set_dst_id
(
0
);
carrier
.
EnqueueInterceptorMessage
(
msg
);
}
}
// namespace distributed
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录