BaiXuePrincess / Paddle · Commit d7f7963f (forked from PaddlePaddle / Paddle)
Unverified commit d7f7963f
Authored by zhaoyingli on Nov 18, 2022; committed via GitHub on Nov 18, 2022
Parent: aafa9820

[AutoParallel] selective recompute (#48111)

* [AutoParallel] selective recompute
* add cmakelist

Showing 10 changed files with 428 additions and 235 deletions (+428, -235)
python/paddle/distributed/auto_parallel/constants.py                             +1    -0
python/paddle/distributed/auto_parallel/dist_loader.py                           +1    -1
python/paddle/distributed/auto_parallel/engine.py                                +1    -30
python/paddle/distributed/auto_parallel/interface.py                             +9    -1
python/paddle/distributed/auto_parallel/utils.py                                 +36   -5
python/paddle/distributed/passes/auto_parallel_recompute.py                      +94   -41
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt                 +2    -0
python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py     +16   -3
python/paddle/fluid/tests/unittests/auto_parallel/test_selective_recompute.py    +175  -0
python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py                   +93   -154
python/paddle/distributed/auto_parallel/constants.py
@@ -55,6 +55,7 @@ set_field_default_config(BASE, "reinit", False)  # Only for debug
 RECOMPUTE = "recompute"
 set_field_default_config(RECOMPUTE, "enable", False)
 set_field_default_config(RECOMPUTE, "checkpoints", None)
+set_field_default_config(RECOMPUTE, "no_recompute_segments", [])
 set_field_default_config(RECOMPUTE, "enable_tuning", False)

 #########################################
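Illustrative sketch (not part of this diff): with the new `no_recompute_segments` field registered above, a user-side strategy configuration could look like the following; the concrete values are assumptions for illustration and mirror how the unit tests below set it up.

    # Minimal sketch of configuring selective recompute via the strategy object.
    from paddle.distributed.fleet import auto

    strategy = auto.Strategy()
    strategy.recompute.enable = True
    # Indices of recompute segments that should NOT be recomputed (kept in memory).
    strategy.recompute.no_recompute_segments = [0]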
python/paddle/distributed/auto_parallel/dist_loader.py
@@ -134,7 +134,7 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
                 raise StopIteration

     def _infer_steps(self):
-        if isinstance(self.steps_per_epoch, int) and self.steps_per_epoch > 1:
+        if isinstance(self.steps_per_epoch, int) and self.steps_per_epoch > 0:
             return self.steps_per_epoch
         try:
             if isinstance(self.dataset, IterableDataset):
python/paddle/distributed/auto_parallel/engine.py
@@ -610,7 +610,7 @@ class Engine:
         if mode != "train":
             serial_main_prog = serial_main_prog.clone(for_test=True)

-        self._set_recompute_ckpts()
+        auto_utils.set_recompute_ckpts(self._model, self._strategy)
         self._dist_contexts[mode] = DistributedContext(
             serial_main_prog,
             serial_startup_prog,

@@ -1518,35 +1518,6 @@ class Engine:
             var_name = _to_name_str(var)
             return var_name in self.main_program.global_block().vars

-    def _set_recompute_ckpts(self):
-        # NOTE hack to enable recompute in engine api for GPT-3
-        # TODO support more PaddleNLP/CV models here
-        recompute = self._strategy.recompute
-
-        # extract ckpts by specific model
-        if isinstance(self._model, paddle.nn.Layer):
-            if hasattr(self._model, "gpt") and self._model.__class__.__name__ in [
-                'GPTForPretraining',
-                'GPTForPretrainingAuto',
-            ]:
-                exact_ckpts = self._model.gpt.checkpoints
-            else:
-                exact_ckpts = recompute.checkpoints
-        else:
-            exact_ckpts = recompute.checkpoints
-
-        # modify strategy
-        if recompute.enable:
-            recompute.checkpoints = exact_ckpts[:]
-            logs = {
-                'Model Class': self._model.__class__.__name__,
-                'Applied Recompute ckpts': exact_ckpts,
-            }
-            self._logger.info(logs)
-
     def _reset_metrics(self):
         for metric in self._metrics:
             metric.reset()
python/paddle/distributed/auto_parallel/interface.py
@@ -195,7 +195,13 @@ def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None):
     return op


+_g_recompute_idx = -1
+
+
 def recompute(op):
+    global _g_recompute_idx
+    _g_recompute_idx += 1
+
     class RecomputeOperator:
         def __init__(self, op):
             self._op = op

@@ -209,7 +215,9 @@ def recompute(op):
             for idx in range(op_size, new_op_size):
                 op = cur_block.ops[idx]
                 op._set_attr("is_recompute@auto_parallel", True)
+                op._set_attr(
+                    'op_namescope', "/auto_parallel/rc_" + str(_g_recompute_idx)
+                )

             return output
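Illustrative sketch (not part of this diff): the `recompute` wrapper above tags every op created inside the wrapped call with `is_recompute@auto_parallel` and an `/auto_parallel/rc_<idx>` namescope, so a model author can mark a sub-layer like this; the `Block` layer here is hypothetical.

    import paddle
    from paddle.distributed.fleet import auto

    class Block(paddle.nn.Layer):
        def __init__(self):
            super().__init__()
            self.linear = paddle.nn.Linear(64, 64)

        def forward(self, x):
            # Ops created by the wrapped call are grouped into one recompute segment.
            return auto.recompute(self.linear)(x)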
python/paddle/distributed/auto_parallel/utils.py
@@ -33,6 +33,9 @@ from paddle.distributed.auto_parallel.dist_attribute import (
     OperatorDistributedAttribute,
 )

+OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
+OpRole = core.op_proto_and_checker_maker.OpRole
+
 __no_shape_var_type__ = [
     core.VarDesc.VarType.READER,
     core.VarDesc.VarType.STEP_SCOPES,

@@ -1181,7 +1184,6 @@ def _get_split_indices(
 def set_grad_var_shape(program, dist_context):
     from .operators.common import infer_shape
-    from paddle.distributed.fleet.meta_optimizers.common import OpRole

     block = program.global_block()
     vars = block.vars

@@ -1315,10 +1317,6 @@ def set_grad_var_shape(program, dist_context):
             grad_var.desc.set_shape(ref_shape)

-OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
-OpRole = core.op_proto_and_checker_maker.OpRole
-
 def is_forward_op(op):
     op_role = int(op.attr('op_role'))
     return OP_ROLE_KEY in op.attr_names and (

@@ -1896,6 +1894,39 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank):
     server_socket.close()

+def set_recompute_ckpts(model, strategy):
+    from .interface import _g_recompute_idx
+
+    if _g_recompute_idx > -1:
+        return
+
+    recompute = strategy.recompute
+    if not recompute.enable:
+        return
+
+    # NOTE: hack to enable recompute in engine api for GPT-3
+    # TODO support more PaddleNLP/CV models here
+    # extract ckpts by specific model
+    if isinstance(model, paddle.nn.Layer):
+        if hasattr(model, "gpt") and model.__class__.__name__ in [
+            'GPTForPretraining',
+            'GPTForPretrainingAuto',
+        ]:
+            exact_ckpts = model.gpt.checkpoints
+        else:
+            exact_ckpts = recompute.checkpoints
+    else:
+        exact_ckpts = recompute.checkpoints
+
+    # modify strategy
+    recompute.checkpoints = exact_ckpts[:]
+    logs = {
+        'Model Class': model.__class__.__name__,
+        'Applied Recompute ckpts': exact_ckpts,
+    }
+    logging.info(logs)
+
 def get_input_split_info(cur_rank, var, dist_context):
     # deduce how the input data is split among the cluster
     tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
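Illustrative sketch (not part of this diff): `set_recompute_ckpts` is now invoked from the Engine (see engine.py above) and becomes a no-op once any op has been marked with `auto.recompute()`, so the legacy GPT-specific checkpoint extraction only applies when the new API is unused. A hedged usage sketch, assuming the module is importable as shown:

    # `model` and `strategy` are assumed to be a paddle.nn.Layer and an auto.Strategy.
    from paddle.distributed.auto_parallel import utils as auto_utils

    def resolve_recompute_checkpoints(model, strategy):
        # No-op when auto.recompute() has been used anywhere (_g_recompute_idx > -1)
        # or when strategy.recompute.enable is False.
        auto_utils.set_recompute_ckpts(model, strategy)
        return strategy.recompute.checkpoints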
python/paddle/distributed/passes/auto_parallel_recompute.py
@@ -17,7 +17,6 @@ import logging
 from .pass_base import PassBase, register_pass
 from paddle.fluid import core, unique_name
 from paddle.fluid import framework as framework
 from paddle.fluid.framework import Variable
 from paddle.fluid.backward import _append_grad_suffix_, _get_no_grad_set_name
 from paddle.fluid.backward import ProgramStats, _rename_arg_, _find_op_path_
 from paddle.distributed.auto_parallel.dist_attribute import (

@@ -33,12 +32,21 @@ from paddle.distributed.auto_parallel.utils import (
 )


+def _to_be_recomputed(op):
+    return op.has_attr('op_namescope') and "/auto_parallel/rc_" in op.attr(
+        'op_namescope'
+    )
+
+
 class RecomputeState(ProgramStats):
     def __init__(self, block, ops):
         super().__init__(block=block, ops=ops)
         self._block = block
         self._ops = ops
+        # {varname: {as_input_ops: op_idx, as_output_ops: op_idx}}
         self.var_op_deps = {}
+        # {segment_name: op_idx}
+        self.seg_op_deps = {}

     def build_stats(self):
         for i, op in enumerate(self._ops):

@@ -58,9 +66,34 @@ class RecomputeState(ProgramStats):
                 self.var_op_deps[name]["var_as_input_ops"] = []
                 self.var_op_deps[name]["var_as_output_ops"] = [i]

-    def get_recompute_segments(self, checkpoints):
-        """get recompute segments from checkpoints"""
+            if not _to_be_recomputed(op):
+                continue
+
+            seg_name = op.attr('op_namescope')
+            if seg_name not in self.seg_op_deps:
+                self.seg_op_deps[seg_name] = [i]
+            else:
+                assert (
+                    self.seg_op_deps[seg_name][-1] + 1 == i
+                ), "The recompute segment's ops should be continuous"
+                self.seg_op_deps[seg_name].extend([i])
+
+    def get_recompute_segments(
+        self, checkpoints_list=None, no_recompute_segments=[]
+    ):
+        """get recompute segments and checkpoints"""
         segments = []
+        checkpoints = checkpoints_list or []
+
+        if len(checkpoints) == 0:
+            # the segments is marked by `auto.recompute()` api
+            for segment_idx in self.seg_op_deps.values():
+                if len(segment_idx) == 1:
+                    continue
+                segments.append([segment_idx[0], segment_idx[-1] + 1])
+                checkpoints.extend(self._ops[segment_idx[-1]].output_arg_names)
+        else:
+            # the segments is marked by `strategy.checkpoints` api
+            start_idx = -1
+            pre_segment_end_idx = -1
+            while start_idx + 1 < len(checkpoints):

@@ -69,7 +102,9 @@ class RecomputeState(ProgramStats):
                 if ckpt_name not in self.var_op_deps:
                     start_idx += 1
                     continue
                 op_idx_list = self.var_op_deps[ckpt_name]["var_as_output_ops"]
                 if op_idx_list:
                     segments.append([0, max(op_idx_list) + 1])
                 else:

@@ -89,6 +124,15 @@ class RecomputeState(ProgramStats):
                     )
                 start_idx += 1

+        if no_recompute_segments:
+            for i in reversed(sorted(no_recompute_segments)):
+                assert i < len(
+                    segments
+                ), "the no_recompute_segments idx [{}] should be lower the number of segment [{}]".format(
+                    i, len(segments)
+                )
+                segments.pop(i)
+
         for i, (idx1, idx2) in enumerate(segments):
             logging.info("recompute segment[{}]".format(i))
             logging.info(

@@ -106,7 +150,10 @@ class RecomputeState(ProgramStats):
                 )
             )

-        return segments
+        return segments, checkpoints
+
+    def is_recompute(self):
+        return any([_to_be_recomputed(op) for op in self._ops])

     def modify_forward_desc_for_recompute(self, dist_context):
         """

@@ -162,6 +209,7 @@ class RecomputeState(ProgramStats):
                 outputs={"Out": seed_var},
                 attrs={"seed": seed, "force_cpu": True},
             )
+            seed_op._set_attr('op_namescope', cur_op.attr('op_namescope'))
             # set new seed op's dist_attr
             naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                 seed_op, ref_process_mesh, ref_dims_mapping, dist_context

@@ -196,7 +244,6 @@ def _get_stop_gradients(program, no_grad_set):
     no_grad_set_name = set()
     for var in program.list_vars():
         assert isinstance(var, Variable)
         if "@GRAD" in var.name:
             break
         if var.stop_gradient:

@@ -244,14 +291,13 @@ class RecomputePass(PassBase):
         self.set_attr("loss", None)
         self.set_attr("dist_context", None)
         self.set_attr("no_grad_set", None)
+        self.set_attr("no_recompute_segments", [])

     def _check_self(self):
         if self.get_attr("dist_context") is None:
             return False
         if self.get_attr("loss") is None:
             return False
-        if self.get_attr("checkpoints") is None:
-            return False
         return True

     def _check_conflict(self, other_pass):

@@ -259,25 +305,32 @@ class RecomputePass(PassBase):
     def _apply_single_impl(self, main_program, startup_program, context):
         checkpoints = self.get_attr("checkpoints")
+        no_recompute_segments = self.get_attr("no_recompute_segments")
         loss = self.get_attr("loss")
         no_grad_set = self.get_attr("no_grad_set")
         self._dist_context = self.get_attr("dist_context")

-        # 0. get op_path which is related to loss
         main_block = main_program.global_block()
         no_grad_set_name = _get_stop_gradients(main_program, no_grad_set)
+        # get op_path which is related to loss
         op_path = _find_op_path_(main_block, [loss], [], no_grad_set_name)

-        # step 1: build recompute state
+        # 1. build recompute state
         rc_state = RecomputeState(main_block, op_path)
+        if not rc_state.is_recompute() and not checkpoints:
+            return
+
+        # 2. get the segments to be recomputed
         rc_state.modify_forward_desc_for_recompute(self._dist_context)
         rc_state.build_stats()
-        checkpoints = rc_state.sort_checkpoints(checkpoints)
-        segments = rc_state.get_recompute_segments(checkpoints)
-        if segments == []:
+        checkpoints = rc_state.sort_checkpoints(checkpoints or [])
+        segments, checkpoints = rc_state.get_recompute_segments(
+            checkpoints, no_recompute_segments
+        )
+        if segments == [] or checkpoints == []:
             return

-        # step 2: get vars_should_be_hold
+        # 3. get vars that should be hold in memory
         vars_should_be_hold = []
         for segment in segments:
             vars_should_be_hold.extend(

@@ -295,9 +348,9 @@ class RecomputePass(PassBase):
         vars_should_be_hold = list(set(vars_should_be_hold))
         vars_in_memory = vars_should_be_hold + checkpoints

-        # step 3: get recomputed fwd ops desc
-        var_name_dict = {}
-        ckpt_ops_dict = {}
+        # 4. get the fwd ops desc to be recomputed.
+        var_name_dict = {}  # varname --> varname.subprog_XXX
+        ckpt_ops_dict = {}  # ckpt_op_id --> segment_descs
         buffer_block = main_block.program._create_block()
         for i, segment in enumerate(segments[::-1]):
             fwd_ops = op_path[segment[0] : segment[1]]

@@ -362,7 +415,7 @@ class RecomputePass(PassBase):
             ckpt_op = op_path[segment[1] - 1]
             ckpt_ops_dict[ckpt_op.desc.original_id()] = [True, segment_descs]

-        # step 4: insert recomputed fwd ops
+        # 5. insert recomputed fwd ops into backward parse
         ops = main_block.ops
         loss_op = get_loss_op(main_block)
         loss_op_idx = _find_op_index(main_block, loss_op)
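Illustrative sketch (not part of this diff): a plain-Python rendition of how `RecomputeState` groups contiguous marked ops into segments by their `/auto_parallel/rc_*` namescope and then drops the entries listed in `no_recompute_segments`. The function and variable names are hypothetical.

    def build_segments(op_namescopes, no_recompute_segments):
        # group op indices by recompute namescope
        seg_op_deps = {}
        for i, scope in enumerate(op_namescopes):
            if "/auto_parallel/rc_" not in scope:
                continue
            seg_op_deps.setdefault(scope, []).append(i)

        # a segment is [first_op_idx, last_op_idx + 1); single-op groups are skipped
        segments = []
        for idx_list in seg_op_deps.values():
            if len(idx_list) == 1:
                continue
            segments.append([idx_list[0], idx_list[-1] + 1])

        # drop the segments the user asked not to recompute
        for i in reversed(sorted(no_recompute_segments)):
            assert i < len(segments), "no_recompute_segments index out of range"
            segments.pop(i)
        return segments

    # e.g. three ops in segment rc_0, two in rc_1; drop segment 0:
    print(build_segments(
        ["/auto_parallel/rc_0"] * 3 + ["/auto_parallel/rc_1"] * 2 + [""],
        no_recompute_segments=[0],
    ))  # -> [[3, 5]]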
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
@@ -72,6 +72,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_parallel_tuner_predict MODULES
                   test_parallel_tuner_predict ENVS ${dist_ENVS})
   set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120)
+  py_test_modules(test_selective_recompute MODULES test_selective_recompute)
+  set_tests_properties(test_selective_recompute PROPERTIES TIMEOUT 50)
   py_test_modules(test_while_op_completion MODULES test_while_op_completion
                   ENVS ${dist_ENVS})
python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py
@@ -22,13 +22,14 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
 from get_gpt_model import FakeDataset, generate_model


-def apply_pass(use_recompute=False):
+def apply_pass(use_recompute=False, no_recompute_segments=[]):
     strategy = auto.Strategy()
     strategy.auto_mode = "semi"
     strategy.reinit = True
     if use_recompute:
         recompute = strategy.recompute
         recompute.enable = True
+        recompute.no_recompute_segments = no_recompute_segments
     return strategy

@@ -53,10 +54,10 @@ class TestRecomputePass(unittest.TestCase):
         place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
         engine._executor = paddle.static.Executor(place)

-    def get_engine(self, use_recompute=False):
+    def get_engine(self, use_recompute=False, no_recompute_segments=[]):
         reset_prog()

-        strategy = apply_pass(use_recompute)
+        strategy = apply_pass(use_recompute, no_recompute_segments)
         clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
         opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
         model, loss = generate_model("mp")

@@ -88,6 +89,18 @@ class TestRecomputePass(unittest.TestCase):
         rc_losses = np.array(history.history["loss"])
         self.check_results(mp_losses, rc_losses)

+        # mp2 selective recompute training
+        rc1_engine = self.get_engine(True, [0])
+        history = rc1_engine.fit(self.dataset, 3, batch_size=self.batch_size)
+        rc1_losses = np.array(history.history["loss"])
+        self.check_results(mp_losses, rc1_losses)
+
+    def test_recompute_pass_error(self):
+        with self.assertRaises(AssertionError):
+            rc_engine = self.get_engine(True, [2])
+            history = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size)
+

 if __name__ == "__main__":
     unittest.main()
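Illustrative end-to-end sketch (not part of this diff), following the test above: `make_model_and_loss` stands in for something like `generate_model("mp")`, and the dataset and batch size are placeholders supplied by the caller.

    import paddle
    from paddle.distributed.fleet import auto

    def train_with_selective_recompute(make_model_and_loss, dataset, batch_size):
        strategy = auto.Strategy()
        strategy.auto_mode = "semi"
        strategy.recompute.enable = True
        strategy.recompute.no_recompute_segments = [0]  # keep segment 0 in memory

        model, loss = make_model_and_loss()
        opt = paddle.optimizer.AdamW(learning_rate=0.00001)
        engine = auto.Engine(model, loss, opt, strategy=strategy)
        # Same call shape as the test: fit(dataset, 3, batch_size=...)
        return engine.fit(dataset, 3, batch_size=batch_size)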
python/paddle/fluid/tests/unittests/auto_parallel/test_selective_recompute.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest
import random
import numpy as np

import paddle

from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import FakeDataset

sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import (
    GPTModel,
    GPTForPretraining,
    GPTPretrainingCriterion,
)


def generate_model(use_new_recompute, recompute_granularity):
    modeling.init_global()
    modeling._global_parallel_strategy = "serial"
    modeling._global_process_mesh = auto.ProcessMesh(mesh=[0], dim_names=["x"])

    gpt = GPTModel(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=8,
        intermediate_size=256,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        max_position_embeddings=1024,
        type_vocab_size=1,
        initializer_range=0.02,
        pad_token_id=0,
        eos_token_id=7,
        bos_token_id=0,
        eol_token_id=3,
        use_new_recompute=use_new_recompute,
        recompute_granularity=recompute_granularity,
    )
    model = GPTForPretraining(
        gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
    )
    criterion = GPTPretrainingCriterion()
    return model, criterion


def apply_pass(use_recompute=False, no_recompute_segments=[]):
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    strategy.reinit = True
    if use_recompute:
        recompute = strategy.recompute
        recompute.enable = True
        recompute.no_recompute_segments = no_recompute_segments
    return strategy


def reset_prog():
    paddle.fluid.framework.switch_main_program(paddle.static.Program())
    paddle.fluid.framework.switch_startup_program(paddle.static.Program())


class TestRecomputePassWithRecomputeAPI(unittest.TestCase):
    def setUp(self):
        self.rtol = 1e-6
        self.atol = 1e-8
        self.batch_size = 1
        self.batch_num = 2
        self.clip_norm = 0.2
        self.dataset = FakeDataset(self.batch_size * self.batch_num)

    def init(self, engine):
        paddle.seed(2022)
        np.random.seed(2022)
        random.seed(2022)
        place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
        engine._executor = paddle.static.Executor(place)

    def get_engine(
        self,
        use_recompute=False,
        use_new_recompute=False,
        recompute_granularity="full",
        no_recompute_segments=[],
    ):
        reset_prog()

        strategy = apply_pass(use_recompute, no_recompute_segments)
        clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
        model, loss = generate_model(use_new_recompute, recompute_granularity)

        engine = auto.Engine(model, loss, opt, strategy=strategy)
        self.init(engine)
        return engine

    def check_results(self, ref_losses, check_losses):
        np.testing.assert_allclose(
            ref_losses,
            check_losses,
            rtol=self.rtol,
            atol=self.atol,
            err_msg='pass {} has wrong results!,\nu={}\nv={}\ndiff={}'.format(
                __class__, ref_losses, check_losses, ref_losses - check_losses
            ),
        )

    def recompute_vars(self, program):
        return list(filter(lambda a: "subprog" in a.name, program.list_vars()))

    def test_recompute_pass(self):
        # mp2 training
        mp_engine = self.get_engine()
        history = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
        mp_losses = np.array(history.history["loss"])

        # mp2 recompute with old api
        rc4_engine = self.get_engine(True, False)
        history = rc4_engine.fit(self.dataset, 3, batch_size=self.batch_size)
        rc4_losses = np.array(history.history["loss"])
        self.check_results(mp_losses, rc4_losses)

        # mp2 recompute core_attn
        rc1_engine = self.get_engine(True, True, "core_attn", [0])
        history = rc1_engine.fit(self.dataset, 3, batch_size=self.batch_size)
        rc1_losses = np.array(history.history["loss"])
        self.check_results(mp_losses, rc1_losses)

        # mp2 recompute full_attn
        rc2_engine = self.get_engine(True, True, "full_attn")
        history = rc2_engine.fit(self.dataset, 3, batch_size=self.batch_size)
        rc2_losses = np.array(history.history["loss"])
        self.check_results(mp_losses, rc2_losses)

        # mp2 recompute full
        rc3_engine = self.get_engine(True, True, "full")
        history = rc3_engine.fit(self.dataset, 3, batch_size=self.batch_size)
        rc3_losses = np.array(history.history["loss"])
        self.check_results(mp_losses, rc3_losses)

        rc0_vars = self.recompute_vars(mp_engine.main_program)
        rc1_vars = self.recompute_vars(rc1_engine.main_program)
        rc2_vars = self.recompute_vars(rc2_engine.main_program)
        rc3_vars = self.recompute_vars(rc3_engine.main_program)

        assert rc0_vars == []
        assert len(rc1_vars) < len(rc2_vars) and len(rc2_vars) < len(rc3_vars)

    def test_recompute_pass_error(self):
        with self.assertRaises(AssertionError):
            rc_engine = self.get_engine(True, True, "full", [2])
            history = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size)


if __name__ == "__main__":
    unittest.main()
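Illustrative sketch (not part of this diff): the `recompute_vars` helper above works because variables re-created by the recompute pass carry a "subprog"-suffixed name, so counting them gives a rough proxy for how much of the forward graph will be re-executed during the backward pass.

    def count_recompute_vars(program):
        # Variables produced by recomputed forward ops carry "subprog" in their name.
        return sum(1 for var in program.list_vars() if "subprog" in var.name)

    # The test above expects: core_attn < full_attn < full in this count.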
python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py
@@ -57,6 +57,8 @@ class MultiHeadAttention(nn.Layer):
         bias_attr=None,
         fuse=False,
         mesh_idx=None,
+        use_new_recompute=False,
+        recompute_granularity="full",
     ):
         super().__init__()
         self.embed_dim = embed_dim

@@ -67,6 +69,9 @@ class MultiHeadAttention(nn.Layer):
         self.need_weights = need_weights
         self.fuse = fuse
         self.mesh_idx = mesh_idx
+        self.use_new_recompute = use_new_recompute
+        self.recompute_granularity = recompute_granularity
+
         self.head_dim = embed_dim // num_heads
         assert (
             self.head_dim * num_heads == self.embed_dim

@@ -225,6 +230,27 @@ class MultiHeadAttention(nn.Layer):
             # incremental_state with initial value, mainly for usage like UniLM
             return self.Cache(key, value)

+    def core_attn(self, q, k, v, attn_mask):
+        product = layers.matmul(
+            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5
+        )
+        if attn_mask is not None:
+            product = product + attn_mask
+        weights = F.softmax(product)
+        if self.dropout:
+            weights = F.dropout(
+                weights,
+                self.dropout,
+                training=self.training,
+                mode="upscale_in_train",
+            )
+        out = tensor.matmul(weights, v)
+
+        # combine heads
+        out = tensor.transpose(out, perm=[0, 2, 1, 3])
+        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
+        return out, weights
+
     def forward(
         self, query, key, value, attn_mask=None, use_cache=False, cache=None
     ):

@@ -244,23 +270,12 @@ class MultiHeadAttention(nn.Layer):
             q, k, v, cache = self._prepare_qkv(query, key, value, use_cache, cache)

-        product = layers.matmul(
-            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5
-        )
-        if attn_mask is not None:
-            product = product + attn_mask
-        weights = F.softmax(product)
-        if self.dropout:
-            weights = F.dropout(
-                weights,
-                self.dropout,
-                training=self.training,
-                mode="upscale_in_train",
-            )
-        out = tensor.matmul(weights, v)
-
-        # combine heads
-        out = tensor.transpose(out, perm=[0, 2, 1, 3])
-        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
+        if self.use_new_recompute and self.recompute_granularity == "core_attn":
+            out, weights = auto.recompute(self.core_attn)(q, k, v, attn_mask)
+        else:
+            out, weights = self.core_attn(q, k, v, attn_mask)

         # project to output
         out = self.out_proj(out)
         if _global_parallel_strategy == "mp":

@@ -295,12 +310,22 @@ class TransformerDecoder(nn.Layer):
     TransformerDecoder is a stack of N decoder layers.
     """

-    def __init__(self, decoder_layers, num_layers, norm=None, hidden_size=None):
+    def __init__(
+        self,
+        decoder_layers,
+        num_layers,
+        norm=None,
+        hidden_size=None,
+        use_new_recompute=False,
+        recompute_granularity="full",
+    ):
         super().__init__()
         self.num_layers = num_layers
         self.layers = decoder_layers
         self.norm = norm
+        self.use_new_recompute = use_new_recompute
+        self.recompute_granularity = recompute_granularity
         if norm == "LayerNorm":
             self.norm = nn.LayerNorm(hidden_size)
         elif norm is not None:

@@ -348,47 +373,13 @@ class TransformerDecoder(nn.Layer):
                         DPMPPP_MESH_LIST[0],
                         ["x"] + [None for i in range(len(output.shape) - 1)],
                     )
         for i, mod in enumerate(self.layers):
+            if self.use_new_recompute and self.recompute_granularity == "full":
+                mod = auto.recompute(mod)
+
             if cache is None:
                 if use_cache:
                     if _global_parallel_strategy == "pp":
                         output, new_cache = auto.shard_op(
                             mod, PP_MESH_LIST[mod.mesh_idx]
                         )(output, memory, tgt_mask, use_cache, cache)
                         auto.shard_tensor(
                             output,
                             PP_MESH_LIST[mod.mesh_idx],
                             [None for i in range(len(output.shape))],
                         )
                     elif _global_parallel_strategy == "dp_pp":
                         output, new_cache = auto.shard_op(
                             mod, DPPP_MESH_LIST[mod.mesh_idx]
                         )(output, memory, tgt_mask, use_cache, cache)
                         auto.shard_tensor(
                             output,
                             DPPP_MESH_LIST[mod.mesh_idx],
                             ["x"] + [None for i in range(len(output.shape) - 1)],
                         )
                     elif _global_parallel_strategy == "mp_pp":
                         output, new_cache = auto.shard_op(
                             mod, MPPP_MESH_LIST[mod.mesh_idx]
                         )(output, memory, tgt_mask, use_cache, cache)
                         auto.shard_tensor(
                             output,
                             MPPP_MESH_LIST[mod.mesh_idx],
                             [None for i in range(len(output.shape))],
                         )
                     elif _global_parallel_strategy == "dp_mp_pp":
                         output, new_cache = auto.shard_op(
                             mod, DPMPPP_MESH_LIST[mod.mesh_idx]
                         )(output, memory, tgt_mask, use_cache, cache)
                         auto.shard_tensor(
                             output,
                             DPMPPP_MESH_LIST[mod.mesh_idx],
                             [None for i in range(len(output.shape))],
                         )
                     else:
                         output, new_cache = mod(output, memory,

@@ -398,89 +389,7 @@ class TransformerDecoder(nn.Layer):
                             )
                             new_caches.append(new_cache)
                 else:
                     if _global_parallel_strategy == "pp":
                         output = auto.shard_op(mod, PP_MESH_LIST[mod.mesh_idx])(
                             output, memory, tgt_mask, use_cache, cache
                         )
                         auto.shard_tensor(
                             output,
                             PP_MESH_LIST[mod.mesh_idx],
                             [None for i in range(len(output.shape))],
                         )
                     elif _global_parallel_strategy == "dp_pp":
                         output = auto.shard_op(
                             mod, DPPP_MESH_LIST[mod.mesh_idx]
                         )(output, memory, tgt_mask, use_cache, cache)
                         auto.shard_tensor(
                             output,
                             DPPP_MESH_LIST[mod.mesh_idx],
                             ["x"] + [None for i in range(len(output.shape) - 1)],
                         )
                     elif _global_parallel_strategy == "mp_pp":
                         output = auto.shard_op(
                             mod, MPPP_MESH_LIST[mod.mesh_idx]
                         )(output, memory, tgt_mask, use_cache, cache)
                         auto.shard_tensor(
                             output,
                             MPPP_MESH_LIST[mod.mesh_idx],
                             [None for i in range(len(output.shape))],
                         )
                     elif _global_parallel_strategy == "dp_mp_pp":
                         output = auto.shard_op(
                             mod, DPMPPP_MESH_LIST[mod.mesh_idx]
                         )(output, memory, tgt_mask, use_cache, cache)
                         auto.shard_tensor(
                             output,
                             DPMPPP_MESH_LIST[mod.mesh_idx],
                             ["x"] + [None for i in range(len(output.shape) - 1)],
                         )
                     else:
                         output = mod(
                             output,
                             memory,
                             tgt_mask=tgt_mask,
                             use_cache=use_cache,
                             cache=cache,
                         )
             else:
                 if _global_parallel_strategy == "pp":
                     output, new_cache = auto.shard_op(
                         mod, PP_MESH_LIST[mod.mesh_idx]
                     )(output, memory, tgt_mask, use_cache, cache)
                     auto.shard_tensor(
                         output,
                         PP_MESH_LIST[mod.mesh_idx],
                         [None for i in range(len(output.shape))],
                     )
                 elif _global_parallel_strategy == "dp_pp":
                     output, new_cache = auto.shard_op(
                         mod, DPPP_MESH_LIST[mod.mesh_idx]
                     )(output, memory, tgt_mask, use_cache, cache)
                     auto.shard_tensor(
                         output,
                         DPPP_MESH_LIST[mod.mesh_idx],
                         ["x"] + [None for i in range(len(output.shape) - 1)],
                     )
                 elif _global_parallel_strategy == "mp_pp":
                     output, new_cache = auto.shard_op(
                         mod, MPPP_MESH_LIST[mod.mesh_idx]
                     )(output, memory, tgt_mask, use_cache, cache)
                     auto.shard_tensor(
                         output,
                         MPPP_MESH_LIST[mod.mesh_idx],
                         [None for i in range(len(output.shape))],
                     )
                 elif _global_parallel_strategy == "dp_mp_pp":
                     output, new_cache = auto.shard_op(
                         mod, DPMPPP_MESH_LIST[mod.mesh_idx]
                     )(output, memory, tgt_mask, use_cache, cache)
                     auto.shard_tensor(
                         output,
                         DPMPPP_MESH_LIST[mod.mesh_idx],
                         ["x"] + [None for i in range(len(output.shape) - 1)],
                     )
                 output = mod(output, memory, tgt_mask, use_cache, cache)
             else:
                 output, new_cache = mod(
                     output,

@@ -490,7 +399,10 @@ class TransformerDecoder(nn.Layer):
                     cache=cache[i],
                 )
                 new_caches.append(new_cache)
+            if not self.use_new_recompute:
                 self.checkpoints.append(output.name)

         if self.norm is not None:
             output = self.norm(output)
         return output if use_cache is False else (output, new_caches)

@@ -528,6 +440,8 @@ class TransformerDecoderLayer(nn.Layer):
         weight_attr=None,
         bias_attr=None,
         mesh_idx=None,
+        use_new_recompute=False,
+        recompute_granularity="full",
     ):
         self._config = locals()
         self._config.pop("self")

@@ -537,8 +451,12 @@ class TransformerDecoderLayer(nn.Layer):
         attn_dropout = dropout if attn_dropout is None else attn_dropout
         act_dropout = dropout if act_dropout is None else act_dropout
         self.normalize_before = normalize_before
+        self.use_new_recompute = use_new_recompute
+        self.recompute_granularity = recompute_granularity

         weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
         bias_attrs = _convert_param_attr_to_list(bias_attr, 3)
+
         self.self_attn = MultiHeadAttention(
             d_model,
             nhead,

@@ -546,6 +464,8 @@ class TransformerDecoderLayer(nn.Layer):
             weight_attr=weight_attrs[0],
             bias_attr=bias_attrs[0],
             mesh_idx=self.mesh_idx,
+            use_new_recompute=self.use_new_recompute,
+            recompute_granularity=self.recompute_granularity,
         )
         self.linear1 = nn.Linear(
             d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2]

@@ -563,12 +483,19 @@ class TransformerDecoderLayer(nn.Layer):
         residual = tgt
         if self.normalize_before:
             tgt = self.norm1(tgt)
+
+        if self.use_new_recompute and self.recompute_granularity == "full_attn":
+            self_attn = auto.recompute(self.self_attn)
+        else:
+            self_attn = self.self_attn
+
         if use_cache is False:
-            tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
+            tgt = self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
         else:
-            tgt, incremental_cache = self.self_attn(
+            tgt, incremental_cache = self_attn(
                 tgt, tgt, tgt, tgt_mask, use_cache, cache
             )
         tgt = residual + self.dropout1(tgt)
         if not self.normalize_before:
             tgt = self.norm1(tgt)

@@ -716,12 +643,17 @@ class GPTModel(nn.Layer):
         bos_token_id=0,
         eol_token_id=3,
         pp_degree=None,
+        use_new_recompute=False,
+        recompute_granularity="full",
     ):
         super().__init__()
         self.pad_token_id = pad_token_id
         self.initializer_range = initializer_range
         self.hidden_size = hidden_size
         self.vocab_size = vocab_size
+        self.use_new_recompute = use_new_recompute
+        self.recompute_granularity = recompute_granularity
+
         self.layer_per_stage = None
         self.pipline_mode = pp_degree is not None and pp_degree > 1
         if self.pipline_mode:

@@ -734,6 +666,7 @@ class GPTModel(nn.Layer):
             type_vocab_size,
             self.initializer_range,
         )
+
         decoder_layers = nn.LayerList()
         for i in range(num_hidden_layers):
             mesh_index = None

@@ -756,14 +689,19 @@ class GPTModel(nn.Layer):
                     ),
                     bias_attr=None,
                     mesh_idx=mesh_index,
+                    use_new_recompute=self.use_new_recompute,
+                    recompute_granularity=self.recompute_granularity,
                 )
             )

         Decoder = TransformerDecoder
         self.decoder = Decoder(
             decoder_layers,
             num_hidden_layers,
             norm="LayerNorm",
             hidden_size=hidden_size,
+            use_new_recompute=self.use_new_recompute,
+            recompute_granularity=self.recompute_granularity,
         )
         self.checkpoints = []

@@ -817,6 +755,7 @@ class GPTModel(nn.Layer):
             use_cache=use_cache,
             cache=cache,
         )
+        if not self.use_new_recompute:
             self.checkpoints.extend(self.decoder.checkpoints)
         return encoder_outputs
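Illustrative sketch (not part of this diff): how the three `recompute_granularity` values used in this model map onto the callables that get wrapped with `auto.recompute`; in the real code the wrapping happens inside the respective `forward` methods rather than through a helper like this, so the function below is hypothetical.

    from paddle.distributed.fleet import auto

    def pick_recompute_callable(decoder_layer, granularity):
        # Hypothetical helper for illustration only.
        if granularity == "full":
            return auto.recompute(decoder_layer)                      # whole TransformerDecoderLayer
        if granularity == "full_attn":
            return auto.recompute(decoder_layer.self_attn)            # MultiHeadAttention.forward
        if granularity == "core_attn":
            return auto.recompute(decoder_layer.self_attn.core_attn)  # softmax(QK^T)V block only
        return decoder_layer                                          # no selective recompute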