PaddlePaddle / Paddle, commit 30845734 (unverified)
Authored by zhaoyingli on Jan 18, 2022; committed via GitHub on Jan 18, 2022.
[AutoParallel] Recompute Pass (#38920)
* [AutoParallel] Recompute Pass
* update unittest
* reshard for amp
* add comment
Parent: 4aa91fd6
Showing 18 changed files with 569 additions and 39 deletions (+569, -39).
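For orientation, here is a minimal sketch of how this pass is switched on from user code. It mirrors the unit test added in this commit (test_auto_parallel_recompute_pass.py); the checkpoint names "tmp3" and "tmp6" are simply the ones used in that test and stand in for tensors of the user's own model.

import paddle.distributed.fleet as fleet

# Build a distributed strategy that enables semi-automatic parallelism
# together with the recompute pass introduced by this commit.
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
dist_strategy.recompute = True
dist_strategy.recompute_configs = {"checkpoints": ["tmp3", "tmp6"]}
fleet.init(is_collective=True, strategy=dist_strategy)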
python/paddle/distributed/auto_parallel/completion.py  (+22, -0)
python/paddle/distributed/auto_parallel/dist_attribute.py  (+13, -1)
python/paddle/distributed/auto_parallel/dist_context.py  (+14, -0)
python/paddle/distributed/auto_parallel/dist_op.py  (+2, -0)
python/paddle/distributed/auto_parallel/operators/common.py  (+2, -0)
python/paddle/distributed/auto_parallel/operators/dist_default.py  (+6, -0)
python/paddle/distributed/auto_parallel/operators/dist_embedding.py  (+1, -1)
python/paddle/distributed/auto_parallel/operators/dist_matmul.py  (+4, -4)
python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py  (+2, -2)
python/paddle/distributed/auto_parallel/parallelizer.py  (+14, -11)
python/paddle/distributed/auto_parallel/partitioner.py  (+6, -5)
python/paddle/distributed/auto_parallel/reshard.py  (+14, -0)
python/paddle/distributed/auto_parallel/utils.py  (+2, -7)
python/paddle/distributed/passes/__init__.py  (+1, -0)
python/paddle/distributed/passes/auto_parallel_recompute.py  (new, +402, -0)
python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py  (+0, -1)
python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py  (+1, -7)
python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_recompute_pass.py  (new, +63, -0)
python/paddle/distributed/auto_parallel/completion.py
@@ -822,6 +822,28 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context):
        # TODO to add attribute for moment var
        op = ops[idx]
        if int(op.attr('op_role')) == int(OpRole.Optimize):
            if op.type == "clip_by_norm":
                param_grad = vars[op.input("X")[0]]
                param_grad_dist_attr = dist_context.get_tensor_dist_attr_for_program(
                    param_grad)
                assert param_grad_dist_attr is not None
                ref_process_mesh = param_grad_dist_attr.process_mesh
                ref_dims_mapping = param_grad_dist_attr.dims_mapping

                out = vars[op.output("Out")[0]]
                out_dist_attr = TensorDistributedAttribute()
                out_dist_attr.process_mesh = ref_process_mesh
                out_dist_attr.dims_mapping = ref_dims_mapping
                dist_context.set_tensor_dist_attr_for_program(out, out_dist_attr)

                op_dist_attr = OperatorDistributedAttribute()
                op_dist_attr.process_mesh = ref_process_mesh
                op_dist_attr.set_input_dist_attr(param_grad.name,
                                                 param_grad_dist_attr)
                op_dist_attr.set_output_dist_attr(out.name, out_dist_attr)
                dist_context.set_op_dist_attr_for_program(op, op_dist_attr)

            if "Grad" in op.input_names and "Param" in ops[idx].input_names:
                assert len(op.input(
...
python/paddle/distributed/auto_parallel/dist_attribute.py
@@ -21,7 +21,9 @@ _g_tensor_dist_attr_field_keys = [
    "process_mesh", "dims_mapping", "shard_sizes", "device_placement"
]
-_g_op_dist_attr_field_keys = ["process_mesh", "impl_type", "impl_idx"]
+_g_op_dist_attr_field_keys = [
+    "process_mesh", "impl_type", "impl_idx", "is_recompute"
+]
_g_op_input_suffix = "@input"

@@ -178,6 +180,7 @@ class OperatorDistributedAttribute:
        self._inputs_dist_attrs = {}
        self._outputs_dist_attrs = {}
        self._is_annotated = {}
        self._is_recompute = False

    @property
    def process_mesh(self):

@@ -214,6 +217,15 @@ class OperatorDistributedAttribute:
        if impl_idx is not None:
            self._impl_idx = impl_idx

    @property
    def is_recompute(self):
        return self._is_recompute

    @is_recompute.setter
    def is_recompute(self, is_recompute):
        assert isinstance(is_recompute, bool)
        self._is_recompute = is_recompute

    @property
    def inputs_dist_attrs(self):
        return self._inputs_dist_attrs
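As a side note, a minimal sketch (not part of the commit) of how the new flag behaves, assuming a Paddle build that already contains this change:

from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute

op_dist_attr = OperatorDistributedAttribute()
print(op_dist_attr.is_recompute)  # False: initialized in __init__ above
op_dist_attr.is_recompute = True  # the setter asserts the value is a bool
print(op_dist_attr.is_recompute)  # True

The dist-op implementations changed below (dist_embedding.py, dist_matmul.py) read this flag to skip parameter-initialization sync for recomputed operators.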
python/paddle/distributed/auto_parallel/dist_context.py
@@ -166,6 +166,13 @@ class DistributedContext:
        else:
            return None

    def get_tensor_dist_attr_for_program_with_id(self, tensor_id):
        dist_tensor = self._dist_tensors_for_program.get(tensor_id, None)
        if dist_tensor:
            return dist_tensor.dist_attr
        else:
            return None

    def set_tensor_dist_attr_for_program(self, serial_tensor, dist_attr):
        dist_tensor = DistributedTensor(serial_tensor, dist_attr)
        self.add_dist_tensor_for_program(dist_tensor)

@@ -192,6 +199,13 @@ class DistributedContext:
        else:
            return None

    def get_op_dist_attr_for_program_with_id(self, op_id):
        dist_op = self._dist_ops_for_program.get(op_id, None)
        if dist_op:
            return dist_op.dist_attr
        else:
            return None

    def set_op_dist_attr_for_program(self, serial_op, dist_attr):
        dist_op = DistributedOperator(serial_op, dist_attr)
        self.add_dist_op_for_program(dist_op)
python/paddle/distributed/auto_parallel/dist_op.py
@@ -99,6 +99,8 @@ class DistributedOperator:
            self._dist_attr.impl_type = "default"
        if self._dist_attr.impl_idx is None:
            self._dist_attr.impl_idx = -2
        if self._dist_attr.is_recompute is None:
            self._dist_attr.is_recompute = False

    def _filter_dist_attr(self, dist_attr):
        if dist_attr is None:
python/paddle/distributed/auto_parallel/operators/common.py
@@ -118,6 +118,8 @@ def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True):

def is_parameter_related(varname, block):
    if ".subprog_" in varname:
        varname = varname[:varname.index(".subprog_")]
    if ".cast_fp" in varname:
        varname = varname[:varname.index(".cast_fp")]
    assert block.has_var(varname)
python/paddle/distributed/auto_parallel/operators/dist_default.py
@@ -216,6 +216,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
            for varname in backward_op.desc.input(input_name):
                if "@GRAD" not in varname and is_parameter_related(
                        varname, main_block):
                    # NOTE: When amp and recompute pass are effective at the same time,
                    # if a parameter is casted and recomputed, the 'parameter@GARD' can not
                    # be found in the grad_op's output.
                    if "subprog_" in varname:
                        varname = varname[:varname.index(".subprog_")]
                    assert len(
                        backward_op.desc.input(input_name)
                    ) == 1, "parameter input to grad op should be length 1, but got [{}]".format(
...
python/paddle/distributed/auto_parallel/operators/dist_embedding.py
@@ -283,7 +283,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
                                        allreduce_op_dist_attr)

        # param initialization sync
-       if Weight_var.is_parameter:
+       if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            assert Weight_var.name not in dist_op_context.already_init_sync_vars
            dist_op_context.already_init_sync_vars.add(Weight_var.name)
            param = startup_block.var(Weight_var.name)
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
@@ -680,7 +680,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
        ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)

        # init param sync
-       if Weight_var.is_parameter:
+       if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
                             rank_id)

@@ -968,7 +968,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
                                        allreduce_op_dist_attr)

        # init param sync
-       if Weight_var.is_parameter:
+       if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
                             rank_id)

@@ -1383,7 +1383,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
        ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)

        # init param sync
-       if Weight_var.is_parameter:
+       if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
                             rank_id)

@@ -1666,7 +1666,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
                                        allreduce_op_dist_attr)

        # init param sync
-       if Weight_var.is_parameter:
+       if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
                             rank_id)
python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py
@@ -83,9 +83,9 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
        assert 'Out' in kwargs, "output [{}] is not given".format('Out')
        assert 'LossScaling' in kwargs, "output [{}] is not given".format(
            'LossScaling')
-       assert 'OutGoodSteps' in kwargs, "input [{}] is not given".format(
+       assert 'OutGoodSteps' in kwargs, "output [{}] is not given".format(
            'OutGoodSteps')
-       assert 'OutBadSteps' in kwargs, "input [{}] is not given".format(
+       assert 'OutBadSteps' in kwargs, "output [{}] is not given".format(
            'OutBadSteps')
        assert len(kwargs['FoundInfinite']) == 1, \
python/paddle/distributed/auto_parallel/parallelizer.py
@@ -97,8 +97,8 @@ class AutoParallelizer:
                if suffix in attr_name:
                    op._remove_attr(attr_name)

-   def _apply_pre_optimization_passed(self, main_program, startup_program,
-                                      loss, params_grads):
+   def _apply_pre_optimization_passes(self, main_program, startup_program,
+                                      loss, params_grads, no_grad_set):
        # apply amp pass
        if self._dist_strategy.amp:
            config = copy.deepcopy(self._dist_strategy.amp_configs)

@@ -111,11 +111,14 @@ class AutoParallelizer:
        # apply recompute pass
        if self._dist_strategy.recompute:
-           auto_parallel_recompute_pass = new_pass(
-               "auto_parallel_recompute_pass",
-               self._dist_strategy.recompute_configs)
-           auto_parallel_recompute_pass.apply(main_program, startup_program,
-                                              self._pass_context)
+           config = copy.deepcopy(self._dist_strategy.recompute_configs)
+           config["dist_context"] = self._dist_context
+           config["no_grad_set"] = copy.deepcopy(no_grad_set)
+           config["loss"] = loss
+           auto_parallel_recompute_pass = new_pass("auto_parallel_recompute",
+                                                   config)
+           auto_parallel_recompute_pass.apply([main_program],
+                                              [startup_program],
+                                              self._pass_context)

    def _generate_backward(self, main_program, startup_program, loss,
                           parameter_list, no_grad_set, callbacks):

@@ -144,7 +147,7 @@ class AutoParallelizer:
        return optimize_ops

-   def _apply_post_optimization_passed(self, main_program, startup_program,
+   def _apply_post_optimization_passes(self, main_program, startup_program,
                                        rank, params_grads):
        if self._dist_strategy.sharding:

@@ -188,9 +191,9 @@ class AutoParallelizer:
            self._parameter_list, self._no_grad_set, self._callbacks)

        # serial forward pass
-       self._apply_pre_optimization_passed(completed_main_program,
-                                           serial_startup_program, serial_loss,
-                                           params_grads)
+       self._apply_pre_optimization_passes(completed_main_program,
+                                           serial_startup_program, serial_loss,
+                                           params_grads, self._no_grad_set)
        # Logical partition
        partitioner = Partitioner(self._dist_context, rank)
        dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(

@@ -207,7 +210,7 @@ class AutoParallelizer:
        reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context)

-       self._apply_post_optimization_passed(dist_main_prog, dist_startup_prog,
+       self._apply_post_optimization_passes(dist_main_prog, dist_startup_prog,
                                             rank, dist_params_grads)
        g_process_group_map = None
        if not relaunch_phase:
python/paddle/distributed/auto_parallel/partitioner.py
@@ -25,7 +25,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext, Di
from .dist_attribute import OperatorDistributedAttribute
from .process_group import new_process_group
from .utils import set_dist_op_desc_original_id
-from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_recompute_op
+from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op
from .operators.common import BACKWARD_ONLY_DIST_OPS

__varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]

@@ -200,7 +200,8 @@ class Partitioner(object):
                    serial_output_varname] = new_varname

            # partition op
-           if is_forward_op(op):
+           op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
+           if is_forward_op(op) or op_dist_attr.is_recompute:
                kinputs, koutputs = dist_op_context.prepare_context(op)
                dist_op_forward_impl = _get_dist_op_forward_implement(
                    op, self._dist_context)

@@ -380,9 +381,9 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
    # NOTE trick for dist ops that only have backward implement
    if backward_op.type in BACKWARD_ONLY_DIST_OPS:
        op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
-       assert op_dist_attr.impl_idx >= 0
-       return get_distributed_operator_impl_container(
-           backward_op.type).get_impl(op_dist_attr.impl_idx)
+       dist_op = get_distributed_operator_impl_container(backward_op.type)
+       if dist_op and op_dist_attr.impl_idx >= 0:
+           return dist_op.get_impl(op_dist_attr.impl_idx)

    dist_op = get_distributed_operator_impl_container("default")
    return dist_op.get_impl(0)
python/paddle/distributed/auto_parallel/reshard.py
@@ -26,6 +26,9 @@ from .dist_context import DistributedContext
from .dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from .process_group import new_process_group, ProcessGroup, _g_process_group_map

# NOTE: If op in _g_special_ops, it will not be resharded.
_g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling']


class AllGatherOpDesc:
    """

@@ -966,6 +969,17 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
    while idx < len(block.ops):
        pre_op_count = len(block.ops)
        op = block.ops[idx]

        def _is_special_op(op):
            global _g_special_ops
            if op.type in _g_special_ops:
                return True
            return False

        if _is_special_op(op):
            idx += 1
            continue

        dist_op = dist_context.get_dist_op_for_program(op)
        if dist_op is not None:
            idx_offset = 0
python/paddle/distributed/auto_parallel/utils.py
@@ -1005,8 +1005,8 @@ def set_grad_var_shape(program, dist_context):
        assert op_dist_attr is not None
        for var_name in op.output_arg_names:

-           assert "@GRAD" in var_name
+           if "@GRAD" not in var_name: continue
            forward_var_name = var_name[:var_name.find("@GRAD")]
            if op.type in ["c_allreduce_sum", "c_identity", "scale", "cast"

@@ -1076,11 +1076,6 @@ def is_backward_op(op):
            int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward)


-def is_recompute_op(op):
-    return OP_ROLE_KEY in op.attr_names and \
-           int(op.all_attrs()[OP_ROLE_KEY]) == 9
-
-
def is_loss_op(op):
    return OP_ROLE_KEY in op.attr_names and \
        int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) |
                                             int(core.op_proto_and_checker_maker.OpRole.Loss))
python/paddle/distributed/passes/__init__.py
@@ -17,6 +17,7 @@ from .fuse_all_reduce import *
from .auto_parallel_gradient_merge import *
from .auto_parallel_sharding import *
from .auto_parallel_amp import *
from .auto_parallel_recompute import *
from .cpp_pass import *

__all__ = [
python/paddle/distributed/passes/auto_parallel_recompute.py
new file (mode 100644)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging

from .pass_base import PassBase, register_pass
from paddle.fluid import core, unique_name
from paddle.fluid import framework as framework
from paddle.fluid.framework import Variable, Operator
from paddle.fluid.backward import _append_grad_suffix_, _get_no_grad_set_name
from paddle.fluid.backward import ProgramStats, _rename_arg_, _find_op_path_
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute
from paddle.distributed.auto_parallel.utils import get_loss_op, set_var_dist_attr, set_dist_op_desc_original_id
from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping


class RecomputeState(ProgramStats):
    def __init__(self, block, ops):
        super(RecomputeState, self).__init__(block=block, ops=ops)
        self._block = block
        self._ops = ops
        self.var_op_deps = {}

    def build_stats(self):
        for i, op in enumerate(self._ops):
            for name in op.desc.input_arg_names():
                if name in self.var_op_deps:
                    self.var_op_deps[name]["var_as_input_ops"].extend([i])
                else:
                    self.var_op_deps[name] = {}
                    self.var_op_deps[name]["var_as_input_ops"] = [i]
                    self.var_op_deps[name]["var_as_output_ops"] = []

            for name in op.desc.output_arg_names():
                if name in self.var_op_deps:
                    self.var_op_deps[name]["var_as_output_ops"].extend([i])
                else:
                    self.var_op_deps[name] = {}
                    self.var_op_deps[name]["var_as_input_ops"] = []
                    self.var_op_deps[name]["var_as_output_ops"] = [i]

    def get_recompute_segments(self, checkpoints):
        """ get recompute segments from checkpoints """
        segments = []
        start_idx = -1
        pre_segment_end_idx = -1
        while start_idx + 1 < len(checkpoints):
            if start_idx == -1:
                ckpt_name = checkpoints[start_idx + 1]
                if ckpt_name not in self.var_op_deps:
                    start_idx += 1
                    continue
                op_idx_list = self.var_op_deps[ckpt_name]["var_as_output_ops"]
                if op_idx_list:
                    segments.append([0, max(op_idx_list) + 1])
            else:
                flag, min_idx, max_idx = self.is_subgraph(
                    [checkpoints[start_idx]], [checkpoints[start_idx + 1]])
                if flag:
                    min_idx = self._update_segment_start(min_idx,
                                                         pre_segment_end_idx)
                    segments.append([min_idx, max_idx + 1])
                else:
                    logging.info("Could not recompute op range [{}] - [{}] ".
                                 format(min_idx, max_idx + 1))
            start_idx += 1

        for i, (idx1, idx2) in enumerate(segments):
            logging.info("recompute segment[{}]".format(i))
            logging.info("segment start op: [{}]: [{}] [{}]".format(
                self._ops[idx1].desc.type(),
                self._ops[idx1].desc.input_arg_names(),
                self._ops[idx1].desc.output_arg_names()))
            logging.info("segment end op: [{}]: [{}] [{}]".format(
                self._ops[idx2 - 1].desc.type(),
                self._ops[idx2 - 1].desc.input_arg_names(),
                self._ops[idx2 - 1].desc.output_arg_names()))

        return segments

    def modify_forward_desc_for_recompute(self, dist_context):
        """
        If program's foward part has 'dropout' op, this function will insert
        a seed op before it to guarantee that two dropout op have the same outputs.
        """
        op_types = [op.desc.type() for op in self._ops]
        if "dropout" not in op_types:
            return

        op_idx = 0
        while op_idx < len(self._ops):
            cur_op = self._ops[op_idx]
            if "grad" in cur_op.type:
                break
            if cur_op.type != "dropout":
                op_idx += 1
                continue
            if cur_op.input("Seed") is not None and len(cur_op.input("Seed")):
                op_idx += 1
                continue

            cur_op_dist_attr = dist_context.get_op_dist_attr_for_program(cur_op)
            # insert seed op to guarantee that two dropout op have the same outputs
            op_unique_name = unique_name.generate("seed")
            var_unique_name = unique_name.generate_with_ignorable_key(".".join(
                [op_unique_name, 'tmp']))
            seed_var = self._block.create_var(
                name=var_unique_name,
                dtype='int32',
                type=core.VarDesc.VarType.LOD_TENSOR,
                persistable=False,
                stop_gradient=False)

            # set new seed_var's dist_attr
            ref_dims_mapping = [-1]
            ref_process_mesh = cur_op_dist_attr.process_mesh
            seed_var_dist_attr = set_var_dist_attr(
                dist_context, seed_var, ref_dims_mapping, ref_process_mesh)

            seed = 0 if cur_op.attr("fix_seed") is False else int(
                cur_op.attr("seed"))
            seed_op = self._block._insert_op_without_sync(
                index=cur_op.idx,
                type="seed",
                inputs={},
                outputs={"Out": seed_var},
                attrs={"seed": seed,
                       "force_cpu": True})
            # set new seed op's dist_attr
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                seed_op, ref_process_mesh, ref_dims_mapping, dist_context)

            # modify dropout op's desc
            self._ops.insert(op_idx, seed_op)
            cur_op.desc.set_input("Seed", [var_unique_name])
            cur_op.desc.remove_attr("fix_seed")
            cur_op.desc.remove_attr("seed")
            cur_op_dist_attr.set_input_dist_attr(seed_var.name,
                                                 seed_var_dist_attr)
            self._block._sync_with_cpp()
            op_idx += 2


def _find_op_index(block, cur_op):
    for idx in range(block.desc.op_size()):
        if cur_op.desc == block.desc.op(idx):
            return idx
    return -1


def _get_stop_gradients(program, no_grad_set):
    """ get no grad var """
    if no_grad_set is None:
        no_grad_set = set()
    else:
        no_grad_set = _get_no_grad_set_name(no_grad_set)

    no_grad_set_name = set()
    for var in program.list_vars():
        assert isinstance(var, Variable)
        if "@GRAD" in var.name:
            break
        if var.stop_gradient:
            no_grad_set_name.add(_append_grad_suffix_(var.name))
    no_grad_set_name.update(list(map(_append_grad_suffix_, no_grad_set)))
    return no_grad_set_name


def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars,
                               dist_context):
    """
    Get the recomputed ops which will insert the backward part
    """
    if len(descs) == 0:
        return []
    result_descs = []
    op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
    backward = core.op_proto_and_checker_maker.OpRole.Backward
    for desc in descs:
        if isinstance(desc, framework.Operator):
            desc = desc.desc
        if isinstance(desc, tuple):
            desc = desc[0]
        is_needed = False
        for name in desc.output_arg_names():
            if main_block.has_var(name) and main_block.var(name).persistable:
                continue
            if name not in in_memory_vars:
                is_needed = True
        if is_needed:
            new_op_desc = block.desc.append_op()
            new_op_desc.copy_from(desc)
            set_dist_op_desc_original_id(new_op_desc, desc, dist_context)
            new_op_desc._set_attr(op_role_attr_name, backward)
            result_descs.append(new_op_desc)
    return result_descs


@register_pass("auto_parallel_recompute")
class RecomputePass(PassBase):
    def __init__(self):
        super(RecomputePass, self).__init__()
        self.set_attr("checkpoints", None)
        self.set_attr("loss", None)
        self.set_attr("dist_context", None)
        self.set_attr("no_grad_set", None)

    def _check_self(self):
        if self.get_attr("dist_context") is None:
            return False
        if self.get_attr("loss") is None:
            return False
        if self.get_attr("checkpoints") is None:
            return False
        return True

    def _check_conflict(self, other_pass):
        return True

    def _apply_single_impl(self, main_programs, startup_programs, context):
        checkpoints = self.get_attr("checkpoints")
        loss = self.get_attr("loss")
        no_grad_set = self.get_attr("no_grad_set")
        self._dist_context = self.get_attr("dist_context")
        main_block = main_programs.global_block()
        no_grad_set_name = _get_stop_gradients(main_programs, no_grad_set)
        # get op_path which is related to loss
        op_path = _find_op_path_(main_block, [loss], [], no_grad_set_name)

        # step 1: build recompute state
        rc_state = RecomputeState(main_block, op_path)
        rc_state.modify_forward_desc_for_recompute(self._dist_context)
        rc_state.build_stats()
        checkpoints = rc_state.sort_checkpoints(checkpoints)
        segments = rc_state.get_recompute_segments(checkpoints)
        if segments == []:
            return

        # step 2: get vars_should_be_hold
        vars_should_be_hold = []
        for segment in segments:
            vars_should_be_hold.extend(
                rc_state.get_out_of_subgraph_vars(segment[0], segment[1]))
        cross_vars = set(vars_should_be_hold) - set(checkpoints)
        logging.info("found [{}] vars which cross recompute segment: [{}],"
                     "better checkpoints might be set to reduce those vars".
                     format(len(cross_vars), cross_vars))
        vars_should_be_hold.extend(rc_state.get_reserved_vars())
        vars_should_be_hold.extend(rc_state.get_input_nodes())
        vars_should_be_hold = list(set(vars_should_be_hold))
        vars_in_memory = vars_should_be_hold + checkpoints

        # step 3: get recomputed fwd ops desc
        var_name_dict = {}
        ckpt_ops_dict = {}
        buffer_block = main_block.program._create_block()
        for i, segment in enumerate(segments[::-1]):
            fwd_ops = op_path[segment[0]:segment[1]]
            var_suffix = ".subprog_%d" % i
            for op in fwd_ops:
                input_and_output_names = []
                input_and_output_names.extend(op.desc.input_arg_names())
                input_and_output_names.extend(op.desc.output_arg_names())
                cur_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
                    op)
                assert cur_op_dist_attr is not None
                for name in input_and_output_names:
                    if main_block.var(name).persistable or name in checkpoints:
                        continue
                    if name in vars_should_be_hold:
                        continue
                    if name not in var_name_dict:
                        ref_process_mesh = cur_op_dist_attr.process_mesh
                        if name in op.desc.input_arg_names():
                            ref_dims_mapping = cur_op_dist_attr.get_input_dims_mapping(
                                name)
                        else:
                            ref_dims_mapping = cur_op_dist_attr.get_output_dims_mapping(
                                name)
                        # record recomputed var's old_name and new_name (old_name.subprog_XXX)
                        # create new var with new name
                        var_name_dict[name] = name + var_suffix
                        ref_var = main_block.var(name)
                        rc_var = main_block.create_var(
                            name=var_name_dict[name],
                            shape=ref_var.shape,
                            dtype=ref_var.dtype,
                            type=ref_var.type,
                            persistable=ref_var.persistable,
                            stop_gradient=ref_var.stop_gradient)
                        # set new recomputed var's dist attr
                        set_var_dist_attr(self._dist_context, rc_var,
                                          ref_dims_mapping, ref_process_mesh)
            # get recomputed segment's descs
            segment_descs = _add_needed_descs_to_block(
                fwd_ops, buffer_block, main_block, vars_in_memory,
                self._dist_context)
            # rename recomputed ops' input and output var name
            for key in var_name_dict:
                _rename_arg_(segment_descs, key, var_name_dict[key])

            # NOTE: one forward op could be correspond to multiple xxx_grad op.
            # When traversing all grad_ops in reverse, need to set a flag to indicate
            # whether the ckpt and its segment_descs can be used.
            ckpt_op = op_path[segment[1] - 1]
            ckpt_ops_dict[ckpt_op.desc.id()] = [True, segment_descs]

        # step 4: insert recomputed fwd ops
        ops = main_block.ops
        loss_op = get_loss_op(main_block)
        loss_op_idx = _find_op_index(main_block, loss_op)
        dist_op_context = self._dist_context.dist_op_context
        assert loss_op_idx != -1
        # Traversing all grad_ops in reverse, and if the fwd op corresponding to reverse op is checkpoints,
        # segments ops should be inserted.
        for i in range(len(ops) - 1, loss_op_idx, -1):
            grad_op = ops[i]
            # remove some attrs of dropout_grad op's desc
            if grad_op.type == "dropout_grad":
                grad_op.desc.remove_attr("fix_seed")
                grad_op.desc.remove_attr("seed")
                main_block._sync_with_cpp()

            # rename grad op's var_name which is not in 'vars_in_memory'
            for key in var_name_dict:
                self.reset_op_dist_attr(grad_op, var_name_dict)
                _rename_arg_([grad_op.desc], key, var_name_dict[key])

            # insert recomputed ops
            if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
                fwd_op_id = dist_op_context.grad_op_id_to_op_id[grad_op.desc.id(
                )]
                if fwd_op_id in ckpt_ops_dict and ckpt_ops_dict[fwd_op_id][0]:
                    idx = grad_op.idx
                    while idx - 1 >= 0 and ops[idx - 1].type == "sum":
                        idx -= 1
                    segment_descs = ckpt_ops_dict[fwd_op_id][1]
                    for _, op_desc in reversed(list(enumerate(segment_descs))):
                        rc_desc = main_block.desc._insert_op(idx)
                        rc_desc.copy_from(op_desc)
                        rc_op = Operator(main_block, rc_desc)
                        main_block.ops.insert(idx, rc_op)
                        # set recomputed ops' dist attr
                        fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program_with_id(
                            rc_desc.original_id())
                        assert fwd_op_dist_attr is not None
                        self.set_op_dist_attr(rc_op, fwd_op_dist_attr,
                                              var_name_dict)
                    ckpt_ops_dict[fwd_op_id][0] = False

        main_block._sync_with_cpp()
        main_programs._sync_with_cpp()

    def reset_op_dist_attr(self, op, var_name_dict):
        op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
        assert op_dist_attr is not None
        for input in op.desc.input_arg_names():
            if input in var_name_dict.keys():
                in_dist_attr = op_dist_attr.get_input_dist_attr(input)
                op_dist_attr.set_input_dist_attr(var_name_dict[input],
                                                 in_dist_attr)
        for output in op.desc.output_arg_names():
            if output in var_name_dict.keys():
                out_dist_attr = op_dist_attr.get_output_dist_attr(output)
                op_dist_attr.set_output_dist_attr(var_name_dict[output],
                                                  out_dist_attr)

    def set_op_dist_attr(self, op, old_dist_attr, var_name_dict):
        new_dist_attr = OperatorDistributedAttribute()
        new_dist_attr.is_recompute = True
        new_dist_attr.impl_idx = old_dist_attr.impl_idx
        new_dist_attr.process_mesh = old_dist_attr.process_mesh
        for input in old_dist_attr.inputs_dist_attrs.keys():
            if input in var_name_dict.keys():
                in_dist_attr = old_dist_attr.inputs_dist_attrs[input]
                new_dist_attr.set_input_dist_attr(var_name_dict[input],
                                                  in_dist_attr)
            else:
                in_dist_attr = old_dist_attr.inputs_dist_attrs[input]
                new_dist_attr.set_input_dist_attr(input, in_dist_attr)
        for output in old_dist_attr.outputs_dist_attrs.keys():
            if output in var_name_dict.keys():
                out_dist_attr = old_dist_attr.outputs_dist_attrs[output]
                new_dist_attr.set_output_dist_attr(var_name_dict[output],
                                                   out_dist_attr)
            else:
                out_dist_attr = old_dist_attr.outputs_dist_attrs[output]
                new_dist_attr.set_output_dist_attr(output, out_dist_attr)
        self._dist_context.set_op_dist_attr_for_program(op, new_dist_attr)
python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py
@@ -894,7 +894,6 @@ class GPTModel(nn.Layer):
                    "dims_mapping":
                    [0] + [-1 for i in range(len(input_ids.shape) - 1)]
                })

        attention_mask.stop_gradient = True
        encoder_outputs = self.decoder(
            embedding_output,
            memory=None,
python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py
@@ -110,14 +110,8 @@ class AutoPallelPassTestBase(DistPassTestBase):
        elif strategy == "mp":
            modeling._global_parallel_strategy = "mp"
            modeling._global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
-       elif strategy == "pp":
-           modeling._global_parallel_strategy = "pp"
-           modeling._global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
-           modeling.PP_MESH_LIST = [
-               auto.ProcessMesh(mesh=[0]), auto.ProcessMesh(mesh=[1])
-           ]
        else:
-           raise ValueError("'get_gpt_model' only support dp, mp and pp.")
+           raise ValueError("'get_gpt_model' only support dp and mp.")

        tokens = paddle.static.data(
            name="tokens", shape=[batch_size, sequence_len], dtype='int64')
python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_recompute_pass.py
new file (mode 100644)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import random
import numpy as np

import unittest
import paddle
import paddle.nn as nn
import paddle.distributed.fleet as fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.passes import new_pass, PassManager
from auto_parallel_pass_test_base import AutoPallelPassTestBase


class TestRecomputePass(AutoPallelPassTestBase):
    def init(self):
        if paddle.is_compiled_with_cuda():
            paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
        self.rtol = 1e-6
        self.atol = 1e-8

        rank = paddle.distributed.get_rank()
        paddle.seed(rank + 2021)
        random.seed(rank + 2021)
        np.random.seed(rank + 2021)

    def apply_passes(self):
        dist_strategy = fleet.DistributedStrategy()
        dist_strategy.recompute = True
        dist_strategy.recompute_configs = {"checkpoints": ["tmp3", "tmp6"]}
        dist_strategy.semi_auto = True
        fleet.init(is_collective=True, strategy=dist_strategy)

    def test_bs_8(self):
        self.check_main(
            gpus=[0, 1], batch_size=8, sequence_len=512, vocab_size=1000)

    def get_model(self, place, batch_size, sequence_len, vocab_size):
        return self.get_gpt_model("mp", place, batch_size, sequence_len,
                                  vocab_size)


class TestRecomputePassDP(TestRecomputePass):
    def get_model(self, place, batch_size, sequence_len, vocab_size):
        return self.get_gpt_model("dp", place, batch_size, sequence_len,
                                  vocab_size)


if __name__ == "__main__":
    unittest.main()