Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
92c2dcbd
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
92c2dcbd
编写于
3月 20, 2023
作者:
L
LiYuRio
提交者:
GitHub
3月 20, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Cherry-pick fleet executor and auto parallel (#50071)
上级
4bacf2ab
变更
74
展开全部
显示空白变更内容
内联
并排
Showing
74 changed file
with
7436 addition
and
3086 deletion
+7436
-3086
cmake/third_party.cmake
cmake/third_party.cmake
+2
-1
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
+6
-0
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.cc
...fluid/distributed/fleet_executor/amplifier_interceptor.cc
+3
-3
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h
.../fluid/distributed/fleet_executor/amplifier_interceptor.h
+1
-1
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+90
-38
paddle/fluid/distributed/fleet_executor/carrier.h
paddle/fluid/distributed/fleet_executor/carrier.h
+3
-3
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
...e/fluid/distributed/fleet_executor/compute_interceptor.cc
+261
-125
paddle/fluid/distributed/fleet_executor/compute_interceptor.h
...le/fluid/distributed/fleet_executor/compute_interceptor.h
+17
-20
paddle/fluid/distributed/fleet_executor/cond_interceptor.cc
paddle/fluid/distributed/fleet_executor/cond_interceptor.cc
+167
-0
paddle/fluid/distributed/fleet_executor/cond_interceptor.h
paddle/fluid/distributed/fleet_executor/cond_interceptor.h
+55
-0
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+136
-38
paddle/fluid/distributed/fleet_executor/fleet_executor.h
paddle/fluid/distributed/fleet_executor/fleet_executor.h
+5
-2
paddle/fluid/distributed/fleet_executor/interceptor.h
paddle/fluid/distributed/fleet_executor/interceptor.h
+0
-4
paddle/fluid/distributed/fleet_executor/interceptor_message.proto
...luid/distributed/fleet_executor/interceptor_message.proto
+16
-0
paddle/fluid/distributed/fleet_executor/sink_interceptor.h
paddle/fluid/distributed/fleet_executor/sink_interceptor.h
+1
-1
paddle/fluid/distributed/fleet_executor/source_interceptor.h
paddle/fluid/distributed/fleet_executor/source_interceptor.h
+1
-1
paddle/fluid/distributed/fleet_executor/start_interceptor.cc
paddle/fluid/distributed/fleet_executor/start_interceptor.cc
+115
-0
paddle/fluid/distributed/fleet_executor/start_interceptor.h
paddle/fluid/distributed/fleet_executor/start_interceptor.h
+39
-0
paddle/fluid/distributed/fleet_executor/task_node.cc
paddle/fluid/distributed/fleet_executor/task_node.cc
+26
-36
paddle/fluid/distributed/fleet_executor/task_node.h
paddle/fluid/distributed/fleet_executor/task_node.h
+44
-19
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
...ed/fleet_executor/test/compute_interceptor_run_op_test.cc
+2
-3
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
...stributed/fleet_executor/test/compute_interceptor_test.cc
+24
-36
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
...ributed/fleet_executor/test/interceptor_ping_pong_test.cc
+0
-1
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
...eet_executor/test/interceptor_ping_pong_with_brpc_test.cc
+0
-1
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
...leet_executor/test/interceptor_pipeline_long_path_test.cc
+7
-7
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
...eet_executor/test/interceptor_pipeline_short_path_test.cc
+4
-5
paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc
.../distributed/fleet_executor/test/sink_interceptor_test.cc
+3
-4
paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc
...istributed/fleet_executor/test/source_interceptor_test.cc
+2
-3
paddle/fluid/operators/collective/c_broadcast_op.cu.cc
paddle/fluid/operators/collective/c_broadcast_op.cu.cc
+2
-0
paddle/fluid/operators/collective/c_embedding_op.cu
paddle/fluid/operators/collective/c_embedding_op.cu
+78
-22
paddle/fluid/pybind/bind_fleet_executor.cc
paddle/fluid/pybind/bind_fleet_executor.cc
+9
-6
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
+8
-0
python/paddle/distributed/auto_parallel/completion.py
python/paddle/distributed/auto_parallel/completion.py
+798
-390
python/paddle/distributed/auto_parallel/constants.py
python/paddle/distributed/auto_parallel/constants.py
+10
-0
python/paddle/distributed/auto_parallel/cost/estimate_cost.py
...on/paddle/distributed/auto_parallel/cost/estimate_cost.py
+2
-2
python/paddle/distributed/auto_parallel/dist_context.py
python/paddle/distributed/auto_parallel/dist_context.py
+454
-173
python/paddle/distributed/auto_parallel/dist_op.py
python/paddle/distributed/auto_parallel/dist_op.py
+101
-53
python/paddle/distributed/auto_parallel/engine.py
python/paddle/distributed/auto_parallel/engine.py
+109
-71
python/paddle/distributed/auto_parallel/interface.py
python/paddle/distributed/auto_parallel/interface.py
+60
-30
python/paddle/distributed/auto_parallel/operators/__init__.py
...on/paddle/distributed/auto_parallel/operators/__init__.py
+1
-0
python/paddle/distributed/auto_parallel/operators/common.py
python/paddle/distributed/auto_parallel/operators/common.py
+120
-72
python/paddle/distributed/auto_parallel/operators/dist_default.py
...addle/distributed/auto_parallel/operators/dist_default.py
+158
-86
python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py
..._parallel/operators/dist_fill_constant_batch_size_like.py
+23
-31
python/paddle/distributed/auto_parallel/operators/dist_scale.py
.../paddle/distributed/auto_parallel/operators/dist_scale.py
+90
-0
python/paddle/distributed/auto_parallel/parallelizer.py
python/paddle/distributed/auto_parallel/parallelizer.py
+215
-106
python/paddle/distributed/auto_parallel/parallelizer_v2.py
python/paddle/distributed/auto_parallel/parallelizer_v2.py
+159
-66
python/paddle/distributed/auto_parallel/process_group.py
python/paddle/distributed/auto_parallel/process_group.py
+47
-30
python/paddle/distributed/auto_parallel/process_mesh.py
python/paddle/distributed/auto_parallel/process_mesh.py
+35
-23
python/paddle/distributed/auto_parallel/reshard.py
python/paddle/distributed/auto_parallel/reshard.py
+1107
-559
python/paddle/distributed/auto_parallel/strategy.py
python/paddle/distributed/auto_parallel/strategy.py
+14
-10
python/paddle/distributed/auto_parallel/tuner/profiler.py
python/paddle/distributed/auto_parallel/tuner/profiler.py
+48
-31
python/paddle/distributed/auto_parallel/utils.py
python/paddle/distributed/auto_parallel/utils.py
+6
-0
python/paddle/distributed/fleet/fleet_executor_utils.py
python/paddle/distributed/fleet/fleet_executor_utils.py
+195
-131
python/paddle/distributed/parallel.py
python/paddle/distributed/parallel.py
+7
-0
python/paddle/distributed/passes/__init__.py
python/paddle/distributed/passes/__init__.py
+1
-0
python/paddle/distributed/passes/auto_parallel_grad_clip.py
python/paddle/distributed/passes/auto_parallel_grad_clip.py
+69
-37
python/paddle/distributed/passes/auto_parallel_pipeline.py
python/paddle/distributed/passes/auto_parallel_pipeline.py
+635
-0
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+961
-542
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
...paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py
.../fluid/tests/unittests/auto_parallel/amp_pass_unittest.py
+7
-3
python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py
...tests/unittests/auto_parallel/clip_grad_by_global_norm.py
+7
-5
python/paddle/fluid/tests/unittests/auto_parallel/generation_pipeline_pass_unittest.py
...ttests/auto_parallel/generation_pipeline_pass_unittest.py
+177
-0
python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py
...s/unittests/auto_parallel/gradient_merge_pass_unittest.py
+10
-11
python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py
.../tests/unittests/auto_parallel/recompute_pass_unittest.py
+4
-2
python/paddle/fluid/tests/unittests/auto_parallel/sharding_pass_unittest.py
...d/tests/unittests/auto_parallel/sharding_pass_unittest.py
+13
-11
python/paddle/fluid/tests/unittests/auto_parallel/test_dist_context.py
.../fluid/tests/unittests/auto_parallel/test_dist_context.py
+121
-69
python/paddle/fluid/tests/unittests/auto_parallel/test_pass_generation_pipeline.py
.../unittests/auto_parallel/test_pass_generation_pipeline.py
+58
-0
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
...addle/fluid/tests/unittests/test_auto_parallel_reshard.py
+154
-82
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py
...luid/tests/unittests/test_auto_parallel_reshard_serial.py
+76
-48
python/paddle/fluid/tests/unittests/test_fleet_executor_cond_interceptor.py
...d/tests/unittests/test_fleet_executor_cond_interceptor.py
+217
-0
python/paddle/fluid/tests/unittests/test_fleet_executor_task_node.py
...le/fluid/tests/unittests/test_fleet_executor_task_node.py
+17
-10
python/paddle/fluid/tests/unittests/test_fleet_executor_with_task_nodes.py
...id/tests/unittests/test_fleet_executor_with_task_nodes.py
+20
-22
python/paddle/tensor/stat.py
python/paddle/tensor/stat.py
+1
-0
未找到文件。
cmake/third_party.cmake
浏览文件 @
92c2dcbd
...
@@ -426,7 +426,8 @@ endif()
...
@@ -426,7 +426,8 @@ endif()
if
(
WITH_DISTRIBUTE
if
(
WITH_DISTRIBUTE
AND NOT WITH_PSLIB
AND NOT WITH_PSLIB
AND NOT WITH_PSCORE
)
AND NOT WITH_PSCORE
AND NOT WITH_RPC
)
include
(
external/snappy
)
include
(
external/snappy
)
list
(
APPEND third_party_deps extern_snappy
)
list
(
APPEND third_party_deps extern_snappy
)
...
...
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
浏览文件 @
92c2dcbd
...
@@ -36,6 +36,8 @@ cc_library(
...
@@ -36,6 +36,8 @@ cc_library(
interceptor.cc
interceptor.cc
compute_interceptor.cc
compute_interceptor.cc
amplifier_interceptor.cc
amplifier_interceptor.cc
cond_interceptor.cc
start_interceptor.cc
source_interceptor.cc
source_interceptor.cc
sink_interceptor.cc
sink_interceptor.cc
message_service.cc
message_service.cc
...
@@ -66,6 +68,10 @@ if(WITH_DISTRIBUTE)
...
@@ -66,6 +68,10 @@ if(WITH_DISTRIBUTE)
set_source_files_properties
(
set_source_files_properties
(
amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS
amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
cond_interceptor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
start_interceptor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
set_source_files_properties
(
source_interceptor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
source_interceptor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
set_source_files_properties
(
...
...
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.cc
浏览文件 @
92c2dcbd
...
@@ -33,7 +33,7 @@ void AmplifierInterceptor::RunOps() {
...
@@ -33,7 +33,7 @@ void AmplifierInterceptor::RunOps() {
// run_per_steps_, run_at_offset_
// run_per_steps_, run_at_offset_
// 4, 0 --> run at step 0, 4, 8, 12
// 4, 0 --> run at step 0, 4, 8, 12
// 4, 3 --> run at step 3, 7, 11, 15
// 4, 3 --> run at step 3, 7, 11, 15
if
((
step
_
%
run_per_steps_
)
==
run_at_offset_
)
{
if
((
cur_scope_id
_
%
run_per_steps_
)
==
run_at_offset_
)
{
ComputeInterceptor
::
RunOps
();
ComputeInterceptor
::
RunOps
();
}
}
}
}
...
@@ -41,7 +41,7 @@ void AmplifierInterceptor::RunOps() {
...
@@ -41,7 +41,7 @@ void AmplifierInterceptor::RunOps() {
void
AmplifierInterceptor
::
SendDataReadyToDownStream
()
{
void
AmplifierInterceptor
::
SendDataReadyToDownStream
()
{
// run multi times, send ready one times to downstream, that is
// run multi times, send ready one times to downstream, that is
// input multi times, output one times
// input multi times, output one times
if
(
step
_
%
send_down_per_steps_
==
0
)
{
if
(
cur_scope_id
_
%
send_down_per_steps_
==
0
)
{
ComputeInterceptor
::
SendDataReadyToDownStream
();
ComputeInterceptor
::
SendDataReadyToDownStream
();
}
}
}
}
...
@@ -49,7 +49,7 @@ void AmplifierInterceptor::SendDataReadyToDownStream() {
...
@@ -49,7 +49,7 @@ void AmplifierInterceptor::SendDataReadyToDownStream() {
void
AmplifierInterceptor
::
ReplyCompletedToUpStream
()
{
void
AmplifierInterceptor
::
ReplyCompletedToUpStream
()
{
// run multi times, reply one times to upstream, that is
// run multi times, reply one times to upstream, that is
// input one times, output multi times
// input one times, output multi times
if
(
step
_
%
reply_up_per_steps_
==
0
)
{
if
(
cur_scope_id
_
%
reply_up_per_steps_
==
0
)
{
ComputeInterceptor
::
ReplyCompletedToUpStream
();
ComputeInterceptor
::
ReplyCompletedToUpStream
();
}
}
}
}
...
...
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h
浏览文件 @
92c2dcbd
...
@@ -21,7 +21,7 @@
...
@@ -21,7 +21,7 @@
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
class
AmplifierInterceptor
:
public
ComputeInterceptor
{
class
AmplifierInterceptor
final
:
public
ComputeInterceptor
{
public:
public:
AmplifierInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
);
AmplifierInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
);
...
...
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
92c2dcbd
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include <algorithm>
#include <algorithm>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
...
@@ -24,6 +25,7 @@
...
@@ -24,6 +25,7 @@
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/framework/variable_helper.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -33,6 +35,8 @@ USE_INTERCEPTOR(Source);
...
@@ -33,6 +35,8 @@ USE_INTERCEPTOR(Source);
USE_INTERCEPTOR
(
Compute
);
USE_INTERCEPTOR
(
Compute
);
USE_INTERCEPTOR
(
Amplifier
);
USE_INTERCEPTOR
(
Amplifier
);
USE_INTERCEPTOR
(
Sink
);
USE_INTERCEPTOR
(
Sink
);
USE_INTERCEPTOR
(
Cond
);
USE_INTERCEPTOR
(
Start
);
void
Carrier
::
Init
(
void
Carrier
::
Init
(
int64_t
rank
,
int64_t
rank
,
...
@@ -54,23 +58,37 @@ void Carrier::Init(
...
@@ -54,23 +58,37 @@ void Carrier::Init(
framework
::
Scope
*
scope
,
framework
::
Scope
*
scope
,
int64_t
num_micro_batches
,
int64_t
num_micro_batches
,
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
const
std
::
vector
<
std
::
string
>&
inference_root_scope_vars
)
{
const
std
::
vector
<
std
::
string
>&
inference_root_scope_vars
,
const
std
::
vector
<
framework
::
Scope
*>&
micro_scope_list
)
{
rank_
=
rank
;
rank_
=
rank
;
interceptor_id_to_rank_
=
interceptor_id_to_rank
;
interceptor_id_to_rank_
=
interceptor_id_to_rank
;
interceptor_id_to_node_
=
interceptor_id_to_node
;
interceptor_id_to_node_
=
interceptor_id_to_node
;
place_
=
place
;
place_
=
place
;
root_scope_
=
scope
;
root_scope_
=
scope
;
dev_ctx_
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place_
);
dev_ctx_
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place_
);
bool
need_create_scope
=
micro_scope_list
.
empty
();
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
root_scope_
,
root_scope_
,
platform
::
errors
::
InvalidArgument
(
"root_scope can not be nullptr"
));
platform
::
errors
::
InvalidArgument
(
"root_scope can not be nullptr"
));
if
(
need_create_scope
)
{
minibatch_scope_
=
&
root_scope_
->
NewScope
();
minibatch_scope_
=
&
root_scope_
->
NewScope
();
microbatch_scopes_
.
resize
(
num_micro_batches
);
microbatch_scopes_
.
resize
(
num_micro_batches
);
for
(
int
i
=
0
;
i
<
num_micro_batches
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_micro_batches
;
++
i
)
{
microbatch_scopes_
[
i
]
=
&
minibatch_scope_
->
NewScope
();
microbatch_scopes_
[
i
]
=
&
minibatch_scope_
->
NewScope
();
CopyParameters
(
i
,
program
,
inference_root_scope_vars
);
CopyParameters
(
i
,
program
,
inference_root_scope_vars
);
}
}
}
else
{
microbatch_scopes_
=
micro_scope_list
;
for
(
int
i
=
0
;
i
<
num_micro_batches
;
++
i
)
{
CopyParameters
(
i
,
program
,
inference_root_scope_vars
);
}
}
// Add source and sink interceptor id to rank
interceptor_id_to_rank_
.
emplace
(
SOURCE_ID
,
rank
);
interceptor_id_to_rank_
.
emplace
(
SINK_ID
,
rank
);
// TODO(fleet_exe dev): thread pool
// TODO(fleet_exe dev): thread pool
thread_num_
=
1
;
thread_num_
=
1
;
...
@@ -93,18 +111,18 @@ void Carrier::CopyParameters(
...
@@ -93,18 +111,18 @@ void Carrier::CopyParameters(
int
microbatch_id
,
int
microbatch_id
,
const
framework
::
ProgramDesc
&
program
,
const
framework
::
ProgramDesc
&
program
,
const
std
::
vector
<
std
::
string
>&
inference_root_scope_vars
)
{
const
std
::
vector
<
std
::
string
>&
inference_root_scope_vars
)
{
auto
&
global_block
=
program
.
Block
(
0
);
std
::
map
<
std
::
string
,
int
>
inference_root_scope_var_map
;
std
::
map
<
std
::
string
,
int
>
inference_root_scope_var_map
;
for
(
auto
var_name
:
inference_root_scope_vars
)
{
for
(
auto
var_name
:
inference_root_scope_vars
)
{
inference_root_scope_var_map
.
insert
({
var_name
,
1
});
inference_root_scope_var_map
.
insert
({
var_name
,
1
});
}
}
for
(
auto
&
var
:
global_block
.
AllVars
())
{
for
(
size_t
i
=
0
;
i
<
program
.
Size
();
++
i
)
{
for
(
auto
&
var
:
program
.
Block
(
i
).
AllVars
())
{
std
::
string
var_name
=
var
->
Name
();
std
::
string
var_name
=
var
->
Name
();
bool
force_root
=
inference_root_scope_var_map
.
find
(
var_name
)
!=
bool
force_root
=
inference_root_scope_var_map
.
find
(
var_name
)
!=
inference_root_scope_var_map
.
end
();
inference_root_scope_var_map
.
end
();
if
(
force_root
)
{
if
(
force_root
)
{
VLOG
(
4
)
<<
var_name
<<
" will be forced to be created in the root scope."
;
VLOG
(
4
)
<<
var_name
<<
" will be forced to be created in the root scope."
;
}
}
if
((
var
->
Persistable
()
||
force_root
)
&&
microbatch_id
==
0
)
{
if
((
var
->
Persistable
()
||
force_root
)
&&
microbatch_id
==
0
)
{
auto
*
ptr
=
root_scope_
->
Var
(
var
->
Name
());
auto
*
ptr
=
root_scope_
->
Var
(
var
->
Name
());
...
@@ -118,6 +136,7 @@ void Carrier::CopyParameters(
...
@@ -118,6 +136,7 @@ void Carrier::CopyParameters(
InitializeVariable
(
ptr
,
var
->
GetType
());
InitializeVariable
(
ptr
,
var
->
GetType
());
}
}
}
}
}
}
}
bool
Carrier
::
EnqueueInterceptorMessage
(
bool
Carrier
::
EnqueueInterceptorMessage
(
...
@@ -159,16 +178,11 @@ void Carrier::Start() {
...
@@ -159,16 +178,11 @@ void Carrier::Start() {
true
,
true
,
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
"Using carrier before initialized."
));
"Using carrier before initialized."
));
for
(
int64_t
id
:
source_interceptor_ids_
)
{
VLOG
(
3
)
<<
"Carrier Start is sending start to source interceptor "
<<
id
<<
"."
;
InterceptorMessage
start_msg
;
InterceptorMessage
start_msg
;
// source node data_is_ready is send by carrier, so set src_id=-1
start_msg
.
set_src_id
(
SOURCE_ID
);
start_msg
.
set_src_id
(
-
1
);
start_msg
.
set_dst_id
(
SOURCE_ID
);
start_msg
.
set_dst_id
(
id
);
start_msg
.
set_message_type
(
START
);
start_msg
.
set_message_type
(
DATA_IS_READY
);
Send
(
start_msg
);
Send
(
start_msg
);
}
// TODO(wangxi): async step
// TODO(wangxi): async step
Wait
();
Wait
();
dev_ctx_
->
Wait
();
dev_ctx_
->
Wait
();
...
@@ -270,6 +284,38 @@ void Carrier::CreateInterceptors() {
...
@@ -270,6 +284,38 @@ void Carrier::CreateInterceptors() {
auto
gc
=
GetGC
(
place_
);
auto
gc
=
GetGC
(
place_
);
// create source and sink task node
auto
max_run_times
=
microbatch_scopes_
.
size
();
TaskNode
*
source
=
new
TaskNode
(
rank_
,
SOURCE_ID
,
max_run_times
);
// rank, task_id, max_run_times
TaskNode
*
sink
=
new
TaskNode
(
rank_
,
SINK_ID
,
max_run_times
);
// find nodes without upstreams or without downstreams
std
::
vector
<
TaskNode
*>
origin_sources
,
origin_sinks
;
for
(
const
auto
&
item
:
interceptor_id_to_node_
)
{
TaskNode
*
task_node
=
item
.
second
;
if
(
task_node
->
upstream
().
empty
())
{
origin_sources
.
emplace_back
(
task_node
);
}
if
(
task_node
->
downstream
().
empty
())
{
origin_sinks
.
emplace_back
(
task_node
);
}
}
// link source node with origin source
for
(
const
auto
&
node
:
origin_sources
)
{
source
->
AddDownstreamTask
(
node
->
task_id
(),
std
::
numeric_limits
<
int64_t
>::
max
());
node
->
AddUpstreamTask
(
SOURCE_ID
,
std
::
numeric_limits
<
int64_t
>::
max
());
}
// link sink node with origin sink
for
(
const
auto
&
node
:
origin_sinks
)
{
sink
->
AddUpstreamTask
(
node
->
task_id
(),
std
::
numeric_limits
<
int64_t
>::
max
());
node
->
AddDownstreamTask
(
SINK_ID
,
std
::
numeric_limits
<
int64_t
>::
max
());
}
// create source and sink interceptor
SetInterceptor
(
SOURCE_ID
,
InterceptorFactory
::
Create
(
"Source"
,
SOURCE_ID
,
source
));
SetInterceptor
(
SINK_ID
,
InterceptorFactory
::
Create
(
"Sink"
,
SINK_ID
,
sink
));
// create each Interceptor
// create each Interceptor
// no auto init since there is no config
// no auto init since there is no config
for
(
const
auto
&
item
:
interceptor_id_to_node_
)
{
for
(
const
auto
&
item
:
interceptor_id_to_node_
)
{
...
@@ -303,9 +349,15 @@ void Carrier::CreateInterceptors() {
...
@@ -303,9 +349,15 @@ void Carrier::CreateInterceptors() {
VLOG
(
3
)
<<
"Create Interceptor with interceptor id: "
<<
interceptor_id
VLOG
(
3
)
<<
"Create Interceptor with interceptor id: "
<<
interceptor_id
<<
" with type: "
<<
task_node
->
type
()
<<
"."
;
<<
" with type: "
<<
task_node
->
type
()
<<
"."
;
if
(
task_node
->
upstream
().
empty
())
{
PADDLE_ENFORCE_EQ
(
source_interceptor_ids_
.
emplace_back
(
interceptor_id
);
task_node
->
upstream
().
empty
(),
}
false
,
platform
::
errors
::
PreconditionNotMet
(
"There should not have normal nodes as source nodes"
));
PADDLE_ENFORCE_EQ
(
task_node
->
downstream
().
empty
(),
false
,
platform
::
errors
::
PreconditionNotMet
(
"There should not have normal nodes as sink nodes"
));
}
}
}
}
...
...
paddle/fluid/distributed/fleet_executor/carrier.h
浏览文件 @
92c2dcbd
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/errors.h"
...
@@ -60,7 +61,8 @@ class Carrier final {
...
@@ -60,7 +61,8 @@ class Carrier final {
framework
::
Scope
*
scope
,
framework
::
Scope
*
scope
,
int64_t
num_micro_batches
,
int64_t
num_micro_batches
,
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
const
std
::
vector
<
std
::
string
>&
inference_root_scope_vars
=
{});
const
std
::
vector
<
std
::
string
>&
inference_root_scope_vars
=
{},
const
std
::
vector
<
framework
::
Scope
*>&
micro_scope_list
=
{});
void
CopyParameters
(
void
CopyParameters
(
int
microbatch_id
,
int
microbatch_id
,
...
@@ -100,8 +102,6 @@ class Carrier final {
...
@@ -100,8 +102,6 @@ class Carrier final {
std
::
unordered_map
<
int64_t
,
std
::
unique_ptr
<
Interceptor
>>
std
::
unordered_map
<
int64_t
,
std
::
unique_ptr
<
Interceptor
>>
interceptor_idx_to_interceptor_
;
interceptor_idx_to_interceptor_
;
std
::
vector
<
int64_t
>
source_interceptor_ids_
;
bool
is_init_
{
false
};
bool
is_init_
{
false
};
std
::
mutex
running_mutex_
;
std
::
mutex
running_mutex_
;
...
...
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
浏览文件 @
92c2dcbd
...
@@ -18,10 +18,85 @@
...
@@ -18,10 +18,85 @@
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/serialization.h"
#include "paddle/phi/core/utils/dim.h"
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
namespace
{
template
<
typename
T
>
void
SetVarResult
(
const
std
::
string
&
name
,
T
value
,
int64_t
scope_id
,
framework
::
Scope
*
scope
,
const
platform
::
Place
&
place
,
const
std
::
vector
<
int64_t
>&
dim_vec
)
{
auto
*
var
=
scope
->
FindVar
(
name
);
auto
*
tensor
=
var
->
GetMutable
<
phi
::
DenseTensor
>
();
if
(
!
var
)
{
VLOG
(
3
)
<<
"Create var and memory for var "
<<
name
;
var
=
scope
->
Var
(
name
);
phi
::
DDim
dims
=
phi
::
make_ddim
(
dim_vec
);
tensor
->
Resize
(
dims
);
tensor
->
mutable_data
<
T
>
(
dims
,
place
);
}
PADDLE_ENFORCE_EQ
(
tensor
->
dims
().
size
(),
1
,
platform
::
errors
::
OutOfRange
(
"Only support transfer size 1 value."
));
PADDLE_ENFORCE_EQ
(
tensor
->
dims
().
at
(
0
),
1
,
platform
::
errors
::
OutOfRange
(
"Only support transfer size 1 value."
));
if
(
platform
::
is_gpu_place
(
tensor
->
place
()))
{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi
::
DenseTensor
cpu_tensor
;
auto
dim
=
phi
::
make_ddim
({
1
});
cpu_tensor
.
mutable_data
<
T
>
(
dim
,
platform
::
CPUPlace
());
auto
*
cpu_tensor_ptr
=
cpu_tensor
.
data
<
T
>
();
cpu_tensor_ptr
[
0
]
=
value
;
framework
::
TensorCopySync
(
cpu_tensor
,
tensor
->
place
(),
tensor
);
#endif
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupport device for cond interceptor."
));
}
}
template
<
typename
T
>
T
GetVarResult
(
const
std
::
string
&
name
,
int64_t
scope_id
,
framework
::
Scope
*
scope
)
{
auto
*
var
=
scope
->
FindVar
(
name
);
PADDLE_ENFORCE
(
var
,
platform
::
errors
::
NotFound
(
"Variable %s not exists in scope %ld"
,
name
,
scope_id
));
const
auto
&
tensor
=
var
->
Get
<
phi
::
DenseTensor
>
();
T
res
;
if
(
platform
::
is_gpu_place
(
tensor
.
place
()))
{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi
::
DenseTensor
cpu_tensor
;
framework
::
TensorCopySync
(
tensor
,
platform
::
CPUPlace
(),
&
cpu_tensor
);
res
=
cpu_tensor
.
data
<
T
>
()[
0
];
#endif
}
else
if
(
platform
::
is_cpu_place
(
tensor
.
place
()))
{
res
=
tensor
.
data
<
T
>
()[
0
];
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupport device for cond interceptor."
));
}
return
res
;
}
}
// namespace
ComputeInterceptor
::
ComputeInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
ComputeInterceptor
::
ComputeInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
Interceptor
(
interceptor_id
,
node
)
{
:
Interceptor
(
interceptor_id
,
node
)
{
PrepareDeps
();
PrepareDeps
();
...
@@ -33,49 +108,33 @@ void ComputeInterceptor::PrepareDeps() {
...
@@ -33,49 +108,33 @@ void ComputeInterceptor::PrepareDeps() {
auto
&
downstream
=
node_
->
downstream
();
auto
&
downstream
=
node_
->
downstream
();
for
(
auto
up
:
upstream
)
{
for
(
auto
up
:
upstream
)
{
in_readys_
.
emplace
(
up
.
first
,
std
::
make_pair
(
up
.
second
,
0
));
std
::
map
<
int64_t
,
int64_t
>
ready_size_map
;
in_stops_
.
emplace
(
up
.
first
,
false
);
for
(
int64_t
i
=
0
;
i
<
node_
->
max_run_times
();
++
i
)
{
ready_size_map
.
emplace
(
i
,
0
);
}
in_readys_
.
emplace
(
up
.
first
,
std
::
make_pair
(
up
.
second
,
ready_size_map
));
}
}
for
(
auto
down
:
downstream
)
{
for
(
auto
down
:
downstream
)
{
out_buffs_
.
emplace
(
down
.
first
,
std
::
make_pair
(
down
.
second
,
0
));
out_buffs_
.
emplace
(
down
.
first
,
std
::
make_pair
(
down
.
second
,
0
));
}
}
// source compute node, should we add a new SourceInterceptor?
if
(
upstream
.
empty
())
{
is_source_
=
true
;
PADDLE_ENFORCE_GT
(
node_
->
max_run_times
(),
0
,
platform
::
errors
::
InvalidArgument
(
"Source ComputeInterceptor must run at least one "
"times, but now max_run_times=%ld"
,
node_
->
max_run_times
()));
in_readys_
.
emplace
(
-
1
,
std
::
make_pair
(
std
::
numeric_limits
<
int64_t
>::
max
(),
0
));
}
// If there is no downstream or every downstream is in different rank,
// then this interceptor is the last one for current rank.
// This can be get during init, can be cached for later use.
is_last_
=
downstream
.
empty
();
}
}
void
ComputeInterceptor
::
IncreaseReady
(
int64_t
up_id
)
{
void
ComputeInterceptor
::
IncreaseReady
(
int64_t
up_id
,
int64_t
scope_id
)
{
auto
it
=
in_readys_
.
find
(
up_id
);
auto
it
=
in_readys_
.
find
(
up_id
);
PADDLE_ENFORCE_NE
(
it
,
PADDLE_ENFORCE_NE
(
it
,
in_readys_
.
end
(),
in_readys_
.
end
(),
platform
::
errors
::
NotFound
(
platform
::
errors
::
NotFound
(
"Cannot find upstream=%lld in in_readys."
,
up_id
));
"Cannot find upstream=%lld in in_readys."
,
up_id
));
// source node has no upstream, data_is_ready is send by carrier or others
if
(
is_source_
&&
up_id
==
-
1
)
{
it
->
second
.
second
+=
GetTaskNode
()
->
max_run_times
();
return
;
}
auto
max_ready_size
=
it
->
second
.
first
;
auto
max_ready_size
=
it
->
second
.
first
;
auto
ready_size
=
it
->
second
.
second
;
const
auto
&
ready_scope_map
=
it
->
second
.
second
;
ready_size
+=
1
;
int64_t
ready_size
=
0
;
PADDLE_ENFORCE_LE
(
ready_size
,
for
(
auto
&
scope_iter
:
ready_scope_map
)
{
ready_size
+=
scope_iter
.
second
;
}
if
(
max_ready_size
!=
INFINITE_BUFFER_SIZE
)
{
PADDLE_ENFORCE_LE
(
ready_size
,
max_ready_size
,
max_ready_size
,
platform
::
errors
::
OutOfRange
(
platform
::
errors
::
OutOfRange
(
"upstream=%lld ready_size must <= max_ready_size, but "
"upstream=%lld ready_size must <= max_ready_size, but "
...
@@ -83,7 +142,15 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) {
...
@@ -83,7 +142,15 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) {
up_id
,
up_id
,
ready_size
,
ready_size
,
max_ready_size
));
max_ready_size
));
it
->
second
.
second
=
ready_size
;
}
PADDLE_ENFORCE_NE
(
it
->
second
.
second
.
find
(
scope_id
),
it
->
second
.
second
.
end
(),
platform
::
errors
::
OutOfRange
(
"Interceptor %lld can not find scope %lld in upstream ready map"
,
interceptor_id_
,
scope_id
));
it
->
second
.
second
.
at
(
scope_id
)
=
ready_scope_map
.
at
(
scope_id
)
+
1
;
}
}
void
ComputeInterceptor
::
DecreaseBuff
(
int64_t
down_id
)
{
void
ComputeInterceptor
::
DecreaseBuff
(
int64_t
down_id
)
{
...
@@ -105,22 +172,40 @@ void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
...
@@ -105,22 +172,40 @@ void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
}
}
bool
ComputeInterceptor
::
IsInputReady
()
{
bool
ComputeInterceptor
::
IsInputReady
()
{
for
(
int64_t
i
=
0
;
i
<
node_
->
max_run_times
();
++
i
)
{
bool
flag
=
true
;
for
(
auto
&
ins
:
in_readys_
)
{
for
(
auto
&
ins
:
in_readys_
)
{
auto
ready_size
=
ins
.
second
.
second
;
auto
ready_size_map
=
ins
.
second
.
second
;
// not ready, return false
flag
=
flag
&&
(
ready_size_map
.
at
(
i
)
!=
0
);
if
(
ready_size
==
0
)
{
}
VLOG
(
3
)
<<
"Interceptor "
<<
GetInterceptorId
()
if
(
flag
)
{
<<
"'s upstreams aren't all ready."
;
for
(
auto
iter
:
scope_id_to_finish_flag_
)
{
if
(
iter
.
first
==
i
)
{
break
;
}
else
if
(
!
iter
.
second
)
{
VLOG
(
3
)
<<
"The previous scope is not ready, waiting for the "
"previous scope "
<<
iter
.
first
;
return
false
;
return
false
;
}
}
}
}
cur_scope_id_
=
i
;
return
true
;
return
true
;
}
else
{
VLOG
(
3
)
<<
"Interceptor "
<<
GetInterceptorId
()
<<
" in scope "
<<
i
<<
"'s upstreams aren't all ready."
;
}
}
return
false
;
}
}
bool
ComputeInterceptor
::
CanWriteOutput
()
{
bool
ComputeInterceptor
::
CanWriteOutput
()
{
for
(
auto
&
outs
:
out_buffs_
)
{
for
(
auto
&
outs
:
out_buffs_
)
{
auto
max_buffer_size
=
outs
.
second
.
first
;
auto
max_buffer_size
=
outs
.
second
.
first
;
auto
used_size
=
outs
.
second
.
second
;
auto
used_size
=
outs
.
second
.
second
;
if
(
max_buffer_size
==
INFINITE_BUFFER_SIZE
)
{
continue
;
}
// full, return false
// full, return false
if
(
used_size
==
max_buffer_size
)
{
if
(
used_size
==
max_buffer_size
)
{
VLOG
(
3
)
<<
"Interceptor "
<<
GetInterceptorId
()
VLOG
(
3
)
<<
"Interceptor "
<<
GetInterceptorId
()
...
@@ -137,6 +222,7 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
...
@@ -137,6 +222,7 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
auto
max_buff_size
=
outs
.
second
.
first
;
auto
max_buff_size
=
outs
.
second
.
first
;
auto
used_size
=
outs
.
second
.
second
;
auto
used_size
=
outs
.
second
.
second
;
used_size
+=
1
;
used_size
+=
1
;
if
(
max_buff_size
!=
INFINITE_BUFFER_SIZE
)
{
PADDLE_ENFORCE_LE
(
PADDLE_ENFORCE_LE
(
used_size
,
used_size
,
max_buff_size
,
max_buff_size
,
...
@@ -146,21 +232,66 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
...
@@ -146,21 +232,66 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
down_id
,
down_id
,
used_size
,
used_size
,
max_buff_size
));
max_buff_size
));
}
outs
.
second
.
second
=
used_size
;
outs
.
second
.
second
=
used_size
;
bool
need_send_vars
=
!
(
node_
->
vars_to_dtype
().
empty
());
if
(
need_send_vars
)
{
InterceptorMessage
ready_msg
=
PrepareVarsMsg
();
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
<<
" Send data_with_vars msg to "
<<
down_id
<<
" in scope: "
<<
cur_scope_id_
;
Send
(
down_id
,
ready_msg
);
}
else
{
InterceptorMessage
ready_msg
;
InterceptorMessage
ready_msg
;
ready_msg
.
set_message_type
(
DATA_IS_READY
);
ready_msg
.
set_message_type
(
DATA_IS_READY
);
ready_msg
.
set_scope_idx
(
cur_scope_id_
);
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
<<
" Send data_is_ready msg to "
<<
down_id
<<
" Send data_is_ready msg to "
<<
down_id
<<
" for step: "
<<
step
_
;
<<
" in scope: "
<<
cur_scope_id
_
;
Send
(
down_id
,
ready_msg
);
Send
(
down_id
,
ready_msg
);
}
}
}
}
InterceptorMessage
ComputeInterceptor
::
PrepareVarsMsg
()
{
PADDLE_ENFORCE_LT
(
cur_scope_id_
,
microbatch_scopes_
.
size
(),
platform
::
errors
::
InvalidArgument
(
"Step out of range. There are %ld "
"microbatch_scopes, but recevice scope index %ld"
,
microbatch_scopes_
.
size
(),
cur_scope_id_
));
auto
*
scope
=
microbatch_scopes_
[
cur_scope_id_
];
InterceptorMessage
ready_msg
;
ready_msg
.
set_message_type
(
DATA_WITH_VARS
);
ready_msg
.
set_scope_idx
(
cur_scope_id_
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
for
(
auto
iter
:
node_
->
vars_to_dtype
())
{
VarList
*
vars
=
ready_msg
.
add_vars_list
();
const
auto
&
var_name
=
iter
.
first
;
vars
->
set_name
(
var_name
);
std
::
ostringstream
ss
;
auto
&
dev_ctx
=
*
pool
.
Get
(
place_
);
auto
*
var
=
scope
->
FindVar
(
var_name
);
PADDLE_ENFORCE
(
var
,
platform
::
errors
::
NotFound
(
"Variable %s not exists in scope %ld"
,
var_name
,
cur_scope_id_
));
const
auto
&
tensor
=
var
->
Get
<
phi
::
DenseTensor
>
();
SerializeToStream
(
ss
,
tensor
,
dev_ctx
);
vars
->
set_stensor
(
ss
.
str
());
VLOG
(
3
)
<<
"Prepare vars msg "
<<
var_name
<<
" with dimension "
<<
tensor
.
dims
()
<<
" dtype "
<<
tensor
.
dtype
();
}
return
ready_msg
;
}
}
void
ComputeInterceptor
::
ReplyCompletedToUpStream
()
{
void
ComputeInterceptor
::
ReplyCompletedToUpStream
()
{
for
(
auto
&
ins
:
in_readys_
)
{
for
(
auto
&
ins
:
in_readys_
)
{
auto
up_id
=
ins
.
first
;
auto
up_id
=
ins
.
first
;
auto
ready_size
=
ins
.
second
.
second
;
auto
ready_size
=
ins
.
second
.
second
.
at
(
cur_scope_id_
)
;
ready_size
-=
1
;
ready_size
-=
1
;
PADDLE_ENFORCE_GE
(
PADDLE_ENFORCE_GE
(
ready_size
,
ready_size
,
...
@@ -169,27 +300,31 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
...
@@ -169,27 +300,31 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
"upstream=%lld ready_size must >= 0, but now got %lld"
,
"upstream=%lld ready_size must >= 0, but now got %lld"
,
up_id
,
up_id
,
ready_size
));
ready_size
));
ins
.
second
.
second
=
ready_size
;
ins
.
second
.
second
[
cur_scope_id_
]
=
ready_size
;
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
<<
" Reply data_is_useless msg to "
<<
up_id
<<
" Reply data_is_useless msg to "
<<
up_id
<<
" for step: "
<<
step_
;
<<
" in scope: "
<<
cur_scope_id_
;
if
(
is_source_
&&
up_id
==
-
1
)
return
;
InterceptorMessage
reply_msg
;
InterceptorMessage
reply_msg
;
reply_msg
.
set_message_type
(
DATA_IS_USELESS
);
reply_msg
.
set_message_type
(
DATA_IS_USELESS
);
reply_msg
.
set_scope_idx
(
cur_scope_id_
);
Send
(
up_id
,
reply_msg
);
Send
(
up_id
,
reply_msg
);
}
}
}
}
void
ComputeInterceptor
::
RunOps
()
{
void
ComputeInterceptor
::
RunOps
()
{
VLOG
(
3
)
<<
"ComputeInterceptor "
<<
interceptor_id_
<<
" running ops for the "
<<
step_
+
1
<<
" time."
;
for
(
auto
op
:
node_
->
ops
())
{
for
(
auto
op
:
node_
->
ops
())
{
op
->
Run
(
*
microbatch_scopes_
[
step_
%
node_
->
max_run_times
()],
place_
);
PADDLE_ENFORCE_LT
(
cur_scope_id_
,
microbatch_scopes_
.
size
(),
platform
::
errors
::
InvalidArgument
(
"Step out of range. There are %ld "
"microbatch_scopes, but recevice scope index %ld"
,
microbatch_scopes_
.
size
(),
cur_scope_id_
));
op
->
Run
(
*
microbatch_scopes_
[
cur_scope_id_
],
place_
);
if
(
gc_
)
{
if
(
gc_
)
{
framework
::
DeleteUnusedTensors
(
framework
::
DeleteUnusedTensors
(
*
microbatch_scopes_
[
cur_scope_id_
],
*
microbatch_scopes_
[
step_
%
node_
->
max_run_times
()],
op
,
op
,
node_
->
unused_vars
(),
node_
->
unused_vars
(),
gc_
.
get
());
gc_
.
get
());
...
@@ -199,79 +334,80 @@ void ComputeInterceptor::RunOps() {
...
@@ -199,79 +334,80 @@ void ComputeInterceptor::RunOps() {
void
ComputeInterceptor
::
Run
()
{
void
ComputeInterceptor
::
Run
()
{
while
(
IsInputReady
()
&&
CanWriteOutput
())
{
while
(
IsInputReady
()
&&
CanWriteOutput
())
{
VLOG
(
3
)
<<
"id="
<<
GetInterceptorId
()
<<
" ComputeInterceptor running"
;
VLOG
(
3
)
<<
"id="
<<
GetInterceptorId
()
<<
" ComputeInterceptor running in scope "
<<
cur_scope_id_
;
RunOps
();
RunOps
();
++
step_
;
if
(
!
scope_id_to_finish_flag_
.
empty
())
{
PADDLE_ENFORCE_NE
(
scope_id_to_finish_flag_
.
find
(
cur_scope_id_
),
scope_id_to_finish_flag_
.
end
(),
platform
::
errors
::
NotFound
(
"Can not find scope %ld in scope_id_to_finish"
,
cur_scope_id_
));
scope_id_to_finish_flag_
.
erase
(
cur_scope_id_
);
}
// send to downstream and increase buff used
// send to downstream and increase buff used
SendDataReadyToDownStream
();
SendDataReadyToDownStream
();
// reply to upstream and decrease ready data
// reply to upstream and decrease ready data
ReplyCompletedToUpStream
();
ReplyCompletedToUpStream
();
// Try to stop Carrier
if
(
is_last_
&&
(
step_
%
node_
->
max_run_times
()
==
0
))
{
VLOG
(
3
)
<<
"Interceptor "
<<
GetInterceptorId
()
<<
" is stopping carrier."
;
// FIXME(wangxi): with multi sink interceptor
StopCarrier
();
}
}
}
}
void
ComputeInterceptor
::
ReceivedStop
(
int64_t
up_id
)
{
received_stop_
=
true
;
// source node has no upstream, stop is send by carrier or others
if
(
is_source_
&&
up_id
==
-
1
)
return
;
auto
it
=
in_stops_
.
find
(
up_id
);
PADDLE_ENFORCE_NE
(
it
,
in_stops_
.
end
(),
platform
::
errors
::
NotFound
(
"Cannot find upstream=%lld in in_stops."
,
up_id
));
PADDLE_ENFORCE_EQ
(
it
->
second
,
false
,
platform
::
errors
::
AlreadyExists
(
"Already received stop from %lld, stop "
"cannot be send more than once."
));
it
->
second
=
true
;
}
}
void
ComputeInterceptor
::
TryStop
()
{
void
ComputeInterceptor
::
DecodeMsgVars
(
const
InterceptorMessage
&
msg
)
{
if
(
!
received_stop_
)
return
;
int64_t
scope_id
=
msg
.
scope_idx
();
PADDLE_ENFORCE_LT
(
scope_id
,
// can stop only when all upstream is stop and
microbatch_scopes_
.
size
(),
// downstream complete
platform
::
errors
::
InvalidArgument
(
for
(
auto
&
in_stop
:
in_stops_
)
{
"Step out of range. There are %ld "
if
(
!
in_stop
.
second
)
return
;
"microbatch_scopes, but recevice scope index %ld"
,
}
microbatch_scopes_
.
size
(),
for
(
auto
&
out_buff
:
out_buffs_
)
{
scope_id
));
auto
used_size
=
out_buff
.
second
.
second
;
auto
*
scope
=
microbatch_scopes_
[
scope_id
];
if
(
used_size
!=
0
)
return
;
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
}
for
(
const
auto
&
var_iter
:
msg
.
vars_list
())
{
const
std
::
string
&
name
=
var_iter
.
name
();
// send stop to downstream
auto
&
dev_ctx
=
*
pool
.
Get
(
place_
);
for
(
auto
&
out
:
out_buffs_
)
{
std
::
istringstream
ss
(
var_iter
.
stensor
());
auto
down_id
=
out
.
first
;
auto
*
var
=
scope
->
Var
(
name
);
InterceptorMessage
stop
;
auto
*
tensor
=
var
->
GetMutable
<
phi
::
DenseTensor
>
();
stop
.
set_message_type
(
STOP
);
DeserializeFromStream
(
ss
,
tensor
,
dev_ctx
);
Send
(
down_id
,
stop
);
VLOG
(
3
)
<<
"Set vars "
<<
name
<<
" with value in scope "
<<
scope_id
<<
" with dims "
<<
tensor
->
dims
()
<<
" with dtype "
<<
tensor
->
dtype
();
}
}
stop_
=
true
;
}
}
void
ComputeInterceptor
::
Compute
(
const
InterceptorMessage
&
msg
)
{
void
ComputeInterceptor
::
Compute
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
DATA_IS_READY
)
{
if
(
msg
.
message_type
()
==
DATA_IS_READY
)
{
IncreaseReady
(
msg
.
src_id
());
VLOG
(
3
)
<<
"Compute interceptor "
<<
interceptor_id_
<<
" receive data_is_ready "
<<
msg
.
src_id
()
<<
" "
<<
msg
.
scope_idx
()
<<
" "
;
IncreaseReady
(
msg
.
src_id
(),
msg
.
scope_idx
());
Run
();
Run
();
}
else
if
(
msg
.
message_type
()
==
DATA_IS_USELESS
)
{
}
else
if
(
msg
.
message_type
()
==
DATA_IS_USELESS
)
{
VLOG
(
3
)
<<
"Compute interceptor "
<<
interceptor_id_
<<
" receive data_is_useless "
<<
msg
.
src_id
()
<<
" "
<<
msg
.
scope_idx
()
<<
" "
;
DecreaseBuff
(
msg
.
src_id
());
DecreaseBuff
(
msg
.
src_id
());
Run
();
Run
();
}
else
if
(
msg
.
message_type
()
==
STOP
)
{
}
else
if
(
msg
.
message_type
()
==
DATA_WITH_VARS
)
{
ReceivedStop
(
msg
.
src_id
());
VLOG
(
3
)
<<
"Compute interceptor "
<<
interceptor_id_
<<
" receive data_with_vars "
<<
msg
.
src_id
()
<<
" "
<<
msg
.
scope_idx
()
<<
" "
;
DecodeMsgVars
(
msg
);
IncreaseReady
(
msg
.
src_id
(),
msg
.
scope_idx
());
Run
();
}
else
if
(
msg
.
message_type
()
==
START_LOOP
)
{
VLOG
(
3
)
<<
"Compute interceptor "
<<
interceptor_id_
<<
" receive start_loop "
<<
msg
.
src_id
()
<<
" "
<<
msg
.
scope_idx
()
<<
" "
;
IncreaseReady
(
msg
.
src_id
(),
msg
.
scope_idx
());
scope_id_to_finish_flag_
.
emplace
(
msg
.
scope_idx
(),
false
);
Run
();
}
}
TryStop
();
}
}
REGISTER_INTERCEPTOR
(
Compute
,
ComputeInterceptor
);
REGISTER_INTERCEPTOR
(
Compute
,
ComputeInterceptor
);
...
...
paddle/fluid/distributed/fleet_executor/compute_interceptor.h
浏览文件 @
92c2dcbd
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#pragma once
#pragma once
#include <queue>
#include <utility>
#include <utility>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
...
@@ -21,6 +22,8 @@
...
@@ -21,6 +22,8 @@
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
const
int64_t
INFINITE_BUFFER_SIZE
=
-
1
;
class
ComputeInterceptor
:
public
Interceptor
{
class
ComputeInterceptor
:
public
Interceptor
{
public:
public:
ComputeInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
);
ComputeInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
);
...
@@ -29,33 +32,27 @@ class ComputeInterceptor : public Interceptor {
...
@@ -29,33 +32,27 @@ class ComputeInterceptor : public Interceptor {
virtual
void
RunOps
();
virtual
void
RunOps
();
virtual
void
SendDataReadyToDownStream
();
virtual
void
SendDataReadyToDownStream
();
virtual
void
ReplyCompletedToUpStream
();
virtual
void
ReplyCompletedToUpStream
();
virtual
void
Compute
(
const
InterceptorMessage
&
msg
);
void
Run
();
void
IncreaseReady
(
int64_t
up_id
,
int64_t
scope_id
);
void
DecreaseBuff
(
int64_t
down_id
);
int64_t
cur_scope_id_
;
int64_t
step_
{
0
};
// upstream_id-->(max_ready_size, scope-->ready_size)
std
::
map
<
int64_t
,
std
::
pair
<
int64_t
,
std
::
map
<
int64_t
,
int64_t
>>>
in_readys_
{};
// downstream_id-->(max_buffer_size, used_size)
std
::
map
<
int64_t
,
std
::
pair
<
int64_t
,
int64_t
>>
out_buffs_
{};
private:
private:
void
PrepareDeps
();
void
PrepareDeps
();
InterceptorMessage
PrepareVarsMsg
();
void
DecodeMsgVars
(
const
InterceptorMessage
&
msg
);
void
IncreaseReady
(
int64_t
up_id
);
void
DecreaseBuff
(
int64_t
down_id
);
bool
IsInputReady
();
bool
IsInputReady
();
bool
CanWriteOutput
();
bool
CanWriteOutput
();
std
::
map
<
int64_t
,
bool
>
scope_id_to_finish_flag_
;
void
Run
();
void
Compute
(
const
InterceptorMessage
&
msg
);
void
ReceivedStop
(
int64_t
up_id
);
void
TryStop
();
bool
is_source_
{
false
};
bool
is_last_
{
false
};
// upstream_id-->(max_ready_size, ready_size)
std
::
map
<
int64_t
,
std
::
pair
<
int64_t
,
int64_t
>>
in_readys_
{};
// downstream_id-->(max_buffer_size, used_size)
std
::
map
<
int64_t
,
std
::
pair
<
int64_t
,
int64_t
>>
out_buffs_
{};
bool
received_stop_
{
false
};
std
::
map
<
int64_t
,
bool
>
in_stops_
{};
};
};
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/cond_interceptor.cc
0 → 100644
浏览文件 @
92c2dcbd
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/cond_interceptor.h"
#include <algorithm>
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/errors.h"
namespace
paddle
{
namespace
distributed
{
CondInterceptor
::
CondInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
Interceptor
(
interceptor_id
,
node
)
{
PrepareDeps
();
RegisterMsgHandle
([
this
](
const
InterceptorMessage
&
msg
)
{
Run
(
msg
);
});
}
void
CondInterceptor
::
PrepareDeps
()
{
auto
&
upstream
=
node_
->
upstream
();
auto
&
downstream
=
node_
->
downstream
();
auto
&
id_to_dep_type
=
node_
->
id_to_dep_type
();
for
(
const
auto
&
up
:
upstream
)
{
if
(
id_to_dep_type
.
at
(
up
.
first
)
==
DependType
::
NORMAL
)
{
normal_in_id_
.
insert
(
up
.
first
);
}
else
if
(
id_to_dep_type
.
at
(
up
.
first
)
==
DependType
::
LOOP
)
{
loop_id_
=
up
.
first
;
}
}
for
(
const
auto
&
down
:
downstream
)
{
if
(
id_to_dep_type
.
at
(
down
.
first
)
==
DependType
::
NORMAL
)
{
normal_out_id_
.
insert
(
down
.
first
);
}
else
if
(
id_to_dep_type
.
at
(
down
.
first
)
==
DependType
::
STOP_LOOP
)
{
stop_loop_id_
=
down
.
first
;
}
}
}
bool
CondInterceptor
::
GetCondResult
()
{
PADDLE_ENFORCE_LT
(
cur_scope_id_
,
microbatch_scopes_
.
size
(),
platform
::
errors
::
InvalidArgument
(
"Step out of range. There are %ld "
"microbatch_scopes, but recevice scope index %ld"
,
microbatch_scopes_
.
size
(),
cur_scope_id_
));
auto
*
cond_var
=
microbatch_scopes_
[
cur_scope_id_
]
->
FindVar
(
node_
->
cond_var
());
PADDLE_ENFORCE
(
cond_var
,
platform
::
errors
::
NotFound
(
"Condition variable %s not exists in scope %ld"
,
node_
->
cond_var
(),
cur_scope_id_
));
const
auto
&
cond_tensor
=
cond_var
->
Get
<
phi
::
DenseTensor
>
();
bool
res
=
false
;
if
(
platform
::
is_gpu_place
(
cond_tensor
.
place
()))
{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi
::
DenseTensor
cpu_tensor
;
framework
::
TensorCopy
(
cond_tensor
,
platform
::
CPUPlace
(),
&
cpu_tensor
);
platform
::
DeviceContextPool
::
Instance
().
Get
(
cond_tensor
.
place
())
->
Wait
();
res
=
cpu_tensor
.
data
<
bool
>
()[
0
];
#endif
}
else
if
(
platform
::
is_cpu_place
(
cond_tensor
.
place
()))
{
res
=
cond_tensor
.
data
<
bool
>
()[
0
];
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupport device for cond interceptor."
));
}
return
res
;
}
void
CondInterceptor
::
SendDataReady
(
int64_t
down_id
)
{
InterceptorMessage
ready_msg
;
ready_msg
.
set_message_type
(
DATA_IS_READY
);
ready_msg
.
set_scope_idx
(
cur_scope_id_
);
Send
(
down_id
,
ready_msg
);
}
void
CondInterceptor
::
SendStartLoop
(
int64_t
down_id
)
{
InterceptorMessage
ready_msg
;
ready_msg
.
set_message_type
(
START_LOOP
);
ready_msg
.
set_scope_idx
(
cur_scope_id_
);
Send
(
down_id
,
ready_msg
);
}
void
CondInterceptor
::
ReplyDataIsUseless
(
int64_t
up_id
)
{
InterceptorMessage
ready_msg
;
ready_msg
.
set_message_type
(
DATA_IS_USELESS
);
ready_msg
.
set_scope_idx
(
cur_scope_id_
);
Send
(
up_id
,
ready_msg
);
}
void
CondInterceptor
::
Compute
()
{
bool
cond
=
GetCondResult
();
VLOG
(
3
)
<<
"Cond interceptor get condition var "
<<
node_
->
cond_var
()
<<
" with value "
<<
cond
;
if
(
cond
)
{
VLOG
(
3
)
<<
"Loop again in scope "
<<
cur_scope_id_
;
for
(
auto
&
down_id
:
normal_out_id_
)
{
SendStartLoop
(
down_id
);
}
++
num_of_scopes_
;
}
else
{
VLOG
(
3
)
<<
"Finish loop in scope "
<<
cur_scope_id_
;
SendDataReady
(
stop_loop_id_
);
}
}
void
CondInterceptor
::
Run
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
DATA_IS_READY
||
msg
.
message_type
()
==
DATA_WITH_VARS
)
{
if
(
msg
.
src_id
()
==
loop_id_
)
{
--
num_of_scopes_
;
VLOG
(
3
)
<<
"Receving loop again message from "
<<
msg
.
src_id
()
<<
" waiting other "
<<
num_of_scopes_
<<
" scopes ready"
;
ready_scope_id_
.
emplace_back
(
msg
.
scope_idx
());
if
(
num_of_scopes_
==
0
)
{
std
::
sort
(
ready_scope_id_
.
begin
(),
ready_scope_id_
.
end
());
for
(
auto
scope_id
:
ready_scope_id_
)
{
VLOG
(
3
)
<<
"Start a new loop in scope "
<<
scope_id
;
cur_scope_id_
=
scope_id
;
Compute
();
}
ready_scope_id_
.
clear
();
}
}
else
{
cur_scope_id_
=
msg
.
scope_idx
();
Compute
();
}
}
else
if
(
msg
.
message_type
()
==
DATA_IS_USELESS
)
{
if
(
node_
->
id_to_dep_type
().
at
(
msg
.
src_id
())
==
DependType
::
STOP_LOOP
)
{
for
(
auto
&
up_id
:
normal_in_id_
)
{
ReplyDataIsUseless
(
up_id
);
}
// Gc the variable in while block
int64_t
scope_id
=
msg
.
scope_idx
();
if
(
gc_
)
{
VLOG
(
3
)
<<
"Release vars in while block in scope "
<<
scope_id
;
framework
::
DeleteUnusedTensors
(
*
microbatch_scopes_
[
scope_id
],
node_
->
while_block_vars
(),
gc_
.
get
());
}
}
}
}
REGISTER_INTERCEPTOR
(
Cond
,
CondInterceptor
);
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/cond_interceptor.h
0 → 100644
浏览文件 @
92c2dcbd
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iomanip>
#include <queue>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace
paddle
{
namespace
distributed
{
/* Condition Interceptor
* This is a special interceptor and only one condition op in the task node.
* This interceptor has two downstreams,
* 1. If the program result is true, select one of the downstreams, otherwise
* select another.
* 2. Used to implement while op in program.
*/
class
CondInterceptor
final
:
public
Interceptor
{
public:
CondInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
);
private:
void
PrepareDeps
();
void
Run
(
const
InterceptorMessage
&
msg
);
void
Compute
();
bool
GetCondResult
();
void
SendDataReady
(
int64_t
down_id
);
void
SendStartLoop
(
int64_t
down_id
);
void
ReplyDataIsUseless
(
int64_t
up_id
);
int64_t
cur_scope_id_
;
std
::
set
<
int64_t
>
normal_in_id_
;
std
::
set
<
int64_t
>
normal_out_id_
;
int64_t
stop_loop_id_
;
int64_t
loop_id_
;
int64_t
num_of_scopes_
{
0
};
std
::
vector
<
int64_t
>
ready_scope_id_
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
浏览文件 @
92c2dcbd
...
@@ -14,6 +14,8 @@
...
@@ -14,6 +14,8 @@
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include <algorithm>
#include <algorithm>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...
@@ -24,6 +26,7 @@
...
@@ -24,6 +26,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/variable.h"
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
...
@@ -51,40 +54,40 @@ FleetExecutor::~FleetExecutor() {
...
@@ -51,40 +54,40 @@ FleetExecutor::~FleetExecutor() {
}
}
}
}
void
FleetExecutor
::
Init
(
namespace
{
const
std
::
string
&
carrier_id
,
void
GetSubBlockTask
(
const
std
::
vector
<
TaskNode
*>&
tasks
,
const
framework
::
ProgramDesc
&
program_desc
,
TaskNode
*
cur_task
,
framework
::
Scope
*
scope
,
std
::
set
<
TaskNode
*>*
sub_block_task
)
{
const
platform
::
Place
&
place
,
auto
&
downstream
=
cur_task
->
downstream
();
int64_t
num_micro_batches
,
auto
&
id_to_dep_type
=
cur_task
->
id_to_dep_type
();
const
std
::
vector
<
TaskNode
*>&
task_nodes
,
for
(
auto
&
down
:
downstream
)
{
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
,
int64_t
task_id
=
down
.
first
;
const
std
::
vector
<
std
::
string
>&
inference_root_scope_vars
)
{
if
(
id_to_dep_type
.
at
(
task_id
)
==
DependType
::
NORMAL
)
{
PADDLE_ENFORCE_GT
(
task_nodes
.
size
(),
for
(
const
auto
&
task
:
tasks
)
{
0
,
if
(
task
->
task_id
()
==
task_id
)
{
platform
::
errors
::
InvalidArgument
(
sub_block_task
->
emplace
(
task
);
"Fleet executor is inited with empty task node"
));
GetSubBlockTask
(
tasks
,
task
,
sub_block_task
);
// TODO(fleet_exe devs): the unused_vars should be got from run time graph
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
ops
;
for
(
auto
task_node
:
task_nodes
)
{
for
(
auto
op
:
task_node
->
ops
())
{
ops
.
emplace_back
(
std
::
unique_ptr
<
framework
::
OperatorBase
>
(
op
));
}
}
}
}
auto
unused_vars
=
framework
::
GetUnusedVars
(
program_desc
.
Block
(
0
),
ops
,
{});
}
// NOTE: For inference, the vars in inference_root_scope_vars
}
// shouldn't be deleted during inf, for that they may be the result of the
}
// inf. If they are GCed, it will cause error during ZeroCopy the result.
void
PreventVarsDelete
(
std
::
unordered_map
<
const
framework
::
OperatorBase
*
,
std
::
vector
<
std
::
string
>>*
unused_vars
,
const
std
::
vector
<
std
::
string
>&
vars_not_gc
)
{
std
::
vector
<
const
framework
::
OperatorBase
*>
changed_ops
;
std
::
vector
<
const
framework
::
OperatorBase
*>
changed_ops
;
for
(
auto
pair
:
unused_vars
)
{
for
(
const
auto
&
pair
:
*
unused_vars
)
{
const
framework
::
OperatorBase
*
op
=
pair
.
first
;
const
framework
::
OperatorBase
*
op
=
pair
.
first
;
std
::
vector
<
std
::
string
>
unused
=
pair
.
second
;
std
::
vector
<
std
::
string
>
cur_
unused
=
pair
.
second
;
for
(
auto
name
:
inference_root_scope_vars
)
{
for
(
auto
name
:
vars_not_gc
)
{
auto
iter
=
std
::
find
(
unused
.
begin
(),
unused
.
end
(),
name
);
auto
iter
=
std
::
find
(
cur_unused
.
begin
(),
cur_
unused
.
end
(),
name
);
if
(
iter
!=
unused
.
end
())
{
if
(
iter
!=
cur_
unused
.
end
())
{
VLOG
(
3
)
<<
"Removing var: ["
<<
name
VLOG
(
3
)
<<
"Removing var: ["
<<
name
<<
"] from the unused vars list of op: ["
<<
op
->
Type
()
<<
"]"
;
<<
"] from the unused vars list of op: ["
<<
op
->
Type
()
<<
"]"
;
unused
.
erase
(
iter
);
cur_
unused
.
erase
(
iter
);
if
(
std
::
find
(
changed_ops
.
begin
(),
changed_ops
.
end
(),
op
)
==
if
(
std
::
find
(
changed_ops
.
begin
(),
changed_ops
.
end
(),
op
)
==
changed_ops
.
end
())
{
changed_ops
.
end
())
{
// record the op whose unused vars have been updated
// record the op whose unused vars have been updated
...
@@ -93,28 +96,120 @@ void FleetExecutor::Init(
...
@@ -93,28 +96,120 @@ void FleetExecutor::Init(
}
}
}
}
// update the unused vars list in the map
// update the unused vars list in the map
unused_vars
[
op
]
=
unused
;
unused_vars
->
at
(
op
)
=
cur_
unused
;
}
}
for
(
auto
op
:
changed_ops
)
{
for
(
auto
op
:
changed_ops
)
{
auto
iter
=
unused_vars
.
find
(
op
);
const
auto
&
iter
=
unused_vars
->
find
(
op
);
if
(
iter
->
second
.
empty
())
{
if
(
iter
->
second
.
empty
())
{
// remove those ops in the map that have empty unused vars list
// remove those ops in the map that have empty unused vars list
VLOG
(
3
)
<<
"Removing op: ["
<<
op
->
Type
()
<<
"] from unused_vars map."
;
VLOG
(
3
)
<<
"Removing op: ["
<<
op
->
Type
()
<<
"] from unused_vars map."
;
unused_vars
.
erase
(
iter
);
unused_vars
->
erase
(
iter
);
}
}
}
std
::
vector
<
std
::
string
>
GetUnusedVarsAfterWhile
(
const
framework
::
ProgramDesc
&
program_desc
,
TaskNode
*
cond_task
,
const
std
::
vector
<
std
::
string
>&
vars_not_gc
)
{
// NOTE: Since while op won't appear in task node, in order to analyze
// the vars which should be free after calling while op, we rebuild the
// whole program and get the unused vars after calling while op.
// The vars in while block should not be free until the while op is finished.
// In a word, the vars need to be free after while op is:
// 1. Vars in parent block and being used in while block.
// 2. Local vars only defined in while block.
// The unused vars above will be free in cond interceptor.
std
::
vector
<
std
::
string
>
while_block_vars
;
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
ops
;
for
(
const
auto
&
desc
:
program_desc
.
Block
(
0
).
AllOps
())
{
ops
.
emplace_back
(
framework
::
OpRegistry
::
CreateOp
(
*
desc
));
}
auto
unused_vars
=
framework
::
GetUnusedVars
(
program_desc
.
Block
(
0
),
ops
,
{});
PreventVarsDelete
(
&
unused_vars
,
vars_not_gc
);
for
(
const
auto
&
pair
:
unused_vars
)
{
if
(
pair
.
first
->
Type
()
==
"while"
)
{
for
(
const
auto
&
var_name
:
pair
.
second
)
{
while_block_vars
.
emplace_back
(
var_name
);
}
for
(
auto
&
var
:
program_desc
.
Block
(
1
).
AllVars
())
{
while_block_vars
.
emplace_back
(
var
->
Name
());
}
}
}
return
while_block_vars
;
}
}
// namespace
void
FleetExecutor
::
Init
(
const
std
::
string
&
carrier_id
,
const
framework
::
ProgramDesc
&
program_desc
,
framework
::
Scope
*
scope
,
const
platform
::
Place
&
place
,
int64_t
num_micro_batches
,
const
std
::
vector
<
TaskNode
*>&
task_nodes
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
,
const
std
::
vector
<
std
::
string
>&
inference_root_scope_vars
,
const
std
::
vector
<
framework
::
Scope
*>&
micro_scope_list
)
{
PADDLE_ENFORCE_GT
(
task_nodes
.
size
(),
0
,
platform
::
errors
::
InvalidArgument
(
"Fleet executor is inited with empty task node"
));
// Set the unused var after running while op
std
::
set
<
TaskNode
*>
sub_block_tasks
;
std
::
vector
<
std
::
string
>
while_block_vars
;
for
(
const
auto
&
task_node
:
task_nodes
)
{
if
(
task_node
->
type
()
==
"Cond"
)
{
GetSubBlockTask
(
task_nodes
,
task_node
,
&
sub_block_tasks
);
while_block_vars
=
GetUnusedVarsAfterWhile
(
program_desc
,
task_node
,
inference_root_scope_vars
);
VLOG
(
3
)
<<
"Vars will be gced after while op"
;
for
(
auto
var
:
while_block_vars
)
{
VLOG
(
3
)
<<
var
;
}
task_node
->
SetWhileBlockVars
(
while_block_vars
);
}
}
std
::
vector
<
framework
::
OperatorBase
*>
sub_block_ops
;
for
(
const
auto
&
task_node
:
sub_block_tasks
)
{
for
(
const
auto
&
op
:
task_node
->
ops
())
{
sub_block_ops
.
emplace_back
(
op
);
}
}
}
}
// Analyse the unused vars in block 0. The operators in block 1
// should be passed in first for prevent vars been released but removed soon.
// Since the unused vars in block 1 need to analyse separately.
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
ops
;
for
(
const
auto
&
task_node
:
task_nodes
)
{
for
(
const
auto
&
op
:
task_node
->
ops
())
{
ops
.
emplace_back
(
std
::
unique_ptr
<
framework
::
OperatorBase
>
(
op
));
}
}
auto
global_unused_vars
=
framework
::
GetUnusedVars
(
program_desc
.
Block
(
0
),
ops
,
{});
for
(
auto
&
unique_op
:
ops
)
{
unique_op
.
release
();
}
// NOTE: For inference, the vars in inference_root_scope_vars
// shouldn't be deleted during inf, for that they may be the result of the
// inf. If they are GCed, it will cause error during ZeroCopy the result.
PreventVarsDelete
(
&
global_unused_vars
,
inference_root_scope_vars
);
runtime_graph_
=
std
::
make_shared
<
RuntimeGraph
>
();
runtime_graph_
=
std
::
make_shared
<
RuntimeGraph
>
();
std
::
unordered_map
<
int64_t
,
TaskNode
*>
interceptor_id_to_task
;
std
::
unordered_map
<
int64_t
,
TaskNode
*>
interceptor_id_to_task
;
for
(
auto
task_node
:
task_nodes
)
{
for
(
auto
task_node
:
task_nodes
)
{
task_node
->
SetUnusedVars
(
unused_vars
);
if
(
sub_block_tasks
.
find
(
task_node
)
==
sub_block_tasks
.
end
())
{
task_node
->
SetUnusedVars
(
global_unused_vars
);
}
int64_t
interceptor_id
=
task_node
->
task_id
();
int64_t
interceptor_id
=
task_node
->
task_id
();
interceptor_id_to_task
.
emplace
(
interceptor_id
,
task_node
);
interceptor_id_to_task
.
emplace
(
interceptor_id
,
task_node
);
}
}
runtime_graph_
->
SetInterceptorIdToRank
(
task_id_to_rank
);
runtime_graph_
->
SetInterceptorIdToRank
(
task_id_to_rank
);
runtime_graph_
->
SetInterceptorIdToNode
(
interceptor_id_to_task
);
runtime_graph_
->
SetInterceptorIdToNode
(
interceptor_id_to_task
);
for
(
auto
&
unique_op
:
ops
)
{
unique_op
.
release
();
}
VLOG
(
5
)
<<
runtime_graph_
->
DebugString
();
VLOG
(
5
)
<<
runtime_graph_
->
DebugString
();
Carrier
*
carrier
=
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
...
@@ -126,7 +221,8 @@ void FleetExecutor::Init(
...
@@ -126,7 +221,8 @@ void FleetExecutor::Init(
place
,
place
,
num_micro_batches
,
num_micro_batches
,
program_desc
,
program_desc
,
inference_root_scope_vars
);
inference_root_scope_vars
,
micro_scope_list
);
GlobalVal
<
MessageBus
>::
Get
()
->
Barrier
();
GlobalVal
<
MessageBus
>::
Get
()
->
Barrier
();
}
}
...
@@ -136,7 +232,8 @@ void FleetExecutor::InitCarrier(
...
@@ -136,7 +232,8 @@ void FleetExecutor::InitCarrier(
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
int64_t
num_micro_batches
,
int64_t
num_micro_batches
,
const
framework
::
ProgramDesc
&
program_desc
,
const
framework
::
ProgramDesc
&
program_desc
,
const
std
::
vector
<
std
::
string
>&
inference_root_scope_vars
)
{
const
std
::
vector
<
std
::
string
>&
inference_root_scope_vars
,
const
std
::
vector
<
framework
::
Scope
*>&
micro_scope_list
)
{
carrier
->
Init
(
exe_desc_
.
cur_rank
(),
carrier
->
Init
(
exe_desc_
.
cur_rank
(),
runtime_graph_
->
interceptor_id_to_rank
(),
runtime_graph_
->
interceptor_id_to_rank
(),
runtime_graph_
->
interceptor_id_to_node
(),
runtime_graph_
->
interceptor_id_to_node
(),
...
@@ -144,7 +241,8 @@ void FleetExecutor::InitCarrier(
...
@@ -144,7 +241,8 @@ void FleetExecutor::InitCarrier(
scope
,
scope
,
num_micro_batches
,
num_micro_batches
,
place
,
place
,
inference_root_scope_vars
);
inference_root_scope_vars
,
micro_scope_list
);
}
}
void
FleetExecutor
::
InitMessageBus
()
{
void
FleetExecutor
::
InitMessageBus
()
{
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.h
浏览文件 @
92c2dcbd
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
...
@@ -45,7 +46,8 @@ class FleetExecutor final {
...
@@ -45,7 +46,8 @@ class FleetExecutor final {
int64_t
num_micro_batches
,
int64_t
num_micro_batches
,
const
std
::
vector
<
TaskNode
*>&
task_nodes
,
const
std
::
vector
<
TaskNode
*>&
task_nodes
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
,
const
std
::
vector
<
std
::
string
>&
inference_root_scope_vars
=
{});
const
std
::
vector
<
std
::
string
>&
inference_root_scope_vars
=
{},
const
std
::
vector
<
framework
::
Scope
*>&
micro_scope_list
=
{});
void
Run
(
const
std
::
string
&
carrier_id
);
void
Run
(
const
std
::
string
&
carrier_id
);
private:
private:
...
@@ -57,7 +59,8 @@ class FleetExecutor final {
...
@@ -57,7 +59,8 @@ class FleetExecutor final {
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
int64_t
num_micro_batches
,
int64_t
num_micro_batches
,
const
framework
::
ProgramDesc
&
program_desc
,
const
framework
::
ProgramDesc
&
program_desc
,
const
std
::
vector
<
std
::
string
>&
inference_root_scope_vars
=
{});
const
std
::
vector
<
std
::
string
>&
inference_root_scope_vars
=
{},
const
std
::
vector
<
framework
::
Scope
*>&
micro_scope_list
=
{});
FleetExecutorDesc
exe_desc_
;
FleetExecutorDesc
exe_desc_
;
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph_
;
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph_
;
std
::
unordered_set
<
std
::
string
>
carrier_ids_
;
std
::
unordered_set
<
std
::
string
>
carrier_ids_
;
...
...
paddle/fluid/distributed/fleet_executor/interceptor.h
浏览文件 @
92c2dcbd
...
@@ -93,7 +93,6 @@ class Interceptor {
...
@@ -93,7 +93,6 @@ class Interceptor {
TaskNode
*
node_
;
TaskNode
*
node_
;
// for stop
// for stop
bool
stop_
{
false
};
void
StopCarrier
();
void
StopCarrier
();
// for runtime
// for runtime
...
@@ -114,9 +113,6 @@ class Interceptor {
...
@@ -114,9 +113,6 @@ class Interceptor {
std
::
mutex
mutex_
;
std
::
mutex
mutex_
;
std
::
deque
<
InterceptorMessage
>
messages_
;
std
::
deque
<
InterceptorMessage
>
messages_
;
int64_t
already_run_times_
{
0
};
int64_t
used_slot_nums_
{
0
};
};
};
class
InterceptorFactory
{
class
InterceptorFactory
{
...
...
paddle/fluid/distributed/fleet_executor/interceptor_message.proto
浏览文件 @
92c2dcbd
...
@@ -24,6 +24,21 @@ enum MessageType {
...
@@ -24,6 +24,21 @@ enum MessageType {
ERR
=
4
;
// current Interceptor encounters error
ERR
=
4
;
// current Interceptor encounters error
RESET
=
5
;
// reset the status
RESET
=
5
;
// reset the status
START
=
6
;
START
=
6
;
DATA_WITH_VARS
=
7
;
START_LOOP
=
8
;
}
enum
ValueType
{
INT3
=
0
;
INT6
=
1
;
FLOAT
=
2
;
DOUBLE
=
3
;
BOOL
=
4
;
}
message
VarList
{
required
string
name
=
1
;
required
string
stensor
=
2
;
}
}
message
InterceptorMessage
{
message
InterceptorMessage
{
...
@@ -32,6 +47,7 @@ message InterceptorMessage {
...
@@ -32,6 +47,7 @@ message InterceptorMessage {
optional
MessageType
message_type
=
3
[
default
=
RESET
];
optional
MessageType
message_type
=
3
[
default
=
RESET
];
optional
bool
ctrl_message
=
4
[
default
=
false
];
optional
bool
ctrl_message
=
4
[
default
=
false
];
optional
int64
scope_idx
=
5
[
default
=
0
];
optional
int64
scope_idx
=
5
[
default
=
0
];
repeated
VarList
vars_list
=
6
;
}
}
message
InterceptorResponse
{
optional
bool
rst
=
1
[
default
=
false
];
}
message
InterceptorResponse
{
optional
bool
rst
=
1
[
default
=
false
];
}
...
...
paddle/fluid/distributed/fleet_executor/sink_interceptor.h
浏览文件 @
92c2dcbd
...
@@ -25,7 +25,7 @@ namespace distributed {
...
@@ -25,7 +25,7 @@ namespace distributed {
* 1. record the num of micro-step
* 1. record the num of micro-step
* 2. check whether to notify carrier the current step is finished
* 2. check whether to notify carrier the current step is finished
*/
*/
class
SinkInterceptor
:
public
Interceptor
{
class
SinkInterceptor
final
:
public
Interceptor
{
public:
public:
SinkInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
);
SinkInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
);
...
...
paddle/fluid/distributed/fleet_executor/source_interceptor.h
浏览文件 @
92c2dcbd
...
@@ -25,7 +25,7 @@ namespace distributed {
...
@@ -25,7 +25,7 @@ namespace distributed {
* 1. receive `start` message from carrier
* 1. receive `start` message from carrier
* 2. send num_of_steps `data_is_ready` message to downstream
* 2. send num_of_steps `data_is_ready` message to downstream
*/
*/
class
SourceInterceptor
:
public
Interceptor
{
class
SourceInterceptor
final
:
public
Interceptor
{
public:
public:
SourceInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
);
SourceInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
);
...
...
paddle/fluid/distributed/fleet_executor/start_interceptor.cc
0 → 100644
浏览文件 @
92c2dcbd
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/start_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/errors.h"
namespace
paddle
{
namespace
distributed
{
StartInterceptor
::
StartInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
ComputeInterceptor
(
interceptor_id
,
node
)
{
auto
&
downstream
=
node_
->
downstream
();
PADDLE_ENFORCE_EQ
(
downstream
.
size
(),
1
,
platform
::
errors
::
OutOfRange
(
"The downstream for StartInterceptor only support 1 for now."
));
for
(
auto
down
:
downstream
)
{
batch_size_
=
down
.
second
;
}
bool
evenly_divisible
=
((
node_
->
max_run_times
()
%
batch_size_
)
==
0
);
PADDLE_ENFORCE
(
evenly_divisible
,
platform
::
errors
::
Fatal
(
"Wrong config: Num of step should be divided by batch_size,"
"num_step=%lld, batch_size=%lld"
,
node_
->
max_run_times
(),
batch_size_
));
}
void
StartInterceptor
::
RunOps
()
{
finish_count_
++
;
ComputeInterceptor
::
RunOps
();
}
void
StartInterceptor
::
SendDataReadyToDownStream
()
{
for
(
auto
&
outs
:
out_buffs_
)
{
auto
down_id
=
outs
.
first
;
auto
max_buff_size
=
outs
.
second
.
first
;
auto
used_size
=
outs
.
second
.
second
;
used_size
+=
1
;
if
(
max_buff_size
!=
INFINITE_BUFFER_SIZE
)
{
PADDLE_ENFORCE_LE
(
used_size
,
max_buff_size
,
platform
::
errors
::
OutOfRange
(
"downstream=%lld used buff size must <= "
"max_buff_size, but now used_size=%lld, "
"max_buff_size=%lld"
,
down_id
,
used_size
,
max_buff_size
));
}
outs
.
second
.
second
=
used_size
;
}
if
(
finish_count_
==
batch_size_
)
{
for
(
int64_t
i
=
0
;
i
<
batch_size_
;
++
i
)
{
int64_t
scope_id
=
step_
%
node_
->
max_run_times
();
for
(
auto
&
outs
:
out_buffs_
)
{
auto
down_id
=
outs
.
first
;
InterceptorMessage
ready_msg
;
ready_msg
.
set_message_type
(
DATA_IS_READY
);
ready_msg
.
set_scope_idx
(
scope_id
);
VLOG
(
3
)
<<
"StartInterceptor "
<<
interceptor_id_
<<
" Send data_is_ready msg to "
<<
down_id
<<
" in scope: "
<<
scope_id
;
Send
(
down_id
,
ready_msg
);
}
step_
++
;
}
}
}
void
StartInterceptor
::
Compute
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
DATA_IS_READY
)
{
VLOG
(
3
)
<<
"Start interceptor "
<<
interceptor_id_
<<
" receive data_is_ready "
<<
msg
.
src_id
()
<<
" "
<<
msg
.
scope_idx
()
<<
" "
;
IncreaseReady
(
msg
.
src_id
(),
msg
.
scope_idx
());
Run
();
}
else
if
(
msg
.
message_type
()
==
DATA_IS_USELESS
)
{
VLOG
(
3
)
<<
"Start interceptor receive data_is_useless "
<<
msg
.
src_id
()
<<
" "
<<
finish_count_
;
finish_count_
--
;
if
(
finish_count_
==
0
)
{
for
(
int64_t
i
=
0
;
i
<
batch_size_
;
++
i
)
{
for
(
auto
&
outs
:
out_buffs_
)
{
auto
down_id
=
outs
.
first
;
DecreaseBuff
(
down_id
);
}
}
for
(
int64_t
i
=
0
;
i
<
batch_size_
;
++
i
)
{
Run
();
}
}
}
}
REGISTER_INTERCEPTOR
(
Start
,
StartInterceptor
);
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/start_interceptor.h
0 → 100644
浏览文件 @
92c2dcbd
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <utility>
#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"
namespace
paddle
{
namespace
distributed
{
class
StartInterceptor
final
:
public
ComputeInterceptor
{
public:
StartInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
);
private:
void
SendDataReadyToDownStream
()
override
;
void
RunOps
()
override
;
void
Compute
(
const
InterceptorMessage
&
msg
)
override
;
int64_t
batch_size_
{
0
};
int64_t
finish_count_
{
0
};
int64_t
step_
{
0
};
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/task_node.cc
浏览文件 @
92c2dcbd
...
@@ -24,33 +24,14 @@ namespace {
...
@@ -24,33 +24,14 @@ namespace {
using
OperatorBase
=
TaskNode
::
OperatorBase
;
using
OperatorBase
=
TaskNode
::
OperatorBase
;
}
}
TaskNode
::
TaskNode
(
paddle
::
framework
::
ProgramDesc
*
program
,
int64_t
rank
,
int64_t
max_run_times
,
int64_t
max_slot_nums
)
:
program_
(
program
),
rank_
(
rank
),
max_run_times_
(
max_run_times
),
max_slot_nums_
(
max_slot_nums
)
{
// Should be serially invoked, not thread-safe
// NOTE: when instantiate TaskNode with program, won't init task node
// immediately, since the provided program may be updated later (with
// high probability) by adding_feed_fetch_ops or by RuntimeGraph.
// So, delay the init part to the Init() function.
static
int64_t
task_node_cnt
=
0
;
task_id_
=
task_node_cnt
++
;
}
TaskNode
::
TaskNode
(
paddle
::
framework
::
ProgramDesc
*
program
,
TaskNode
::
TaskNode
(
paddle
::
framework
::
ProgramDesc
*
program
,
int64_t
rank
,
int64_t
rank
,
int64_t
task_id
,
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
max_run_times
)
int64_t
max_slot_nums
)
:
program_
(
program
),
:
program_
(
program
),
rank_
(
rank
),
rank_
(
rank
),
task_id_
(
task_id
),
task_id_
(
task_id
),
max_run_times_
(
max_run_times
),
max_run_times_
(
max_run_times
)
{
max_slot_nums_
(
max_slot_nums
)
{
// TODO(liyurui): Will be removed when execute program is supported.
// TODO(liyurui): Will be removed when execute program is supported.
Init
();
Init
();
}
}
...
@@ -58,7 +39,6 @@ TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
...
@@ -58,7 +39,6 @@ TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
TaskNode
::
TaskNode
(
paddle
::
framework
::
ProgramDesc
*
program
,
int64_t
rank
)
TaskNode
::
TaskNode
(
paddle
::
framework
::
ProgramDesc
*
program
,
int64_t
rank
)
:
program_
(
program
),
rank_
(
rank
),
task_id_
(
rank
)
{
:
program_
(
program
),
rank_
(
rank
),
task_id_
(
rank
)
{
max_run_times_
=
1
;
max_run_times_
=
1
;
max_slot_nums_
=
1
;
LOG
(
INFO
)
LOG
(
INFO
)
<<
"Constructing TaskNode for DistModelInf. The TaskNode's id is: "
<<
"Constructing TaskNode for DistModelInf. The TaskNode's id is: "
<<
rank
<<
rank
...
@@ -69,6 +49,16 @@ void TaskNode::SetProgram(paddle::framework::ProgramDesc* program) {
...
@@ -69,6 +49,16 @@ void TaskNode::SetProgram(paddle::framework::ProgramDesc* program) {
program_
=
program
;
program_
=
program
;
}
}
void
TaskNode
::
SetVarsToDtype
(
const
std
::
map
<
std
::
string
,
std
::
string
>&
vars_to_dtype
)
{
vars_to_dtype_
=
vars_to_dtype
;
}
void
TaskNode
::
SetVarsToShape
(
const
std
::
map
<
std
::
string
,
std
::
vector
<
int64_t
>>&
vars_to_shape
)
{
vars_to_shape_
=
vars_to_shape
;
}
void
TaskNode
::
Init
(
bool
use_feed_fetch_ops
)
{
void
TaskNode
::
Init
(
bool
use_feed_fetch_ops
)
{
if
(
!
use_feed_fetch_ops
)
{
if
(
!
use_feed_fetch_ops
)
{
VLOG
(
3
)
<<
"TaskNode will be inited without feed and fetch ops"
;
VLOG
(
3
)
<<
"TaskNode will be inited without feed and fetch ops"
;
...
@@ -98,13 +88,11 @@ TaskNode::TaskNode(int32_t role,
...
@@ -98,13 +88,11 @@ TaskNode::TaskNode(int32_t role,
const
std
::
vector
<
framework
::
OpDesc
*>&
op_descs
,
const
std
::
vector
<
framework
::
OpDesc
*>&
op_descs
,
int64_t
rank
,
int64_t
rank
,
int64_t
task_id
,
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
max_run_times
)
int64_t
max_slot_nums
)
:
role_
(
role
),
:
role_
(
role
),
rank_
(
rank
),
rank_
(
rank
),
task_id_
(
task_id
),
task_id_
(
task_id
),
max_run_times_
(
max_run_times
),
max_run_times_
(
max_run_times
)
{
max_slot_nums_
(
max_slot_nums
)
{
if
(
op_descs
.
empty
())
{
if
(
op_descs
.
empty
())
{
return
;
return
;
}
}
...
@@ -121,33 +109,35 @@ TaskNode::TaskNode(int32_t role,
...
@@ -121,33 +109,35 @@ TaskNode::TaskNode(int32_t role,
const
std
::
vector
<
framework
::
OperatorBase
*>&
ops
,
const
std
::
vector
<
framework
::
OperatorBase
*>&
ops
,
int64_t
rank
,
int64_t
rank
,
int64_t
task_id
,
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
max_run_times
)
int64_t
max_slot_nums
)
:
ops_
(
ops
),
:
ops_
(
ops
),
role_
(
role
),
role_
(
role
),
rank_
(
rank
),
rank_
(
rank
),
task_id_
(
task_id
),
task_id_
(
task_id
),
max_run_times_
(
max_run_times
),
max_run_times_
(
max_run_times
)
{}
max_slot_nums_
(
max_slot_nums
)
{}
TaskNode
::
TaskNode
(
int32_t
role
,
TaskNode
::
TaskNode
(
int32_t
role
,
int64_t
rank
,
int64_t
rank
,
int64_t
task_id
,
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
max_run_times
)
int64_t
max_slot_nums
)
:
role_
(
role
),
:
role_
(
role
),
rank_
(
rank
),
rank_
(
rank
),
task_id_
(
task_id
),
task_id_
(
task_id
),
max_run_times_
(
max_run_times
),
max_run_times_
(
max_run_times
)
{}
max_slot_nums_
(
max_slot_nums
)
{}
bool
TaskNode
::
AddUpstreamTask
(
int64_t
task_id
,
int64_t
buff_size
)
{
bool
TaskNode
::
AddUpstreamTask
(
int64_t
task_id
,
int64_t
buff_size
,
DependType
type
)
{
const
auto
&
ret
=
upstream_
.
emplace
(
task_id
,
buff_size
);
const
auto
&
ret
=
upstream_
.
emplace
(
task_id
,
buff_size
);
id_to_dep_type_
.
emplace
(
task_id
,
type
);
return
ret
.
second
;
return
ret
.
second
;
}
}
bool
TaskNode
::
AddDownstreamTask
(
int64_t
task_id
,
int64_t
buff_size
)
{
bool
TaskNode
::
AddDownstreamTask
(
int64_t
task_id
,
int64_t
buff_size
,
DependType
type
)
{
const
auto
&
ret
=
downstream_
.
emplace
(
task_id
,
buff_size
);
const
auto
&
ret
=
downstream_
.
emplace
(
task_id
,
buff_size
);
id_to_dep_type_
.
emplace
(
task_id
,
type
);
return
ret
.
second
;
return
ret
.
second
;
}
}
...
...
paddle/fluid/distributed/fleet_executor/task_node.h
浏览文件 @
92c2dcbd
...
@@ -14,8 +14,10 @@
...
@@ -14,8 +14,10 @@
#pragma once
#pragma once
#include <cstdint>
#include <cstdint>
#include <functional>
#include <memory>
#include <memory>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include <vector>
#include <vector>
...
@@ -29,38 +31,30 @@ class OpDesc;
...
@@ -29,38 +31,30 @@ class OpDesc;
}
// namespace framework
}
// namespace framework
namespace
distributed
{
namespace
distributed
{
enum
class
DependType
{
NORMAL
,
LOOP
,
STOP_LOOP
};
class
TaskNode
final
{
class
TaskNode
final
{
public:
public:
using
OperatorBase
=
paddle
::
framework
::
OperatorBase
;
using
OperatorBase
=
paddle
::
framework
::
OperatorBase
;
TaskNode
(
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
);
TaskNode
(
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
);
TaskNode
(
int32_t
role
,
TaskNode
(
int32_t
role
,
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
);
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
max_slot_nums
);
TaskNode
(
int32_t
role
,
TaskNode
(
int32_t
role
,
const
std
::
vector
<
framework
::
OpDesc
*>&
op_descs
,
const
std
::
vector
<
framework
::
OpDesc
*>&
op_descs
,
int64_t
rank
,
int64_t
rank
,
int64_t
task_id
,
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
max_run_times
);
int64_t
max_slot_nums
);
TaskNode
(
int32_t
role
,
TaskNode
(
int32_t
role
,
const
std
::
vector
<
framework
::
OperatorBase
*>&
ops
,
const
std
::
vector
<
framework
::
OperatorBase
*>&
ops
,
int64_t
rank
,
int64_t
rank
,
int64_t
task_id
,
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
max_run_times
);
int64_t
max_slot_nums
);
TaskNode
(
paddle
::
framework
::
ProgramDesc
*
program
,
int64_t
rank
,
int64_t
max_run_times
,
int64_t
max_slot_nums
);
TaskNode
(
paddle
::
framework
::
ProgramDesc
*
program
,
int64_t
rank
);
TaskNode
(
paddle
::
framework
::
ProgramDesc
*
program
,
int64_t
rank
);
// TODO(liyurui): This will be the only constructor for task node
// TODO(liyurui): This will be the only constructor for task node
TaskNode
(
paddle
::
framework
::
ProgramDesc
*
program
,
TaskNode
(
paddle
::
framework
::
ProgramDesc
*
program
,
int64_t
task_id
,
int64_t
task_id
,
int64_t
rank
,
int64_t
rank
,
int64_t
max_run_times
,
int64_t
max_run_times
);
int64_t
max_slot_nums
);
~
TaskNode
()
=
default
;
~
TaskNode
()
=
default
;
void
SetProgram
(
paddle
::
framework
::
ProgramDesc
*
program
);
void
SetProgram
(
paddle
::
framework
::
ProgramDesc
*
program
);
...
@@ -69,11 +63,11 @@ class TaskNode final {
...
@@ -69,11 +63,11 @@ class TaskNode final {
int64_t
task_id
()
const
{
return
task_id_
;
}
int64_t
task_id
()
const
{
return
task_id_
;
}
int32_t
role
()
const
{
return
role_
;
}
int32_t
role
()
const
{
return
role_
;
}
int64_t
max_run_times
()
const
{
return
max_run_times_
;
}
int64_t
max_run_times
()
const
{
return
max_run_times_
;
}
int64_t
max_slot_nums
()
const
{
return
max_slot_nums_
;
}
int64_t
run_per_steps
()
const
{
return
run_per_steps_
;
}
int64_t
run_per_steps
()
const
{
return
run_per_steps_
;
}
int64_t
run_at_offset
()
const
{
return
run_at_offset_
;
}
int64_t
run_at_offset
()
const
{
return
run_at_offset_
;
}
int64_t
reply_up_per_steps
()
const
{
return
reply_up_per_steps_
;
}
int64_t
reply_up_per_steps
()
const
{
return
reply_up_per_steps_
;
}
int64_t
send_down_per_steps
()
const
{
return
send_down_per_steps_
;
}
int64_t
send_down_per_steps
()
const
{
return
send_down_per_steps_
;
}
const
std
::
string
&
cond_var
()
const
{
return
cond_var_
;
}
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
upstream
()
const
{
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
upstream
()
const
{
return
upstream_
;
return
upstream_
;
}
}
...
@@ -86,11 +80,20 @@ class TaskNode final {
...
@@ -86,11 +80,20 @@ class TaskNode final {
const
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>&
unique_ops
()
const
{
const
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>&
unique_ops
()
const
{
return
ops_vec_
;
return
ops_vec_
;
}
}
const
std
::
unordered_map
<
int64_t
,
DependType
>
id_to_dep_type
()
const
{
return
id_to_dep_type_
;
}
const
std
::
unordered_map
<
const
OperatorBase
*
,
std
::
vector
<
std
::
string
>>&
const
std
::
unordered_map
<
const
OperatorBase
*
,
std
::
vector
<
std
::
string
>>&
unused_vars
()
const
{
unused_vars
()
const
{
return
unused_vars_
;
return
unused_vars_
;
}
}
const
std
::
vector
<
std
::
string
>
while_block_vars
()
const
{
return
while_block_vars_
;
}
void
SetCondVarName
(
const
std
::
string
&
cond_var_name
)
{
cond_var_
=
cond_var_name
;
}
void
SetRunPerSteps
(
int64_t
value
);
void
SetRunPerSteps
(
int64_t
value
);
void
SetRunAtOffset
(
int64_t
value
);
void
SetRunAtOffset
(
int64_t
value
);
void
SetReplyUpPerSteps
(
int64_t
value
);
void
SetReplyUpPerSteps
(
int64_t
value
);
...
@@ -101,11 +104,27 @@ class TaskNode final {
...
@@ -101,11 +104,27 @@ class TaskNode final {
unused_vars
)
{
unused_vars
)
{
unused_vars_
=
unused_vars
;
unused_vars_
=
unused_vars
;
}
}
void
SetWhileBlockVars
(
const
std
::
vector
<
std
::
string
>&
vars
)
{
while_block_vars_
=
vars
;
}
// upstream need buffs?
// upstream need buffs?
bool
AddUpstreamTask
(
int64_t
task_id
,
int64_t
buff_size
=
1
);
bool
AddUpstreamTask
(
int64_t
task_id
,
bool
AddDownstreamTask
(
int64_t
task_id
,
int64_t
buff_size
=
1
);
int64_t
buff_size
=
1
,
DependType
type
=
DependType
::
NORMAL
);
bool
AddDownstreamTask
(
int64_t
task_id
,
int64_t
buff_size
=
1
,
DependType
type
=
DependType
::
NORMAL
);
std
::
string
DebugString
()
const
;
std
::
string
DebugString
()
const
;
const
std
::
map
<
std
::
string
,
std
::
string
>&
vars_to_dtype
()
const
{
return
vars_to_dtype_
;
}
void
SetVarsToDtype
(
const
std
::
map
<
std
::
string
,
std
::
string
>&
vars_to_dtype
);
const
std
::
map
<
std
::
string
,
std
::
vector
<
int64_t
>>&
vars_to_shape
()
const
{
return
vars_to_shape_
;
}
void
SetVarsToShape
(
const
std
::
map
<
std
::
string
,
std
::
vector
<
int64_t
>>&
vars_to_shape
);
private:
private:
DISABLE_COPY_AND_ASSIGN
(
TaskNode
);
DISABLE_COPY_AND_ASSIGN
(
TaskNode
);
...
@@ -115,16 +134,22 @@ class TaskNode final {
...
@@ -115,16 +134,22 @@ class TaskNode final {
// task_id-->buff_size
// task_id-->buff_size
std
::
unordered_map
<
int64_t
,
int64_t
>
upstream_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
upstream_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
downstream_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
downstream_
;
// task_id-->type
std
::
unordered_map
<
int64_t
,
DependType
>
id_to_dep_type_
;
framework
::
ProgramDesc
*
program_
;
framework
::
ProgramDesc
*
program_
;
std
::
string
cond_var_
;
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>
ops_vec_
;
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>
ops_vec_
;
std
::
unordered_map
<
const
OperatorBase
*
,
std
::
vector
<
std
::
string
>>
std
::
unordered_map
<
const
OperatorBase
*
,
std
::
vector
<
std
::
string
>>
unused_vars_
;
unused_vars_
;
std
::
vector
<
std
::
string
>
while_block_vars_
;
std
::
map
<
std
::
string
,
std
::
string
>
vars_to_dtype_
;
std
::
map
<
std
::
string
,
std
::
vector
<
int64_t
>>
vars_to_shape_
;
int32_t
role_
;
int32_t
role_
;
int64_t
rank_
;
int64_t
rank_
;
int64_t
task_id_
;
int64_t
task_id_
;
int64_t
max_run_times_
;
int64_t
max_run_times_
;
int64_t
max_slot_nums_
;
int64_t
run_per_steps_
{
1
};
int64_t
run_per_steps_
{
1
};
int64_t
run_at_offset_
{
0
};
int64_t
run_at_offset_
{
0
};
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
浏览文件 @
92c2dcbd
...
@@ -77,9 +77,8 @@ TEST(ComputeInterceptor, Compute) {
...
@@ -77,9 +77,8 @@ TEST(ComputeInterceptor, Compute) {
// FIXME: don't delete, otherwise interceptor will use undefined node
// FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode
*
source
=
TaskNode
*
source
=
new
TaskNode
(
0
,
SOURCE_ID
,
2
);
// rank, task_id, max_run_times
new
TaskNode
(
0
,
SOURCE_ID
,
2
);
// rank, task_id, max_run_times
TaskNode
*
node_a
=
TaskNode
*
node_a
=
new
TaskNode
(
0
,
ops
,
0
,
0
,
2
);
// role, ops, rank, task_id
new
TaskNode
(
0
,
ops
,
0
,
0
,
2
,
0
);
// role, ops, rank, task_id
TaskNode
*
node_b
=
new
TaskNode
(
0
,
0
,
1
,
2
);
TaskNode
*
node_b
=
new
TaskNode
(
0
,
0
,
1
,
2
,
0
);
TaskNode
*
sink
=
new
TaskNode
(
0
,
SINK_ID
,
2
);
TaskNode
*
sink
=
new
TaskNode
(
0
,
SINK_ID
,
2
);
// source->a->b->sink
// source->a->b->sink
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
浏览文件 @
92c2dcbd
...
@@ -21,61 +21,49 @@ limitations under the License. */
...
@@ -21,61 +21,49 @@ limitations under the License. */
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
class
StartInterceptor
:
public
Interceptor
{
public:
StartInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
Interceptor
(
interceptor_id
,
node
)
{
RegisterMsgHandle
([
this
](
const
InterceptorMessage
&
msg
)
{
NOP
(
msg
);
});
}
void
NOP
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
STOP
)
{
stop_
=
true
;
InterceptorMessage
stop
;
stop
.
set_message_type
(
STOP
);
Send
(
1
,
stop
);
// stop 1, compute
return
;
}
std
::
cout
<<
GetInterceptorId
()
<<
" recv msg from "
<<
msg
.
src_id
()
<<
std
::
endl
;
}
};
TEST
(
ComputeInterceptor
,
Compute
)
{
TEST
(
ComputeInterceptor
,
Compute
)
{
std
::
string
carrier_id
=
"0"
;
std
::
string
carrier_id
=
"0"
;
Carrier
*
carrier
=
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
}});
carrier
->
Init
(
0
,
{{
SOURCE_ID
,
0
},
{
0
,
0
},
{
1
,
0
},
{
SINK_ID
,
0
}});
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
// NOTE: don't delete, otherwise interceptor will use undefined node
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
3
,
0
);
// role, rank, task_id
TaskNode
*
source
=
TaskNode
*
node_b
=
new
TaskNode
(
0
,
0
,
1
,
3
,
0
);
new
TaskNode
(
0
,
SOURCE_ID
,
3
);
// rank, task_id, max_run_times
TaskNode
*
node_c
=
new
TaskNode
(
0
,
0
,
2
,
3
,
0
);
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
3
);
TaskNode
*
node_b
=
new
TaskNode
(
0
,
0
,
1
,
3
);
// a->b->c
TaskNode
*
sink
=
new
TaskNode
(
0
,
SINK_ID
,
3
);
// source->a->b->sink
source
->
AddDownstreamTask
(
0
);
node_a
->
AddUpstreamTask
(
SOURCE_ID
);
node_a
->
AddDownstreamTask
(
1
,
3
);
node_a
->
AddDownstreamTask
(
1
,
3
);
node_b
->
AddUpstreamTask
(
0
,
3
);
node_b
->
AddUpstreamTask
(
0
,
3
);
node_b
->
AddDownstreamTask
(
2
);
node_b
->
AddDownstreamTask
(
SINK_ID
);
node_c
->
AddUpstreamTask
(
1
);
sink
->
AddUpstreamTask
(
1
);
Interceptor
*
a
=
carrier
->
SetInterceptor
(
carrier
->
SetInterceptor
(
0
,
std
::
make_unique
<
StartInterceptor
>
(
0
,
node_a
));
SOURCE_ID
,
InterceptorFactory
::
Create
(
"Source"
,
SOURCE_ID
,
source
));
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Compute"
,
0
,
node_a
));
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Compute"
,
1
,
node_b
));
carrier
->
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Compute"
,
2
,
node_c
));
carrier
->
SetInterceptor
(
SINK_ID
,
InterceptorFactory
::
Create
(
"Sink"
,
SINK_ID
,
sink
));
// start
InterceptorMessage
msg
;
InterceptorMessage
msg
;
msg
.
set_message_type
(
DATA_IS_READY
);
msg
.
set_message_type
(
START
);
// test run three times
msg
.
set_dst_id
(
SOURCE_ID
);
a
->
Send
(
1
,
msg
);
carrier
->
EnqueueInterceptorMessage
(
msg
);
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
carrier
->
Wait
();
carrier
->
Wait
();
carrier
->
Release
();
carrier
->
Release
();
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
浏览文件 @
92c2dcbd
...
@@ -33,7 +33,6 @@ class PingPongInterceptor : public Interceptor {
...
@@ -33,7 +33,6 @@ class PingPongInterceptor : public Interceptor {
void
PingPong
(
const
InterceptorMessage
&
msg
)
{
void
PingPong
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
STOP
)
{
if
(
msg
.
message_type
()
==
STOP
)
{
stop_
=
true
;
return
;
return
;
}
}
std
::
cout
<<
GetInterceptorId
()
<<
" recv msg, count="
<<
count_
std
::
cout
<<
GetInterceptorId
()
<<
" recv msg, count="
<<
count_
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
浏览文件 @
92c2dcbd
...
@@ -36,7 +36,6 @@ class PingPongInterceptor : public Interceptor {
...
@@ -36,7 +36,6 @@ class PingPongInterceptor : public Interceptor {
void
PingPong
(
const
InterceptorMessage
&
msg
)
{
void
PingPong
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
STOP
)
{
if
(
msg
.
message_type
()
==
STOP
)
{
stop_
=
true
;
StopCarrier
();
StopCarrier
();
return
;
return
;
}
}
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
浏览文件 @
92c2dcbd
...
@@ -66,17 +66,17 @@ TEST(AmplifierInterceptor, Amplifier) {
...
@@ -66,17 +66,17 @@ TEST(AmplifierInterceptor, Amplifier) {
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
MessageBus
*
msg_bus
=
GlobalVal
<
MessageBus
>::
Create
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
int64_t
micro_steps
=
3
;
int64_t
micro_steps
=
1
;
// NOTE: don't delete, otherwise interceptor will use undefined node
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode
*
source
=
TaskNode
*
source
=
new
TaskNode
(
0
,
SOURCE_ID
,
micro_steps
);
// rank, task_id, max_run_times
new
TaskNode
(
0
,
SOURCE_ID
,
micro_steps
);
// rank, task_id, max_run_times
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
1
,
0
);
// role, rank, task_id
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
1
);
// role, rank, task_id
TaskNode
*
node_b
=
new
TaskNode
(
0
,
0
,
1
,
1
,
0
);
TaskNode
*
node_b
=
new
TaskNode
(
0
,
0
,
1
,
1
);
TaskNode
*
node_c
=
new
TaskNode
(
0
,
0
,
2
,
1
,
0
);
TaskNode
*
node_c
=
new
TaskNode
(
0
,
0
,
2
,
1
);
TaskNode
*
node_d
=
new
TaskNode
(
0
,
0
,
3
,
1
,
0
);
TaskNode
*
node_d
=
new
TaskNode
(
0
,
0
,
3
,
1
);
TaskNode
*
node_e
=
new
TaskNode
(
0
,
0
,
4
,
1
,
0
);
TaskNode
*
node_e
=
new
TaskNode
(
0
,
0
,
4
,
1
);
TaskNode
*
node_f
=
new
TaskNode
(
0
,
0
,
5
,
1
,
0
);
TaskNode
*
node_f
=
new
TaskNode
(
0
,
0
,
5
,
1
);
TaskNode
*
sink
=
new
TaskNode
(
0
,
SINK_ID
,
micro_steps
);
TaskNode
*
sink
=
new
TaskNode
(
0
,
SINK_ID
,
micro_steps
);
// source->a->b->c->d->e->f->sink
// source->a->b->c->d->e->f->sink
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
浏览文件 @
92c2dcbd
...
@@ -83,11 +83,10 @@ TEST(AmplifierInterceptor, Amplifier) {
...
@@ -83,11 +83,10 @@ TEST(AmplifierInterceptor, Amplifier) {
// NOTE: don't delete, otherwise interceptor will use undefined node
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode
*
source
=
TaskNode
*
source
=
new
TaskNode
(
0
,
SOURCE_ID
,
micro_steps
);
// rank, task_id, max_run_times
new
TaskNode
(
0
,
SOURCE_ID
,
micro_steps
);
// rank, task_id, max_run_times
TaskNode
*
node_a
=
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
micro_steps
);
// role, rank, task_id
new
TaskNode
(
0
,
0
,
0
,
micro_steps
,
0
);
// role, rank, task_id
TaskNode
*
node_b
=
new
TaskNode
(
0
,
0
,
1
,
micro_steps
);
TaskNode
*
node_b
=
new
TaskNode
(
0
,
0
,
1
,
3
,
0
);
TaskNode
*
node_c
=
new
TaskNode
(
0
,
0
,
2
,
micro_steps
);
TaskNode
*
node_c
=
new
TaskNode
(
0
,
0
,
2
,
3
,
0
);
TaskNode
*
node_d
=
new
TaskNode
(
0
,
0
,
3
,
micro_steps
);
TaskNode
*
node_d
=
new
TaskNode
(
0
,
0
,
3
,
micro_steps
,
0
);
TaskNode
*
sink
=
new
TaskNode
(
0
,
SINK_ID
,
micro_steps
);
TaskNode
*
sink
=
new
TaskNode
(
0
,
SINK_ID
,
micro_steps
);
// source->a->b->c->d->sink
// source->a->b->c->d->sink
...
...
paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc
浏览文件 @
92c2dcbd
...
@@ -62,10 +62,9 @@ TEST(SourceInterceptor, Source) {
...
@@ -62,10 +62,9 @@ TEST(SourceInterceptor, Source) {
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
// NOTE: don't delete, otherwise interceptor will use undefined node
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode
*
source
=
TaskNode
*
source
=
new
TaskNode
(
0
,
SOURCE_ID
,
0
,
3
);
// role, rank, task_id
new
TaskNode
(
0
,
SOURCE_ID
,
0
,
3
,
0
);
// role, rank, task_id
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
3
);
// role, rank, task_id
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
3
,
0
);
// role, rank, task_id
TaskNode
*
sink
=
new
TaskNode
(
0
,
SINK_ID
,
0
,
3
);
// role, rank, task_id
TaskNode
*
sink
=
new
TaskNode
(
0
,
SINK_ID
,
0
,
3
,
0
);
// role, rank, task_id
source
->
AddDownstreamTask
(
0
,
1
);
source
->
AddDownstreamTask
(
0
,
1
);
node_a
->
AddUpstreamTask
(
SOURCE_ID
,
1
);
node_a
->
AddUpstreamTask
(
SOURCE_ID
,
1
);
...
...
paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc
浏览文件 @
92c2dcbd
...
@@ -61,9 +61,8 @@ TEST(SourceInterceptor, Source) {
...
@@ -61,9 +61,8 @@ TEST(SourceInterceptor, Source) {
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
// NOTE: don't delete, otherwise interceptor will use undefined node
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode
*
source
=
TaskNode
*
source
=
new
TaskNode
(
0
,
SOURCE_ID
,
0
,
3
);
// role, rank, task_id
new
TaskNode
(
0
,
SOURCE_ID
,
0
,
3
,
0
);
// role, rank, task_id
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
3
);
// role, rank, task_id
TaskNode
*
node_a
=
new
TaskNode
(
0
,
0
,
0
,
3
,
0
);
// role, rank, task_id
source
->
AddDownstreamTask
(
0
,
1
);
source
->
AddDownstreamTask
(
0
,
1
);
node_a
->
AddUpstreamTask
(
SOURCE_ID
,
1
);
node_a
->
AddUpstreamTask
(
SOURCE_ID
,
1
);
...
...
paddle/fluid/operators/collective/c_broadcast_op.cu.cc
浏览文件 @
92c2dcbd
...
@@ -112,5 +112,7 @@ REGISTER_OP_CUDA_KERNEL(c_broadcast,
...
@@ -112,5 +112,7 @@ REGISTER_OP_CUDA_KERNEL(c_broadcast,
ops
::
CBroadcastOpCUDAKernel
<
plat
::
bfloat16
>
,
ops
::
CBroadcastOpCUDAKernel
<
plat
::
bfloat16
>
,
#endif
#endif
ops
::
CBroadcastOpCUDAKernel
<
int
>
,
ops
::
CBroadcastOpCUDAKernel
<
int
>
,
ops
::
CBroadcastOpCUDAKernel
<
uint8_t
>
,
ops
::
CBroadcastOpCUDAKernel
<
int8_t
>
,
ops
::
CBroadcastOpCUDAKernel
<
int64_t
>
,
ops
::
CBroadcastOpCUDAKernel
<
int64_t
>
,
ops
::
CBroadcastOpCUDAKernel
<
plat
::
float16
>
);
ops
::
CBroadcastOpCUDAKernel
<
plat
::
float16
>
);
paddle/fluid/operators/collective/c_embedding_op.cu
浏览文件 @
92c2dcbd
...
@@ -19,6 +19,8 @@ limitations under the License. */
...
@@ -19,6 +19,8 @@ limitations under the License. */
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/float16.h"
DECLARE_bool
(
cudnn_deterministic
);
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -83,6 +85,32 @@ __global__ void CEmbeddingGrad(T *table,
...
@@ -83,6 +85,32 @@ __global__ void CEmbeddingGrad(T *table,
}
}
}
}
template
<
typename
T
,
typename
IndexT
>
__global__
void
CEmbeddingGradSerial
(
T
*
table
,
const
T
*
output
,
const
IndexT
*
ids
,
const
int
rows
,
const
int
columns
,
const
int64_t
N
,
const
int64_t
start_idx
,
const
int64_t
end_idx
,
const
int64_t
limit
)
{
CUDA_KERNEL_LOOP
(
i
,
limit
)
{
if
(
i
==
0
)
{
for
(
int
j
=
0
;
j
<
limit
;
j
++
)
{
size_t
row
=
j
/
columns
;
size_t
col
=
j
%
columns
;
auto
id
=
ids
[
row
];
if
(
id
>=
start_idx
&&
id
<
end_idx
)
{
auto
real_idx
=
id
-
start_idx
;
paddle
::
platform
::
CudaAtomicAdd
(
&
table
[
real_idx
*
columns
+
col
],
output
[
i
]);
}
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
class
CEmbeddingCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
class
CEmbeddingCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
...
@@ -163,6 +191,33 @@ class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
...
@@ -163,6 +191,33 @@ class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
t
.
device
(
*
dev_ctx
.
eigen_device
())
=
t
.
constant
(
static_cast
<
T
>
(
0
));
t
.
device
(
*
dev_ctx
.
eigen_device
())
=
t
.
constant
(
static_cast
<
T
>
(
0
));
const
auto
&
index_type
=
framework
::
TransToProtoVarType
(
ids_t
->
dtype
());
const
auto
&
index_type
=
framework
::
TransToProtoVarType
(
ids_t
->
dtype
());
if
(
FLAGS_cudnn_deterministic
)
{
VLOG
(
2
)
<<
"Run grad kernel of embedding with single thread."
;
blocks
=
1
;
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
CEmbeddingGradSerial
<
T
,
int32_t
>
<<<
blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
d_table
,
d_output
,
ids_t
->
data
<
int32_t
>
(),
K
,
D
,
N
,
start_idx
,
end_idx
,
limit
);
}
else
if
(
index_type
==
framework
::
proto
::
VarType
::
INT64
)
{
CEmbeddingGradSerial
<
T
,
int64_t
>
<<<
blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
d_table
,
d_output
,
ids_t
->
data
<
int64_t
>
(),
K
,
D
,
N
,
start_idx
,
end_idx
,
limit
);
}
}
else
{
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
CEmbeddingGrad
<
T
,
int32_t
>
CEmbeddingGrad
<
T
,
int32_t
>
<<<
blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
d_table
,
<<<
blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
d_table
,
...
@@ -187,6 +242,7 @@ class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
...
@@ -187,6 +242,7 @@ class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
limit
);
limit
);
}
}
}
}
}
};
};
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/pybind/bind_fleet_executor.cc
浏览文件 @
92c2dcbd
...
@@ -65,6 +65,7 @@ struct npy_format_descriptor<paddle::platform::float16> {
...
@@ -65,6 +65,7 @@ struct npy_format_descriptor<paddle::platform::float16> {
namespace
paddle
{
namespace
paddle
{
namespace
pybind
{
namespace
pybind
{
using
paddle
::
distributed
::
DependType
;
using
paddle
::
distributed
::
DistModel
;
using
paddle
::
distributed
::
DistModel
;
using
paddle
::
distributed
::
DistModelConfig
;
using
paddle
::
distributed
::
DistModelConfig
;
using
paddle
::
distributed
::
DistModelDataBuf
;
using
paddle
::
distributed
::
DistModelDataBuf
;
...
@@ -164,18 +165,17 @@ void BindFleetExecutor(py::module* m) {
...
@@ -164,18 +165,17 @@ void BindFleetExecutor(py::module* m) {
.
def
(
.
def
(
"run"
,
&
FleetExecutor
::
Run
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
"run"
,
&
FleetExecutor
::
Run
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
py
::
enum_
<
DependType
>
(
*
m
,
"DependType"
)
.
value
(
"NORMAL"
,
DependType
::
NORMAL
)
.
value
(
"LOOP"
,
DependType
::
LOOP
)
.
value
(
"STOP_LOOP"
,
DependType
::
STOP_LOOP
);
py
::
class_
<
TaskNode
>
(
*
m
,
"TaskNode"
)
py
::
class_
<
TaskNode
>
(
*
m
,
"TaskNode"
)
.
def
(
py
::
init
<
framework
::
ProgramDesc
*
,
int64_t
,
int64_t
,
int64_t
,
int64_t
>
())
.
def
(
py
::
init
<
framework
::
ProgramDesc
*
,
int64_t
,
int64_t
,
int64_t
>
())
.
def
(
py
::
init
<
framework
::
ProgramDesc
*
,
int64_t
,
int64_t
,
int64_t
>
())
.
def
(
py
::
init
<
int32_t
,
.
def
(
py
::
init
<
int32_t
,
const
std
::
vector
<
framework
::
OpDesc
*>&
,
const
std
::
vector
<
framework
::
OpDesc
*>&
,
int64_t
,
int64_t
,
int64_t
,
int64_t
,
int64_t
,
int64_t
>
())
int64_t
>
())
.
def
(
"task_id"
,
&
TaskNode
::
task_id
)
.
def
(
"task_id"
,
&
TaskNode
::
task_id
)
.
def
(
"add_upstream_task"
,
&
TaskNode
::
AddUpstreamTask
)
.
def
(
"add_upstream_task"
,
&
TaskNode
::
AddUpstreamTask
)
...
@@ -183,7 +183,10 @@ void BindFleetExecutor(py::module* m) {
...
@@ -183,7 +183,10 @@ void BindFleetExecutor(py::module* m) {
.
def
(
"set_run_pre_steps"
,
&
TaskNode
::
SetRunPerSteps
)
.
def
(
"set_run_pre_steps"
,
&
TaskNode
::
SetRunPerSteps
)
.
def
(
"set_run_at_offset"
,
&
TaskNode
::
SetRunAtOffset
)
.
def
(
"set_run_at_offset"
,
&
TaskNode
::
SetRunAtOffset
)
.
def
(
"set_type"
,
&
TaskNode
::
SetType
)
.
def
(
"set_type"
,
&
TaskNode
::
SetType
)
.
def
(
"set_cond_var_name"
,
&
TaskNode
::
SetCondVarName
)
.
def
(
"role"
,
&
TaskNode
::
role
)
.
def
(
"role"
,
&
TaskNode
::
role
)
.
def
(
"set_vars_to_shape"
,
&
TaskNode
::
SetVarsToShape
)
.
def
(
"set_vars_to_dtype"
,
&
TaskNode
::
SetVarsToDtype
)
.
def
(
"init"
,
[](
TaskNode
&
self
)
{
self
.
Init
();
})
.
def
(
"init"
,
[](
TaskNode
&
self
)
{
self
.
Init
();
})
.
def
(
"set_program"
,
&
TaskNode
::
SetProgram
);
.
def
(
"set_program"
,
&
TaskNode
::
SetProgram
);
...
...
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
浏览文件 @
92c2dcbd
...
@@ -23,6 +23,8 @@
...
@@ -23,6 +23,8 @@
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
DECLARE_bool
(
cudnn_deterministic
);
namespace
phi
{
namespace
phi
{
template
<
typename
InT
,
typename
OutT
>
template
<
typename
InT
,
typename
OutT
>
...
@@ -101,6 +103,12 @@ struct EmbeddingGradCUDAFunctor {
...
@@ -101,6 +103,12 @@ struct EmbeddingGradCUDAFunctor {
const
int
gridx
=
2
*
dev_ctx_
.
GetSMCount
();
const
int
gridx
=
2
*
dev_ctx_
.
GetSMCount
();
dim3
threads
(
128
,
8
);
dim3
threads
(
128
,
8
);
dim3
grids
(
gridx
,
1
);
dim3
grids
(
gridx
,
1
);
if
(
FLAGS_cudnn_deterministic
)
{
VLOG
(
2
)
<<
"Run grad kernel of embedding with single thread."
;
grids
.
x
=
1
;
threads
.
y
=
1
;
}
EmbeddingGrad
<
T
,
IdT
><<<
grids
,
threads
,
0
,
dev_ctx_
.
stream
()
>>>
(
EmbeddingGrad
<
T
,
IdT
><<<
grids
,
threads
,
0
,
dev_ctx_
.
stream
()
>>>
(
d_table
,
d_output
,
ids
,
N
,
K
,
D
);
d_table
,
d_output
,
ids
,
N
,
K
,
D
);
}
}
...
...
python/paddle/distributed/auto_parallel/completion.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/distributed/auto_parallel/constants.py
浏览文件 @
92c2dcbd
...
@@ -94,6 +94,16 @@ set_field_default_config(GRADIENT_MERGE, "enable", False)
...
@@ -94,6 +94,16 @@ set_field_default_config(GRADIENT_MERGE, "enable", False)
set_field_default_config
(
GRADIENT_MERGE
,
"k_steps"
,
1
)
set_field_default_config
(
GRADIENT_MERGE
,
"k_steps"
,
1
)
set_field_default_config
(
GRADIENT_MERGE
,
"avg"
,
True
)
set_field_default_config
(
GRADIENT_MERGE
,
"avg"
,
True
)
#########################################
# pipeline configuration
#########################################
PIPELINE
=
"pipeline"
set_field_default_config
(
PIPELINE
,
"enable"
,
False
)
set_field_default_config
(
PIPELINE
,
"schedule_mode"
,
"1F1B"
)
set_field_default_config
(
PIPELINE
,
"micro_batch_size"
,
1
)
set_field_default_config
(
PIPELINE
,
"accumulate_steps"
,
1
)
set_field_default_config
(
PIPELINE
,
"generation_batch_size"
,
1
)
#########################################
#########################################
# quantization configuration
# quantization configuration
#########################################
#########################################
...
...
python/paddle/distributed/auto_parallel/cost/estimate_cost.py
浏览文件 @
92c2dcbd
...
@@ -556,8 +556,8 @@ def get_cost_from_engine(engine, mode):
...
@@ -556,8 +556,8 @@ def get_cost_from_engine(engine, mode):
)
)
serial_startup_prog
=
(
serial_startup_prog
=
(
engine
.
_
serial_startup_progs
[
mode
]
.
clone
()
engine
.
_
fwd_dist_contexts
[
mode
].
_original_serial_main_program
.
clone
()
if
mode
in
engine
.
_
serial_startup_prog
s
if
mode
in
engine
.
_
fwd_dist_context
s
else
engine
.
_orig_startup_prog
.
clone
()
else
engine
.
_orig_startup_prog
.
clone
()
)
)
losses
=
(
losses
=
(
...
...
python/paddle/distributed/auto_parallel/dist_context.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/distributed/auto_parallel/dist_op.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/distributed/auto_parallel/engine.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/distributed/auto_parallel/interface.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/distributed/auto_parallel/operators/__init__.py
浏览文件 @
92c2dcbd
...
@@ -35,3 +35,4 @@ from . import dist_fused_attention
...
@@ -35,3 +35,4 @@ from . import dist_fused_attention
from
.
import
dist_reduce_sum_p
from
.
import
dist_reduce_sum_p
from
.
import
dist_shape
from
.
import
dist_shape
from
.
import
dist_assign
from
.
import
dist_assign
from
.
import
dist_scale
python/paddle/distributed/auto_parallel/operators/common.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/distributed/auto_parallel/operators/dist_default.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/distributed/auto_parallel/operators/dist_scale.py
0 → 100644
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/distributed/auto_parallel/parallelizer.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/distributed/auto_parallel/parallelizer_v2.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/distributed/auto_parallel/process_group.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/distributed/auto_parallel/process_mesh.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/distributed/auto_parallel/reshard.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/distributed/auto_parallel/strategy.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/distributed/auto_parallel/tuner/profiler.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/distributed/auto_parallel/utils.py
浏览文件 @
92c2dcbd
...
@@ -1874,6 +1874,12 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank):
...
@@ -1874,6 +1874,12 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank):
)
)
)
)
break
break
print
(
"***process_group: id:"
,
process_group
.
id
,
"rank:"
,
process_group
.
ranks
,
)
process_group
.
instantiate
()
process_group
.
instantiate
()
server_socket
.
close
()
server_socket
.
close
()
...
...
python/paddle/distributed/fleet/fleet_executor_utils.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/distributed/parallel.py
浏览文件 @
92c2dcbd
...
@@ -56,6 +56,13 @@ ParallelStrategy = core.ParallelStrategy
...
@@ -56,6 +56,13 @@ ParallelStrategy = core.ParallelStrategy
_global_parallel_env
=
None
_global_parallel_env
=
None
def
_is_global_parallel_initialize
():
global
_global_parallel_env
if
_global_parallel_env
is
None
:
return
False
return
True
def
_get_global_parallel_env
():
def
_get_global_parallel_env
():
global
_global_parallel_env
global
_global_parallel_env
if
_global_parallel_env
is
None
:
if
_global_parallel_env
is
None
:
...
...
python/paddle/distributed/passes/__init__.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/distributed/passes/auto_parallel_grad_clip.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/distributed/passes/auto_parallel_pipeline.py
0 → 100644
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/fluid/executor.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/auto_parallel/generation_pipeline_pass_unittest.py
0 → 100644
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/auto_parallel/sharding_pass_unittest.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/auto_parallel/test_dist_context.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/auto_parallel/test_pass_generation_pipeline.py
0 → 100644
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/test_fleet_executor_cond_interceptor.py
0 → 100644
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/test_fleet_executor_task_node.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/test_fleet_executor_with_task_nodes.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
python/paddle/tensor/stat.py
浏览文件 @
92c2dcbd
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录