PaddlePaddle / Paddle
Unverified commit 92c2dcbd, authored Mar 20, 2023 by LiYuRio, committed by GitHub on Mar 20, 2023.
Cherry-pick fleet executor and auto parallel (#50071)
Parent: 4bacf2ab

Showing 74 changed files with 7,436 additions and 3,086 deletions (+7436 −3086).
cmake/third_party.cmake  +2 −1
paddle/fluid/distributed/fleet_executor/CMakeLists.txt  +6 −0
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.cc  +3 −3
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h  +1 −1
paddle/fluid/distributed/fleet_executor/carrier.cc  +90 −38
paddle/fluid/distributed/fleet_executor/carrier.h  +3 −3
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc  +261 −125
paddle/fluid/distributed/fleet_executor/compute_interceptor.h  +17 −20
paddle/fluid/distributed/fleet_executor/cond_interceptor.cc  +167 −0
paddle/fluid/distributed/fleet_executor/cond_interceptor.h  +55 −0
paddle/fluid/distributed/fleet_executor/fleet_executor.cc  +136 −38
paddle/fluid/distributed/fleet_executor/fleet_executor.h  +5 −2
paddle/fluid/distributed/fleet_executor/interceptor.h  +0 −4
paddle/fluid/distributed/fleet_executor/interceptor_message.proto  +16 −0
paddle/fluid/distributed/fleet_executor/sink_interceptor.h  +1 −1
paddle/fluid/distributed/fleet_executor/source_interceptor.h  +1 −1
paddle/fluid/distributed/fleet_executor/start_interceptor.cc  +115 −0
paddle/fluid/distributed/fleet_executor/start_interceptor.h  +39 −0
paddle/fluid/distributed/fleet_executor/task_node.cc  +26 −36
paddle/fluid/distributed/fleet_executor/task_node.h  +44 −19
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc  +2 −3
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc  +24 −36
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc  +0 −1
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc  +0 −1
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc  +7 −7
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc  +4 −5
paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc  +3 −4
paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc  +2 −3
paddle/fluid/operators/collective/c_broadcast_op.cu.cc  +2 −0
paddle/fluid/operators/collective/c_embedding_op.cu  +78 −22
paddle/fluid/pybind/bind_fleet_executor.cc  +9 −6
paddle/phi/kernels/gpu/embedding_grad_kernel.cu  +8 −0
python/paddle/distributed/auto_parallel/completion.py  +798 −390
python/paddle/distributed/auto_parallel/constants.py  +10 −0
python/paddle/distributed/auto_parallel/cost/estimate_cost.py  +2 −2
python/paddle/distributed/auto_parallel/dist_context.py  +454 −173
python/paddle/distributed/auto_parallel/dist_op.py  +101 −53
python/paddle/distributed/auto_parallel/engine.py  +109 −71
python/paddle/distributed/auto_parallel/interface.py  +60 −30
python/paddle/distributed/auto_parallel/operators/__init__.py  +1 −0
python/paddle/distributed/auto_parallel/operators/common.py  +120 −72
python/paddle/distributed/auto_parallel/operators/dist_default.py  +158 −86
python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py  +23 −31
python/paddle/distributed/auto_parallel/operators/dist_scale.py  +90 −0
python/paddle/distributed/auto_parallel/parallelizer.py  +215 −106
python/paddle/distributed/auto_parallel/parallelizer_v2.py  +159 −66
python/paddle/distributed/auto_parallel/process_group.py  +47 −30
python/paddle/distributed/auto_parallel/process_mesh.py  +35 −23
python/paddle/distributed/auto_parallel/reshard.py  +1107 −559
python/paddle/distributed/auto_parallel/strategy.py  +14 −10
python/paddle/distributed/auto_parallel/tuner/profiler.py  +48 −31
python/paddle/distributed/auto_parallel/utils.py  +6 −0
python/paddle/distributed/fleet/fleet_executor_utils.py  +195 −131
python/paddle/distributed/parallel.py  +7 −0
python/paddle/distributed/passes/__init__.py  +1 −0
python/paddle/distributed/passes/auto_parallel_grad_clip.py  +69 −37
python/paddle/distributed/passes/auto_parallel_pipeline.py  +635 −0
python/paddle/fluid/executor.py  +961 −542
python/paddle/fluid/tests/unittests/CMakeLists.txt  +1 −0
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt  +1 −0
python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py  +7 −3
python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py  +7 −5
python/paddle/fluid/tests/unittests/auto_parallel/generation_pipeline_pass_unittest.py  +177 −0
python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py  +10 −11
python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py  +4 −2
python/paddle/fluid/tests/unittests/auto_parallel/sharding_pass_unittest.py  +13 −11
python/paddle/fluid/tests/unittests/auto_parallel/test_dist_context.py  +121 −69
python/paddle/fluid/tests/unittests/auto_parallel/test_pass_generation_pipeline.py  +58 −0
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py  +154 −82
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py  +76 −48
python/paddle/fluid/tests/unittests/test_fleet_executor_cond_interceptor.py  +217 −0
python/paddle/fluid/tests/unittests/test_fleet_executor_task_node.py  +17 −10
python/paddle/fluid/tests/unittests/test_fleet_executor_with_task_nodes.py  +20 −22
python/paddle/tensor/stat.py  +1 −0
cmake/third_party.cmake
@@ -426,7 +426,8 @@ endif()
 if(WITH_DISTRIBUTE
    AND NOT WITH_PSLIB
-   AND NOT WITH_PSCORE)
+   AND NOT WITH_PSCORE
+   AND NOT WITH_RPC)
   include(external/snappy)
   list(APPEND third_party_deps extern_snappy)
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
@@ -36,6 +36,8 @@ cc_library(
   interceptor.cc
   compute_interceptor.cc
   amplifier_interceptor.cc
+  cond_interceptor.cc
+  start_interceptor.cc
   source_interceptor.cc
   sink_interceptor.cc
   message_service.cc
@@ -66,6 +68,10 @@ if(WITH_DISTRIBUTE)
   set_source_files_properties(
     amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS
                                         ${DISTRIBUTE_COMPILE_FLAGS})
+  set_source_files_properties(
+    cond_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+  set_source_files_properties(
+    start_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set_source_files_properties(
     source_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set_source_files_properties(
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.cc
@@ -33,7 +33,7 @@ void AmplifierInterceptor::RunOps() {
   // run_per_steps_, run_at_offset_
   // 4, 0 --> run at step 0, 4, 8, 12
   // 4, 3 --> run at step 3, 7, 11, 15
-  if ((step_ % run_per_steps_) == run_at_offset_) {
+  if ((cur_scope_id_ % run_per_steps_) == run_at_offset_) {
     ComputeInterceptor::RunOps();
   }
 }
@@ -41,7 +41,7 @@ void AmplifierInterceptor::RunOps() {
 void AmplifierInterceptor::SendDataReadyToDownStream() {
   // run multi times, send ready one times to downstream, that is
   // input multi times, output one times
-  if (step_ % send_down_per_steps_ == 0) {
+  if (cur_scope_id_ % send_down_per_steps_ == 0) {
     ComputeInterceptor::SendDataReadyToDownStream();
   }
 }
@@ -49,7 +49,7 @@ void AmplifierInterceptor::SendDataReadyToDownStream() {
 void AmplifierInterceptor::ReplyCompletedToUpStream() {
   // run multi times, reply one times to upstream, that is
   // input one times, output multi times
-  if (step_ % reply_up_per_steps_ == 0) {
+  if (cur_scope_id_ % reply_up_per_steps_ == 0) {
     ComputeInterceptor::ReplyCompletedToUpStream();
   }
 }
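The gating in AmplifierInterceptor above runs the underlying ops only on selected scope ids, exactly as the inline comments describe ("4, 0 --> run at step 0, 4, 8, 12"). A minimal plain-Python sketch of that modular-arithmetic check (hypothetical helper name, not a Paddle API):

```python
def should_run(cur_scope_id: int, run_per_steps: int, run_at_offset: int) -> bool:
    """Mirror of `(cur_scope_id_ % run_per_steps_) == run_at_offset_` above."""
    return cur_scope_id % run_per_steps == run_at_offset

# run_per_steps=4, run_at_offset=0 -> runs at steps 0, 4, 8, 12
assert [s for s in range(16) if should_run(s, 4, 0)] == [0, 4, 8, 12]
# run_per_steps=4, run_at_offset=3 -> runs at steps 3, 7, 11, 15
assert [s for s in range(16) if should_run(s, 4, 3)] == [3, 7, 11, 15]
```

The commit's behavioral change is only which counter feeds this check: the per-interceptor `step_` is replaced by the per-micro-batch `cur_scope_id_`.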
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h
@@ -21,7 +21,7 @@
 namespace paddle {
 namespace distributed {

-class AmplifierInterceptor : public ComputeInterceptor {
+class AmplifierInterceptor final : public ComputeInterceptor {
  public:
   AmplifierInterceptor(int64_t interceptor_id, TaskNode* node);
paddle/fluid/distributed/fleet_executor/carrier.cc
@@ -15,6 +15,7 @@
 #include "paddle/fluid/distributed/fleet_executor/carrier.h"

 #include <algorithm>
+#include <vector>

 #include "paddle/fluid/distributed/fleet_executor/global.h"
 #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
@@ -24,6 +25,7 @@
 #include "paddle/fluid/framework/garbage_collector.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/framework/variable_helper.h"

 namespace paddle {
@@ -33,6 +35,8 @@ USE_INTERCEPTOR(Source);
 USE_INTERCEPTOR(Compute);
 USE_INTERCEPTOR(Amplifier);
 USE_INTERCEPTOR(Sink);
+USE_INTERCEPTOR(Cond);
+USE_INTERCEPTOR(Start);

 void Carrier::Init(
     int64_t rank,
@@ -54,24 +58,38 @@ void Carrier::Init(
     framework::Scope* scope,
     int64_t num_micro_batches,
     const platform::Place& place,
-    const std::vector<std::string>& inference_root_scope_vars) {
+    const std::vector<std::string>& inference_root_scope_vars,
+    const std::vector<framework::Scope*>& micro_scope_list) {
   rank_ = rank;
   interceptor_id_to_rank_ = interceptor_id_to_rank;
   interceptor_id_to_node_ = interceptor_id_to_node;
   place_ = place;
   root_scope_ = scope;
   dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
+  bool need_create_scope = micro_scope_list.empty();

   PADDLE_ENFORCE_NOT_NULL(
       root_scope_,
       platform::errors::InvalidArgument("root_scope can not be nullptr"));
-  minibatch_scope_ = &root_scope_->NewScope();
-  microbatch_scopes_.resize(num_micro_batches);
-  for (int i = 0; i < num_micro_batches; ++i) {
-    microbatch_scopes_[i] = &minibatch_scope_->NewScope();
-    CopyParameters(i, program, inference_root_scope_vars);
+
+  if (need_create_scope) {
+    minibatch_scope_ = &root_scope_->NewScope();
+    microbatch_scopes_.resize(num_micro_batches);
+    for (int i = 0; i < num_micro_batches; ++i) {
+      microbatch_scopes_[i] = &minibatch_scope_->NewScope();
+      CopyParameters(i, program, inference_root_scope_vars);
+    }
+  } else {
+    microbatch_scopes_ = micro_scope_list;
+    for (int i = 0; i < num_micro_batches; ++i) {
+      CopyParameters(i, program, inference_root_scope_vars);
+    }
   }

+  // Add source and sink interceptor id to rank
+  interceptor_id_to_rank_.emplace(SOURCE_ID, rank);
+  interceptor_id_to_rank_.emplace(SINK_ID, rank);
+
   // TODO(fleet_exe dev): thread pool
   thread_num_ = 1;
   thread_pool_.SetThreadNum(thread_num_);
@@ -93,29 +111,30 @@ void Carrier::CopyParameters(
     int microbatch_id,
     const framework::ProgramDesc& program,
     const std::vector<std::string>& inference_root_scope_vars) {
-  auto& global_block = program.Block(0);
-
   std::map<std::string, int> inference_root_scope_var_map;
   for (auto var_name : inference_root_scope_vars) {
     inference_root_scope_var_map.insert({var_name, 1});
   }
-  for (auto& var : global_block.AllVars()) {
-    std::string var_name = var->Name();
-    bool force_root = inference_root_scope_var_map.find(var_name) !=
-                      inference_root_scope_var_map.end();
-    if (force_root) {
-      VLOG(4) << var_name
-              << " will be forced to be created in the root scope.";
-    }
-    if ((var->Persistable() || force_root) && microbatch_id == 0) {
-      auto* ptr = root_scope_->Var(var->Name());
-      InitializeVariable(ptr, var->GetType());
-      VLOG(5) << "Create persistable var: " << var->Name()
-              << ", which pointer is " << ptr;
-    } else if (!var->Persistable()) {
-      auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name());
-      VLOG(5) << "Create variable " << var->Name() << " for microbatch "
-              << microbatch_id << ", which pointer is " << ptr << ".";
-      InitializeVariable(ptr, var->GetType());
+  for (size_t i = 0; i < program.Size(); ++i) {
+    for (auto& var : program.Block(i).AllVars()) {
+      std::string var_name = var->Name();
+      bool force_root = inference_root_scope_var_map.find(var_name) !=
+                        inference_root_scope_var_map.end();
+      if (force_root) {
+        VLOG(4) << var_name
+                << " will be forced to be created in the root scope.";
+      }
+      if ((var->Persistable() || force_root) && microbatch_id == 0) {
+        auto* ptr = root_scope_->Var(var->Name());
+        InitializeVariable(ptr, var->GetType());
+        VLOG(5) << "Create persistable var: " << var->Name()
+                << ", which pointer is " << ptr;
+      } else if (!var->Persistable()) {
+        auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name());
+        VLOG(5) << "Create variable " << var->Name() << " for microbatch "
+                << microbatch_id << ", which pointer is " << ptr << ".";
+        InitializeVariable(ptr, var->GetType());
+      }
     }
   }
 }
@@ -159,16 +178,11 @@ void Carrier::Start() {
       true,
       platform::errors::PreconditionNotMet(
          "Using carrier before initialized."));
-  for (int64_t id : source_interceptor_ids_) {
-    VLOG(3) << "Carrier Start is sending start to source interceptor " << id
-            << ".";
-    InterceptorMessage start_msg;
-    // source node data_is_ready is send by carrier, so set src_id=-1
-    start_msg.set_src_id(-1);
-    start_msg.set_dst_id(id);
-    start_msg.set_message_type(DATA_IS_READY);
-    Send(start_msg);
-  }
+  InterceptorMessage start_msg;
+  start_msg.set_src_id(SOURCE_ID);
+  start_msg.set_dst_id(SOURCE_ID);
+  start_msg.set_message_type(START);
+  Send(start_msg);
   // TODO(wangxi): async step
   Wait();
   dev_ctx_->Wait();
@@ -270,6 +284,38 @@ void Carrier::CreateInterceptors() {
   auto gc = GetGC(place_);

+  // create source and sink task node
+  auto max_run_times = microbatch_scopes_.size();
+  TaskNode* source = new TaskNode(
+      rank_, SOURCE_ID, max_run_times);  // rank, task_id, max_run_times
+  TaskNode* sink = new TaskNode(rank_, SINK_ID, max_run_times);
+  // find nodes without upstreams or without downstreams
+  std::vector<TaskNode*> origin_sources, origin_sinks;
+  for (const auto& item : interceptor_id_to_node_) {
+    TaskNode* task_node = item.second;
+    if (task_node->upstream().empty()) {
+      origin_sources.emplace_back(task_node);
+    }
+    if (task_node->downstream().empty()) {
+      origin_sinks.emplace_back(task_node);
+    }
+  }
+  // link source node with origin source
+  for (const auto& node : origin_sources) {
+    source->AddDownstreamTask(node->task_id(),
+                              std::numeric_limits<int64_t>::max());
+    node->AddUpstreamTask(SOURCE_ID, std::numeric_limits<int64_t>::max());
+  }
+  // link sink node with origin sink
+  for (const auto& node : origin_sinks) {
+    sink->AddUpstreamTask(node->task_id(),
+                          std::numeric_limits<int64_t>::max());
+    node->AddDownstreamTask(SINK_ID, std::numeric_limits<int64_t>::max());
+  }
+  // create source and sink interceptor
+  SetInterceptor(SOURCE_ID,
+                 InterceptorFactory::Create("Source", SOURCE_ID, source));
+  SetInterceptor(SINK_ID, InterceptorFactory::Create("Sink", SINK_ID, sink));
+
   // create each Interceptor
   // no auto init since there is no config
   for (const auto& item : interceptor_id_to_node_) {
@@ -303,9 +349,15 @@ void Carrier::CreateInterceptors() {
     VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id
             << " with type: " << task_node->type() << ".";

-    if (task_node->upstream().empty()) {
-      source_interceptor_ids_.emplace_back(interceptor_id);
-    }
+    PADDLE_ENFORCE_EQ(
+        task_node->upstream().empty(),
+        false,
+        platform::errors::PreconditionNotMet(
+            "There should not have normal nodes as source nodes"));
+    PADDLE_ENFORCE_EQ(
+        task_node->downstream().empty(),
+        false,
+        platform::errors::PreconditionNotMet(
+            "There should not have normal nodes as sink nodes"));
   }
 }
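The Carrier::CreateInterceptors change above replaces per-rank source-interceptor bookkeeping with two synthetic task nodes: every node with no upstream is linked under a single SOURCE node, and every node with no downstream is linked above a single SINK node. A framework-free Python sketch of that wiring (placeholder ids and helper name, not Paddle's actual constants or API):

```python
SOURCE_ID, SINK_ID = -1, -2  # placeholder ids for the synthetic nodes

def wire_source_sink(upstreams: dict, downstreams: dict):
    """upstreams/downstreams: task_id -> set of neighbor task_ids.
    Links parentless nodes under SOURCE and childless nodes above SINK,
    returning the original source/sink task ids."""
    origin_sources = [t for t, ups in upstreams.items() if not ups]
    origin_sinks = [t for t, downs in downstreams.items() if not downs]
    for t in origin_sources:
        upstreams[t].add(SOURCE_ID)      # node->AddUpstreamTask(SOURCE_ID, ...)
    for t in origin_sinks:
        downstreams[t].add(SINK_ID)      # node->AddDownstreamTask(SINK_ID, ...)
    return origin_sources, origin_sinks

# A 3-stage pipeline: 1 -> 2 -> 3
ups = {1: set(), 2: {1}, 3: {2}}
downs = {1: {2}, 2: {3}, 3: set()}
srcs, sinks = wire_source_sink(ups, downs)
assert srcs == [1] and sinks == [3]
assert SOURCE_ID in ups[1] and SINK_ID in downs[3]
```

After this wiring, every "normal" node has both an upstream and a downstream, which is exactly what the new PADDLE_ENFORCE_EQ checks in the loop assert.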
paddle/fluid/distributed/fleet_executor/carrier.h
@@ -25,6 +25,7 @@
 #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
 #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
 #include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
+#include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/errors.h"
@@ -60,7 +61,8 @@ class Carrier final {
       framework::Scope* scope,
       int64_t num_micro_batches,
       const platform::Place& place,
-      const std::vector<std::string>& inference_root_scope_vars = {});
+      const std::vector<std::string>& inference_root_scope_vars = {},
+      const std::vector<framework::Scope*>& micro_scope_list = {});

   void CopyParameters(
       int microbatch_id,
@@ -100,8 +102,6 @@ class Carrier final {
   std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
       interceptor_idx_to_interceptor_;

-  std::vector<int64_t> source_interceptor_ids_;
-
   bool is_init_{false};

   std::mutex running_mutex_;
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
浏览文件 @
92c2dcbd
...
@@ -18,10 +18,85 @@
...
@@ -18,10 +18,85 @@
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/serialization.h"
#include "paddle/phi/core/utils/dim.h"
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
namespace
{
template
<
typename
T
>
void
SetVarResult
(
const
std
::
string
&
name
,
T
value
,
int64_t
scope_id
,
framework
::
Scope
*
scope
,
const
platform
::
Place
&
place
,
const
std
::
vector
<
int64_t
>&
dim_vec
)
{
auto
*
var
=
scope
->
FindVar
(
name
);
auto
*
tensor
=
var
->
GetMutable
<
phi
::
DenseTensor
>
();
if
(
!
var
)
{
VLOG
(
3
)
<<
"Create var and memory for var "
<<
name
;
var
=
scope
->
Var
(
name
);
phi
::
DDim
dims
=
phi
::
make_ddim
(
dim_vec
);
tensor
->
Resize
(
dims
);
tensor
->
mutable_data
<
T
>
(
dims
,
place
);
}
PADDLE_ENFORCE_EQ
(
tensor
->
dims
().
size
(),
1
,
platform
::
errors
::
OutOfRange
(
"Only support transfer size 1 value."
));
PADDLE_ENFORCE_EQ
(
tensor
->
dims
().
at
(
0
),
1
,
platform
::
errors
::
OutOfRange
(
"Only support transfer size 1 value."
));
if
(
platform
::
is_gpu_place
(
tensor
->
place
()))
{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi
::
DenseTensor
cpu_tensor
;
auto
dim
=
phi
::
make_ddim
({
1
});
cpu_tensor
.
mutable_data
<
T
>
(
dim
,
platform
::
CPUPlace
());
auto
*
cpu_tensor_ptr
=
cpu_tensor
.
data
<
T
>
();
cpu_tensor_ptr
[
0
]
=
value
;
framework
::
TensorCopySync
(
cpu_tensor
,
tensor
->
place
(),
tensor
);
#endif
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupport device for cond interceptor."
));
}
}
template
<
typename
T
>
T
GetVarResult
(
const
std
::
string
&
name
,
int64_t
scope_id
,
framework
::
Scope
*
scope
)
{
auto
*
var
=
scope
->
FindVar
(
name
);
PADDLE_ENFORCE
(
var
,
platform
::
errors
::
NotFound
(
"Variable %s not exists in scope %ld"
,
name
,
scope_id
));
const
auto
&
tensor
=
var
->
Get
<
phi
::
DenseTensor
>
();
T
res
;
if
(
platform
::
is_gpu_place
(
tensor
.
place
()))
{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi
::
DenseTensor
cpu_tensor
;
framework
::
TensorCopySync
(
tensor
,
platform
::
CPUPlace
(),
&
cpu_tensor
);
res
=
cpu_tensor
.
data
<
T
>
()[
0
];
#endif
}
else
if
(
platform
::
is_cpu_place
(
tensor
.
place
()))
{
res
=
tensor
.
data
<
T
>
()[
0
];
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupport device for cond interceptor."
));
}
return
res
;
}
}
// namespace
ComputeInterceptor
::
ComputeInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
ComputeInterceptor
::
ComputeInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
Interceptor
(
interceptor_id
,
node
)
{
:
Interceptor
(
interceptor_id
,
node
)
{
PrepareDeps
();
PrepareDeps
();
...
@@ -33,57 +108,49 @@ void ComputeInterceptor::PrepareDeps() {
...
@@ -33,57 +108,49 @@ void ComputeInterceptor::PrepareDeps() {
auto
&
downstream
=
node_
->
downstream
();
auto
&
downstream
=
node_
->
downstream
();
for
(
auto
up
:
upstream
)
{
for
(
auto
up
:
upstream
)
{
in_readys_
.
emplace
(
up
.
first
,
std
::
make_pair
(
up
.
second
,
0
));
std
::
map
<
int64_t
,
int64_t
>
ready_size_map
;
in_stops_
.
emplace
(
up
.
first
,
false
);
for
(
int64_t
i
=
0
;
i
<
node_
->
max_run_times
();
++
i
)
{
ready_size_map
.
emplace
(
i
,
0
);
}
in_readys_
.
emplace
(
up
.
first
,
std
::
make_pair
(
up
.
second
,
ready_size_map
));
}
}
for
(
auto
down
:
downstream
)
{
for
(
auto
down
:
downstream
)
{
out_buffs_
.
emplace
(
down
.
first
,
std
::
make_pair
(
down
.
second
,
0
));
out_buffs_
.
emplace
(
down
.
first
,
std
::
make_pair
(
down
.
second
,
0
));
}
}
// source compute node, should we add a new SourceInterceptor?
if
(
upstream
.
empty
())
{
is_source_
=
true
;
PADDLE_ENFORCE_GT
(
node_
->
max_run_times
(),
0
,
platform
::
errors
::
InvalidArgument
(
"Source ComputeInterceptor must run at least one "
"times, but now max_run_times=%ld"
,
node_
->
max_run_times
()));
in_readys_
.
emplace
(
-
1
,
std
::
make_pair
(
std
::
numeric_limits
<
int64_t
>::
max
(),
0
));
}
// If there is no downstream or every downstream is in different rank,
// then this interceptor is the last one for current rank.
// This can be get during init, can be cached for later use.
is_last_
=
downstream
.
empty
();
}
}
void ComputeInterceptor::IncreaseReady(int64_t up_id, int64_t scope_id) {
  auto it = in_readys_.find(up_id);
  PADDLE_ENFORCE_NE(it,
                    in_readys_.end(),
                    platform::errors::NotFound(
                        "Cannot find upstream=%lld in in_readys.", up_id));

  // source node has no upstream, data_is_ready is send by carrier or others
  if (is_source_ && up_id == -1) {
    it->second.second += GetTaskNode()->max_run_times();
    return;
  }

  auto max_ready_size = it->second.first;
  const auto& ready_scope_map = it->second.second;
  int64_t ready_size = 0;
  for (auto& scope_iter : ready_scope_map) {
    ready_size += scope_iter.second;
  }
  if (max_ready_size != INFINITE_BUFFER_SIZE) {
    PADDLE_ENFORCE_LE(ready_size,
                      max_ready_size,
                      platform::errors::OutOfRange(
                          "upstream=%lld ready_size must <= max_ready_size, "
                          "but now ready_size=%lld, max_ready_size=%lld",
                          up_id,
                          ready_size,
                          max_ready_size));
  }
  PADDLE_ENFORCE_NE(
      it->second.second.find(scope_id),
      it->second.second.end(),
      platform::errors::OutOfRange(
          "Interceptor %lld can not find scope %lld in upstream ready map",
          interceptor_id_,
          scope_id));
  it->second.second.at(scope_id) = ready_scope_map.at(scope_id) + 1;
}
void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
...
@@ -105,22 +172,40 @@ void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
}
bool ComputeInterceptor::IsInputReady() {
  for (int64_t i = 0; i < node_->max_run_times(); ++i) {
    bool flag = true;
    for (auto& ins : in_readys_) {
      auto ready_size_map = ins.second.second;
      flag = flag && (ready_size_map.at(i) != 0);
    }
    if (flag) {
      for (auto iter : scope_id_to_finish_flag_) {
        if (iter.first == i) {
          break;
        } else if (!iter.second) {
          VLOG(3) << "The previous scope is not ready, waiting for the "
                     "previous scope "
                  << iter.first;
          return false;
        }
      }
      cur_scope_id_ = i;
      return true;
    } else {
      VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i
              << "'s upstreams aren't all ready.";
    }
  }
  return false;
}
bool ComputeInterceptor::CanWriteOutput() {
  for (auto& outs : out_buffs_) {
    auto max_buffer_size = outs.second.first;
    auto used_size = outs.second.second;
    if (max_buffer_size == INFINITE_BUFFER_SIZE) {
      continue;
    }
    // full, return false
    if (used_size == max_buffer_size) {
      VLOG(3) << "Interceptor " << GetInterceptorId()
...
@@ -137,30 +222,76 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
    auto max_buff_size = outs.second.first;
    auto used_size = outs.second.second;
    used_size += 1;
    if (max_buff_size != INFINITE_BUFFER_SIZE) {
      PADDLE_ENFORCE_LE(used_size,
                        max_buff_size,
                        platform::errors::OutOfRange(
                            "downstream=%lld used buff size must <= "
                            "max_buff_size, but now used_size=%lld, "
                            "max_buff_size=%lld",
                            down_id,
                            used_size,
                            max_buff_size));
    }
    outs.second.second = used_size;

    bool need_send_vars = !(node_->vars_to_dtype().empty());
    if (need_send_vars) {
      InterceptorMessage ready_msg = PrepareVarsMsg();
      VLOG(3) << "ComputeInterceptor " << interceptor_id_
              << " Send data_with_vars msg to " << down_id
              << " in scope: " << cur_scope_id_;
      Send(down_id, ready_msg);
    } else {
      InterceptorMessage ready_msg;
      ready_msg.set_message_type(DATA_IS_READY);
      ready_msg.set_scope_idx(cur_scope_id_);
      VLOG(3) << "ComputeInterceptor " << interceptor_id_
              << " Send data_is_ready msg to " << down_id
              << " in scope: " << cur_scope_id_;
      Send(down_id, ready_msg);
    }
  }
}
InterceptorMessage ComputeInterceptor::PrepareVarsMsg() {
  PADDLE_ENFORCE_LT(cur_scope_id_,
                    microbatch_scopes_.size(),
                    platform::errors::InvalidArgument(
                        "Step out of range. There are %ld "
                        "microbatch_scopes, but recevice scope index %ld",
                        microbatch_scopes_.size(),
                        cur_scope_id_));
  auto* scope = microbatch_scopes_[cur_scope_id_];
  InterceptorMessage ready_msg;
  ready_msg.set_message_type(DATA_WITH_VARS);
  ready_msg.set_scope_idx(cur_scope_id_);
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  for (auto iter : node_->vars_to_dtype()) {
    VarList* vars = ready_msg.add_vars_list();
    const auto& var_name = iter.first;
    vars->set_name(var_name);
    std::ostringstream ss;
    auto& dev_ctx = *pool.Get(place_);
    auto* var = scope->FindVar(var_name);
    PADDLE_ENFORCE(
        var,
        platform::errors::NotFound(
            "Variable %s not exists in scope %ld", var_name, cur_scope_id_));
    const auto& tensor = var->Get<phi::DenseTensor>();
    SerializeToStream(ss, tensor, dev_ctx);
    vars->set_stensor(ss.str());
    VLOG(3) << "Prepare vars msg " << var_name << " with dimension "
            << tensor.dims() << " dtype " << tensor.dtype();
  }
  return ready_msg;
}
void ComputeInterceptor::ReplyCompletedToUpStream() {
  for (auto& ins : in_readys_) {
    auto up_id = ins.first;
    auto ready_size = ins.second.second.at(cur_scope_id_);
    ready_size -= 1;
    PADDLE_ENFORCE_GE(
        ready_size,
...
@@ -169,109 +300,114 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
        "upstream=%lld ready_size must >= 0, but now got %lld",
        up_id,
        ready_size));
    ins.second.second[cur_scope_id_] = ready_size;

    VLOG(3) << "ComputeInterceptor " << interceptor_id_
            << " Reply data_is_useless msg to " << up_id
            << " in scope: " << cur_scope_id_;
    if (is_source_ && up_id == -1) return;

    InterceptorMessage reply_msg;
    reply_msg.set_message_type(DATA_IS_USELESS);
    reply_msg.set_scope_idx(cur_scope_id_);
    Send(up_id, reply_msg);
  }
}
void ComputeInterceptor::RunOps() {
  for (auto op : node_->ops()) {
    PADDLE_ENFORCE_LT(cur_scope_id_,
                      microbatch_scopes_.size(),
                      platform::errors::InvalidArgument(
                          "Step out of range. There are %ld "
                          "microbatch_scopes, but recevice scope index %ld",
                          microbatch_scopes_.size(),
                          cur_scope_id_));
    op->Run(*microbatch_scopes_[cur_scope_id_], place_);
    if (gc_) {
      framework::DeleteUnusedTensors(*microbatch_scopes_[cur_scope_id_],
                                     op,
                                     node_->unused_vars(),
                                     gc_.get());
    }
  }
}
void ComputeInterceptor::Run() {
  while (IsInputReady() && CanWriteOutput()) {
    VLOG(3) << "id=" << GetInterceptorId()
            << " ComputeInterceptor running in scope " << cur_scope_id_;

    RunOps();

    if (!scope_id_to_finish_flag_.empty()) {
      PADDLE_ENFORCE_NE(
          scope_id_to_finish_flag_.find(cur_scope_id_),
          scope_id_to_finish_flag_.end(),
          platform::errors::NotFound(
              "Can not find scope %ld in scope_id_to_finish", cur_scope_id_));
      scope_id_to_finish_flag_.erase(cur_scope_id_);
    }

    // send to downstream and increase buff used
    SendDataReadyToDownStream();
    // reply to upstream and decrease ready data
    ReplyCompletedToUpStream();
  }
}
void ComputeInterceptor::DecodeMsgVars(const InterceptorMessage& msg) {
  int64_t scope_id = msg.scope_idx();
  PADDLE_ENFORCE_LT(scope_id,
                    microbatch_scopes_.size(),
                    platform::errors::InvalidArgument(
                        "Step out of range. There are %ld "
                        "microbatch_scopes, but recevice scope index %ld",
                        microbatch_scopes_.size(),
                        scope_id));
  auto* scope = microbatch_scopes_[scope_id];
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  for (const auto& var_iter : msg.vars_list()) {
    const std::string& name = var_iter.name();
    auto& dev_ctx = *pool.Get(place_);
    std::istringstream ss(var_iter.stensor());
    auto* var = scope->Var(name);
    auto* tensor = var->GetMutable<phi::DenseTensor>();
    DeserializeFromStream(ss, tensor, dev_ctx);

    VLOG(3) << "Set vars " << name << " with value in scope " << scope_id
            << " with dims " << tensor->dims() << " with dtype "
            << tensor->dtype();
  }
}
void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
  if (msg.message_type() == DATA_IS_READY) {
    VLOG(3) << "Compute interceptor " << interceptor_id_
            << " receive data_is_ready " << msg.src_id() << " "
            << msg.scope_idx() << " ";
    IncreaseReady(msg.src_id(), msg.scope_idx());
    Run();
  } else if (msg.message_type() == DATA_IS_USELESS) {
    VLOG(3) << "Compute interceptor " << interceptor_id_
            << " receive data_is_useless " << msg.src_id() << " "
            << msg.scope_idx() << " ";
    DecreaseBuff(msg.src_id());
    Run();
  } else if (msg.message_type() == DATA_WITH_VARS) {
    VLOG(3) << "Compute interceptor " << interceptor_id_
            << " receive data_with_vars " << msg.src_id() << " "
            << msg.scope_idx() << " ";
    DecodeMsgVars(msg);
    IncreaseReady(msg.src_id(), msg.scope_idx());
    Run();
  } else if (msg.message_type() == START_LOOP) {
    VLOG(3) << "Compute interceptor " << interceptor_id_
            << " receive start_loop " << msg.src_id() << " "
            << msg.scope_idx() << " ";
    IncreaseReady(msg.src_id(), msg.scope_idx());
    scope_id_to_finish_flag_.emplace(msg.scope_idx(), false);
    Run();
  }
}

REGISTER_INTERCEPTOR(Compute, ComputeInterceptor);
...
paddle/fluid/distributed/fleet_executor/compute_interceptor.h
View file @ 92c2dcbd
...
@@ -14,6 +14,7 @@
#pragma once

#include <queue>
#include <utility>

#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
...
@@ -21,6 +22,8 @@
namespace paddle {
namespace distributed {

const int64_t INFINITE_BUFFER_SIZE = -1;

class ComputeInterceptor : public Interceptor {
 public:
  ComputeInterceptor(int64_t interceptor_id, TaskNode* node);
...
@@ -29,33 +32,27 @@ class ComputeInterceptor : public Interceptor {
  virtual void RunOps();
  virtual void SendDataReadyToDownStream();
  virtual void ReplyCompletedToUpStream();
  virtual void Compute(const InterceptorMessage& msg);

  void Run();
  void IncreaseReady(int64_t up_id, int64_t scope_id);
  void DecreaseBuff(int64_t down_id);

  int64_t cur_scope_id_;
  // upstream_id-->(max_ready_size, scope-->ready_size)
  std::map<int64_t, std::pair<int64_t, std::map<int64_t, int64_t>>>
      in_readys_{};
  // downstream_id-->(max_buffer_size, used_size)
  std::map<int64_t, std::pair<int64_t, int64_t>> out_buffs_{};

 private:
  void PrepareDeps();
  InterceptorMessage PrepareVarsMsg();
  void DecodeMsgVars(const InterceptorMessage& msg);
  bool IsInputReady();
  bool CanWriteOutput();

  bool is_source_{false};
  bool is_last_{false};
  std::map<int64_t, bool> scope_id_to_finish_flag_;
};

}  // namespace distributed
...
paddle/fluid/distributed/fleet_executor/cond_interceptor.cc
0 → 100644
View file @ 92c2dcbd
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/cond_interceptor.h"
#include <algorithm>
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace distributed {

CondInterceptor::CondInterceptor(int64_t interceptor_id, TaskNode* node)
    : Interceptor(interceptor_id, node) {
  PrepareDeps();
  RegisterMsgHandle([this](const InterceptorMessage& msg) { Run(msg); });
}

void CondInterceptor::PrepareDeps() {
  auto& upstream = node_->upstream();
  auto& downstream = node_->downstream();
  auto& id_to_dep_type = node_->id_to_dep_type();

  for (const auto& up : upstream) {
    if (id_to_dep_type.at(up.first) == DependType::NORMAL) {
      normal_in_id_.insert(up.first);
    } else if (id_to_dep_type.at(up.first) == DependType::LOOP) {
      loop_id_ = up.first;
    }
  }

  for (const auto& down : downstream) {
    if (id_to_dep_type.at(down.first) == DependType::NORMAL) {
      normal_out_id_.insert(down.first);
    } else if (id_to_dep_type.at(down.first) == DependType::STOP_LOOP) {
      stop_loop_id_ = down.first;
    }
  }
}
bool CondInterceptor::GetCondResult() {
  PADDLE_ENFORCE_LT(cur_scope_id_,
                    microbatch_scopes_.size(),
                    platform::errors::InvalidArgument(
                        "Step out of range. There are %ld "
                        "microbatch_scopes, but recevice scope index %ld",
                        microbatch_scopes_.size(),
                        cur_scope_id_));
  auto* cond_var =
      microbatch_scopes_[cur_scope_id_]->FindVar(node_->cond_var());
  PADDLE_ENFORCE(cond_var,
                 platform::errors::NotFound(
                     "Condition variable %s not exists in scope %ld",
                     node_->cond_var(),
                     cur_scope_id_));
  const auto& cond_tensor = cond_var->Get<phi::DenseTensor>();
  bool res = false;
  if (platform::is_gpu_place(cond_tensor.place())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    phi::DenseTensor cpu_tensor;
    framework::TensorCopy(cond_tensor, platform::CPUPlace(), &cpu_tensor);
    platform::DeviceContextPool::Instance().Get(cond_tensor.place())->Wait();
    res = cpu_tensor.data<bool>()[0];
#endif
  } else if (platform::is_cpu_place(cond_tensor.place())) {
    res = cond_tensor.data<bool>()[0];
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Unsupport device for cond interceptor."));
  }
  return res;
}
void CondInterceptor::SendDataReady(int64_t down_id) {
  InterceptorMessage ready_msg;
  ready_msg.set_message_type(DATA_IS_READY);
  ready_msg.set_scope_idx(cur_scope_id_);
  Send(down_id, ready_msg);
}

void CondInterceptor::SendStartLoop(int64_t down_id) {
  InterceptorMessage ready_msg;
  ready_msg.set_message_type(START_LOOP);
  ready_msg.set_scope_idx(cur_scope_id_);
  Send(down_id, ready_msg);
}

void CondInterceptor::ReplyDataIsUseless(int64_t up_id) {
  InterceptorMessage ready_msg;
  ready_msg.set_message_type(DATA_IS_USELESS);
  ready_msg.set_scope_idx(cur_scope_id_);
  Send(up_id, ready_msg);
}
void CondInterceptor::Compute() {
  bool cond = GetCondResult();
  VLOG(3) << "Cond interceptor get condition var " << node_->cond_var()
          << " with value " << cond;
  if (cond) {
    VLOG(3) << "Loop again in scope " << cur_scope_id_;
    for (auto& down_id : normal_out_id_) {
      SendStartLoop(down_id);
    }
    ++num_of_scopes_;
  } else {
    VLOG(3) << "Finish loop in scope " << cur_scope_id_;
    SendDataReady(stop_loop_id_);
  }
}
void CondInterceptor::Run(const InterceptorMessage& msg) {
  if (msg.message_type() == DATA_IS_READY ||
      msg.message_type() == DATA_WITH_VARS) {
    if (msg.src_id() == loop_id_) {
      --num_of_scopes_;
      VLOG(3) << "Receving loop again message from " << msg.src_id()
              << " waiting other " << num_of_scopes_ << " scopes ready";
      ready_scope_id_.emplace_back(msg.scope_idx());
      if (num_of_scopes_ == 0) {
        std::sort(ready_scope_id_.begin(), ready_scope_id_.end());
        for (auto scope_id : ready_scope_id_) {
          VLOG(3) << "Start a new loop in scope " << scope_id;
          cur_scope_id_ = scope_id;
          Compute();
        }
        ready_scope_id_.clear();
      }
    } else {
      cur_scope_id_ = msg.scope_idx();
      Compute();
    }
  } else if (msg.message_type() == DATA_IS_USELESS) {
    if (node_->id_to_dep_type().at(msg.src_id()) == DependType::STOP_LOOP) {
      for (auto& up_id : normal_in_id_) {
        ReplyDataIsUseless(up_id);
      }
      // Gc the variable in while block
      int64_t scope_id = msg.scope_idx();
      if (gc_) {
        VLOG(3) << "Release vars in while block in scope " << scope_id;
        framework::DeleteUnusedTensors(*microbatch_scopes_[scope_id],
                                       node_->while_block_vars(),
                                       gc_.get());
      }
    }
  }
}

REGISTER_INTERCEPTOR(Cond, CondInterceptor);

}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/fleet_executor/cond_interceptor.h
0 → 100644
View file @ 92c2dcbd
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iomanip>
#include <queue>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace paddle {
namespace distributed {

/* Condition Interceptor
 * This is a special interceptor and only one condition op in the task node.
 * This interceptor has two downstreams,
 * 1. If the program result is true, select one of the downstreams, otherwise
 * select another.
 * 2. Used to implement while op in program.
 */
class CondInterceptor final : public Interceptor {
 public:
  CondInterceptor(int64_t interceptor_id, TaskNode* node);

 private:
  void PrepareDeps();
  void Run(const InterceptorMessage& msg);
  void Compute();
  bool GetCondResult();
  void SendDataReady(int64_t down_id);
  void SendStartLoop(int64_t down_id);
  void ReplyDataIsUseless(int64_t up_id);

  int64_t cur_scope_id_;
  std::set<int64_t> normal_in_id_;
  std::set<int64_t> normal_out_id_;
  int64_t stop_loop_id_;
  int64_t loop_id_;
  int64_t num_of_scopes_{0};
  std::vector<int64_t> ready_scope_id_;
};

}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
View file @ 92c2dcbd
...
@@ -14,6 +14,8 @@
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"

#include <algorithm>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...
@@ -24,6 +26,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/variable.h"

namespace paddle {
namespace distributed {
...
@@ -51,40 +54,40 @@ FleetExecutor::~FleetExecutor() {
  }
}
namespace {

void GetSubBlockTask(const std::vector<TaskNode*>& tasks,
                     TaskNode* cur_task,
                     std::set<TaskNode*>* sub_block_task) {
  auto& downstream = cur_task->downstream();
  auto& id_to_dep_type = cur_task->id_to_dep_type();
  for (auto& down : downstream) {
    int64_t task_id = down.first;
    if (id_to_dep_type.at(task_id) == DependType::NORMAL) {
      for (const auto& task : tasks) {
        if (task->task_id() == task_id) {
          sub_block_task->emplace(task);
          GetSubBlockTask(tasks, task, sub_block_task);
        }
      }
    }
  }
}
void PreventVarsDelete(
    std::unordered_map<const framework::OperatorBase*,
                       std::vector<std::string>>* unused_vars,
    const std::vector<std::string>& vars_not_gc) {
  std::vector<const framework::OperatorBase*> changed_ops;

  for (const auto& pair : *unused_vars) {
    const framework::OperatorBase* op = pair.first;
    std::vector<std::string> cur_unused = pair.second;
    for (auto name : vars_not_gc) {
      auto iter = std::find(cur_unused.begin(), cur_unused.end(), name);
      if (iter != cur_unused.end()) {
        VLOG(3) << "Removing var: [" << name
                << "] from the unused vars list of op: [" << op->Type()
                << "]";
        cur_unused.erase(iter);
        if (std::find(changed_ops.begin(), changed_ops.end(), op) ==
            changed_ops.end()) {
          // record the op whose unused vars have been updated
...
@@ -93,28 +96,120 @@ void FleetExecutor::Init(
        }
      }
    }
    // update the unused vars list in the map
    unused_vars->at(op) = cur_unused;
  }
  for (auto op : changed_ops) {
    const auto& iter = unused_vars->find(op);
    if (iter->second.empty()) {
      // remove those ops in the map that have empty unused vars list
      VLOG(3) << "Removing op: [" << op->Type()
              << "] from unused_vars map.";
      unused_vars->erase(iter);
    }
  }
}
std::vector<std::string> GetUnusedVarsAfterWhile(
    const framework::ProgramDesc& program_desc,
    TaskNode* cond_task,
    const std::vector<std::string>& vars_not_gc) {
  // NOTE: Since while op won't appear in task node, in order to analyze
  // the vars which should be free after calling while op, we rebuild the
  // whole program and get the unused vars after calling while op.
  // The vars in while block should not be free until the while op is
  // finished. In a word, the vars need to be free after while op is:
  // 1. Vars in parent block and being used in while block.
  // 2. Local vars only defined in while block.
  // The unused vars above will be free in cond interceptor.
  std::vector<std::string> while_block_vars;
  std::vector<std::unique_ptr<framework::OperatorBase>> ops;
  for (const auto& desc : program_desc.Block(0).AllOps()) {
    ops.emplace_back(framework::OpRegistry::CreateOp(*desc));
  }
  auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
  PreventVarsDelete(&unused_vars, vars_not_gc);
  for (const auto& pair : unused_vars) {
    if (pair.first->Type() == "while") {
      for (const auto& var_name : pair.second) {
        while_block_vars.emplace_back(var_name);
      }
      for (auto& var : program_desc.Block(1).AllVars()) {
        while_block_vars.emplace_back(var->Name());
      }
    }
  }
  return while_block_vars;
}

}  // namespace
void FleetExecutor::Init(
    const std::string& carrier_id,
    const framework::ProgramDesc& program_desc,
    framework::Scope* scope,
    const platform::Place& place,
    int64_t num_micro_batches,
    const std::vector<TaskNode*>& task_nodes,
    const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
    const std::vector<std::string>& inference_root_scope_vars,
    const std::vector<framework::Scope*>& micro_scope_list) {
  PADDLE_ENFORCE_GT(task_nodes.size(),
                    0,
                    platform::errors::InvalidArgument(
                        "Fleet executor is inited with empty task node"));
  // Set the unused var after running while op
  std::set<TaskNode*> sub_block_tasks;
  std::vector<std::string> while_block_vars;
  for (const auto& task_node : task_nodes) {
    if (task_node->type() == "Cond") {
      GetSubBlockTask(task_nodes, task_node, &sub_block_tasks);
      while_block_vars = GetUnusedVarsAfterWhile(
          program_desc, task_node, inference_root_scope_vars);
      VLOG(3) << "Vars will be gced after while op";
      for (auto var : while_block_vars) {
        VLOG(3) << var;
      }
      task_node->SetWhileBlockVars(while_block_vars);
    }
  }
  std::vector<framework::OperatorBase*> sub_block_ops;
  for (const auto& task_node : sub_block_tasks) {
    for (const auto& op : task_node->ops()) {
      sub_block_ops.emplace_back(op);
    }
  }
  // Analyse the unused vars in block 0. The operators in block 1
  // should be passed in first for prevent vars been released but removed
  // soon. Since the unused vars in block 1 need to analyse separately.
  std::vector<std::unique_ptr<framework::OperatorBase>> ops;
  for (const auto& task_node : task_nodes) {
    for (const auto& op : task_node->ops()) {
      ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
    }
  }
  auto global_unused_vars =
      framework::GetUnusedVars(program_desc.Block(0), ops, {});
  for (auto& unique_op : ops) {
    unique_op.release();
  }
  // NOTE: For inference, the vars in inference_root_scope_vars
  // shouldn't be deleted during inf, for that they may be the result of the
  // inf. If they are GCed, it will cause error during ZeroCopy the result.
  PreventVarsDelete(&global_unused_vars, inference_root_scope_vars);
  runtime_graph_ = std::make_shared<RuntimeGraph>();
  std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
  for (auto task_node : task_nodes) {
    if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) {
      task_node->SetUnusedVars(global_unused_vars);
    }
    int64_t interceptor_id = task_node->task_id();
    interceptor_id_to_task.emplace(interceptor_id, task_node);
  }
  runtime_graph_->SetInterceptorIdToRank(task_id_to_rank);
  runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task);
VLOG
(
5
)
<<
runtime_graph_
->
DebugString
();
VLOG
(
5
)
<<
runtime_graph_
->
DebugString
();
Carrier
*
carrier
=
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
...
@@ -126,7 +221,8 @@ void FleetExecutor::Init(
       place,
       num_micro_batches,
       program_desc,
-      inference_root_scope_vars);
+      inference_root_scope_vars,
+      micro_scope_list);
   GlobalVal<MessageBus>::Get()->Barrier();
 }
...
@@ -136,7 +232,8 @@ void FleetExecutor::InitCarrier(
     const platform::Place& place,
     int64_t num_micro_batches,
     const framework::ProgramDesc& program_desc,
-    const std::vector<std::string>& inference_root_scope_vars) {
+    const std::vector<std::string>& inference_root_scope_vars,
+    const std::vector<framework::Scope*>& micro_scope_list) {
   carrier->Init(exe_desc_.cur_rank(),
                 runtime_graph_->interceptor_id_to_rank(),
                 runtime_graph_->interceptor_id_to_node(),
...
@@ -144,7 +241,8 @@ void FleetExecutor::InitCarrier(
       scope,
       num_micro_batches,
       place,
-      inference_root_scope_vars);
+      inference_root_scope_vars,
+      micro_scope_list);
 }
 void FleetExecutor::InitMessageBus() {
...
paddle/fluid/distributed/fleet_executor/fleet_executor.h
...
@@ -18,6 +18,7 @@
 #include "paddle/fluid/distributed/fleet_executor/carrier.h"
 #include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
+#include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/platform/macros.h"
 #include "paddle/fluid/platform/place.h"
...
@@ -45,7 +46,8 @@ class FleetExecutor final {
             int64_t num_micro_batches,
             const std::vector<TaskNode*>& task_nodes,
             const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
-            const std::vector<std::string>& inference_root_scope_vars = {});
+            const std::vector<std::string>& inference_root_scope_vars = {},
+            const std::vector<framework::Scope*>& micro_scope_list = {});
   void Run(const std::string& carrier_id);

  private:
...
@@ -57,7 +59,8 @@ class FleetExecutor final {
       const platform::Place& place,
       int64_t num_micro_batches,
       const framework::ProgramDesc& program_desc,
-      const std::vector<std::string>& inference_root_scope_vars = {});
+      const std::vector<std::string>& inference_root_scope_vars = {},
+      const std::vector<framework::Scope*>& micro_scope_list = {});
   FleetExecutorDesc exe_desc_;
   std::shared_ptr<RuntimeGraph> runtime_graph_;
   std::unordered_set<std::string> carrier_ids_;
...
paddle/fluid/distributed/fleet_executor/interceptor.h
...
@@ -93,7 +93,6 @@ class Interceptor {
   TaskNode* node_;

   // for stop
-  bool stop_{false};
   void StopCarrier();

   // for runtime
...
@@ -114,9 +113,6 @@ class Interceptor {
   std::mutex mutex_;
   std::deque<InterceptorMessage> messages_;
-
-  int64_t already_run_times_{0};
-  int64_t used_slot_nums_{0};
 };

 class InterceptorFactory {
...
paddle/fluid/distributed/fleet_executor/interceptor_message.proto
...
@@ -24,6 +24,21 @@ enum MessageType {
   ERR = 4;    // current Interceptor encounters error
   RESET = 5;  // reset the status
   START = 6;
+  DATA_WITH_VARS = 7;
+  START_LOOP = 8;
 }

+enum ValueType {
+  INT32 = 0;
+  INT64 = 1;
+  FLOAT = 2;
+  DOUBLE = 3;
+  BOOL = 4;
+}
+
+message VarList {
+  required string name = 1;
+  required string stensor = 2;
+}
+
 message InterceptorMessage {
...
@@ -32,6 +47,7 @@ message InterceptorMessage {
   optional MessageType message_type = 3 [ default = RESET ];
   optional bool ctrl_message = 4 [ default = false ];
   optional int64 scope_idx = 5 [ default = 0 ];
+  repeated VarList vars_list = 6;
 }

 message InterceptorResponse { optional bool rst = 1 [ default = false ]; }
...
paddle/fluid/distributed/fleet_executor/sink_interceptor.h
...
@@ -25,7 +25,7 @@ namespace distributed {
  * 1. record the num of micro-step
  * 2. check whether to notify carrier the current step is finished
  */
-class SinkInterceptor : public Interceptor {
+class SinkInterceptor final : public Interceptor {
  public:
   SinkInterceptor(int64_t interceptor_id, TaskNode* node);
...
paddle/fluid/distributed/fleet_executor/source_interceptor.h
...
@@ -25,7 +25,7 @@ namespace distributed {
  * 1. receive `start` message from carrier
  * 2. send num_of_steps `data_is_ready` message to downstream
  */
-class SourceInterceptor : public Interceptor {
+class SourceInterceptor final : public Interceptor {
  public:
   SourceInterceptor(int64_t interceptor_id, TaskNode* node);
...
paddle/fluid/distributed/fleet_executor/start_interceptor.cc
0 → 100644

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/fleet_executor/start_interceptor.h"

#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/errors.h"

namespace paddle {
namespace distributed {

StartInterceptor::StartInterceptor(int64_t interceptor_id, TaskNode* node)
    : ComputeInterceptor(interceptor_id, node) {
  auto& downstream = node_->downstream();
  PADDLE_ENFORCE_EQ(
      downstream.size(),
      1,
      platform::errors::OutOfRange(
          "The downstream for StartInterceptor only support 1 for now."));
  for (auto down : downstream) {
    batch_size_ = down.second;
  }
  bool evenly_divisible = ((node_->max_run_times() % batch_size_) == 0);
  PADDLE_ENFORCE(
      evenly_divisible,
      platform::errors::Fatal(
          "Wrong config: Num of step should be divided by batch_size,"
          "num_step=%lld, batch_size=%lld",
          node_->max_run_times(),
          batch_size_));
}

void StartInterceptor::RunOps() {
  finish_count_++;
  ComputeInterceptor::RunOps();
}

void StartInterceptor::SendDataReadyToDownStream() {
  for (auto& outs : out_buffs_) {
    auto down_id = outs.first;
    auto max_buff_size = outs.second.first;
    auto used_size = outs.second.second;
    used_size += 1;
    if (max_buff_size != INFINITE_BUFFER_SIZE) {
      PADDLE_ENFORCE_LE(
          used_size,
          max_buff_size,
          platform::errors::OutOfRange("downstream=%lld used buff size must <= "
                                       "max_buff_size, but now used_size=%lld, "
                                       "max_buff_size=%lld",
                                       down_id,
                                       used_size,
                                       max_buff_size));
    }
    outs.second.second = used_size;
  }
  if (finish_count_ == batch_size_) {
    for (int64_t i = 0; i < batch_size_; ++i) {
      int64_t scope_id = step_ % node_->max_run_times();
      for (auto& outs : out_buffs_) {
        auto down_id = outs.first;
        InterceptorMessage ready_msg;
        ready_msg.set_message_type(DATA_IS_READY);
        ready_msg.set_scope_idx(scope_id);
        VLOG(3) << "StartInterceptor " << interceptor_id_
                << " Send data_is_ready msg to " << down_id
                << " in scope: " << scope_id;
        Send(down_id, ready_msg);
      }
      step_++;
    }
  }
}

void StartInterceptor::Compute(const InterceptorMessage& msg) {
  if (msg.message_type() == DATA_IS_READY) {
    VLOG(3) << "Start interceptor " << interceptor_id_
            << " receive data_is_ready " << msg.src_id() << " "
            << msg.scope_idx() << " ";
    IncreaseReady(msg.src_id(), msg.scope_idx());
    Run();
  } else if (msg.message_type() == DATA_IS_USELESS) {
    VLOG(3) << "Start interceptor receive data_is_useless " << msg.src_id()
            << " " << finish_count_;
    finish_count_--;
    if (finish_count_ == 0) {
      for (int64_t i = 0; i < batch_size_; ++i) {
        for (auto& outs : out_buffs_) {
          auto down_id = outs.first;
          DecreaseBuff(down_id);
        }
      }
      for (int64_t i = 0; i < batch_size_; ++i) {
        Run();
      }
    }
  }
}

REGISTER_INTERCEPTOR(Start, StartInterceptor);

}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/fleet_executor/start_interceptor.h
0 → 100644

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <utility>

#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"

namespace paddle {
namespace distributed {

class StartInterceptor final : public ComputeInterceptor {
 public:
  StartInterceptor(int64_t interceptor_id, TaskNode* node);

 private:
  void SendDataReadyToDownStream() override;
  void RunOps() override;
  void Compute(const InterceptorMessage& msg) override;

  int64_t batch_size_{0};
  int64_t finish_count_{0};
  int64_t step_{0};
};

}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/fleet_executor/task_node.cc
...
@@ -24,33 +24,14 @@ namespace {
 using OperatorBase = TaskNode::OperatorBase;
 }

-TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
-                   int64_t rank,
-                   int64_t max_run_times,
-                   int64_t max_slot_nums)
-    : program_(program),
-      rank_(rank),
-      max_run_times_(max_run_times),
-      max_slot_nums_(max_slot_nums) {
-  // Should be serially invoked, not thread-safe
-  // NOTE: when instantiate TaskNode with program, won't init task node
-  // immediately, since the provided program may be updated later (with
-  // high probability) by adding_feed_fetch_ops or by RuntimeGraph.
-  // So, delay the init part to the Init() function.
-  static int64_t task_node_cnt = 0;
-  task_id_ = task_node_cnt++;
-}
-
 TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
                    int64_t rank,
                    int64_t task_id,
-                   int64_t max_run_times,
-                   int64_t max_slot_nums)
+                   int64_t max_run_times)
     : program_(program),
       rank_(rank),
       task_id_(task_id),
-      max_run_times_(max_run_times),
-      max_slot_nums_(max_slot_nums) {
+      max_run_times_(max_run_times) {
   // TODO(liyurui): Will be removed when execute program is supported.
   Init();
 }
...
@@ -58,7 +39,6 @@ TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
 TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank)
     : program_(program), rank_(rank), task_id_(rank) {
   max_run_times_ = 1;
-  max_slot_nums_ = 1;
   LOG(INFO)
       << "Constructing TaskNode for DistModelInf. The TaskNode's id is: "
       << rank
...
@@ -69,6 +49,16 @@ void TaskNode::SetProgram(paddle::framework::ProgramDesc* program) {
   program_ = program;
 }

+void TaskNode::SetVarsToDtype(
+    const std::map<std::string, std::string>& vars_to_dtype) {
+  vars_to_dtype_ = vars_to_dtype;
+}
+
+void TaskNode::SetVarsToShape(
+    const std::map<std::string, std::vector<int64_t>>& vars_to_shape) {
+  vars_to_shape_ = vars_to_shape;
+}
+
 void TaskNode::Init(bool use_feed_fetch_ops) {
   if (!use_feed_fetch_ops) {
     VLOG(3) << "TaskNode will be inited without feed and fetch ops";
...
@@ -98,13 +88,11 @@ TaskNode::TaskNode(int32_t role,
                    const std::vector<framework::OpDesc*>& op_descs,
                    int64_t rank,
                    int64_t task_id,
-                   int64_t max_run_times,
-                   int64_t max_slot_nums)
+                   int64_t max_run_times)
     : role_(role),
       rank_(rank),
       task_id_(task_id),
-      max_run_times_(max_run_times),
-      max_slot_nums_(max_slot_nums) {
+      max_run_times_(max_run_times) {
   if (op_descs.empty()) {
     return;
   }
...
@@ -121,33 +109,35 @@ TaskNode::TaskNode(int32_t role,
                    const std::vector<framework::OperatorBase*>& ops,
                    int64_t rank,
                    int64_t task_id,
-                   int64_t max_run_times,
-                   int64_t max_slot_nums)
+                   int64_t max_run_times)
     : ops_(ops),
       role_(role),
       rank_(rank),
       task_id_(task_id),
-      max_run_times_(max_run_times),
-      max_slot_nums_(max_slot_nums) {}
+      max_run_times_(max_run_times) {}

 TaskNode::TaskNode(int32_t role,
                    int64_t rank,
                    int64_t task_id,
-                   int64_t max_run_times,
-                   int64_t max_slot_nums)
+                   int64_t max_run_times)
     : role_(role),
       rank_(rank),
       task_id_(task_id),
-      max_run_times_(max_run_times),
-      max_slot_nums_(max_slot_nums) {}
+      max_run_times_(max_run_times) {}

-bool TaskNode::AddUpstreamTask(int64_t task_id, int64_t buff_size) {
+bool TaskNode::AddUpstreamTask(int64_t task_id,
+                               int64_t buff_size,
+                               DependType type) {
   const auto& ret = upstream_.emplace(task_id, buff_size);
+  id_to_dep_type_.emplace(task_id, type);
   return ret.second;
 }

-bool TaskNode::AddDownstreamTask(int64_t task_id, int64_t buff_size) {
+bool TaskNode::AddDownstreamTask(int64_t task_id,
+                                 int64_t buff_size,
+                                 DependType type) {
   const auto& ret = downstream_.emplace(task_id, buff_size);
+  id_to_dep_type_.emplace(task_id, type);
   return ret.second;
 }
...
paddle/fluid/distributed/fleet_executor/task_node.h
...
@@ -14,8 +14,10 @@
 #pragma once
 #include <cstdint>
+#include <functional>
 #include <memory>
 #include <string>
+#include <unordered_map>
 #include <unordered_set>
 #include <vector>
...
@@ -29,38 +31,30 @@ class OpDesc;
 }  // namespace framework
 namespace distributed {

+enum class DependType { NORMAL, LOOP, STOP_LOOP };
+
 class TaskNode final {
  public:
   using OperatorBase = paddle::framework::OperatorBase;
   TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times);
   TaskNode(int32_t role,
            int64_t rank,
            int64_t task_id,
-           int64_t max_run_times,
-           int64_t max_slot_nums);
+           int64_t max_run_times);
   TaskNode(int32_t role,
            const std::vector<framework::OpDesc*>& op_descs,
            int64_t rank,
            int64_t task_id,
-           int64_t max_run_times,
-           int64_t max_slot_nums);
+           int64_t max_run_times);
   TaskNode(int32_t role,
            const std::vector<framework::OperatorBase*>& ops,
            int64_t rank,
            int64_t task_id,
-           int64_t max_run_times,
-           int64_t max_slot_nums);
-  TaskNode(paddle::framework::ProgramDesc* program,
-           int64_t rank,
-           int64_t max_run_times,
-           int64_t max_slot_nums);
+           int64_t max_run_times);
   TaskNode(paddle::framework::ProgramDesc* program, int64_t rank);
   // TODO(liyurui): This will be the only constructor for task node
   TaskNode(paddle::framework::ProgramDesc* program,
            int64_t task_id,
            int64_t rank,
-           int64_t max_run_times,
-           int64_t max_slot_nums);
+           int64_t max_run_times);
   ~TaskNode() = default;

   void SetProgram(paddle::framework::ProgramDesc* program);
...
@@ -69,11 +63,11 @@ class TaskNode final {
   int64_t task_id() const { return task_id_; }
   int32_t role() const { return role_; }
   int64_t max_run_times() const { return max_run_times_; }
-  int64_t max_slot_nums() const { return max_slot_nums_; }
   int64_t run_per_steps() const { return run_per_steps_; }
   int64_t run_at_offset() const { return run_at_offset_; }
   int64_t reply_up_per_steps() const { return reply_up_per_steps_; }
   int64_t send_down_per_steps() const { return send_down_per_steps_; }
+  const std::string& cond_var() const { return cond_var_; }
   const std::unordered_map<int64_t, int64_t>& upstream() const {
     return upstream_;
   }
...
@@ -86,11 +80,20 @@ class TaskNode final {
   const std::vector<std::unique_ptr<OperatorBase>>& unique_ops() const {
     return ops_vec_;
   }
+  const std::unordered_map<int64_t, DependType> id_to_dep_type() const {
+    return id_to_dep_type_;
+  }
   const std::unordered_map<const OperatorBase*, std::vector<std::string>>&
   unused_vars() const {
     return unused_vars_;
   }
+  const std::vector<std::string> while_block_vars() const {
+    return while_block_vars_;
+  }

+  void SetCondVarName(const std::string& cond_var_name) {
+    cond_var_ = cond_var_name;
+  }
   void SetRunPerSteps(int64_t value);
   void SetRunAtOffset(int64_t value);
   void SetReplyUpPerSteps(int64_t value);
...
@@ -101,11 +104,27 @@ class TaskNode final {
                      unused_vars) {
     unused_vars_ = unused_vars;
   }
+  void SetWhileBlockVars(const std::vector<std::string>& vars) {
+    while_block_vars_ = vars;
+  }

   // upstream need buffs?
-  bool AddUpstreamTask(int64_t task_id, int64_t buff_size = 1);
-  bool AddDownstreamTask(int64_t task_id, int64_t buff_size = 1);
+  bool AddUpstreamTask(int64_t task_id,
+                       int64_t buff_size = 1,
+                       DependType type = DependType::NORMAL);
+  bool AddDownstreamTask(int64_t task_id,
+                         int64_t buff_size = 1,
+                         DependType type = DependType::NORMAL);
   std::string DebugString() const;

+  const std::map<std::string, std::string>& vars_to_dtype() const {
+    return vars_to_dtype_;
+  }
+  void SetVarsToDtype(const std::map<std::string, std::string>& vars_to_dtype);
+  const std::map<std::string, std::vector<int64_t>>& vars_to_shape() const {
+    return vars_to_shape_;
+  }
+  void SetVarsToShape(
+      const std::map<std::string, std::vector<int64_t>>& vars_to_shape);

  private:
   DISABLE_COPY_AND_ASSIGN(TaskNode);
...
@@ -115,16 +134,22 @@ class TaskNode final {
   // task_id-->buff_size
   std::unordered_map<int64_t, int64_t> upstream_;
   std::unordered_map<int64_t, int64_t> downstream_;
+  // task_id-->type
+  std::unordered_map<int64_t, DependType> id_to_dep_type_;
   framework::ProgramDesc* program_;
+  std::string cond_var_;
   std::vector<std::unique_ptr<OperatorBase>> ops_vec_;
   std::unordered_map<const OperatorBase*, std::vector<std::string>>
       unused_vars_;
+  std::vector<std::string> while_block_vars_;
+  std::map<std::string, std::string> vars_to_dtype_;
+  std::map<std::string, std::vector<int64_t>> vars_to_shape_;
   int32_t role_;
   int64_t rank_;
   int64_t task_id_;
   int64_t max_run_times_;
-  int64_t max_slot_nums_;
   int64_t run_per_steps_{1};
   int64_t run_at_offset_{0};
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
...
@@ -77,9 +77,8 @@ TEST(ComputeInterceptor, Compute) {
   // FIXME: don't delete, otherwise interceptor will use undefined node
   TaskNode* source =
       new TaskNode(0, SOURCE_ID, 2);  // rank, task_id, max_run_times
   TaskNode* node_a =
-      new TaskNode(0, ops, 0, 0, 2, 0);  // role, ops, rank, task_id
-  TaskNode* node_b = new TaskNode(0, 0, 1, 2, 0);
+      new TaskNode(0, ops, 0, 0, 2);  // role, ops, rank, task_id
+  TaskNode* node_b = new TaskNode(0, 0, 1, 2);
   TaskNode* sink = new TaskNode(0, SINK_ID, 2);

   // source->a->b->sink
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
...
@@ -21,61 +21,49 @@ limitations under the License. */
 #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
 #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
 #include "paddle/fluid/distributed/fleet_executor/task_node.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/phi/core/kernel_registry.h"

 namespace paddle {
 namespace distributed {

-class StartInterceptor : public Interceptor {
- public:
-  StartInterceptor(int64_t interceptor_id, TaskNode* node)
-      : Interceptor(interceptor_id, node) {
-    RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); });
-  }
-  void NOP(const InterceptorMessage& msg) {
-    if (msg.message_type() == STOP) {
-      stop_ = true;
-      InterceptorMessage stop;
-      stop.set_message_type(STOP);
-      Send(1, stop);  // stop 1, compute
-      return;
-    }
-    std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
-              << std::endl;
-  }
-};
-
 TEST(ComputeInterceptor, Compute) {
   std::string carrier_id = "0";
   Carrier* carrier =
       GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
-  carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}});
+  carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}, {1, 0}, {SINK_ID, 0}});
   MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
   msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");

   // NOTE: don't delete, otherwise interceptor will use undefined node
-  TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0);  // role, rank, task_id
-  TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
-  TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);
+  TaskNode* source =
+      new TaskNode(0, SOURCE_ID, 3);  // rank, task_id, max_run_times
+  TaskNode* node_a = new TaskNode(0, 0, 0, 3);
+  TaskNode* node_b = new TaskNode(0, 0, 1, 3);
+  TaskNode* sink = new TaskNode(0, SINK_ID, 3);

-  // a->b->c
+  // source->a->b->sink
+  source->AddDownstreamTask(0);
+  node_a->AddUpstreamTask(SOURCE_ID);
   node_a->AddDownstreamTask(1, 3);
   node_b->AddUpstreamTask(0, 3);
-  node_b->AddDownstreamTask(2);
-  node_c->AddUpstreamTask(1);
+  node_b->AddDownstreamTask(SINK_ID);
+  sink->AddUpstreamTask(1);

-  Interceptor* a = carrier->SetInterceptor(
-      0, std::make_unique<StartInterceptor>(0, node_a));
+  carrier->SetInterceptor(
+      SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
+  carrier->SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a));
   carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
-  carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
+  carrier->SetInterceptor(SINK_ID,
+                          InterceptorFactory::Create("Sink", SINK_ID, sink));

   // start
   InterceptorMessage msg;
-  msg.set_message_type(DATA_IS_READY);
-  // test run three times
-  a->Send(1, msg);
-  a->Send(1, msg);
-  a->Send(1, msg);
+  msg.set_message_type(START);
+  msg.set_dst_id(SOURCE_ID);
+  carrier->EnqueueInterceptorMessage(msg);

   carrier->Wait();
   carrier->Release();
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
...
@@ -33,7 +33,6 @@ class PingPongInterceptor : public Interceptor {
   void PingPong(const InterceptorMessage& msg) {
     if (msg.message_type() == STOP) {
-      stop_ = true;
       return;
     }
     std::cout << GetInterceptorId() << " recv msg, count=" << count_
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
...
@@ -36,7 +36,6 @@ class PingPongInterceptor : public Interceptor {
   void PingPong(const InterceptorMessage& msg) {
     if (msg.message_type() == STOP) {
-      stop_ = true;
       StopCarrier();
       return;
     }
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
...
@@ -66,17 +66,17 @@ TEST(AmplifierInterceptor, Amplifier) {
   MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
   msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");

-  int64_t micro_steps = 3;
+  int64_t micro_steps = 1;
   // NOTE: don't delete, otherwise interceptor will use undefined node
   TaskNode* source =
       new TaskNode(0, SOURCE_ID, micro_steps);  // rank, task_id, max_run_times
-  TaskNode* node_a = new TaskNode(0, 0, 0, 1, 0);  // role, rank, task_id
-  TaskNode* node_b = new TaskNode(0, 0, 1, 1, 0);
-  TaskNode* node_c = new TaskNode(0, 0, 2, 1, 0);
-  TaskNode* node_d = new TaskNode(0, 0, 3, 1, 0);
-  TaskNode* node_e = new TaskNode(0, 0, 4, 1, 0);
-  TaskNode* node_f = new TaskNode(0, 0, 5, 1, 0);
+  TaskNode* node_a = new TaskNode(0, 0, 0, 1);  // role, rank, task_id
+  TaskNode* node_b = new TaskNode(0, 0, 1, 1);
+  TaskNode* node_c = new TaskNode(0, 0, 2, 1);
+  TaskNode* node_d = new TaskNode(0, 0, 3, 1);
+  TaskNode* node_e = new TaskNode(0, 0, 4, 1);
+  TaskNode* node_f = new TaskNode(0, 0, 5, 1);
   TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);

   // source->a->b->c->d->e->f->sink
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
@@ -83,11 +83,10 @@ TEST(AmplifierInterceptor, Amplifier) {
   // NOTE: don't delete, otherwise interceptor will use undefined node
   TaskNode* source =
       new TaskNode(0, SOURCE_ID, micro_steps);  // rank, task_id, max_run_times
-  TaskNode* node_a =
-      new TaskNode(0, 0, 0, micro_steps, 0);  // role, rank, task_id
-  TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
-  TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);
-  TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps, 0);
+  TaskNode* node_a =
+      new TaskNode(0, 0, 0, micro_steps);  // role, rank, task_id
+  TaskNode* node_b = new TaskNode(0, 0, 1, micro_steps);
+  TaskNode* node_c = new TaskNode(0, 0, 2, micro_steps);
+  TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps);
   TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);

   // source->a->b->c->d->sink
paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc
@@ -62,10 +62,9 @@ TEST(SourceInterceptor, Source) {
   msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
   // NOTE: don't delete, otherwise interceptor will use undefined node
-  TaskNode* source =
-      new TaskNode(0, SOURCE_ID, 0, 3, 0);  // role, rank, task_id
-  TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0);  // role, rank, task_id
-  TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3, 0);  // role, rank, task_id
+  TaskNode* source = new TaskNode(0, SOURCE_ID, 0, 3);  // role, rank, task_id
+  TaskNode* node_a = new TaskNode(0, 0, 0, 3);  // role, rank, task_id
+  TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3);  // role, rank, task_id
   source->AddDownstreamTask(0, 1);
   node_a->AddUpstreamTask(SOURCE_ID, 1);
paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc
@@ -61,9 +61,8 @@ TEST(SourceInterceptor, Source) {
   msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "")
   // NOTE: don't delete, otherwise interceptor will use undefined node
-  TaskNode* source =
-      new TaskNode(0, SOURCE_ID, 0, 3, 0);  // role, rank, task_id
-  TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0);  // role, rank, task_id
+  TaskNode* source = new TaskNode(0, SOURCE_ID, 0, 3);  // role, rank, task_id
+  TaskNode* node_a = new TaskNode(0, 0, 0, 3);  // role, rank, task_id
   source->AddDownstreamTask(0, 1);
   node_a->AddUpstreamTask(SOURCE_ID, 1);
paddle/fluid/operators/collective/c_broadcast_op.cu.cc
@@ -112,5 +112,7 @@ REGISTER_OP_CUDA_KERNEL(c_broadcast,
                         ops::CBroadcastOpCUDAKernel<plat::bfloat16>,
 #endif
                         ops::CBroadcastOpCUDAKernel<int>,
+                        ops::CBroadcastOpCUDAKernel<uint8_t>,
+                        ops::CBroadcastOpCUDAKernel<int8_t>,
                         ops::CBroadcastOpCUDAKernel<int64_t>,
                         ops::CBroadcastOpCUDAKernel<plat::float16>);
paddle/fluid/operators/collective/c_embedding_op.cu
@@ -19,6 +19,8 @@ limitations under the License. */
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/fluid/platform/float16.h"

+DECLARE_bool(cudnn_deterministic);
+
 namespace paddle {
 namespace operators {
@@ -83,6 +85,32 @@ __global__ void CEmbeddingGrad(T *table,
   }
 }

+template <typename T, typename IndexT>
+__global__ void CEmbeddingGradSerial(T *table,
+                                     const T *output,
+                                     const IndexT *ids,
+                                     const int rows,
+                                     const int columns,
+                                     const int64_t N,
+                                     const int64_t start_idx,
+                                     const int64_t end_idx,
+                                     const int64_t limit) {
+  CUDA_KERNEL_LOOP(i, limit) {
+    if (i == 0) {
+      for (int j = 0; j < limit; j++) {
+        size_t row = j / columns;
+        size_t col = j % columns;
+        auto id = ids[row];
+        if (id >= start_idx && id < end_idx) {
+          auto real_idx = id - start_idx;
+          paddle::platform::CudaAtomicAdd(&table[real_idx * columns + col],
+                                          output[i]);
+        }
+      }
+    }
+  }
+}
+
 template <typename T>
 class CEmbeddingCUDAKernel : public framework::OpKernel<T> {
  public:
@@ -163,28 +191,56 @@ class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
     t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));

     const auto &index_type = framework::TransToProtoVarType(ids_t->dtype());
-    if (index_type == framework::proto::VarType::INT32) {
-      CEmbeddingGrad<T, int32_t>
-          <<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
-                                                     d_output,
-                                                     ids_t->data<int32_t>(),
-                                                     K, D, N,
-                                                     start_idx, end_idx,
-                                                     limit);
-    } else if (index_type == framework::proto::VarType::INT64) {
-      CEmbeddingGrad<T, int64_t>
-          <<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
-                                                     d_output,
-                                                     ids_t->data<int64_t>(),
-                                                     K, D, N,
-                                                     start_idx, end_idx,
-                                                     limit);
+    if (FLAGS_cudnn_deterministic) {
+      VLOG(2) << "Run grad kernel of embedding with single thread.";
+      blocks = 1;
+      if (index_type == framework::proto::VarType::INT32) {
+        CEmbeddingGradSerial<T, int32_t>
+            <<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
+                                                       d_output,
+                                                       ids_t->data<int32_t>(),
+                                                       K, D, N,
+                                                       start_idx, end_idx,
+                                                       limit);
+      } else if (index_type == framework::proto::VarType::INT64) {
+        CEmbeddingGradSerial<T, int64_t>
+            <<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
+                                                       d_output,
+                                                       ids_t->data<int64_t>(),
+                                                       K, D, N,
+                                                       start_idx, end_idx,
+                                                       limit);
+      }
+    } else {
+      if (index_type == framework::proto::VarType::INT32) {
+        CEmbeddingGrad<T, int32_t>
+            <<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
+                                                       d_output,
+                                                       ids_t->data<int32_t>(),
+                                                       K, D, N,
+                                                       start_idx, end_idx,
+                                                       limit);
+      } else if (index_type == framework::proto::VarType::INT64) {
+        CEmbeddingGrad<T, int64_t>
+            <<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
+                                                       d_output,
+                                                       ids_t->data<int64_t>(),
+                                                       K, D, N,
+                                                       start_idx, end_idx,
+                                                       limit);
+      }
     }
   }
 };
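The diff above adds a serial fallback kernel guarded by `FLAGS_cudnn_deterministic` because parallel `atomicAdd` operations land in a nondeterministic order, and floating-point addition is not associative, so results can differ bit-for-bit between runs. The Python sketch below demonstrates the underlying non-associativity that motivates the fixed-order accumulation path:

```python
# Floating-point addition is not associative: two "arrival orders" of the
# same three addends differ in the last bits. Parallel atomic adds pick an
# arbitrary order per run, so a deterministic path must fix the order.
a, b, c = 0.1, 0.2, 0.3

left_to_right = (a + b) + c   # one possible accumulation order
right_to_left = a + (b + c)   # another possible accumulation order

print(left_to_right == right_to_left)  # False
```

The two sums differ only at the level of the unit in the last place, which is exactly the run-to-run jitter that a single-thread, fixed-order gradient kernel eliminates.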
paddle/fluid/pybind/bind_fleet_executor.cc
@@ -65,6 +65,7 @@ struct npy_format_descriptor<paddle::platform::float16> {
 namespace paddle {
 namespace pybind {

+using paddle::distributed::DependType;
 using paddle::distributed::DistModel;
 using paddle::distributed::DistModelConfig;
 using paddle::distributed::DistModelDataBuf;
@@ -164,18 +165,17 @@ void BindFleetExecutor(py::module* m) {
       .def("run",
            &FleetExecutor::Run,
            py::call_guard<py::gil_scoped_release>());

+  py::enum_<DependType>(*m, "DependType")
+      .value("NORMAL", DependType::NORMAL)
+      .value("LOOP", DependType::LOOP)
+      .value("STOP_LOOP", DependType::STOP_LOOP);
+
   py::class_<TaskNode>(*m, "TaskNode")
-      .def(py::init<framework::ProgramDesc*,
-                    int64_t, int64_t, int64_t, int64_t>())
+      .def(py::init<framework::ProgramDesc*, int64_t, int64_t, int64_t>())
       .def(py::init<int32_t,
                     const std::vector<framework::OpDesc*>&,
                     int64_t,
                     int64_t,
-                    int64_t,
                     int64_t>())
       .def("task_id", &TaskNode::task_id)
       .def("add_upstream_task", &TaskNode::AddUpstreamTask)
@@ -183,7 +183,10 @@ void BindFleetExecutor(py::module* m) {
       .def("set_run_pre_steps", &TaskNode::SetRunPerSteps)
       .def("set_run_at_offset", &TaskNode::SetRunAtOffset)
       .def("set_type", &TaskNode::SetType)
+      .def("set_cond_var_name", &TaskNode::SetCondVarName)
       .def("role", &TaskNode::role)
+      .def("set_vars_to_shape", &TaskNode::SetVarsToShape)
+      .def("set_vars_to_dtype", &TaskNode::SetVarsToDtype)
       .def("init", [](TaskNode& self) { self.Init(); })
       .def("set_program", &TaskNode::SetProgram);
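The binding above exposes a `DependType` enum (`NORMAL`, `LOOP`, `STOP_LOOP`) so Python-side graph construction can label task-graph edges. As a rough illustration of how such labels are used, the sketch below mirrors the enum values in a Python `Enum` and classifies edges in a pipeline DAG; the `classify` helper is purely hypothetical, not part of Paddle's API.

```python
from enum import Enum

# Mirrors the DependType values exposed by the pybind binding above.
class DependType(Enum):
    NORMAL = 0
    LOOP = 1
    STOP_LOOP = 2

def classify(src_id, dst_id):
    # Illustrative rule: in a pipeline DAG an edge back to an earlier (or the
    # same) task id is a loop dependency; otherwise it is a forward edge.
    return DependType.LOOP if dst_id <= src_id else DependType.NORMAL

print(classify(1, 2))  # DependType.NORMAL
print(classify(5, 1))  # DependType.LOOP
```

Distinguishing loop edges matters for the cond interceptor added by this commit, since a scheduler must not wait on a back edge before the first iteration.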
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
@@ -23,6 +23,8 @@
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/embedding_util.h"

+DECLARE_bool(cudnn_deterministic);
+
 namespace phi {

 template <typename InT, typename OutT>
@@ -101,6 +103,12 @@ struct EmbeddingGradCUDAFunctor {
       const int gridx = 2 * dev_ctx_.GetSMCount();
       dim3 threads(128, 8);
       dim3 grids(gridx, 1);
+      if (FLAGS_cudnn_deterministic) {
+        VLOG(2) << "Run grad kernel of embedding with single thread.";
+        grids.x = 1;
+        threads.y = 1;
+      }
       EmbeddingGrad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
           d_table, d_output, ids, N, K, D);
     }
python/paddle/distributed/auto_parallel/completion.py
This diff is collapsed. Click to expand it.
python/paddle/distributed/auto_parallel/constants.py
@@ -94,6 +94,16 @@ set_field_default_config(GRADIENT_MERGE, "enable", False)
 set_field_default_config(GRADIENT_MERGE, "k_steps", 1)
 set_field_default_config(GRADIENT_MERGE, "avg", True)

+#########################################
+# pipeline configuration
+#########################################
+PIPELINE = "pipeline"
+set_field_default_config(PIPELINE, "enable", False)
+set_field_default_config(PIPELINE, "schedule_mode", "1F1B")
+set_field_default_config(PIPELINE, "micro_batch_size", 1)
+set_field_default_config(PIPELINE, "accumulate_steps", 1)
+set_field_default_config(PIPELINE, "generation_batch_size", 1)
+
 #########################################
 # quantization configuration
 #########################################
python/paddle/distributed/auto_parallel/cost/estimate_cost.py
@@ -556,8 +556,8 @@ def get_cost_from_engine(engine, mode):
     )
     serial_startup_prog = (
-        engine._serial_startup_progs[mode].clone()
-        if mode in engine._serial_startup_progs
+        engine._fwd_dist_contexts[mode]._original_serial_main_program.clone()
+        if mode in engine._fwd_dist_contexts
         else engine._orig_startup_prog.clone()
     )
     losses = (
python/paddle/distributed/auto_parallel/dist_context.py
This diff is collapsed. Click to expand it.
python/paddle/distributed/auto_parallel/dist_op.py
浏览文件 @
92c2dcbd
...
@@ -27,7 +27,6 @@ from .utils import convert_to_shard_spec, verify_shard_spec
...
@@ -27,7 +27,6 @@ from .utils import convert_to_shard_spec, verify_shard_spec
class
DistributedOperator
:
class
DistributedOperator
:
def
__init__
(
self
,
serial_op
,
dist_attr
=
None
):
def
__init__
(
self
,
serial_op
,
dist_attr
=
None
):
self
.
_serial_op
=
serial_op
self
.
_serial_op
=
serial_op
self
.
_serial_inputs
=
{}
self
.
_serial_inputs
=
{}
...
@@ -78,28 +77,34 @@ class DistributedOperator:
...
@@ -78,28 +77,34 @@ class DistributedOperator:
if
tensor
is
None
:
if
tensor
is
None
:
tensor_shape
=
[]
tensor_shape
=
[]
else
:
else
:
if
tensor
.
type
==
core
.
VarDesc
.
VarType
.
READER
\
if
(
or
tensor
.
type
==
core
.
VarDesc
.
VarType
.
LOD_TENSOR_ARRAY
:
tensor
.
type
==
core
.
VarDesc
.
VarType
.
READER
or
tensor
.
type
==
core
.
VarDesc
.
VarType
.
LOD_TENSOR_ARRAY
):
tensor_shape
=
[]
tensor_shape
=
[]
else
:
else
:
tensor_shape
=
tensor
.
shape
tensor_shape
=
tensor
.
shape
if
self
.
_dist_attr
.
get_input_dims_mapping
(
tensor_name
)
is
None
:
if
self
.
_dist_attr
.
get_input_dims_mapping
(
tensor_name
)
is
None
:
tensor_dims_mapping
=
[
-
1
for
_
in
range
(
len
(
tensor_shape
))]
tensor_dims_mapping
=
[
-
1
for
_
in
range
(
len
(
tensor_shape
))]
self
.
_dist_attr
.
set_input_dims_mapping
(
tensor_name
,
self
.
_dist_attr
.
set_input_dims_mapping
(
tensor_dims_mapping
)
tensor_name
,
tensor_dims_mapping
)
for
tensor_name
in
self
.
_serial_op
.
output_arg_names
:
for
tensor_name
in
self
.
_serial_op
.
output_arg_names
:
tensor
=
self
.
_serial_op
.
block
.
_var_recursive
(
tensor_name
)
tensor
=
self
.
_serial_op
.
block
.
_var_recursive
(
tensor_name
)
if
tensor
.
type
==
core
.
VarDesc
.
VarType
.
READER
\
if
(
or
tensor
.
type
==
core
.
VarDesc
.
VarType
.
LOD_TENSOR_ARRAY
\
tensor
.
type
==
core
.
VarDesc
.
VarType
.
READER
or
tensor
.
type
==
core
.
VarDesc
.
VarType
.
STEP_SCOPES
:
or
tensor
.
type
==
core
.
VarDesc
.
VarType
.
LOD_TENSOR_ARRAY
or
tensor
.
type
==
core
.
VarDesc
.
VarType
.
STEP_SCOPES
):
tensor_shape
=
[]
tensor_shape
=
[]
else
:
else
:
tensor_shape
=
tensor
.
shape
tensor_shape
=
tensor
.
shape
self
.
_serial_outputs
[
tensor_name
]
=
tensor
self
.
_serial_outputs
[
tensor_name
]
=
tensor
if
self
.
_dist_attr
.
get_output_dims_mapping
(
tensor_name
)
is
None
:
if
self
.
_dist_attr
.
get_output_dims_mapping
(
tensor_name
)
is
None
:
tensor_dims_mapping
=
[
-
1
for
_
in
range
(
len
(
tensor_shape
))]
tensor_dims_mapping
=
[
-
1
for
_
in
range
(
len
(
tensor_shape
))]
self
.
_dist_attr
.
set_output_dims_mapping
(
tensor_name
,
self
.
_dist_attr
.
set_output_dims_mapping
(
tensor_dims_mapping
)
tensor_name
,
tensor_dims_mapping
)
if
self
.
_dist_attr
.
op_type
is
None
:
if
self
.
_dist_attr
.
op_type
is
None
:
self
.
_dist_attr
.
op_type
=
self
.
serial_op
.
type
self
.
_dist_attr
.
op_type
=
self
.
serial_op
.
type
if
self
.
_dist_attr
.
impl_type
is
None
:
if
self
.
_dist_attr
.
impl_type
is
None
:
...
@@ -117,8 +122,10 @@ class DistributedOperator:
...
@@ -117,8 +122,10 @@ class DistributedOperator:
new_dist_attr
=
{}
new_dist_attr
=
{}
for
key
,
value
in
dist_attr
.
items
():
for
key
,
value
in
dist_attr
.
items
():
if
isinstance
(
key
,
Variable
):
if
isinstance
(
key
,
Variable
):
if
key
.
name
in
self
.
_serial_op
.
input_arg_names
\
if
(
or
key
.
name
in
self
.
_serial_op
.
output_arg_names
:
key
.
name
in
self
.
_serial_op
.
input_arg_names
or
key
.
name
in
self
.
_serial_op
.
output_arg_names
):
new_dist_attr
[
key
]
=
value
new_dist_attr
[
key
]
=
value
else
:
else
:
new_dist_attr
[
key
]
=
value
new_dist_attr
[
key
]
=
value
...
@@ -129,13 +136,15 @@ class DistributedOperator:
...
@@ -129,13 +136,15 @@ class DistributedOperator:
for
tensor_name
in
self
.
_serial_op
.
input_arg_names
:
for
tensor_name
in
self
.
_serial_op
.
input_arg_names
:
tensor_dist_attr
=
dist_attr
.
get_input_dist_attr
(
tensor_name
)
tensor_dist_attr
=
dist_attr
.
get_input_dist_attr
(
tensor_name
)
if
tensor_dist_attr
:
if
tensor_dist_attr
:
new_dist_attr
.
set_input_dist_attr
(
tensor_name
,
new_dist_attr
.
set_input_dist_attr
(
tensor_dist_attr
)
tensor_name
,
tensor_dist_attr
)
for
tensor_name
in
self
.
_serial_op
.
output_arg_names
:
for
tensor_name
in
self
.
_serial_op
.
output_arg_names
:
tensor_dist_attr
=
dist_attr
.
get_output_dist_attr
(
tensor_name
)
tensor_dist_attr
=
dist_attr
.
get_output_dist_attr
(
tensor_name
)
if
tensor_dist_attr
:
if
tensor_dist_attr
:
new_dist_attr
.
set_output_dist_attr
(
tensor_name
,
new_dist_attr
.
set_output_dist_attr
(
tensor_dist_attr
)
tensor_name
,
tensor_dist_attr
)
else
:
else
:
assert
False
,
"Cannot recognize the {} parameter."
.
format
(
dist_attr
)
assert
False
,
"Cannot recognize the {} parameter."
.
format
(
dist_attr
)
return
new_dist_attr
return
new_dist_attr
...
@@ -146,8 +155,10 @@ class DistributedOperator:
...
@@ -146,8 +155,10 @@ class DistributedOperator:
for
name
in
self
.
serial_op
.
input_arg_names
:
for
name
in
self
.
serial_op
.
input_arg_names
:
input_dist_attr
=
self
.
dist_attr
.
get_input_dist_attr
(
name
)
input_dist_attr
=
self
.
dist_attr
.
get_input_dist_attr
(
name
)
dims_mapping
=
input_dist_attr
.
dims_mapping
dims_mapping
=
input_dist_attr
.
dims_mapping
if
self
.
get_serial_input
(
if
(
name
).
type
==
core
.
VarDesc
.
VarType
.
LOD_TENSOR_ARRAY
:
self
.
get_serial_input
(
name
).
type
==
core
.
VarDesc
.
VarType
.
LOD_TENSOR_ARRAY
):
shape
=
[]
shape
=
[]
else
:
else
:
shape
=
self
.
get_serial_input
(
name
).
shape
shape
=
self
.
get_serial_input
(
name
).
shape
...
@@ -155,7 +166,8 @@ class DistributedOperator:
...
@@ -155,7 +166,8 @@ class DistributedOperator:
return
False
return
False
for
i
in
range
(
len
(
dims_mapping
)):
for
i
in
range
(
len
(
dims_mapping
)):
if
dims_mapping
[
i
]
<
-
1
or
dims_mapping
[
i
]
>=
len
(
if
dims_mapping
[
i
]
<
-
1
or
dims_mapping
[
i
]
>=
len
(
self
.
dist_attr
.
process_mesh
.
topology
):
self
.
dist_attr
.
process_mesh
.
topology
):
return
False
return
False
for
i
in
range
(
len
(
self
.
dist_attr
.
process_mesh
.
topology
)):
for
i
in
range
(
len
(
self
.
dist_attr
.
process_mesh
.
topology
)):
if
dims_mapping
.
count
(
i
)
>
1
:
if
dims_mapping
.
count
(
i
)
>
1
:
...
@@ -166,8 +178,12 @@ class DistributedOperator:
...
@@ -166,8 +178,12 @@ class DistributedOperator:
for
name
in
self
.
serial_op
.
output_arg_names
:
for
name
in
self
.
serial_op
.
output_arg_names
:
output_dist_attr
=
self
.
dist_attr
.
get_output_dist_attr
(
name
)
output_dist_attr
=
self
.
dist_attr
.
get_output_dist_attr
(
name
)
dims_mapping
=
output_dist_attr
.
dims_mapping
dims_mapping
=
output_dist_attr
.
dims_mapping
if
self
.
get_serial_output
(
name
).
type
==
core
.
VarDesc
.
VarType
.
LOD_TENSOR_ARRAY
\
if
(
or
self
.
get_serial_output
(
name
).
type
==
core
.
VarDesc
.
VarType
.
STEP_SCOPES
:
self
.
get_serial_output
(
name
).
type
==
core
.
VarDesc
.
VarType
.
LOD_TENSOR_ARRAY
or
self
.
get_serial_output
(
name
).
type
==
core
.
VarDesc
.
VarType
.
STEP_SCOPES
):
shape
=
[]
shape
=
[]
else
:
else
:
shape
=
self
.
get_serial_output
(
name
).
shape
shape
=
self
.
get_serial_output
(
name
).
shape
...
@@ -175,7 +191,8 @@ class DistributedOperator:
...
@@ -175,7 +191,8 @@ class DistributedOperator:
return
False
return
False
for
i
in
range
(
len
(
dims_mapping
)):
for
i
in
range
(
len
(
dims_mapping
)):
if
dims_mapping
[
i
]
<
-
1
or
dims_mapping
[
i
]
>=
len
(
if
dims_mapping
[
i
]
<
-
1
or
dims_mapping
[
i
]
>=
len
(
self
.
dist_attr
.
process_mesh
.
topology
):
self
.
dist_attr
.
process_mesh
.
topology
):
return
False
return
False
for
i
in
range
(
len
(
self
.
dist_attr
.
process_mesh
.
topology
)):
for
i
in
range
(
len
(
self
.
dist_attr
.
process_mesh
.
topology
)):
if
dims_mapping
.
count
(
i
)
>
1
:
if
dims_mapping
.
count
(
i
)
>
1
:
...
@@ -185,8 +202,9 @@ class DistributedOperator:
...
@@ -185,8 +202,9 @@ class DistributedOperator:
return
True
return
True
def
__str__
(
self
):
def
__str__
(
self
):
str
=
"{{op type: {}, op id: {}"
.
format
(
self
.
serial_op
.
desc
.
type
(),
str
=
"{{op type: {}, op id: {}"
.
format
(
self
.
serial_op
.
desc
.
id
())
self
.
serial_op
.
desc
.
type
(),
self
.
serial_op
.
desc
.
id
()
)
# str += ", {}".format(self.dist_attr)
# str += ", {}".format(self.dist_attr)
# return str
# return str
...
@@ -195,8 +213,9 @@ class DistributedOperator:
...
@@ -195,8 +213,9 @@ class DistributedOperator:
annotated_str
=
"annotated"
annotated_str
=
"annotated"
else
:
else
:
annotated_str
=
"non-annotated"
annotated_str
=
"non-annotated"
str
+=
", process_mesh ({}): {}"
.
format
(
annotated_str
,
str
+=
", process_mesh ({}): {}"
.
format
(
self
.
dist_attr
.
process_mesh
)
annotated_str
,
self
.
dist_attr
.
process_mesh
)
for
arg_name
in
self
.
serial_op
.
desc
.
input_arg_names
():
for
arg_name
in
self
.
serial_op
.
desc
.
input_arg_names
():
dims_mapping
=
self
.
dist_attr
.
get_input_dims_mapping
(
arg_name
)
dims_mapping
=
self
.
dist_attr
.
get_input_dims_mapping
(
arg_name
)
...
@@ -212,7 +231,8 @@ class DistributedOperator:
...
@@ -212,7 +231,8 @@ class DistributedOperator:
else
:
else
:
is_parameter_str
=
"non-parameter"
is_parameter_str
=
"non-parameter"
str
+=
", {}'s dims_mapping (input, {}, {}): {}"
.
format
(
str
+=
", {}'s dims_mapping (input, {}, {}): {}"
.
format
(
arg_name
,
annotated_str
,
is_parameter_str
,
dims_mapping
)
arg_name
,
annotated_str
,
is_parameter_str
,
dims_mapping
)
for
arg_name
in
self
.
serial_op
.
desc
.
output_arg_names
():
for
arg_name
in
self
.
serial_op
.
desc
.
output_arg_names
():
dims_mapping
=
self
.
dist_attr
.
get_output_dims_mapping
(
arg_name
)
dims_mapping
=
self
.
dist_attr
.
get_output_dims_mapping
(
arg_name
)
...
@@ -228,12 +248,14 @@ class DistributedOperator:
...
@@ -228,12 +248,14 @@ class DistributedOperator:
else
:
else
:
is_parameter_str
=
"non-parameter"
is_parameter_str
=
"non-parameter"
str
+=
", {}'s dims_mapping (output, {}, {}): {}"
.
format
(
str
+=
", {}'s dims_mapping (output, {}, {}): {}"
.
format
(
arg_name
,
annotated_str
,
is_parameter_str
,
dims_mapping
)
arg_name
,
annotated_str
,
is_parameter_str
,
dims_mapping
)
str
+=
", pipeline stage: {}"
.
format
(
None
)
str
+=
", pipeline stage: {}"
.
format
(
None
)
str
+=
", dist_impl idx: {} , dist_impl type {} }}"
.
format
(
str
+=
", dist_impl idx: {} , dist_impl type {} }}"
.
format
(
self
.
dist_attr
.
_impl_idx
,
self
.
dist_attr
.
_impl_type
)
self
.
dist_attr
.
_impl_idx
,
self
.
dist_attr
.
_impl_type
)
return
str
return
str
...
@@ -242,7 +264,11 @@ class DistributedOperator:
...
@@ -242,7 +264,11 @@ class DistributedOperator:
result
=
cls
.
__new__
(
cls
)
result
=
cls
.
__new__
(
cls
)
memo
[
id
(
self
)]
=
result
memo
[
id
(
self
)]
=
result
for
k
,
v
in
self
.
__dict__
.
items
():
for
k
,
v
in
self
.
__dict__
.
items
():
if
k
==
"_serial_op"
or
k
==
"_serial_inputs"
or
k
==
"_serial_outputs"
:
if
(
k
==
"_serial_op"
or
k
==
"_serial_inputs"
or
k
==
"_serial_outputs"
):
setattr
(
result
,
k
,
v
)
setattr
(
result
,
k
,
v
)
else
:
else
:
setattr
(
result
,
k
,
copy
.
deepcopy
(
v
,
memo
))
setattr
(
result
,
k
,
copy
.
deepcopy
(
v
,
memo
))
...
@@ -250,9 +276,9 @@ class DistributedOperator:
...
@@ -250,9 +276,9 @@ class DistributedOperator:
class
DistributedOperatorHelper
:
class
DistributedOperatorHelper
:
def
__init__
(
def
__init__
(
self
,
serial_op
,
process_mesh
,
in_dims_mappings
,
self
,
serial_op
,
process_mesh
,
in_dims_mappings
,
out_dims_mappings
out_dims_mappings
):
):
self
.
_serial_op
=
serial_op
self
.
_serial_op
=
serial_op
self
.
_process_mesh
=
process_mesh
self
.
_process_mesh
=
process_mesh
self
.
_in_dims_mappings
=
in_dims_mappings
self
.
_in_dims_mappings
=
in_dims_mappings
...
@@ -262,8 +288,11 @@ class DistributedOperatorHelper:
...
@@ -262,8 +288,11 @@ class DistributedOperatorHelper:
tensor_to_dims_mapping
=
{}
tensor_to_dims_mapping
=
{}
index
=
0
index
=
0
if
self
.
_in_dims_mappings
:
if
self
.
_in_dims_mappings
:
assert
len
(
args
)
+
len
(
kwargs
)
==
len
(
self
.
_in_dims_mappings
),
\
assert
len
(
args
)
+
len
(
kwargs
)
==
len
(
"The length of dims_mapping {} does not matching the length output {}."
.
format
(
len
(
self
.
_in_dims_mappings
),
len
(
args
)
+
len
(
kwargs
))
self
.
_in_dims_mappings
),
"The length of dims_mapping {} does not matching the length output {}."
.
format
(
len
(
self
.
_in_dims_mappings
),
len
(
args
)
+
len
(
kwargs
)
)
for
arg
in
args
:
for
arg
in
args
:
if
isinstance
(
arg
,
Variable
)
and
self
.
_in_dims_mappings
:
if
isinstance
(
arg
,
Variable
)
and
self
.
_in_dims_mappings
:
tensor_to_dims_mapping
[
arg
.
name
]
=
self
.
_in_dims_mappings
[
index
]
tensor_to_dims_mapping
[
arg
.
name
]
=
self
.
_in_dims_mappings
[
index
]
...
@@ -287,13 +316,17 @@ class DistributedOperatorHelper:
...
@@ -287,13 +316,17 @@ class DistributedOperatorHelper:
raise
ValueError
(
"Unrecognized outpout."
)
raise
ValueError
(
"Unrecognized outpout."
)
if
self
.
_out_dims_mappings
:
if
self
.
_out_dims_mappings
:
assert
len
(
new_output
)
==
len
(
self
.
_out_dims_mappings
),
\
assert
len
(
new_output
)
==
len
(
"The length of dims_mapping {} does not matching the length output {}."
.
format
(
len
(
self
.
_out_dims_mappings
),
len
(
new_output
))
self
.
_out_dims_mappings
),
"The length of dims_mapping {} does not matching the length output {}."
.
format
(
len
(
self
.
_out_dims_mappings
),
len
(
new_output
)
)
for
i
,
item
in
enumerate
(
new_output
):
for
i
,
item
in
enumerate
(
new_output
):
if
isinstance
(
item
,
Variable
)
and
self
.
_out_dims_mappings
:
if
isinstance
(
item
,
Variable
)
and
self
.
_out_dims_mappings
:
tensor_to_dims_mapping
[
item
.
name
]
=
self
.
_out_dims_mappings
[
i
]
tensor_to_dims_mapping
[
item
.
name
]
=
self
.
_out_dims_mappings
[
i
]
from
.dist_context
import
get_default_distributed_context
from
.dist_context
import
get_default_distributed_context
default_dist_ctx
=
get_default_distributed_context
()
default_dist_ctx
=
get_default_distributed_context
()
for
idx
in
range
(
op_size
,
new_op_size
):
for
idx
in
range
(
op_size
,
new_op_size
):
op
=
cur_block
.
ops
[
idx
]
op
=
cur_block
.
ops
[
idx
]
...
@@ -302,53 +335,68 @@ class DistributedOperatorHelper:
...
@@ -302,53 +335,68 @@ class DistributedOperatorHelper:
if
name
in
tensor_to_dims_mapping
.
keys
():
if
name
in
tensor_to_dims_mapping
.
keys
():
tensor
=
dist_op
.
get_serial_input
(
name
)
tensor
=
dist_op
.
get_serial_input
(
name
)
tensor_dist_attr
=
dist_op
.
dist_attr
.
get_input_dist_attr
(
tensor_dist_attr
=
dist_op
.
dist_attr
.
get_input_dist_attr
(
name
)
name
)
dims_mapping
=
tensor_to_dims_mapping
[
name
]
dims_mapping
=
tensor_to_dims_mapping
[
name
]
if
tensor
is
None
:
if
tensor
is
None
:
tensor_shape
=
[]
tensor_shape
=
[]
else
:
else
:
if
tensor
.
type
==
core
.
VarDesc
.
VarType
.
READER
\
if
(
or
tensor
.
type
==
core
.
VarDesc
.
VarType
.
LOD_TENSOR_ARRAY
\
tensor
.
type
==
core
.
VarDesc
.
-                    if tensor.type == core.VarDesc.VarType.READER \
-                        or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
-                        or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
+                    if (
+                        tensor.type == core.VarDesc.VarType.READER
+                        or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
+                        or tensor.type == core.VarDesc.VarType.STEP_SCOPES
+                    ):
                         tensor_shape = []
                     else:
                         tensor_shape = tensor.shape
                 if dims_mapping is not None:
                     dims_mapping = tensor_to_dims_mapping[name]
-                    shard_spec = convert_to_shard_spec(
-                        dims_mapping, self._process_mesh)
-                    assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \
-                        "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
-                            name, shard_spec, tensor_shape, self._process_mesh)
+                    shard_spec = convert_to_shard_spec(
+                        dims_mapping, self._process_mesh
+                    )
+                    assert verify_shard_spec(
+                        shard_spec, tensor_shape, self._process_mesh
+                    ), "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
+                        name, shard_spec, tensor_shape, self._process_mesh
+                    )
                     tensor_dist_attr.dims_mapping = dims_mapping
                     tensor_dist_attr.mark_annotated("dims_mapping")
         for name in dist_op.serial_op.output_arg_names:
             if name in tensor_to_dims_mapping.keys():
                 tensor = dist_op.get_serial_output(name)
                 tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(
                     name
                 )
                 dims_mapping = tensor_to_dims_mapping[name]
                 if tensor is None:
                     tensor_shape = []
                 else:
-                    if tensor.type == core.VarDesc.VarType.READER \
-                        or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
-                        or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
+                    if (
+                        tensor.type == core.VarDesc.VarType.READER
+                        or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
+                        or tensor.type == core.VarDesc.VarType.STEP_SCOPES
+                    ):
                         tensor_shape = []
                     else:
                         tensor_shape = tensor.shape
                 if dims_mapping is not None:
                     dims_mapping = tensor_to_dims_mapping[name]
-                    shard_spec = convert_to_shard_spec(
-                        dims_mapping, self._process_mesh)
-                    assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \
-                        "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
-                            name, shard_spec, tensor_shape, self._process_mesh)
+                    shard_spec = convert_to_shard_spec(
+                        dims_mapping, self._process_mesh
+                    )
+                    assert verify_shard_spec(
+                        shard_spec, tensor_shape, self._process_mesh
+                    ), "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
+                        name, shard_spec, tensor_shape, self._process_mesh
+                    )
                     tensor_dist_attr.dims_mapping = dims_mapping
                     tensor_dist_attr.mark_annotated("dims_mapping")

         dist_op.dist_attr.process_mesh = self._process_mesh
         if self._process_mesh is not None:
             dist_op.dist_attr.mark_annotated("process_mesh")
         default_dist_ctx.add_dist_op_for_program(dist_op)
+        default_dist_ctx.add_process_mesh(self._process_mesh)
         return output
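The hunk above converts a `dims_mapping` back into a `shard_spec` and verifies it against the tensor shape and process mesh. Below is a minimal self-contained sketch of that relationship, given as an assumption for illustration: the real `convert_to_shard_spec` and `verify_shard_spec` take a `ProcessMesh` object, while this sketch uses a plain list of mesh-axis names.

```python
# Sketch (assumption, not Paddle's implementation): a dims_mapping maps each
# tensor dimension to a process-mesh axis index, with -1 meaning "replicated";
# a shard_spec names the mesh axis instead (None = replicated).

def convert_to_shard_spec(dims_mapping, mesh_dim_names):
    # e.g. [-1, 0] with mesh axes ["x", "y"] -> [None, "x"]
    return [None if d == -1 else mesh_dim_names[d] for d in dims_mapping]

def verify_shard_spec(shard_spec, tensor_shape, mesh_dim_names):
    # One entry per tensor dimension, each either None or a known mesh axis.
    if len(shard_spec) != len(tensor_shape):
        return False
    return all(s is None or s in mesh_dim_names for s in shard_spec)
```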
python/paddle/distributed/auto_parallel/engine.py
@@ -34,6 +34,7 @@ from paddle.fluid.framework import Operator, _non_static_mode
 from paddle.fluid.framework import _current_expected_place as _get_device
 from paddle.fluid.dygraph.parallel import ParallelEnv
 from paddle.distributed import fleet
+from paddle.distributed.parallel import _is_global_parallel_initialize
 from .callbacks import config_callbacks
 from .converter import Converter
@@ -160,7 +161,6 @@ class Engine:
                 " or `paddle.fluid.optimizer.Optimizer`."
             )
         self._optimizer = validate_opt(optimizer)
-        self._orig_optimizer = copy.deepcopy(self._optimizer)
         metrics = metrics or []
         for metric in to_list(metrics):
@@ -185,12 +185,18 @@ class Engine:
         self._strategy = strategy or Strategy()
         self._logger = get_logger(logging.INFO)
-        if os.getenv("POD_NAME"):
+        if os.getenv("POD_NAME") and not _is_global_parallel_initialize():
             self._logger.info(
                 "Distribute training by paddle.distributed.launch"
             )
             fleet.init(is_collective=True)

+        # for compute cost
+        # TODO: remove _fwd_main_progs and _orig_optimizer
+        self._fwd_dist_contexts = {}
+        self._fwd_main_progs = {}
+        self._orig_optimizer = copy.deepcopy(self._optimizer)
+
         self._executor = None
         self._cur_rank = paddle.distributed.get_rank()
         self._nranks = paddle.distributed.get_world_size()
@@ -200,14 +206,6 @@ class Engine:
         self._orig_startup_prog = static.default_startup_program()
         self._orig_dist_context = get_default_distributed_context()
         self._dist_contexts = {}
-        self._fwd_main_progs = {}
-        self._fwd_dist_contexts = {}
-        self._serial_main_progs = {}
-        self._serial_startup_progs = {}
-        self._dist_main_progs = defaultdict(dict)  # dist main programs
-        self._dist_startup_progs = defaultdict(dict)  # dist startup programs
-        self._feed_vars = {}
-        self._fetch_vars = {}
         self._planners = {}
         self._has_prepared = {"train": False, "eval": False, "predict": False}
         self._has_prepared_reader = {
@@ -338,9 +336,9 @@ class Engine:
         return inputs, labels

-    def _prepare_reader(self):
-        dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
+    def _prepare_reader(self, feed_list=[]):
         dist_context = self._dist_contexts[self._mode]
+        dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
         dist_main_block = dist_main_prog.global_block()
         # NOTE: this list may be changed if Paddle changes the existing rules.
@@ -361,10 +359,13 @@ class Engine:
             if op.type in related_reader_ops:
                 reader_op_indices.append(idx)
         # Step 2: insert the new reader ops to cpp
+        # record the read ops' desc to insert to program of forward task_node
+        read_ops_desc = []
         new_reader_ops = []
         for idx in reversed(reader_op_indices):
             new_op_desc = dist_main_block.desc._prepend_op()
             new_op_desc.copy_from(dist_main_block.ops[idx].desc)
+            read_ops_desc.append(new_op_desc)
             new_op = Operator(
                 dist_main_block, new_op_desc, type=new_op_desc.type()
             )
@@ -383,6 +384,29 @@ class Engine:
         dist_main_block._sync_with_cpp()
         self._has_prepared_reader[self._mode] = True

+        # Insert read op to forward TaskNode if 1F1B pass is setted
+        if self.main_program._pipeline_opt:
+            assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
+            fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
+            fwd_task = fleet_opt["tasks"][0]
+            fwd_prog = fwd_task.get_program()
+            fwd_block = fwd_prog.global_block()
+
+            for var in feed_list:
+                if var.name not in fwd_block.vars:
+                    fwd_block._clone_variable(var)
+
+            for op_desc in read_ops_desc:
+                new_op_desc = fwd_block.desc._prepend_op()
+                new_op_desc.copy_from(op_desc)
+                new_op = Operator(
+                    fwd_block, new_op_desc, type=new_op_desc.type()
+                )
+                fwd_block.ops.insert(0, new_op)
+
+            fwd_block._sync_with_cpp()
+            fwd_task.set_program(fwd_prog)
+
     def _prepare_feed(self, data, user_feeds, mode):
         feeds = {}
         if data is not None:
@@ -430,14 +454,16 @@ class Engine:
                 fetch_names.append([])
                 fetch_indices.append(group_indices)

+        dist_context = self._dist_contexts[mode]
+        fetch_vars = dist_context.serial_fetch_vars
         if mode != "predict":
-            _process_fetch_group("loss", self._fetch_vars[mode]["loss"])
+            _process_fetch_group("loss", fetch_vars["loss"])
         if mode != "predict":
-            metrics = self._fetch_vars[mode]["metrics"]
+            metrics = fetch_vars["metrics"]
             for i, var_list in enumerate(metrics):
                 _process_fetch_group("metrics_" + str(i), var_list)
         if mode == "predict":
-            _process_fetch_group("outputs", self._fetch_vars[mode]["outputs"])
+            _process_fetch_group("outputs", fetch_vars["outputs"])
         user_fetches_collection = [
             item[1] for item in get_collection(CollectionNames.FETCHES)
         ]
@@ -471,7 +497,8 @@ class Engine:
                 logs["loss"] = outs[idx][0]
                 group_idx += 1
         # logging metrics
-        metric_vars = self._fetch_vars[mode]["metrics"]
+        dist_context = self._dist_contexts[mode]
+        metric_vars = dist_context.serial_fetch_vars["metrics"]
         if metric_vars:
             for metric in self._metrics:
                 metrics_indices = fetch_indices[group_idx]
@@ -502,15 +529,18 @@ class Engine:
             logs["fetches"] = logs_fetch
         return logs

-    def _prepare_program(self, mode):
+    def _prepare_program(self, mode, init_parameters=True):
         # Do the build process
         self._build(mode)
         # Do the planning process
         self._plan(mode)
         # Do the parallel process
         self._parallel(mode)
-        # Init comm and startup program
-        self._initialize(mode)
+        # Init comm
+        self._init_comm()
+        if init_parameters:
+            # startup program
+            self._initialize(mode)
         self._has_prepared[mode] = True

     def _build(self, mode):
@@ -542,8 +572,8 @@ class Engine:
             paddle.enable_static()
         else:
             # build program in static mode
-            serial_main_prog = self._serial_main_progs.get(mode, None)
-            if serial_main_prog is not None:
+            dist_context = self._dist_contexts.get(mode, None)
+            if dist_context is not None:
                 return

         outputs = []
@@ -581,7 +611,7 @@ class Engine:
                         metric.compute(*(outputs + self._labels))
                     )
                 )
-        else:
+        elif mode == "train":
             assert isinstance(
                 self._loss, Variable
             ), "the type of `loss` of the Engine arguments should be Variable."
@@ -724,37 +754,21 @@ class Engine:
             )
             dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr)

-    def _initialize(self, mode):
-        # Get the current content from the distributed context
-        self._serial_main_progs[mode] = self._dist_contexts[mode].serial_main_program
-        self._serial_startup_progs[mode] = self._dist_contexts[mode].serial_startup_program
-        self._dist_main_progs[mode] = self._dist_contexts[mode].dist_main_programs
-        self._dist_startup_progs[mode] = self._dist_contexts[mode].dist_startup_programs
-        self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars
-        self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars
-        self._optimizer = self._dist_contexts[mode]._serial_optimizer
-
+    def _init_comm(self):
         if self._nranks > 1:
             # Traverse different rank programs and traverse each op of them,
             # instantiate communication by process_mapping.
             all_process_groups = get_all_process_groups()
             if self._strategy.auto_mode == "full":
-                initialize_pg_in_full_mode(all_process_groups, cur_rank)
+                initialize_pg_in_full_mode(all_process_groups, self._cur_rank)
             else:
                 for process_group in all_process_groups:
                     if self._cur_rank not in process_group.ranks:
                         continue
                     process_group.instantiate()

+    def _initialize(self, mode):
         place = _get_device()
         if isinstance(place, fluid.CUDAPlace):
             place = fluid.CUDAPlace(ParallelEnv().dev_id)
@@ -764,15 +778,17 @@ class Engine:
             np.random.seed(self._strategy.seed + self._dp_ranks[0])
             random.seed(self._strategy.seed + self._dp_ranks[0])

+        dist_context = self._dist_contexts[mode]
         if self._dygraph_mode:
-            dist_context = self._dist_contexts[mode]
-            dist_main_program = self._dist_main_progs[mode][self._cur_rank]
+            dist_main_program = dist_context.dist_main_programs[self._cur_rank]
             self.program_helper.init(dist_main_program, place, dist_context)

         if self._executor is None:
             self._executor = paddle.static.Executor(place)
             uninitialized = []
-            dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
+            dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
             for var in dist_startup_prog.list_vars():
                 scope_var = global_scope().find_var(var.name)
                 if scope_var and scope_var.get_tensor()._is_initialized():
@@ -789,7 +805,9 @@ class Engine:
         if self._strategy.reinit:
             self._logger.info("NOTE: parameters will be re-initialized.")
-            dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
+            dist_startup_prog = dist_context.dist_startup_programs[
+                self._cur_rank
+            ]
             self._executor.run(dist_startup_prog)

     def fit(
@@ -926,7 +944,7 @@ class Engine:
                     )
                 except core.EOFException:
                     break
-                lr = get_lr(self._optimizer)
+                lr = get_lr(self.optimizer)
                 logs = self._prepare_logger(
                     outs,
                     epoch,
@@ -1262,6 +1280,7 @@ class Engine:
         main_program=None,
         startup_program=None,
         mode=None,
+        init_parameters=True,
     ):
         if mode is not None:
             self.to_mode(mode)
@@ -1304,7 +1323,7 @@ class Engine:
         self._inputs_spec, self._labels_spec = inputs_spec, labels_spec
         self._inputs, self._labels = inputs, labels
         if not self._has_prepared[self._mode]:
-            self._prepare_program(self._mode)
+            self._prepare_program(self._mode, init_parameters)
         else:
             self._switch_mode(self._mode)
@@ -1355,16 +1374,17 @@ class Engine:
             )
         batch_size //= self._k_steps

-        dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
-        dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank]
+        dist_context = self._dist_contexts[self._mode]
+        dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
+        dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
         dist_main_block = dist_main_prog.global_block()
         # NOTE: Get feed_list, then insert dataloader op with sharded var shape.
         # Cause predict_program does not contain labels var,
         # then we will add labels var from serial_program to dist_program,
         # that maintains the length of feed_list equal to the length of dataset's values.
-        inputs_var = self._feed_vars[self._mode]["inputs"]
-        labels_var = self._feed_vars[self._mode]["labels"]
+        inputs_var = dist_context.serial_feed_vars["inputs"]
+        labels_var = dist_context.serial_feed_vars["labels"]
         feed_list = []
         for var in inputs_var + labels_var:
             if var.name in dist_main_block.vars:
@@ -1423,16 +1443,17 @@ class Engine:
             )
         batch_size //= self._k_steps

-        dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
-        dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank]
+        dist_context = self._dist_contexts[self._mode]
+        dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
+        dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
         dist_main_block = dist_main_prog.global_block()
         # NOTE: Get feed_list, then insert dataloader op with sharded var shape.
         # Cause predict_program does not contain labels var,
         # then we will add labels var from serial_program to dist_program,
         # that maintains the length of feed_list equal to the length of dataset's values.
-        inputs_var = self._feed_vars[self._mode]["inputs"]
-        labels_var = self._feed_vars[self._mode]["labels"]
+        inputs_var = dist_context.serial_feed_vars["inputs"]
+        labels_var = dist_context.serial_feed_vars["labels"]
         feed_list = []
         for var in inputs_var + labels_var:
             if var.name in dist_main_block.vars:
@@ -1462,7 +1483,7 @@ class Engine:
             data_parallel_world_size=self._dp_world_sizes,
             data_parallel_rank=self._dp_ranks,
         )
-        self._prepare_reader()
+        self._prepare_reader(feed_list)
         return dataloader

     def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
@@ -1551,10 +1572,9 @@ class Engine:
     def _switch_mode(self, mode):
         assert (
-            mode in self._dist_main_progs
+            mode in self._dist_contexts
         ), "{} model is not ready, please call `prepare()` first.".format(mode)
         self.to_mode(mode)
-        self._optimizer = self._dist_contexts[mode]._serial_optimizer

     def to_mode(self, mode):
         assert mode in [
@@ -1565,8 +1585,8 @@ class Engine:
         self._mode = mode

     def _set_state_dict(self, mode, strict, state_dict, dist_attr):
-        program = self._dist_main_progs[mode][self._cur_rank]
         dist_context = self._dist_contexts[mode]
+        program = dist_context.dist_main_programs[self._cur_rank]
         cur_dist_attr = get_dist_attr(program, dist_context)
         converter = Converter(state_dict, dist_attr, cur_dist_attr)
         state_dict = converter.convert(strict=strict)
@@ -1618,10 +1638,10 @@ class Engine:
         """
         if training:
-            assert self._mode in self._serial_main_progs
-            serial_program = self._serial_main_progs[self._mode]
-            dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
+            assert self._mode in self._dist_contexts
             dist_context = self._dist_contexts[self._mode]
+            serial_program = dist_context.serial_main_program
+            dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
             self._saver.save(
                 path,
                 serial_program=serial_program,
@@ -1629,10 +1649,11 @@ class Engine:
                 dist_context=dist_context,
             )
         else:
-            assert "predict" in self._dist_main_progs
-            feed_vars = self._feed_vars["predict"]['inputs']
-            fetch_vars = self._fetch_vars["predict"]['outputs']
-            dist_main_prog = self._dist_main_progs["predict"][self._cur_rank]
+            assert "predict" in self._dist_contexts
+            dist_context = self._dist_contexts["predict"]
+            feed_vars = dist_context.serial_feed_vars['inputs']
+            fetch_vars = dist_context.serial_fetch_vars['outputs']
+            dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
             self._saver.save_inference_model(
                 path,
                 feed_vars,
@@ -1758,11 +1779,13 @@ class Engine:
     @property
     def main_program(self):
-        return self._dist_main_progs[self._mode][self._cur_rank]
+        dist_context = self._dist_contexts[self._mode]
+        return dist_context.dist_main_programs[self._cur_rank]

     @property
     def startup_program(self):
-        return self._dist_startup_progs[self._mode][self._cur_rank]
+        dist_context = self._dist_contexts[self._mode]
+        return dist_context.dist_startup_programs[self._cur_rank]

     @property
     def dist_context(self):
@@ -1770,15 +1793,30 @@ class Engine:
     @property
     def serial_main_program(self):
-        return self._serial_main_progs[self._mode]
+        dist_context = self._dist_contexts[self._mode]
+        return dist_context.serial_main_program

     @property
     def serial_startup_program(self):
-        return self._serial_startup_progs[self._mode]
+        dist_context = self._dist_contexts[self._mode]
+        return dist_context.serial_startup_program

+    @property
+    def feed_vars(self):
+        dist_context = self._dist_contexts[self._mode]
+        return dist_context.serial_feed_vars
+
     @property
     def fetch_vars(self):
-        return self._fetch_vars[self._mode]
+        dist_context = self._dist_contexts[self._mode]
+        return dist_context.serial_fetch_vars

+    @property
+    def optimizer(self):
+        dist_context = self._dist_contexts[self._mode]
+        if dist_context._serial_optimizer:
+            return dist_context._serial_optimizer
+        return self._optimizer
+
     @property
     def inputs(self):
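Most of the engine.py changes above replace Engine's cached per-mode dictionaries (`_dist_main_progs`, `_dist_startup_progs`, `_feed_vars`, `_fetch_vars`, ...) with lookups on the mode's distributed context, so programs have a single source of truth. A stripped-down sketch of that delegation pattern, using hypothetical minimal classes rather than Paddle's actual ones:

```python
# Hypothetical minimal classes illustrating the refactor: programs live only
# on the per-mode dist_context, and Engine properties delegate to it.

class DistContext:
    def __init__(self):
        self.dist_main_programs = {}     # rank -> main program
        self.dist_startup_programs = {}  # rank -> startup program

class Engine:
    def __init__(self, mode="train", cur_rank=0):
        self._mode = mode
        self._cur_rank = cur_rank
        self._dist_contexts = {mode: DistContext()}

    @property
    def main_program(self):
        # No cached copy on Engine; always read from the mode's context.
        dist_context = self._dist_contexts[self._mode]
        return dist_context.dist_main_programs[self._cur_rank]
```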
python/paddle/distributed/auto_parallel/interface.py
...
@@ -67,29 +67,43 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
...
@@ -67,29 +67,43 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
"""
"""
if
process_mesh
is
not
None
:
if
process_mesh
is
not
None
:
assert
isinstance
(
process_mesh
,
ProcessMesh
),
\
assert
isinstance
(
"Argument process_mesh {} is not an instance of ProcessMesh"
.
format
(
process_mesh
)
process_mesh
,
ProcessMesh
),
"Argument process_mesh {} is not an instance of ProcessMesh"
.
format
(
process_mesh
)
else
:
else
:
process_mesh
=
get_current_process_mesh
()
process_mesh
=
get_current_process_mesh
()
assert
process_mesh
is
not
None
,
\
assert
(
"Specify the process mesh argument or use ProcessMesh context manager first."
process_mesh
is
not
None
assert
isinstance
(
shard_spec
,
list
),
\
),
"Specify the process mesh argument or use ProcessMesh context manager first."
"Argument shard_spec {} is not an instance of list"
.
format
(
shard_spec
)
assert
isinstance
(
dist_tensor
=
DistributedTensor
(
x
)
shard_spec
,
list
),
"Argument shard_spec {} is not an instance of list"
.
format
(
shard_spec
)
if
isinstance
(
x
,
str
):
x
=
paddle
.
fluid
.
default_main_program
().
global_block
().
_var_recursive
(
x
)
dist_tensor
=
DistributedTensor
(
x
)
else
:
dist_tensor
=
DistributedTensor
(
x
)
serial_tensor
=
dist_tensor
.
serial_tensor
serial_tensor
=
dist_tensor
.
serial_tensor
dist_tensor
.
dist_attr
.
process_mesh
=
process_mesh
dist_tensor
.
dist_attr
.
process_mesh
=
process_mesh
if
serial_tensor
.
type
==
core
.
VarDesc
.
VarType
.
READER
\
if
(
or
serial_tensor
.
type
==
core
.
VarDesc
.
VarType
.
LOD_TENSOR_ARRAY
\
serial_tensor
.
type
==
core
.
VarDesc
.
VarType
.
READER
or
serial_tensor
.
type
==
core
.
VarDesc
.
VarType
.
STEP_SCOPES
:
or
serial_tensor
.
type
==
core
.
VarDesc
.
VarType
.
LOD_TENSOR_ARRAY
or
serial_tensor
.
type
==
core
.
VarDesc
.
VarType
.
STEP_SCOPES
):
tensor_shape
=
[]
tensor_shape
=
[]
else
:
else
:
tensor_shape
=
serial_tensor
.
shape
tensor_shape
=
serial_tensor
.
shape
if
shard_spec
is
not
None
:
if
shard_spec
is
not
None
:
assert
verify_shard_spec
(
shard_spec
,
tensor_shape
,
process_mesh
),
\
assert
verify_shard_spec
(
"For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}."
.
format
(
shard_spec
,
tensor_shape
,
process_mesh
serial_tensor
.
name
,
shard_spec
,
tensor_shape
,
process_mesh
)
),
"For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}."
.
format
(
serial_tensor
.
name
,
shard_spec
,
tensor_shape
,
process_mesh
)
dist_tensor
.
dist_attr
.
dims_mapping
=
convert_to_dims_mapping
(
dist_tensor
.
dist_attr
.
dims_mapping
=
convert_to_dims_mapping
(
shard_spec
,
process_mesh
)
shard_spec
,
process_mesh
)
if
process_mesh
is
not
None
:
if
process_mesh
is
not
None
:
dist_tensor
.
dist_attr
.
mark_annotated
(
"process_mesh"
)
dist_tensor
.
dist_attr
.
mark_annotated
(
"process_mesh"
)
if
shard_spec
is
not
None
:
if
shard_spec
is
not
None
:
...
@@ -97,6 +111,7 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
...
@@ -97,6 +111,7 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
default_dist_ctx
=
get_default_distributed_context
()
default_dist_ctx
=
get_default_distributed_context
()
default_dist_ctx
.
add_dist_tensor_for_program
(
dist_tensor
)
default_dist_ctx
.
add_dist_tensor_for_program
(
dist_tensor
)
dist_tensor
=
default_dist_ctx
.
get_dist_tensor_for_program
(
x
)
dist_tensor
=
default_dist_ctx
.
get_dist_tensor_for_program
(
x
)
default_dist_ctx
.
add_process_mesh
(
process_mesh
)
return
x
return
x
...
@@ -144,41 +159,54 @@ def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None):
...
@@ -144,41 +159,54 @@ def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None):
"""
"""
if
process_mesh
is
not
None
:
if
process_mesh
is
not
None
:
assert
isinstance
(
process_mesh
,
ProcessMesh
),
\
assert
isinstance
(
"Argument process_mesh {} is not an instance of ProcessMesh"
.
format
(
process_mesh
)
process_mesh
,
ProcessMesh
),
"Argument process_mesh {} is not an instance of ProcessMesh"
.
format
(
process_mesh
)
else
:
else
:
process_mesh
=
get_current_process_mesh
()
process_mesh
=
get_current_process_mesh
()
assert
process_mesh
is
not
None
,
\
assert
(
"Specify the process mesh argument or use ProcessMesh context manager first."
process_mesh
is
not
None
),
"Specify the process mesh argument or use ProcessMesh context manager first."
in_dims_mappings
=
[]
in_dims_mappings
=
[]
if
in_shard_specs
is
not
None
:
if
in_shard_specs
is
not
None
:
assert
all
((
isinstance
(
shard_spec
,
list
)
or
shard_spec
is
None
)
for
shard_spec
in
in_shard_specs
),
\
assert
all
(
"in_shard_spec {} is not a list of list or None"
.
format
(
in_shard_specs
)
(
isinstance
(
shard_spec
,
list
)
or
shard_spec
is
None
)
for
shard_spec
in
in_shard_specs
),
"in_shard_spec {} is not a list of list or None"
.
format
(
in_shard_specs
)
for
shard_spec
in
in_shard_specs
:
for
shard_spec
in
in_shard_specs
:
if
shard_spec
is
not
None
:
if
shard_spec
is
not
None
:
in_dims_mappings
.
append
(
in_dims_mappings
.
append
(
convert_to_dims_mapping
(
shard_spec
,
process_mesh
))
convert_to_dims_mapping
(
shard_spec
,
process_mesh
)
)
else
:
else
:
in_dims_mappings
.
append
(
None
)
in_dims_mappings
.
append
(
None
)
out_dims_mappings
=
[]
out_dims_mappings
=
[]
if
out_shard_specs
is
not
None
:
if
out_shard_specs
is
not
None
:
assert
all
((
isinstance
(
shard_spec
,
list
)
or
shard_spec
is
None
)
for
shard_spec
in
out_shard_specs
),
\
assert
all
(
"out_shard_spec {} is not a list of list or None"
.
format
(
out_shard_specs
)
(
isinstance
(
shard_spec
,
list
)
or
shard_spec
is
None
)
for
shard_spec
in
out_shard_specs
),
"out_shard_spec {} is not a list of list or None"
.
format
(
out_shard_specs
)
for
shard_spec
in
out_shard_specs
:
for
shard_spec
in
out_shard_specs
:
if
shard_spec
is
not
None
:
if
shard_spec
is
not
None
:
out_dims_mappings
.
append
(
out_dims_mappings
.
append
(
convert_to_dims_mapping
(
shard_spec
,
process_mesh
))
convert_to_dims_mapping
(
shard_spec
,
process_mesh
)
)
else
:
else
:
out_dims_mappings
.
append
(
None
)
out_dims_mappings
.
append
(
None
)
op
=
DistributedOperatorHelper
(
op
,
process_mesh
,
in_dims_mappings
,
op
=
DistributedOperatorHelper
(
out_dims_mappings
)
op
,
process_mesh
,
in_dims_mappings
,
out_dims_mappings
)
return
op
return
op
def
recompute
(
op
):
def
recompute
(
op
):
class
RecomputeOperator
:
class
RecomputeOperator
:
def
__init__
(
self
,
op
):
def
__init__
(
self
,
op
):
self
.
_op
=
op
self
.
_op
=
op
...
@@ -219,11 +247,13 @@ def add_to_collection(collection_name, value, name=None):
...
@@ -219,11 +247,13 @@ def add_to_collection(collection_name, value, name=None):
        _g_collections[collection_name] = []
    if name is not None:
        for _, v in _g_collections[collection_name]:
            if v == value:
                return
        _g_collections[collection_name].append((name, value))
    else:
        for _, v in _g_collections[collection_name]:
            if v == value:
                return
        _g_collections[collection_name].append((None, value))
...
python/paddle/distributed/auto_parallel/operators/__init__.py
...
@@ -35,3 +35,4 @@ from . import dist_fused_attention
from . import dist_reduce_sum_p
from . import dist_shape
from . import dist_assign
from . import dist_scale
The remaining changed files (diffs collapsed on the original page):

python/paddle/distributed/auto_parallel/operators/common.py
python/paddle/distributed/auto_parallel/operators/dist_default.py
python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py
python/paddle/distributed/auto_parallel/operators/dist_scale.py (new file)
python/paddle/distributed/auto_parallel/parallelizer.py
python/paddle/distributed/auto_parallel/parallelizer_v2.py
python/paddle/distributed/auto_parallel/process_group.py
python/paddle/distributed/auto_parallel/process_mesh.py
python/paddle/distributed/auto_parallel/reshard.py
python/paddle/distributed/auto_parallel/strategy.py
python/paddle/distributed/auto_parallel/tuner/profiler.py
python/paddle/distributed/auto_parallel/utils.py
python/paddle/distributed/fleet/fleet_executor_utils.py
python/paddle/distributed/parallel.py
python/paddle/distributed/passes/__init__.py
python/paddle/distributed/passes/auto_parallel_grad_clip.py
python/paddle/distributed/passes/auto_parallel_pipeline.py (new file)
python/paddle/fluid/executor.py
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py
python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py
python/paddle/fluid/tests/unittests/auto_parallel/generation_pipeline_pass_unittest.py (new file)
python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py
python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py
python/paddle/fluid/tests/unittests/auto_parallel/sharding_pass_unittest.py
python/paddle/fluid/tests/unittests/auto_parallel/test_dist_context.py
python/paddle/fluid/tests/unittests/auto_parallel/test_pass_generation_pipeline.py (new file)
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py
python/paddle/fluid/tests/unittests/test_fleet_executor_cond_interceptor.py (new file)
python/paddle/fluid/tests/unittests/test_fleet_executor_task_node.py
python/paddle/fluid/tests/unittests/test_fleet_executor_with_task_nodes.py
python/paddle/tensor/stat.py