Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d7f7963f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d7f7963f
编写于
11月 18, 2022
作者:
Z
zhaoyingli
提交者:
GitHub
11月 18, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[AutoParallel] selective recompute (#48111)
* [AutoParallel] selective recompute * add cmakelist
上级
aafa9820
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
428 addition
and
235 deletion
+428
-235
python/paddle/distributed/auto_parallel/constants.py
python/paddle/distributed/auto_parallel/constants.py
+1
-0
python/paddle/distributed/auto_parallel/dist_loader.py
python/paddle/distributed/auto_parallel/dist_loader.py
+1
-1
python/paddle/distributed/auto_parallel/engine.py
python/paddle/distributed/auto_parallel/engine.py
+1
-30
python/paddle/distributed/auto_parallel/interface.py
python/paddle/distributed/auto_parallel/interface.py
+9
-1
python/paddle/distributed/auto_parallel/utils.py
python/paddle/distributed/auto_parallel/utils.py
+36
-5
python/paddle/distributed/passes/auto_parallel_recompute.py
python/paddle/distributed/passes/auto_parallel_recompute.py
+94
-41
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
...paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+2
-0
python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py
.../tests/unittests/auto_parallel/recompute_pass_unittest.py
+16
-3
python/paddle/fluid/tests/unittests/auto_parallel/test_selective_recompute.py
...tests/unittests/auto_parallel/test_selective_recompute.py
+175
-0
python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py
...n/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py
+93
-154
未找到文件。
python/paddle/distributed/auto_parallel/constants.py
浏览文件 @
d7f7963f
...
...
@@ -55,6 +55,7 @@ set_field_default_config(BASE, "reinit", False) # Only for debug
RECOMPUTE
=
"recompute"
set_field_default_config
(
RECOMPUTE
,
"enable"
,
False
)
set_field_default_config
(
RECOMPUTE
,
"checkpoints"
,
None
)
set_field_default_config
(
RECOMPUTE
,
"no_recompute_segments"
,
[])
set_field_default_config
(
RECOMPUTE
,
"enable_tuning"
,
False
)
#########################################
...
...
python/paddle/distributed/auto_parallel/dist_loader.py
浏览文件 @
d7f7963f
...
...
@@ -134,7 +134,7 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
raise
StopIteration
def
_infer_steps
(
self
):
if
isinstance
(
self
.
steps_per_epoch
,
int
)
and
self
.
steps_per_epoch
>
1
:
if
isinstance
(
self
.
steps_per_epoch
,
int
)
and
self
.
steps_per_epoch
>
0
:
return
self
.
steps_per_epoch
try
:
if
isinstance
(
self
.
dataset
,
IterableDataset
):
...
...
python/paddle/distributed/auto_parallel/engine.py
浏览文件 @
d7f7963f
...
...
@@ -610,7 +610,7 @@ class Engine:
if
mode
!=
"train"
:
serial_main_prog
=
serial_main_prog
.
clone
(
for_test
=
True
)
self
.
_set_recompute_ckpts
(
)
auto_utils
.
set_recompute_ckpts
(
self
.
_model
,
self
.
_strategy
)
self
.
_dist_contexts
[
mode
]
=
DistributedContext
(
serial_main_prog
,
serial_startup_prog
,
...
...
@@ -1518,35 +1518,6 @@ class Engine:
var_name
=
_to_name_str
(
var
)
return
var_name
in
self
.
main_program
.
global_block
().
vars
def
_set_recompute_ckpts
(
self
):
# NOTE hack to enable recompute in engine api for GPT-3
# TODO support more PaddleNLP/CV models here
recompute
=
self
.
_strategy
.
recompute
# extract ckpts by specific model
if
isinstance
(
self
.
_model
,
paddle
.
nn
.
Layer
):
if
hasattr
(
self
.
_model
,
"gpt"
)
and
self
.
_model
.
__class__
.
__name__
in
[
'GPTForPretraining'
,
'GPTForPretrainingAuto'
,
]:
exact_ckpts
=
self
.
_model
.
gpt
.
checkpoints
else
:
exact_ckpts
=
recompute
.
checkpoints
else
:
exact_ckpts
=
recompute
.
checkpoints
# modify strategy
if
recompute
.
enable
:
recompute
.
checkpoints
=
exact_ckpts
[:]
logs
=
{
'Model Class'
:
self
.
_model
.
__class__
.
__name__
,
'Applied Recompute ckpts'
:
exact_ckpts
,
}
self
.
_logger
.
info
(
logs
)
def
_reset_metrics
(
self
):
for
metric
in
self
.
_metrics
:
metric
.
reset
()
...
...
python/paddle/distributed/auto_parallel/interface.py
浏览文件 @
d7f7963f
...
...
@@ -195,7 +195,13 @@ def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None):
return
op
_g_recompute_idx
=
-
1
def
recompute
(
op
):
global
_g_recompute_idx
_g_recompute_idx
+=
1
class
RecomputeOperator
:
def
__init__
(
self
,
op
):
self
.
_op
=
op
...
...
@@ -209,7 +215,9 @@ def recompute(op):
for
idx
in
range
(
op_size
,
new_op_size
):
op
=
cur_block
.
ops
[
idx
]
op
.
_set_attr
(
"is_recompute@auto_parallel"
,
True
)
op
.
_set_attr
(
'op_namescope'
,
"/auto_parallel/rc_"
+
str
(
_g_recompute_idx
)
)
return
output
...
...
python/paddle/distributed/auto_parallel/utils.py
浏览文件 @
d7f7963f
...
...
@@ -33,6 +33,9 @@ from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistributedAttribute
,
)
OP_ROLE_KEY
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
()
OpRole
=
core
.
op_proto_and_checker_maker
.
OpRole
__no_shape_var_type__
=
[
core
.
VarDesc
.
VarType
.
READER
,
core
.
VarDesc
.
VarType
.
STEP_SCOPES
,
...
...
@@ -1181,7 +1184,6 @@ def _get_split_indices(
def
set_grad_var_shape
(
program
,
dist_context
):
from
.operators.common
import
infer_shape
from
paddle.distributed.fleet.meta_optimizers.common
import
OpRole
block
=
program
.
global_block
()
vars
=
block
.
vars
...
...
@@ -1315,10 +1317,6 @@ def set_grad_var_shape(program, dist_context):
grad_var
.
desc
.
set_shape
(
ref_shape
)
OP_ROLE_KEY
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
()
OpRole
=
core
.
op_proto_and_checker_maker
.
OpRole
def
is_forward_op
(
op
):
op_role
=
int
(
op
.
attr
(
'op_role'
))
return
OP_ROLE_KEY
in
op
.
attr_names
and
(
...
...
@@ -1896,6 +1894,39 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank):
server_socket
.
close
()
def
set_recompute_ckpts
(
model
,
strategy
):
from
.interface
import
_g_recompute_idx
if
_g_recompute_idx
>
-
1
:
return
recompute
=
strategy
.
recompute
if
not
recompute
.
enable
:
return
# NOTE: hack to enable recompute in engine api for GPT-3
# TODO support more PaddleNLP/CV models here
# extract ckpts by specific model
if
isinstance
(
model
,
paddle
.
nn
.
Layer
):
if
hasattr
(
model
,
"gpt"
)
and
model
.
__class__
.
__name__
in
[
'GPTForPretraining'
,
'GPTForPretrainingAuto'
,
]:
exact_ckpts
=
model
.
gpt
.
checkpoints
else
:
exact_ckpts
=
recompute
.
checkpoints
else
:
exact_ckpts
=
recompute
.
checkpoints
# modify strategy
recompute
.
checkpoints
=
exact_ckpts
[:]
logs
=
{
'Model Class'
:
model
.
__class__
.
__name__
,
'Applied Recompute ckpts'
:
exact_ckpts
,
}
logging
.
info
(
logs
)
def
get_input_split_info
(
cur_rank
,
var
,
dist_context
):
# deduce how the input data is split among the cluster
tensor_dist_attr
=
dist_context
.
get_tensor_dist_attr_for_program
(
var
)
...
...
python/paddle/distributed/passes/auto_parallel_recompute.py
浏览文件 @
d7f7963f
...
...
@@ -17,7 +17,6 @@ import logging
from
.pass_base
import
PassBase
,
register_pass
from
paddle.fluid
import
core
,
unique_name
from
paddle.fluid
import
framework
as
framework
from
paddle.fluid.framework
import
Variable
from
paddle.fluid.backward
import
_append_grad_suffix_
,
_get_no_grad_set_name
from
paddle.fluid.backward
import
ProgramStats
,
_rename_arg_
,
_find_op_path_
from
paddle.distributed.auto_parallel.dist_attribute
import
(
...
...
@@ -33,12 +32,21 @@ from paddle.distributed.auto_parallel.utils import (
)
def
_to_be_recomputed
(
op
):
return
op
.
has_attr
(
'op_namescope'
)
and
"/auto_parallel/rc_"
in
op
.
attr
(
'op_namescope'
)
class
RecomputeState
(
ProgramStats
):
def
__init__
(
self
,
block
,
ops
):
super
().
__init__
(
block
=
block
,
ops
=
ops
)
self
.
_block
=
block
self
.
_ops
=
ops
# {varname: {as_input_ops: op_idx, as_output_ops: op_idx}}
self
.
var_op_deps
=
{}
# {segment_name: op_idx}
self
.
seg_op_deps
=
{}
def
build_stats
(
self
):
for
i
,
op
in
enumerate
(
self
.
_ops
):
...
...
@@ -58,36 +66,72 @@ class RecomputeState(ProgramStats):
self
.
var_op_deps
[
name
][
"var_as_input_ops"
]
=
[]
self
.
var_op_deps
[
name
][
"var_as_output_ops"
]
=
[
i
]
def
get_recompute_segments
(
self
,
checkpoints
):
"""get recompute segments from checkpoints"""
if
not
_to_be_recomputed
(
op
):
continue
seg_name
=
op
.
attr
(
'op_namescope'
)
if
seg_name
not
in
self
.
seg_op_deps
:
self
.
seg_op_deps
[
seg_name
]
=
[
i
]
else
:
assert
(
self
.
seg_op_deps
[
seg_name
][
-
1
]
+
1
==
i
),
"The recompute segment's ops should be continuous"
self
.
seg_op_deps
[
seg_name
].
extend
([
i
])
def
get_recompute_segments
(
self
,
checkpoints_list
=
None
,
no_recompute_segments
=
[]
):
"""get recompute segments and checkpoints"""
segments
=
[]
start_idx
=
-
1
pre_segment_end_idx
=
-
1
while
start_idx
+
1
<
len
(
checkpoints
):
if
start_idx
==
-
1
:
ckpt_name
=
checkpoints
[
start_idx
+
1
]
if
ckpt_name
not
in
self
.
var_op_deps
:
start_idx
+=
1
checkpoints
=
checkpoints_list
or
[]
if
len
(
checkpoints
)
==
0
:
# the segments is marked by `auto.recompute()` api
for
segment_idx
in
self
.
seg_op_deps
.
values
():
if
len
(
segment_idx
)
==
1
:
continue
op_idx_list
=
self
.
var_op_deps
[
ckpt_name
][
"var_as_output_ops"
]
if
op_idx_list
:
segments
.
append
([
0
,
max
(
op_idx_list
)
+
1
])
else
:
flag
,
min_idx
,
max_idx
=
self
.
is_subgraph
(
[
checkpoints
[
start_idx
]],
[
checkpoints
[
start_idx
+
1
]]
)
if
flag
:
min_idx
=
self
.
_update_segment_start
(
min_idx
,
pre_segment_end_idx
)
segments
.
append
([
min_idx
,
max_idx
+
1
])
segments
.
append
([
segment_idx
[
0
],
segment_idx
[
-
1
]
+
1
])
checkpoints
.
extend
(
self
.
_ops
[
segment_idx
[
-
1
]].
output_arg_names
)
else
:
# the segments is marked by `strategy.checkpoints` api
start_idx
=
-
1
pre_segment_end_idx
=
-
1
while
start_idx
+
1
<
len
(
checkpoints
):
if
start_idx
==
-
1
:
ckpt_name
=
checkpoints
[
start_idx
+
1
]
if
ckpt_name
not
in
self
.
var_op_deps
:
start_idx
+=
1
continue
op_idx_list
=
self
.
var_op_deps
[
ckpt_name
][
"var_as_output_ops"
]
if
op_idx_list
:
segments
.
append
([
0
,
max
(
op_idx_list
)
+
1
])
else
:
logging
.
info
(
"Could not recompute op range [{}] - [{}] "
.
format
(
min_idx
,
max_idx
+
1
)
flag
,
min_idx
,
max_idx
=
self
.
is_subgraph
(
[
checkpoints
[
start_idx
]],
[
checkpoints
[
start_idx
+
1
]]
)
start_idx
+=
1
if
flag
:
min_idx
=
self
.
_update_segment_start
(
min_idx
,
pre_segment_end_idx
)
segments
.
append
([
min_idx
,
max_idx
+
1
])
else
:
logging
.
info
(
"Could not recompute op range [{}] - [{}] "
.
format
(
min_idx
,
max_idx
+
1
)
)
start_idx
+=
1
if
no_recompute_segments
:
for
i
in
reversed
(
sorted
(
no_recompute_segments
)):
assert
i
<
len
(
segments
),
"the no_recompute_segments idx [{}] should be lower the number of segment [{}]"
.
format
(
i
,
len
(
segments
)
)
segments
.
pop
(
i
)
for
i
,
(
idx1
,
idx2
)
in
enumerate
(
segments
):
logging
.
info
(
"recompute segment[{}]"
.
format
(
i
))
...
...
@@ -106,7 +150,10 @@ class RecomputeState(ProgramStats):
)
)
return
segments
return
segments
,
checkpoints
def
is_recompute
(
self
):
return
any
([
_to_be_recomputed
(
op
)
for
op
in
self
.
_ops
])
def
modify_forward_desc_for_recompute
(
self
,
dist_context
):
"""
...
...
@@ -162,6 +209,7 @@ class RecomputeState(ProgramStats):
outputs
=
{
"Out"
:
seed_var
},
attrs
=
{
"seed"
:
seed
,
"force_cpu"
:
True
},
)
seed_op
.
_set_attr
(
'op_namescope'
,
cur_op
.
attr
(
'op_namescope'
))
# set new seed op's dist_attr
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
seed_op
,
ref_process_mesh
,
ref_dims_mapping
,
dist_context
...
...
@@ -196,7 +244,6 @@ def _get_stop_gradients(program, no_grad_set):
no_grad_set_name
=
set
()
for
var
in
program
.
list_vars
():
assert
isinstance
(
var
,
Variable
)
if
"@GRAD"
in
var
.
name
:
break
if
var
.
stop_gradient
:
...
...
@@ -244,14 +291,13 @@ class RecomputePass(PassBase):
self
.
set_attr
(
"loss"
,
None
)
self
.
set_attr
(
"dist_context"
,
None
)
self
.
set_attr
(
"no_grad_set"
,
None
)
self
.
set_attr
(
"no_recompute_segments"
,
[])
def
_check_self
(
self
):
if
self
.
get_attr
(
"dist_context"
)
is
None
:
return
False
if
self
.
get_attr
(
"loss"
)
is
None
:
return
False
if
self
.
get_attr
(
"checkpoints"
)
is
None
:
return
False
return
True
def
_check_conflict
(
self
,
other_pass
):
...
...
@@ -259,25 +305,32 @@ class RecomputePass(PassBase):
def
_apply_single_impl
(
self
,
main_program
,
startup_program
,
context
):
checkpoints
=
self
.
get_attr
(
"checkpoints"
)
no_recompute_segments
=
self
.
get_attr
(
"no_recompute_segments"
)
loss
=
self
.
get_attr
(
"loss"
)
no_grad_set
=
self
.
get_attr
(
"no_grad_set"
)
self
.
_dist_context
=
self
.
get_attr
(
"dist_context"
)
# 0. get op_path which is related to loss
main_block
=
main_program
.
global_block
()
no_grad_set_name
=
_get_stop_gradients
(
main_program
,
no_grad_set
)
# get op_path which is related to loss
op_path
=
_find_op_path_
(
main_block
,
[
loss
],
[],
no_grad_set_name
)
#
step 1:
build recompute state
#
1.
build recompute state
rc_state
=
RecomputeState
(
main_block
,
op_path
)
if
not
rc_state
.
is_recompute
()
and
not
checkpoints
:
return
# 2. get the segments to be recomputed
rc_state
.
modify_forward_desc_for_recompute
(
self
.
_dist_context
)
rc_state
.
build_stats
()
checkpoints
=
rc_state
.
sort_checkpoints
(
checkpoints
)
segments
=
rc_state
.
get_recompute_segments
(
checkpoints
)
if
segments
==
[]:
checkpoints
=
rc_state
.
sort_checkpoints
(
checkpoints
or
[])
segments
,
checkpoints
=
rc_state
.
get_recompute_segments
(
checkpoints
,
no_recompute_segments
)
if
segments
==
[]
or
checkpoints
==
[]:
return
#
step 2: get vars_should_be_hold
#
3. get vars that should be hold in memory
vars_should_be_hold
=
[]
for
segment
in
segments
:
vars_should_be_hold
.
extend
(
...
...
@@ -295,9 +348,9 @@ class RecomputePass(PassBase):
vars_should_be_hold
=
list
(
set
(
vars_should_be_hold
))
vars_in_memory
=
vars_should_be_hold
+
checkpoints
#
step 3: get recomputed fwd ops desc
var_name_dict
=
{}
ckpt_ops_dict
=
{}
#
4. get the fwd ops desc to be recomputed.
var_name_dict
=
{}
# varname --> varname.subprog_XXX
ckpt_ops_dict
=
{}
# ckpt_op_id --> segment_descs
buffer_block
=
main_block
.
program
.
_create_block
()
for
i
,
segment
in
enumerate
(
segments
[::
-
1
]):
fwd_ops
=
op_path
[
segment
[
0
]
:
segment
[
1
]]
...
...
@@ -362,7 +415,7 @@ class RecomputePass(PassBase):
ckpt_op
=
op_path
[
segment
[
1
]
-
1
]
ckpt_ops_dict
[
ckpt_op
.
desc
.
original_id
()]
=
[
True
,
segment_descs
]
#
step 4: insert recomputed fwd ops
#
5. insert recomputed fwd ops into backward parse
ops
=
main_block
.
ops
loss_op
=
get_loss_op
(
main_block
)
loss_op_idx
=
_find_op_index
(
main_block
,
loss_op
)
...
...
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
浏览文件 @
d7f7963f
...
...
@@ -72,6 +72,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules
(
test_parallel_tuner_predict MODULES
test_parallel_tuner_predict ENVS
${
dist_ENVS
}
)
set_tests_properties
(
test_parallel_tuner_predict PROPERTIES TIMEOUT 120
)
py_test_modules
(
test_selective_recompute MODULES test_selective_recompute
)
set_tests_properties
(
test_selective_recompute PROPERTIES TIMEOUT 50
)
py_test_modules
(
test_while_op_completion MODULES test_while_op_completion
ENVS
${
dist_ENVS
}
)
...
...
python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py
浏览文件 @
d7f7963f
...
...
@@ -22,13 +22,14 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from
get_gpt_model
import
FakeDataset
,
generate_model
def
apply_pass
(
use_recompute
=
False
):
def
apply_pass
(
use_recompute
=
False
,
no_recompute_segments
=
[]
):
strategy
=
auto
.
Strategy
()
strategy
.
auto_mode
=
"semi"
strategy
.
reinit
=
True
if
use_recompute
:
recompute
=
strategy
.
recompute
recompute
.
enable
=
True
recompute
.
no_recompute_segments
=
no_recompute_segments
return
strategy
...
...
@@ -53,10 +54,10 @@ class TestRecomputePass(unittest.TestCase):
place
=
paddle
.
fluid
.
CUDAPlace
(
ParallelEnv
().
dev_id
)
engine
.
_executor
=
paddle
.
static
.
Executor
(
place
)
def
get_engine
(
self
,
use_recompute
=
False
):
def
get_engine
(
self
,
use_recompute
=
False
,
no_recompute_segments
=
[]
):
reset_prog
()
strategy
=
apply_pass
(
use_recompute
)
strategy
=
apply_pass
(
use_recompute
,
no_recompute_segments
)
clip
=
paddle
.
nn
.
ClipGradByGlobalNorm
(
self
.
clip_norm
)
opt
=
paddle
.
optimizer
.
AdamW
(
learning_rate
=
0.00001
,
grad_clip
=
clip
)
model
,
loss
=
generate_model
(
"mp"
)
...
...
@@ -88,6 +89,18 @@ class TestRecomputePass(unittest.TestCase):
rc_losses
=
np
.
array
(
history
.
history
[
"loss"
])
self
.
check_results
(
mp_losses
,
rc_losses
)
# mp2 selective recompute training
rc1_engine
=
self
.
get_engine
(
True
,
[
0
])
history
=
rc1_engine
.
fit
(
self
.
dataset
,
3
,
batch_size
=
self
.
batch_size
)
rc1_losses
=
np
.
array
(
history
.
history
[
"loss"
])
self
.
check_results
(
mp_losses
,
rc1_losses
)
def
test_recompute_pass_error
(
self
):
with
self
.
assertRaises
(
AssertionError
):
rc_engine
=
self
.
get_engine
(
True
,
[
2
])
history
=
rc_engine
.
fit
(
self
.
dataset
,
3
,
batch_size
=
self
.
batch_size
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/auto_parallel/test_selective_recompute.py
0 → 100644
浏览文件 @
d7f7963f
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
import
unittest
import
random
import
numpy
as
np
import
paddle
from
paddle.distributed.fleet
import
auto
from
paddle.fluid.dygraph.parallel
import
ParallelEnv
from
get_gpt_model
import
FakeDataset
sys
.
path
.
append
(
".."
)
import
auto_parallel_gpt_model
as
modeling
from
auto_parallel_gpt_model
import
(
GPTModel
,
GPTForPretraining
,
GPTPretrainingCriterion
,
)
def
generate_model
(
use_new_recompute
,
recompute_granularity
):
modeling
.
init_global
()
modeling
.
_global_parallel_strategy
=
"serial"
modeling
.
_global_process_mesh
=
auto
.
ProcessMesh
(
mesh
=
[
0
],
dim_names
=
[
"x"
])
gpt
=
GPTModel
(
vocab_size
=
1000
,
hidden_size
=
64
,
num_hidden_layers
=
2
,
num_attention_heads
=
8
,
intermediate_size
=
256
,
hidden_act
=
"gelu"
,
hidden_dropout_prob
=
0.0
,
attention_probs_dropout_prob
=
0.0
,
max_position_embeddings
=
1024
,
type_vocab_size
=
1
,
initializer_range
=
0.02
,
pad_token_id
=
0
,
eos_token_id
=
7
,
bos_token_id
=
0
,
eol_token_id
=
3
,
use_new_recompute
=
use_new_recompute
,
recompute_granularity
=
recompute_granularity
,
)
model
=
GPTForPretraining
(
gpt
,
vocab_size
=
1000
,
hidden_size
=
64
,
initializer_range
=
0.02
)
criterion
=
GPTPretrainingCriterion
()
return
model
,
criterion
def
apply_pass
(
use_recompute
=
False
,
no_recompute_segments
=
[]):
strategy
=
auto
.
Strategy
()
strategy
.
auto_mode
=
"semi"
strategy
.
reinit
=
True
if
use_recompute
:
recompute
=
strategy
.
recompute
recompute
.
enable
=
True
recompute
.
no_recompute_segments
=
no_recompute_segments
return
strategy
def
reset_prog
():
paddle
.
fluid
.
framework
.
switch_main_program
(
paddle
.
static
.
Program
())
paddle
.
fluid
.
framework
.
switch_startup_program
(
paddle
.
static
.
Program
())
class
TestRecomputePassWithRecomputeAPI
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
rtol
=
1e-6
self
.
atol
=
1e-8
self
.
batch_size
=
1
self
.
batch_num
=
2
self
.
clip_norm
=
0.2
self
.
dataset
=
FakeDataset
(
self
.
batch_size
*
self
.
batch_num
)
def
init
(
self
,
engine
):
paddle
.
seed
(
2022
)
np
.
random
.
seed
(
2022
)
random
.
seed
(
2022
)
place
=
paddle
.
fluid
.
CUDAPlace
(
ParallelEnv
().
dev_id
)
engine
.
_executor
=
paddle
.
static
.
Executor
(
place
)
def
get_engine
(
self
,
use_recompute
=
False
,
use_new_recompute
=
False
,
recompute_granularity
=
"full"
,
no_recompute_segments
=
[],
):
reset_prog
()
strategy
=
apply_pass
(
use_recompute
,
no_recompute_segments
)
clip
=
paddle
.
nn
.
ClipGradByGlobalNorm
(
self
.
clip_norm
)
opt
=
paddle
.
optimizer
.
AdamW
(
learning_rate
=
0.00001
,
grad_clip
=
clip
)
model
,
loss
=
generate_model
(
use_new_recompute
,
recompute_granularity
)
engine
=
auto
.
Engine
(
model
,
loss
,
opt
,
strategy
=
strategy
)
self
.
init
(
engine
)
return
engine
def
check_results
(
self
,
ref_losses
,
check_losses
):
np
.
testing
.
assert_allclose
(
ref_losses
,
check_losses
,
rtol
=
self
.
rtol
,
atol
=
self
.
atol
,
err_msg
=
'pass {} has wrong results!,
\n
u={}
\n
v={}
\n
diff={}'
.
format
(
__class__
,
ref_losses
,
check_losses
,
ref_losses
-
check_losses
),
)
def
recompute_vars
(
self
,
program
):
return
list
(
filter
(
lambda
a
:
"subprog"
in
a
.
name
,
program
.
list_vars
()))
def
test_recompute_pass
(
self
):
# mp2 training
mp_engine
=
self
.
get_engine
()
history
=
mp_engine
.
fit
(
self
.
dataset
,
3
,
batch_size
=
self
.
batch_size
)
mp_losses
=
np
.
array
(
history
.
history
[
"loss"
])
# mp2 recompute with old api
rc4_engine
=
self
.
get_engine
(
True
,
False
)
history
=
rc4_engine
.
fit
(
self
.
dataset
,
3
,
batch_size
=
self
.
batch_size
)
rc4_losses
=
np
.
array
(
history
.
history
[
"loss"
])
self
.
check_results
(
mp_losses
,
rc4_losses
)
# mp2 recompute core_attn
rc1_engine
=
self
.
get_engine
(
True
,
True
,
"core_attn"
,
[
0
])
history
=
rc1_engine
.
fit
(
self
.
dataset
,
3
,
batch_size
=
self
.
batch_size
)
rc1_losses
=
np
.
array
(
history
.
history
[
"loss"
])
self
.
check_results
(
mp_losses
,
rc1_losses
)
# mp2 recompute full_attn
rc2_engine
=
self
.
get_engine
(
True
,
True
,
"full_attn"
)
history
=
rc2_engine
.
fit
(
self
.
dataset
,
3
,
batch_size
=
self
.
batch_size
)
rc2_losses
=
np
.
array
(
history
.
history
[
"loss"
])
self
.
check_results
(
mp_losses
,
rc2_losses
)
# mp2 recompute full
rc3_engine
=
self
.
get_engine
(
True
,
True
,
"full"
)
history
=
rc3_engine
.
fit
(
self
.
dataset
,
3
,
batch_size
=
self
.
batch_size
)
rc3_losses
=
np
.
array
(
history
.
history
[
"loss"
])
self
.
check_results
(
mp_losses
,
rc3_losses
)
rc0_vars
=
self
.
recompute_vars
(
mp_engine
.
main_program
)
rc1_vars
=
self
.
recompute_vars
(
rc1_engine
.
main_program
)
rc2_vars
=
self
.
recompute_vars
(
rc2_engine
.
main_program
)
rc3_vars
=
self
.
recompute_vars
(
rc3_engine
.
main_program
)
assert
rc0_vars
==
[]
assert
len
(
rc1_vars
)
<
len
(
rc2_vars
)
and
len
(
rc2_vars
)
<
len
(
rc3_vars
)
def
test_recompute_pass_error
(
self
):
with
self
.
assertRaises
(
AssertionError
):
rc_engine
=
self
.
get_engine
(
True
,
True
,
"full"
,
[
2
])
history
=
rc_engine
.
fit
(
self
.
dataset
,
3
,
batch_size
=
self
.
batch_size
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py
浏览文件 @
d7f7963f
...
...
@@ -57,6 +57,8 @@ class MultiHeadAttention(nn.Layer):
bias_attr
=
None
,
fuse
=
False
,
mesh_idx
=
None
,
use_new_recompute
=
False
,
recompute_granularity
=
"full"
,
):
super
().
__init__
()
self
.
embed_dim
=
embed_dim
...
...
@@ -67,6 +69,9 @@ class MultiHeadAttention(nn.Layer):
self
.
need_weights
=
need_weights
self
.
fuse
=
fuse
self
.
mesh_idx
=
mesh_idx
self
.
use_new_recompute
=
use_new_recompute
self
.
recompute_granularity
=
recompute_granularity
self
.
head_dim
=
embed_dim
//
num_heads
assert
(
self
.
head_dim
*
num_heads
==
self
.
embed_dim
...
...
@@ -225,6 +230,27 @@ class MultiHeadAttention(nn.Layer):
# incremental_state with initial value, mainly for usage like UniLM
return
self
.
Cache
(
key
,
value
)
def
core_attn
(
self
,
q
,
k
,
v
,
attn_mask
):
product
=
layers
.
matmul
(
x
=
q
,
y
=
k
,
transpose_y
=
True
,
alpha
=
self
.
head_dim
**-
0.5
)
if
attn_mask
is
not
None
:
product
=
product
+
attn_mask
weights
=
F
.
softmax
(
product
)
if
self
.
dropout
:
weights
=
F
.
dropout
(
weights
,
self
.
dropout
,
training
=
self
.
training
,
mode
=
"upscale_in_train"
,
)
out
=
tensor
.
matmul
(
weights
,
v
)
# combine heads
out
=
tensor
.
transpose
(
out
,
perm
=
[
0
,
2
,
1
,
3
])
out
=
tensor
.
reshape
(
x
=
out
,
shape
=
[
0
,
0
,
out
.
shape
[
2
]
*
out
.
shape
[
3
]])
return
out
,
weights
def
forward
(
self
,
query
,
key
,
value
,
attn_mask
=
None
,
use_cache
=
False
,
cache
=
None
):
...
...
@@ -244,23 +270,12 @@ class MultiHeadAttention(nn.Layer):
q
,
k
,
v
,
cache
=
self
.
_prepare_qkv
(
query
,
key
,
value
,
use_cache
,
cache
)
product
=
layers
.
matmul
(
x
=
q
,
y
=
k
,
transpose_y
=
True
,
alpha
=
self
.
head_dim
**-
0.5
)
if
attn_mask
is
not
None
:
product
=
product
+
attn_mask
weights
=
F
.
softmax
(
product
)
if
self
.
dropout
:
weights
=
F
.
dropout
(
weights
,
self
.
dropout
,
training
=
self
.
training
,
mode
=
"upscale_in_train"
,
)
out
=
tensor
.
matmul
(
weights
,
v
)
# combine heads
out
=
tensor
.
transpose
(
out
,
perm
=
[
0
,
2
,
1
,
3
])
out
=
tensor
.
reshape
(
x
=
out
,
shape
=
[
0
,
0
,
out
.
shape
[
2
]
*
out
.
shape
[
3
]])
if
self
.
use_new_recompute
and
self
.
recompute_granularity
==
"core_attn"
:
out
,
weights
=
auto
.
recompute
(
self
.
core_attn
)(
q
,
k
,
v
,
attn_mask
)
else
:
out
,
weights
=
self
.
core_attn
(
q
,
k
,
v
,
attn_mask
)
# project to output
out
=
self
.
out_proj
(
out
)
if
_global_parallel_strategy
==
"mp"
:
...
...
@@ -295,12 +310,22 @@ class TransformerDecoder(nn.Layer):
TransformerDecoder is a stack of N decoder layers.
"""
def
__init__
(
self
,
decoder_layers
,
num_layers
,
norm
=
None
,
hidden_size
=
None
):
def
__init__
(
self
,
decoder_layers
,
num_layers
,
norm
=
None
,
hidden_size
=
None
,
use_new_recompute
=
False
,
recompute_granularity
=
"full"
,
):
super
().
__init__
()
self
.
num_layers
=
num_layers
self
.
layers
=
decoder_layers
self
.
norm
=
norm
self
.
use_new_recompute
=
use_new_recompute
self
.
recompute_granularity
=
recompute_granularity
if
norm
==
"LayerNorm"
:
self
.
norm
=
nn
.
LayerNorm
(
hidden_size
)
elif
norm
is
not
None
:
...
...
@@ -348,149 +373,36 @@ class TransformerDecoder(nn.Layer):
DPMPPP_MESH_LIST
[
0
],
[
"x"
]
+
[
None
for
i
in
range
(
len
(
output
.
shape
)
-
1
)],
)
for
i
,
mod
in
enumerate
(
self
.
layers
):
if
self
.
use_new_recompute
and
self
.
recompute_granularity
==
"full"
:
mod
=
auto
.
recompute
(
mod
)
if
cache
is
None
:
if
use_cache
:
if
_global_parallel_strategy
==
"pp"
:
output
,
new_cache
=
auto
.
shard_op
(
mod
,
PP_MESH_LIST
[
mod
.
mesh_idx
]
)(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)
auto
.
shard_tensor
(
output
,
PP_MESH_LIST
[
mod
.
mesh_idx
],
[
None
for
i
in
range
(
len
(
output
.
shape
))],
)
elif
_global_parallel_strategy
==
"dp_pp"
:
output
,
new_cache
=
auto
.
shard_op
(
mod
,
DPPP_MESH_LIST
[
mod
.
mesh_idx
]
)(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)
auto
.
shard_tensor
(
output
,
DPPP_MESH_LIST
[
mod
.
mesh_idx
],
[
"x"
]
+
[
None
for
i
in
range
(
len
(
output
.
shape
)
-
1
)],
)
elif
_global_parallel_strategy
==
"mp_pp"
:
output
,
new_cache
=
auto
.
shard_op
(
mod
,
MPPP_MESH_LIST
[
mod
.
mesh_idx
]
)(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)
auto
.
shard_tensor
(
output
,
MPPP_MESH_LIST
[
mod
.
mesh_idx
],
[
None
for
i
in
range
(
len
(
output
.
shape
))],
)
elif
_global_parallel_strategy
==
"dp_mp_pp"
:
output
,
new_cache
=
auto
.
shard_op
(
mod
,
DPMPPP_MESH_LIST
[
mod
.
mesh_idx
]
)(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)
auto
.
shard_tensor
(
output
,
DPMPPP_MESH_LIST
[
mod
.
mesh_idx
],
[
None
for
i
in
range
(
len
(
output
.
shape
))],
)
else
:
output
,
new_cache
=
mod
(
output
,
memory
,
tgt_mask
=
tgt_mask
,
use_cache
=
use_cache
,
cache
=
cache
,
)
new_caches
.
append
(
new_cache
)
else
:
if
_global_parallel_strategy
==
"pp"
:
output
=
auto
.
shard_op
(
mod
,
PP_MESH_LIST
[
mod
.
mesh_idx
])(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)
auto
.
shard_tensor
(
output
,
PP_MESH_LIST
[
mod
.
mesh_idx
],
[
None
for
i
in
range
(
len
(
output
.
shape
))],
)
elif
_global_parallel_strategy
==
"dp_pp"
:
output
=
auto
.
shard_op
(
mod
,
DPPP_MESH_LIST
[
mod
.
mesh_idx
]
)(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)
auto
.
shard_tensor
(
output
,
DPPP_MESH_LIST
[
mod
.
mesh_idx
],
[
"x"
]
+
[
None
for
i
in
range
(
len
(
output
.
shape
)
-
1
)],
)
elif
_global_parallel_strategy
==
"mp_pp"
:
output
=
auto
.
shard_op
(
mod
,
MPPP_MESH_LIST
[
mod
.
mesh_idx
]
)(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)
auto
.
shard_tensor
(
output
,
MPPP_MESH_LIST
[
mod
.
mesh_idx
],
[
None
for
i
in
range
(
len
(
output
.
shape
))],
)
elif
_global_parallel_strategy
==
"dp_mp_pp"
:
output
=
auto
.
shard_op
(
mod
,
DPMPPP_MESH_LIST
[
mod
.
mesh_idx
]
)(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)
auto
.
shard_tensor
(
output
,
DPMPPP_MESH_LIST
[
mod
.
mesh_idx
],
[
"x"
]
+
[
None
for
i
in
range
(
len
(
output
.
shape
)
-
1
)],
)
else
:
output
=
mod
(
output
,
memory
,
tgt_mask
=
tgt_mask
,
use_cache
=
use_cache
,
cache
=
cache
,
)
else
:
if
_global_parallel_strategy
==
"pp"
:
output
,
new_cache
=
auto
.
shard_op
(
mod
,
PP_MESH_LIST
[
mod
.
mesh_idx
]
)(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)
auto
.
shard_tensor
(
output
,
PP_MESH_LIST
[
mod
.
mesh_idx
],
[
None
for
i
in
range
(
len
(
output
.
shape
))],
)
elif
_global_parallel_strategy
==
"dp_pp"
:
output
,
new_cache
=
auto
.
shard_op
(
mod
,
DPPP_MESH_LIST
[
mod
.
mesh_idx
]
)(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)
auto
.
shard_tensor
(
output
,
DPPP_MESH_LIST
[
mod
.
mesh_idx
],
[
"x"
]
+
[
None
for
i
in
range
(
len
(
output
.
shape
)
-
1
)],
)
elif
_global_parallel_strategy
==
"mp_pp"
:
output
,
new_cache
=
auto
.
shard_op
(
mod
,
MPPP_MESH_LIST
[
mod
.
mesh_idx
]
)(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)
auto
.
shard_tensor
(
output
,
MPPP_MESH_LIST
[
mod
.
mesh_idx
],
[
None
for
i
in
range
(
len
(
output
.
shape
))],
)
elif
_global_parallel_strategy
==
"dp_mp_pp"
:
output
,
new_cache
=
auto
.
shard_op
(
mod
,
DPMPPP_MESH_LIST
[
mod
.
mesh_idx
]
)(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)
auto
.
shard_tensor
(
output
,
DPMPPP_MESH_LIST
[
mod
.
mesh_idx
],
[
"x"
]
+
[
None
for
i
in
range
(
len
(
output
.
shape
)
-
1
)],
)
else
:
output
,
new_cache
=
mod
(
output
,
memory
,
tgt_mask
=
tgt_mask
,
use_cache
=
use_cache
,
cache
=
cache
[
i
]
,
cache
=
cache
,
)
new_caches
.
append
(
new_cache
)
else
:
output
=
mod
(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)
else
:
output
,
new_cache
=
mod
(
output
,
memory
,
tgt_mask
=
tgt_mask
,
use_cache
=
use_cache
,
cache
=
cache
[
i
],
)
new_caches
.
append
(
new_cache
)
self
.
checkpoints
.
append
(
output
.
name
)
if
not
self
.
use_new_recompute
:
self
.
checkpoints
.
append
(
output
.
name
)
if
self
.
norm
is
not
None
:
output
=
self
.
norm
(
output
)
return
output
if
use_cache
is
False
else
(
output
,
new_caches
)
...
...
@@ -528,6 +440,8 @@ class TransformerDecoderLayer(nn.Layer):
weight_attr
=
None
,
bias_attr
=
None
,
mesh_idx
=
None
,
use_new_recompute
=
False
,
recompute_granularity
=
"full"
,
):
self
.
_config
=
locals
()
self
.
_config
.
pop
(
"self"
)
...
...
@@ -537,8 +451,12 @@ class TransformerDecoderLayer(nn.Layer):
attn_dropout
=
dropout
if
attn_dropout
is
None
else
attn_dropout
act_dropout
=
dropout
if
act_dropout
is
None
else
act_dropout
self
.
normalize_before
=
normalize_before
self
.
use_new_recompute
=
use_new_recompute
self
.
recompute_granularity
=
recompute_granularity
weight_attrs
=
_convert_param_attr_to_list
(
weight_attr
,
3
)
bias_attrs
=
_convert_param_attr_to_list
(
bias_attr
,
3
)
self
.
self_attn
=
MultiHeadAttention
(
d_model
,
nhead
,
...
...
@@ -546,6 +464,8 @@ class TransformerDecoderLayer(nn.Layer):
weight_attr
=
weight_attrs
[
0
],
bias_attr
=
bias_attrs
[
0
],
mesh_idx
=
self
.
mesh_idx
,
use_new_recompute
=
self
.
use_new_recompute
,
recompute_granularity
=
self
.
recompute_granularity
,
)
self
.
linear1
=
nn
.
Linear
(
d_model
,
dim_feedforward
,
weight_attrs
[
2
],
bias_attr
=
bias_attrs
[
2
]
...
...
@@ -563,12 +483,19 @@ class TransformerDecoderLayer(nn.Layer):
residual
=
tgt
if
self
.
normalize_before
:
tgt
=
self
.
norm1
(
tgt
)
if
self
.
use_new_recompute
and
self
.
recompute_granularity
==
"full_attn"
:
self_attn
=
auto
.
recompute
(
self
.
self_attn
)
else
:
self_attn
=
self
.
self_attn
if
use_cache
is
False
:
tgt
=
self
.
self
_attn
(
tgt
,
tgt
,
tgt
,
tgt_mask
,
use_cache
,
cache
)
tgt
=
self_attn
(
tgt
,
tgt
,
tgt
,
tgt_mask
,
use_cache
,
cache
)
else
:
tgt
,
incremental_cache
=
self
.
self
_attn
(
tgt
,
incremental_cache
=
self_attn
(
tgt
,
tgt
,
tgt
,
tgt_mask
,
use_cache
,
cache
)
tgt
=
residual
+
self
.
dropout1
(
tgt
)
if
not
self
.
normalize_before
:
tgt
=
self
.
norm1
(
tgt
)
...
...
@@ -716,12 +643,17 @@ class GPTModel(nn.Layer):
bos_token_id
=
0
,
eol_token_id
=
3
,
pp_degree
=
None
,
use_new_recompute
=
False
,
recompute_granularity
=
"full"
,
):
super
().
__init__
()
self
.
pad_token_id
=
pad_token_id
self
.
initializer_range
=
initializer_range
self
.
hidden_size
=
hidden_size
self
.
vocab_size
=
vocab_size
self
.
use_new_recompute
=
use_new_recompute
self
.
recompute_granularity
=
recompute_granularity
self
.
layer_per_stage
=
None
self
.
pipline_mode
=
pp_degree
is
not
None
and
pp_degree
>
1
if
self
.
pipline_mode
:
...
...
@@ -734,6 +666,7 @@ class GPTModel(nn.Layer):
type_vocab_size
,
self
.
initializer_range
,
)
decoder_layers
=
nn
.
LayerList
()
for
i
in
range
(
num_hidden_layers
):
mesh_index
=
None
...
...
@@ -756,14 +689,19 @@ class GPTModel(nn.Layer):
),
bias_attr
=
None
,
mesh_idx
=
mesh_index
,
use_new_recompute
=
self
.
use_new_recompute
,
recompute_granularity
=
self
.
recompute_granularity
,
)
)
Decoder
=
TransformerDecoder
self
.
decoder
=
Decoder
(
decoder_layers
,
num_hidden_layers
,
norm
=
"LayerNorm"
,
hidden_size
=
hidden_size
,
use_new_recompute
=
self
.
use_new_recompute
,
recompute_granularity
=
self
.
recompute_granularity
,
)
self
.
checkpoints
=
[]
...
...
@@ -817,7 +755,8 @@ class GPTModel(nn.Layer):
use_cache
=
use_cache
,
cache
=
cache
,
)
self
.
checkpoints
.
extend
(
self
.
decoder
.
checkpoints
)
if
not
self
.
use_new_recompute
:
self
.
checkpoints
.
extend
(
self
.
decoder
.
checkpoints
)
return
encoder_outputs
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录