Unverified commit d7f7963f, authored by zhaoyingli, committed by GitHub

[AutoParallel] selective recompute (#48111)

* [AutoParallel] selective recompute

* add cmakelist
Parent aafa9820
@@ -55,6 +55,7 @@ set_field_default_config(BASE, "reinit", False)  # Only for debug
 RECOMPUTE = "recompute"
 set_field_default_config(RECOMPUTE, "enable", False)
 set_field_default_config(RECOMPUTE, "checkpoints", None)
+set_field_default_config(RECOMPUTE, "no_recompute_segments", [])
 set_field_default_config(RECOMPUTE, "enable_tuning", False)
 #########################################
......
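The new knob rides on the existing recompute strategy config. A minimal usage sketch (it mirrors the apply_pass() helper in the tests further down; the segment index here is illustrative):

    from paddle.distributed.fleet import auto

    strategy = auto.Strategy()
    recompute = strategy.recompute
    recompute.enable = True
    # Indices of detected recompute segments that should NOT be recomputed,
    # i.e. whose activations are kept in memory; defaults to [].
    recompute.no_recompute_segments = [0]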
@@ -134,7 +134,7 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
             raise StopIteration
 
     def _infer_steps(self):
-        if isinstance(self.steps_per_epoch, int) and self.steps_per_epoch > 1:
+        if isinstance(self.steps_per_epoch, int) and self.steps_per_epoch > 0:
             return self.steps_per_epoch
         try:
             if isinstance(self.dataset, IterableDataset):
......
@@ -610,7 +610,7 @@ class Engine:
         if mode != "train":
             serial_main_prog = serial_main_prog.clone(for_test=True)
 
-        self._set_recompute_ckpts()
+        auto_utils.set_recompute_ckpts(self._model, self._strategy)
         self._dist_contexts[mode] = DistributedContext(
             serial_main_prog,
             serial_startup_prog,
@@ -1518,35 +1518,6 @@ class Engine:
         var_name = _to_name_str(var)
         return var_name in self.main_program.global_block().vars
 
-    def _set_recompute_ckpts(self):
-        # NOTE hack to enable recompute in engine api for GPT-3
-        # TODO support more PaddleNLP/CV models here
-        recompute = self._strategy.recompute
-
-        # extract ckpts by specific model
-        if isinstance(self._model, paddle.nn.Layer):
-            if hasattr(
-                self._model, "gpt"
-            ) and self._model.__class__.__name__ in [
-                'GPTForPretraining',
-                'GPTForPretrainingAuto',
-            ]:
-                exact_ckpts = self._model.gpt.checkpoints
-            else:
-                exact_ckpts = recompute.checkpoints
-        else:
-            exact_ckpts = recompute.checkpoints
-
-        # modify strategy
-        if recompute.enable:
-            recompute.checkpoints = exact_ckpts[:]
-            logs = {
-                'Model Class': self._model.__class__.__name__,
-                'Applied Recompute ckpts': exact_ckpts,
-            }
-            self._logger.info(logs)
-
     def _reset_metrics(self):
         for metric in self._metrics:
             metric.reset()
......
@@ -195,7 +195,13 @@ def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None):
     return op
 
 
+_g_recompute_idx = -1
+
+
 def recompute(op):
+    global _g_recompute_idx
+    _g_recompute_idx += 1
+
     class RecomputeOperator:
         def __init__(self, op):
             self._op = op
@@ -209,7 +215,9 @@ def recompute(op):
             for idx in range(op_size, new_op_size):
                 op = cur_block.ops[idx]
-                op._set_attr("is_recompute@auto_parallel", True)
+                op._set_attr(
+                    'op_namescope', "/auto_parallel/rc_" + str(_g_recompute_idx)
+                )
 
             return output
......
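With this change, auto.recompute() no longer sets a boolean "is_recompute@auto_parallel" attribute; it stamps every op it creates with an 'op_namescope' of "/auto_parallel/rc_<idx>", one index per wrapped call. A hedged sketch of the call pattern (the Block layer is illustrative; the marking only takes effect while the forward is traced into a static program under the auto-parallel Engine):

    import paddle
    from paddle.distributed.fleet import auto

    class Block(paddle.nn.Layer):
        def __init__(self):
            super().__init__()
            self.linear = paddle.nn.Linear(64, 64)

        def ffn(self, x):
            # ops created here are stamped with op_namescope "/auto_parallel/rc_<idx>"
            return paddle.nn.functional.relu(self.linear(x))

        def forward(self, x):
            # each auto.recompute(...) call bumps _g_recompute_idx, so every wrapped
            # call site becomes its own contiguous segment for the recompute pass
            return auto.recompute(self.ffn)(x)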
@@ -33,6 +33,9 @@ from paddle.distributed.auto_parallel.dist_attribute import (
     OperatorDistributedAttribute,
 )
 
+OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
+OpRole = core.op_proto_and_checker_maker.OpRole
+
 __no_shape_var_type__ = [
     core.VarDesc.VarType.READER,
     core.VarDesc.VarType.STEP_SCOPES,
@@ -1181,7 +1184,6 @@ def _get_split_indices(
 
 def set_grad_var_shape(program, dist_context):
     from .operators.common import infer_shape
-    from paddle.distributed.fleet.meta_optimizers.common import OpRole
 
     block = program.global_block()
     vars = block.vars
@@ -1315,10 +1317,6 @@ def set_grad_var_shape(program, dist_context):
             grad_var.desc.set_shape(ref_shape)
 
 
-OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
-OpRole = core.op_proto_and_checker_maker.OpRole
-
-
 def is_forward_op(op):
     op_role = int(op.attr('op_role'))
     return OP_ROLE_KEY in op.attr_names and (
@@ -1896,6 +1894,39 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank):
     server_socket.close()
 
 
+def set_recompute_ckpts(model, strategy):
+    from .interface import _g_recompute_idx
+
+    if _g_recompute_idx > -1:
+        return
+
+    recompute = strategy.recompute
+    if not recompute.enable:
+        return
+
+    # NOTE: hack to enable recompute in engine api for GPT-3
+    # TODO support more PaddleNLP/CV models here
+    # extract ckpts by specific model
+    if isinstance(model, paddle.nn.Layer):
+        if hasattr(model, "gpt") and model.__class__.__name__ in [
+            'GPTForPretraining',
+            'GPTForPretrainingAuto',
+        ]:
+            exact_ckpts = model.gpt.checkpoints
+        else:
+            exact_ckpts = recompute.checkpoints
+    else:
+        exact_ckpts = recompute.checkpoints
+
+    # modify strategy
+    recompute.checkpoints = exact_ckpts[:]
+    logs = {
+        'Model Class': model.__class__.__name__,
+        'Applied Recompute ckpts': exact_ckpts,
+    }
+    logging.info(logs)
+
+
 def get_input_split_info(cur_rank, var, dist_context):
     # deduce how the input data is split among the cluster
     tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
......
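For orientation, a hedged sketch of the call the Engine now makes in place of its old private method (the import alias is mine; in the Engine this runs only after the serial program has been traced, so a GPT-style wrapper already has model.gpt.checkpoints populated):

    from paddle.distributed.auto_parallel import utils as auto_utils

    # Roughly what the Engine does while preparing a program:
    auto_utils.set_recompute_ckpts(self._model, self._strategy)
    # -> no-op if any auto.recompute() marker exists (_g_recompute_idx > -1) or if
    #    recompute is disabled; otherwise strategy.recompute.checkpoints is filled
    #    from model.gpt.checkpoints (GPT models) or left to the user-provided list.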
@@ -17,7 +17,6 @@ import logging
 from .pass_base import PassBase, register_pass
 from paddle.fluid import core, unique_name
 from paddle.fluid import framework as framework
-from paddle.fluid.framework import Variable
 from paddle.fluid.backward import _append_grad_suffix_, _get_no_grad_set_name
 from paddle.fluid.backward import ProgramStats, _rename_arg_, _find_op_path_
 from paddle.distributed.auto_parallel.dist_attribute import (
@@ -33,12 +32,21 @@ from paddle.distributed.auto_parallel.utils import (
 )
 
 
+def _to_be_recomputed(op):
+    return op.has_attr('op_namescope') and "/auto_parallel/rc_" in op.attr(
+        'op_namescope'
+    )
+
+
 class RecomputeState(ProgramStats):
     def __init__(self, block, ops):
         super().__init__(block=block, ops=ops)
         self._block = block
         self._ops = ops
+        # {varname: {as_input_ops: op_idx, as_output_ops: op_idx}}
         self.var_op_deps = {}
+        # {segment_name: op_idx}
+        self.seg_op_deps = {}
 
     def build_stats(self):
         for i, op in enumerate(self._ops):
@@ -58,36 +66,72 @@ class RecomputeState(ProgramStats):
                 self.var_op_deps[name]["var_as_input_ops"] = []
                 self.var_op_deps[name]["var_as_output_ops"] = [i]
 
-    def get_recompute_segments(self, checkpoints):
-        """get recompute segments from checkpoints"""
+            if not _to_be_recomputed(op):
+                continue
+
+            seg_name = op.attr('op_namescope')
+            if seg_name not in self.seg_op_deps:
+                self.seg_op_deps[seg_name] = [i]
+            else:
+                assert (
+                    self.seg_op_deps[seg_name][-1] + 1 == i
+                ), "The recompute segment's ops should be continuous"
+                self.seg_op_deps[seg_name].extend([i])
+
+    def get_recompute_segments(
+        self, checkpoints_list=None, no_recompute_segments=[]
+    ):
+        """get recompute segments and checkpoints"""
         segments = []
-        start_idx = -1
-        pre_segment_end_idx = -1
-        while start_idx + 1 < len(checkpoints):
-            if start_idx == -1:
-                ckpt_name = checkpoints[start_idx + 1]
-                if ckpt_name not in self.var_op_deps:
-                    start_idx += 1
-                    continue
-                op_idx_list = self.var_op_deps[ckpt_name]["var_as_output_ops"]
-                if op_idx_list:
-                    segments.append([0, max(op_idx_list) + 1])
-            else:
-                flag, min_idx, max_idx = self.is_subgraph(
-                    [checkpoints[start_idx]], [checkpoints[start_idx + 1]]
-                )
-                if flag:
-                    min_idx = self._update_segment_start(
-                        min_idx, pre_segment_end_idx
-                    )
-                    segments.append([min_idx, max_idx + 1])
-                else:
-                    logging.info(
-                        "Could not recompute op range [{}] - [{}] ".format(
-                            min_idx, max_idx + 1
-                        )
-                    )
-            start_idx += 1
+        checkpoints = checkpoints_list or []
+
+        if len(checkpoints) == 0:
+            # the segments is marked by `auto.recompute()` api
+            for segment_idx in self.seg_op_deps.values():
+                if len(segment_idx) == 1:
+                    continue
+                segments.append([segment_idx[0], segment_idx[-1] + 1])
+                checkpoints.extend(self._ops[segment_idx[-1]].output_arg_names)
+        else:
+            # the segments is marked by `strategy.checkpoints` api
+            start_idx = -1
+            pre_segment_end_idx = -1
+            while start_idx + 1 < len(checkpoints):
+                if start_idx == -1:
+                    ckpt_name = checkpoints[start_idx + 1]
+                    if ckpt_name not in self.var_op_deps:
+                        start_idx += 1
+                        continue
+                    op_idx_list = self.var_op_deps[ckpt_name][
+                        "var_as_output_ops"
+                    ]
+                    if op_idx_list:
+                        segments.append([0, max(op_idx_list) + 1])
+                else:
+                    flag, min_idx, max_idx = self.is_subgraph(
+                        [checkpoints[start_idx]], [checkpoints[start_idx + 1]]
+                    )
+                    if flag:
+                        min_idx = self._update_segment_start(
+                            min_idx, pre_segment_end_idx
+                        )
+                        segments.append([min_idx, max_idx + 1])
+                    else:
+                        logging.info(
+                            "Could not recompute op range [{}] - [{}] ".format(
+                                min_idx, max_idx + 1
+                            )
+                        )
+                start_idx += 1
+
+        if no_recompute_segments:
+            for i in reversed(sorted(no_recompute_segments)):
+                assert i < len(
+                    segments
+                ), "the no_recompute_segments idx [{}] should be lower the number of segment [{}]".format(
+                    i, len(segments)
+                )
+                segments.pop(i)
 
         for i, (idx1, idx2) in enumerate(segments):
             logging.info("recompute segment[{}]".format(i))
@@ -106,7 +150,10 @@ class RecomputeState(ProgramStats):
                 )
             )
 
-        return segments
+        return segments, checkpoints
+
+    def is_recompute(self):
+        return any([_to_be_recomputed(op) for op in self._ops])
     def modify_forward_desc_for_recompute(self, dist_context):
         """
@@ -162,6 +209,7 @@ class RecomputeState(ProgramStats):
             outputs={"Out": seed_var},
             attrs={"seed": seed, "force_cpu": True},
         )
+        seed_op._set_attr('op_namescope', cur_op.attr('op_namescope'))
         # set new seed op's dist_attr
         naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
             seed_op, ref_process_mesh, ref_dims_mapping, dist_context
@@ -196,7 +244,6 @@ def _get_stop_gradients(program, no_grad_set):
     no_grad_set_name = set()
     for var in program.list_vars():
-        assert isinstance(var, Variable)
         if "@GRAD" in var.name:
             break
         if var.stop_gradient:
@@ -244,14 +291,13 @@ class RecomputePass(PassBase):
         self.set_attr("loss", None)
         self.set_attr("dist_context", None)
         self.set_attr("no_grad_set", None)
+        self.set_attr("no_recompute_segments", [])
 
     def _check_self(self):
         if self.get_attr("dist_context") is None:
             return False
         if self.get_attr("loss") is None:
             return False
-        if self.get_attr("checkpoints") is None:
-            return False
         return True
     def _check_conflict(self, other_pass):
@@ -259,25 +305,32 @@ class RecomputePass(PassBase):
     def _apply_single_impl(self, main_program, startup_program, context):
         checkpoints = self.get_attr("checkpoints")
+        no_recompute_segments = self.get_attr("no_recompute_segments")
         loss = self.get_attr("loss")
         no_grad_set = self.get_attr("no_grad_set")
         self._dist_context = self.get_attr("dist_context")
 
+        # 0. get op_path which is related to loss
         main_block = main_program.global_block()
         no_grad_set_name = _get_stop_gradients(main_program, no_grad_set)
-        # get op_path which is related to loss
        op_path = _find_op_path_(main_block, [loss], [], no_grad_set_name)
 
-        # step 1: build recompute state
+        # 1. build recompute state
         rc_state = RecomputeState(main_block, op_path)
+        if not rc_state.is_recompute() and not checkpoints:
+            return
+
+        # 2. get the segments to be recomputed
         rc_state.modify_forward_desc_for_recompute(self._dist_context)
         rc_state.build_stats()
-        checkpoints = rc_state.sort_checkpoints(checkpoints)
-        segments = rc_state.get_recompute_segments(checkpoints)
-        if segments == []:
+        checkpoints = rc_state.sort_checkpoints(checkpoints or [])
+        segments, checkpoints = rc_state.get_recompute_segments(
+            checkpoints, no_recompute_segments
+        )
+        if segments == [] or checkpoints == []:
             return
 
-        # step 2: get vars_should_be_hold
+        # 3. get vars that should be hold in memory
         vars_should_be_hold = []
         for segment in segments:
             vars_should_be_hold.extend(
@@ -295,9 +348,9 @@ class RecomputePass(PassBase):
         vars_should_be_hold = list(set(vars_should_be_hold))
         vars_in_memory = vars_should_be_hold + checkpoints
 
-        # step 3: get recomputed fwd ops desc
-        var_name_dict = {}
-        ckpt_ops_dict = {}
+        # 4. get the fwd ops desc to be recomputed.
+        var_name_dict = {}  # varname --> varname.subprog_XXX
+        ckpt_ops_dict = {}  # ckpt_op_id --> segment_descs
         buffer_block = main_block.program._create_block()
         for i, segment in enumerate(segments[::-1]):
             fwd_ops = op_path[segment[0] : segment[1]]
@@ -362,7 +415,7 @@ class RecomputePass(PassBase):
         ckpt_op = op_path[segment[1] - 1]
         ckpt_ops_dict[ckpt_op.desc.original_id()] = [True, segment_descs]
 
-        # step 4: insert recomputed fwd ops
+        # 5. insert recomputed fwd ops into backward parse
         ops = main_block.ops
         loss_op = get_loss_op(main_block)
         loss_op_idx = _find_op_index(main_block, loss_op)
......
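To restate the new segment-derivation rule in isolation: ops that share one "/auto_parallel/rc_<idx>" namescope must be contiguous and become one [start, end) segment, single-op groups are skipped, and indices listed in no_recompute_segments are dropped. A standalone sketch in plain Python (not the Paddle source, just the grouping logic):

    def group_segments(op_namescopes, no_recompute_segments=()):
        # op_namescopes: one entry per forward op, e.g. "/auto_parallel/rc_0" or None
        seg_op_deps = {}
        for idx, ns in enumerate(op_namescopes):
            if not ns or "/auto_parallel/rc_" not in ns:
                continue
            seg_op_deps.setdefault(ns, []).append(idx)

        segments = []
        for idxes in seg_op_deps.values():
            assert idxes == list(range(idxes[0], idxes[-1] + 1)), \
                "The recompute segment's ops should be continuous"
            if len(idxes) == 1:
                continue
            segments.append([idxes[0], idxes[-1] + 1])

        # drop the segments the user asked not to recompute
        for i in sorted(no_recompute_segments, reverse=True):
            assert i < len(segments), "no_recompute_segments index out of range"
            segments.pop(i)
        return segments

    # group_segments(["/auto_parallel/rc_0"] * 3 + [None] + ["/auto_parallel/rc_1"] * 2,
    #                no_recompute_segments=[1])  ->  [[0, 3]]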
@@ -72,6 +72,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_parallel_tuner_predict MODULES
                   test_parallel_tuner_predict ENVS ${dist_ENVS})
   set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120)
+  py_test_modules(test_selective_recompute MODULES test_selective_recompute)
+  set_tests_properties(test_selective_recompute PROPERTIES TIMEOUT 50)
   py_test_modules(test_while_op_completion MODULES test_while_op_completion
                   ENVS ${dist_ENVS})
......
@@ -22,13 +22,14 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
 from get_gpt_model import FakeDataset, generate_model
 
 
-def apply_pass(use_recompute=False):
+def apply_pass(use_recompute=False, no_recompute_segments=[]):
     strategy = auto.Strategy()
     strategy.auto_mode = "semi"
     strategy.reinit = True
     if use_recompute:
         recompute = strategy.recompute
         recompute.enable = True
+        recompute.no_recompute_segments = no_recompute_segments
     return strategy
 
@@ -53,10 +54,10 @@ class TestRecomputePass(unittest.TestCase):
         place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
         engine._executor = paddle.static.Executor(place)
 
-    def get_engine(self, use_recompute=False):
+    def get_engine(self, use_recompute=False, no_recompute_segments=[]):
         reset_prog()
 
-        strategy = apply_pass(use_recompute)
+        strategy = apply_pass(use_recompute, no_recompute_segments)
         clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
         opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
         model, loss = generate_model("mp")
@@ -88,6 +89,18 @@ class TestRecomputePass(unittest.TestCase):
         rc_losses = np.array(history.history["loss"])
         self.check_results(mp_losses, rc_losses)
 
+        # mp2 selective recompute training
+        rc1_engine = self.get_engine(True, [0])
+        history = rc1_engine.fit(self.dataset, 3, batch_size=self.batch_size)
+        rc1_losses = np.array(history.history["loss"])
+        self.check_results(mp_losses, rc1_losses)
+
+    def test_recompute_pass_error(self):
+
+        with self.assertRaises(AssertionError):
+            rc_engine = self.get_engine(True, [2])
+            history = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size)
+
 
 if __name__ == "__main__":
     unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import random
import numpy as np
import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import FakeDataset
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import (
GPTModel,
GPTForPretraining,
GPTPretrainingCriterion,
)
def generate_model(use_new_recompute, recompute_granularity):
modeling.init_global()
modeling._global_parallel_strategy = "serial"
modeling._global_process_mesh = auto.ProcessMesh(mesh=[0], dim_names=["x"])
gpt = GPTModel(
vocab_size=1000,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=8,
intermediate_size=256,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=1024,
type_vocab_size=1,
initializer_range=0.02,
pad_token_id=0,
eos_token_id=7,
bos_token_id=0,
eol_token_id=3,
use_new_recompute=use_new_recompute,
recompute_granularity=recompute_granularity,
)
model = GPTForPretraining(
gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
)
criterion = GPTPretrainingCriterion()
return model, criterion
def apply_pass(use_recompute=False, no_recompute_segments=[]):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_recompute:
recompute = strategy.recompute
recompute.enable = True
recompute.no_recompute_segments = no_recompute_segments
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestRecomputePassWithRecomputeAPI(unittest.TestCase):
def setUp(self):
self.rtol = 1e-6
self.atol = 1e-8
self.batch_size = 1
self.batch_num = 2
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(
self,
use_recompute=False,
use_new_recompute=False,
recompute_granularity="full",
no_recompute_segments=[],
):
reset_prog()
strategy = apply_pass(use_recompute, no_recompute_segments)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model(use_new_recompute, recompute_granularity)
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=self.rtol,
atol=self.atol,
err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses
),
)
def recompute_vars(self, program):
return list(filter(lambda a: "subprog" in a.name, program.list_vars()))
def test_recompute_pass(self):
# mp2 training
mp_engine = self.get_engine()
history = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
mp_losses = np.array(history.history["loss"])
# mp2 recompute with old api
rc4_engine = self.get_engine(True, False)
history = rc4_engine.fit(self.dataset, 3, batch_size=self.batch_size)
rc4_losses = np.array(history.history["loss"])
self.check_results(mp_losses, rc4_losses)
# mp2 recompute core_attn
rc1_engine = self.get_engine(True, True, "core_attn", [0])
history = rc1_engine.fit(self.dataset, 3, batch_size=self.batch_size)
rc1_losses = np.array(history.history["loss"])
self.check_results(mp_losses, rc1_losses)
# mp2 recompute full_attn
rc2_engine = self.get_engine(True, True, "full_attn")
history = rc2_engine.fit(self.dataset, 3, batch_size=self.batch_size)
rc2_losses = np.array(history.history["loss"])
self.check_results(mp_losses, rc2_losses)
# mp2 recompute full
rc3_engine = self.get_engine(True, True, "full")
history = rc3_engine.fit(self.dataset, 3, batch_size=self.batch_size)
rc3_losses = np.array(history.history["loss"])
self.check_results(mp_losses, rc3_losses)
rc0_vars = self.recompute_vars(mp_engine.main_program)
rc1_vars = self.recompute_vars(rc1_engine.main_program)
rc2_vars = self.recompute_vars(rc2_engine.main_program)
rc3_vars = self.recompute_vars(rc3_engine.main_program)
assert rc0_vars == []
assert len(rc1_vars) < len(rc2_vars) and len(rc2_vars) < len(rc3_vars)
def test_recompute_pass_error(self):
with self.assertRaises(AssertionError):
rc_engine = self.get_engine(True, True, "full", [2])
history = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size)
if __name__ == "__main__":
unittest.main()
@@ -57,6 +57,8 @@ class MultiHeadAttention(nn.Layer):
         bias_attr=None,
         fuse=False,
         mesh_idx=None,
+        use_new_recompute=False,
+        recompute_granularity="full",
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -67,6 +69,9 @@ class MultiHeadAttention(nn.Layer):
         self.need_weights = need_weights
         self.fuse = fuse
         self.mesh_idx = mesh_idx
+        self.use_new_recompute = use_new_recompute
+        self.recompute_granularity = recompute_granularity
+
         self.head_dim = embed_dim // num_heads
         assert (
             self.head_dim * num_heads == self.embed_dim
@@ -225,6 +230,27 @@ class MultiHeadAttention(nn.Layer):
             # incremental_state with initial value, mainly for usage like UniLM
             return self.Cache(key, value)
 
+    def core_attn(self, q, k, v, attn_mask):
+        product = layers.matmul(
+            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5
+        )
+        if attn_mask is not None:
+            product = product + attn_mask
+        weights = F.softmax(product)
+        if self.dropout:
+            weights = F.dropout(
+                weights,
+                self.dropout,
+                training=self.training,
+                mode="upscale_in_train",
+            )
+        out = tensor.matmul(weights, v)
+
+        # combine heads
+        out = tensor.transpose(out, perm=[0, 2, 1, 3])
+        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
+        return out, weights
+
     def forward(
         self, query, key, value, attn_mask=None, use_cache=False, cache=None
     ):
@@ -244,23 +270,12 @@ class MultiHeadAttention(nn.Layer):
             q, k, v, cache = self._prepare_qkv(
                 query, key, value, use_cache, cache
             )
-        product = layers.matmul(
-            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5
-        )
-        if attn_mask is not None:
-            product = product + attn_mask
-        weights = F.softmax(product)
-        if self.dropout:
-            weights = F.dropout(
-                weights,
-                self.dropout,
-                training=self.training,
-                mode="upscale_in_train",
-            )
-        out = tensor.matmul(weights, v)
-        # combine heads
-        out = tensor.transpose(out, perm=[0, 2, 1, 3])
-        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
+
+        if self.use_new_recompute and self.recompute_granularity == "core_attn":
+            out, weights = auto.recompute(self.core_attn)(q, k, v, attn_mask)
+        else:
+            out, weights = self.core_attn(q, k, v, attn_mask)
+
         # project to output
         out = self.out_proj(out)
         if _global_parallel_strategy == "mp":
@@ -295,12 +310,22 @@ class TransformerDecoder(nn.Layer):
     TransformerDecoder is a stack of N decoder layers.
     """
 
-    def __init__(self, decoder_layers, num_layers, norm=None, hidden_size=None):
+    def __init__(
+        self,
+        decoder_layers,
+        num_layers,
+        norm=None,
+        hidden_size=None,
+        use_new_recompute=False,
+        recompute_granularity="full",
+    ):
         super().__init__()
         self.num_layers = num_layers
         self.layers = decoder_layers
         self.norm = norm
+        self.use_new_recompute = use_new_recompute
+        self.recompute_granularity = recompute_granularity
         if norm == "LayerNorm":
             self.norm = nn.LayerNorm(hidden_size)
         elif norm is not None:
@@ -348,149 +373,36 @@ class TransformerDecoder(nn.Layer):
                 DPMPPP_MESH_LIST[0],
                 ["x"] + [None for i in range(len(output.shape) - 1)],
             )
 
         for i, mod in enumerate(self.layers):
+            if self.use_new_recompute and self.recompute_granularity == "full":
+                mod = auto.recompute(mod)
+
             if cache is None:
                 if use_cache:
-                    if _global_parallel_strategy == "pp":
-                        output, new_cache = auto.shard_op(
-                            mod, PP_MESH_LIST[mod.mesh_idx]
-                        )(output, memory, tgt_mask, use_cache, cache)
-                        auto.shard_tensor(
-                            output,
-                            PP_MESH_LIST[mod.mesh_idx],
-                            [None for i in range(len(output.shape))],
-                        )
-                    elif _global_parallel_strategy == "dp_pp":
-                        output, new_cache = auto.shard_op(
-                            mod, DPPP_MESH_LIST[mod.mesh_idx]
-                        )(output, memory, tgt_mask, use_cache, cache)
-                        auto.shard_tensor(
-                            output,
-                            DPPP_MESH_LIST[mod.mesh_idx],
-                            ["x"]
-                            + [None for i in range(len(output.shape) - 1)],
-                        )
-                    elif _global_parallel_strategy == "mp_pp":
-                        output, new_cache = auto.shard_op(
-                            mod, MPPP_MESH_LIST[mod.mesh_idx]
-                        )(output, memory, tgt_mask, use_cache, cache)
-                        auto.shard_tensor(
-                            output,
-                            MPPP_MESH_LIST[mod.mesh_idx],
-                            [None for i in range(len(output.shape))],
-                        )
-                    elif _global_parallel_strategy == "dp_mp_pp":
-                        output, new_cache = auto.shard_op(
-                            mod, DPMPPP_MESH_LIST[mod.mesh_idx]
-                        )(output, memory, tgt_mask, use_cache, cache)
-                        auto.shard_tensor(
-                            output,
-                            DPMPPP_MESH_LIST[mod.mesh_idx],
-                            [None for i in range(len(output.shape))],
-                        )
-                    else:
-                        output, new_cache = mod(
-                            output,
-                            memory,
-                            tgt_mask=tgt_mask,
-                            use_cache=use_cache,
-                            cache=cache,
-                        )
-                    new_caches.append(new_cache)
-                else:
-                    if _global_parallel_strategy == "pp":
-                        output = auto.shard_op(mod, PP_MESH_LIST[mod.mesh_idx])(
-                            output, memory, tgt_mask, use_cache, cache
-                        )
-                        auto.shard_tensor(
-                            output,
-                            PP_MESH_LIST[mod.mesh_idx],
-                            [None for i in range(len(output.shape))],
-                        )
-                    elif _global_parallel_strategy == "dp_pp":
-                        output = auto.shard_op(
-                            mod, DPPP_MESH_LIST[mod.mesh_idx]
-                        )(output, memory, tgt_mask, use_cache, cache)
-                        auto.shard_tensor(
-                            output,
-                            DPPP_MESH_LIST[mod.mesh_idx],
-                            ["x"]
-                            + [None for i in range(len(output.shape) - 1)],
-                        )
-                    elif _global_parallel_strategy == "mp_pp":
-                        output = auto.shard_op(
-                            mod, MPPP_MESH_LIST[mod.mesh_idx]
-                        )(output, memory, tgt_mask, use_cache, cache)
-                        auto.shard_tensor(
-                            output,
-                            MPPP_MESH_LIST[mod.mesh_idx],
-                            [None for i in range(len(output.shape))],
-                        )
-                    elif _global_parallel_strategy == "dp_mp_pp":
-                        output = auto.shard_op(
-                            mod, DPMPPP_MESH_LIST[mod.mesh_idx]
-                        )(output, memory, tgt_mask, use_cache, cache)
-                        auto.shard_tensor(
-                            output,
-                            DPMPPP_MESH_LIST[mod.mesh_idx],
-                            ["x"]
-                            + [None for i in range(len(output.shape) - 1)],
-                        )
-                    else:
-                        output = mod(
-                            output,
-                            memory,
-                            tgt_mask=tgt_mask,
-                            use_cache=use_cache,
-                            cache=cache,
-                        )
-            else:
-                if _global_parallel_strategy == "pp":
-                    output, new_cache = auto.shard_op(
-                        mod, PP_MESH_LIST[mod.mesh_idx]
-                    )(output, memory, tgt_mask, use_cache, cache)
-                    auto.shard_tensor(
-                        output,
-                        PP_MESH_LIST[mod.mesh_idx],
-                        [None for i in range(len(output.shape))],
-                    )
-                elif _global_parallel_strategy == "dp_pp":
-                    output, new_cache = auto.shard_op(
-                        mod, DPPP_MESH_LIST[mod.mesh_idx]
-                    )(output, memory, tgt_mask, use_cache, cache)
-                    auto.shard_tensor(
-                        output,
-                        DPPP_MESH_LIST[mod.mesh_idx],
-                        ["x"] + [None for i in range(len(output.shape) - 1)],
-                    )
-                elif _global_parallel_strategy == "mp_pp":
-                    output, new_cache = auto.shard_op(
-                        mod, MPPP_MESH_LIST[mod.mesh_idx]
-                    )(output, memory, tgt_mask, use_cache, cache)
-                    auto.shard_tensor(
-                        output,
-                        MPPP_MESH_LIST[mod.mesh_idx],
-                        [None for i in range(len(output.shape))],
-                    )
-                elif _global_parallel_strategy == "dp_mp_pp":
-                    output, new_cache = auto.shard_op(
-                        mod, DPMPPP_MESH_LIST[mod.mesh_idx]
-                    )(output, memory, tgt_mask, use_cache, cache)
-                    auto.shard_tensor(
-                        output,
-                        DPMPPP_MESH_LIST[mod.mesh_idx],
-                        ["x"] + [None for i in range(len(output.shape) - 1)],
-                    )
-                else:
-                    output, new_cache = mod(
-                        output,
-                        memory,
-                        tgt_mask=tgt_mask,
-                        use_cache=use_cache,
-                        cache=cache[i],
-                    )
+                    output, new_cache = mod(
+                        output,
+                        memory,
+                        tgt_mask=tgt_mask,
+                        use_cache=use_cache,
+                        cache=cache,
+                    )
+                    new_caches.append(new_cache)
+                else:
+                    output = mod(output, memory, tgt_mask, use_cache, cache)
+            else:
+                output, new_cache = mod(
+                    output,
+                    memory,
+                    tgt_mask=tgt_mask,
+                    use_cache=use_cache,
+                    cache=cache[i],
+                )
                 new_caches.append(new_cache)
-            self.checkpoints.append(output.name)
+
+            if not self.use_new_recompute:
+                self.checkpoints.append(output.name)
 
         if self.norm is not None:
             output = self.norm(output)
         return output if use_cache is False else (output, new_caches)
@@ -528,6 +440,8 @@ class TransformerDecoderLayer(nn.Layer):
         weight_attr=None,
         bias_attr=None,
         mesh_idx=None,
+        use_new_recompute=False,
+        recompute_granularity="full",
     ):
         self._config = locals()
         self._config.pop("self")
@@ -537,8 +451,12 @@ class TransformerDecoderLayer(nn.Layer):
         attn_dropout = dropout if attn_dropout is None else attn_dropout
         act_dropout = dropout if act_dropout is None else act_dropout
         self.normalize_before = normalize_before
+        self.use_new_recompute = use_new_recompute
+        self.recompute_granularity = recompute_granularity
+
         weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
         bias_attrs = _convert_param_attr_to_list(bias_attr, 3)
+
         self.self_attn = MultiHeadAttention(
             d_model,
             nhead,
@@ -546,6 +464,8 @@ class TransformerDecoderLayer(nn.Layer):
             weight_attr=weight_attrs[0],
             bias_attr=bias_attrs[0],
             mesh_idx=self.mesh_idx,
+            use_new_recompute=self.use_new_recompute,
+            recompute_granularity=self.recompute_granularity,
         )
         self.linear1 = nn.Linear(
             d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2]
@@ -563,12 +483,19 @@ class TransformerDecoderLayer(nn.Layer):
         residual = tgt
         if self.normalize_before:
             tgt = self.norm1(tgt)
+
+        if self.use_new_recompute and self.recompute_granularity == "full_attn":
+            self_attn = auto.recompute(self.self_attn)
+        else:
+            self_attn = self.self_attn
+
         if use_cache is False:
-            tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
+            tgt = self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
         else:
-            tgt, incremental_cache = self.self_attn(
+            tgt, incremental_cache = self_attn(
                 tgt, tgt, tgt, tgt_mask, use_cache, cache
             )
+
         tgt = residual + self.dropout1(tgt)
         if not self.normalize_before:
             tgt = self.norm1(tgt)
@@ -716,12 +643,17 @@ class GPTModel(nn.Layer):
         bos_token_id=0,
         eol_token_id=3,
         pp_degree=None,
+        use_new_recompute=False,
+        recompute_granularity="full",
     ):
         super().__init__()
         self.pad_token_id = pad_token_id
         self.initializer_range = initializer_range
         self.hidden_size = hidden_size
         self.vocab_size = vocab_size
+        self.use_new_recompute = use_new_recompute
+        self.recompute_granularity = recompute_granularity
+
         self.layer_per_stage = None
         self.pipline_mode = pp_degree is not None and pp_degree > 1
         if self.pipline_mode:
@@ -734,6 +666,7 @@ class GPTModel(nn.Layer):
             type_vocab_size,
             self.initializer_range,
         )
+
         decoder_layers = nn.LayerList()
         for i in range(num_hidden_layers):
             mesh_index = None
@@ -756,14 +689,19 @@ class GPTModel(nn.Layer):
                     ),
                     bias_attr=None,
                     mesh_idx=mesh_index,
+                    use_new_recompute=self.use_new_recompute,
+                    recompute_granularity=self.recompute_granularity,
                 )
             )
 
         Decoder = TransformerDecoder
+
         self.decoder = Decoder(
             decoder_layers,
             num_hidden_layers,
             norm="LayerNorm",
             hidden_size=hidden_size,
+            use_new_recompute=self.use_new_recompute,
+            recompute_granularity=self.recompute_granularity,
         )
         self.checkpoints = []
@@ -817,7 +755,8 @@ class GPTModel(nn.Layer):
             use_cache=use_cache,
             cache=cache,
         )
-        self.checkpoints.extend(self.decoder.checkpoints)
+        if not self.use_new_recompute:
+            self.checkpoints.extend(self.decoder.checkpoints)
         return encoder_outputs
......
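Putting the model-side changes together, the test model now exposes three granularities for the namescope-based API. A hedged usage sketch assembled from the test helpers above (generate_model is the helper defined in the new test file; everything else follows the Engine wiring shown there):

    import paddle
    from paddle.distributed.fleet import auto

    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    strategy.recompute.enable = True

    # "full"      -> each TransformerDecoderLayer is wrapped: mod = auto.recompute(mod)
    # "full_attn" -> only the whole MultiHeadAttention call is wrapped
    # "core_attn" -> only the softmax(QK^T)V block (core_attn) is wrapped
    model, loss = generate_model(
        use_new_recompute=True, recompute_granularity="core_attn"
    )

    opt = paddle.optimizer.AdamW(learning_rate=0.00001)
    engine = auto.Engine(model, loss, opt, strategy=strategy)
    # Coarser granularity means more recomputed ("subprog") vars, which is what the
    # rc1 < rc2 < rc3 variable-count assertion in the test above checks.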