diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py
index 51afad94c535b6a2e2982db5c7e47720c9f35c00..857245b9be4257ae9536b8d53c614cf8e18d3f96 100644
--- a/python/paddle/distributed/auto_parallel/constants.py
+++ b/python/paddle/distributed/auto_parallel/constants.py
@@ -55,6 +55,7 @@ set_field_default_config(BASE, "reinit", False)  # Only for debug
 RECOMPUTE = "recompute"
 set_field_default_config(RECOMPUTE, "enable", False)
 set_field_default_config(RECOMPUTE, "checkpoints", None)
+set_field_default_config(RECOMPUTE, "no_recompute_segments", [])
 set_field_default_config(RECOMPUTE, "enable_tuning", False)

 #########################################
diff --git a/python/paddle/distributed/auto_parallel/dist_loader.py b/python/paddle/distributed/auto_parallel/dist_loader.py
index f982f7458999e5b459dbd86f055d2bc6edc23c69..f0e0b8aa5a0d7a7029fd34995c119e14cf5e221c 100644
--- a/python/paddle/distributed/auto_parallel/dist_loader.py
+++ b/python/paddle/distributed/auto_parallel/dist_loader.py
@@ -134,7 +134,7 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
             raise StopIteration

     def _infer_steps(self):
-        if isinstance(self.steps_per_epoch, int) and self.steps_per_epoch > 1:
+        if isinstance(self.steps_per_epoch, int) and self.steps_per_epoch > 0:
             return self.steps_per_epoch
         try:
             if isinstance(self.dataset, IterableDataset):
diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py
index 116eaa97f1088530df16dc71b3f60167be532936..8e27b9aac6c703d1765ca11c08b438afef1e6805 100644
--- a/python/paddle/distributed/auto_parallel/engine.py
+++ b/python/paddle/distributed/auto_parallel/engine.py
@@ -610,7 +610,7 @@ class Engine:
         if mode != "train":
             serial_main_prog = serial_main_prog.clone(for_test=True)

-        self._set_recompute_ckpts()
+        auto_utils.set_recompute_ckpts(self._model, self._strategy)
         self._dist_contexts[mode] = DistributedContext(
             serial_main_prog,
             serial_startup_prog,
@@ -1518,35 +1518,6 @@ class Engine:
         var_name = _to_name_str(var)
         return var_name in self.main_program.global_block().vars

-    def _set_recompute_ckpts(self):
-        # NOTE hack to enable recompute in engine api for GPT-3
-        # TODO support more PaddleNLP/CV models here
-
-        recompute = self._strategy.recompute
-
-        # extract ckpts by specific model
-        if isinstance(self._model, paddle.nn.Layer):
-            if hasattr(
-                self._model, "gpt"
-            ) and self._model.__class__.__name__ in [
-                'GPTForPretraining',
-                'GPTForPretrainingAuto',
-            ]:
-                exact_ckpts = self._model.gpt.checkpoints
-            else:
-                exact_ckpts = recompute.checkpoints
-        else:
-            exact_ckpts = recompute.checkpoints
-
-        # modify strategy
-        if recompute.enable:
-            recompute.checkpoints = exact_ckpts[:]
-            logs = {
-                'Model Class': self._model.__class__.__name__,
-                'Applied Recompute ckpts': exact_ckpts,
-            }
-            self._logger.info(logs)
-
     def _reset_metrics(self):
         for metric in self._metrics:
             metric.reset()
diff --git a/python/paddle/distributed/auto_parallel/interface.py b/python/paddle/distributed/auto_parallel/interface.py
index cc8afb4f27173b674a325a4042f0735ed0435f31..b85d85011a1fab09d2ea136eb9f1612b85b62de9 100644
--- a/python/paddle/distributed/auto_parallel/interface.py
+++ b/python/paddle/distributed/auto_parallel/interface.py
@@ -195,7 +195,13 @@ def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None):
     return op


+_g_recompute_idx = -1
+
+
 def recompute(op):
+    global _g_recompute_idx
+    _g_recompute_idx += 1
+
     class RecomputeOperator:
         def __init__(self, op):
             self._op = op
@@ -209,7 +215,9 @@ def recompute(op):

             for idx in range(op_size, new_op_size):
                 op = cur_block.ops[idx]
-                op._set_attr("is_recompute@auto_parallel", True)
+                op._set_attr(
+                    'op_namescope', "/auto_parallel/rc_" + str(_g_recompute_idx)
+                )

             return output
diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py
index 35b3483a31481c5af164cf5ec5a9358c2e6347a8..be4c68d97d840c8803c2e172d45e0fa700eb75cf 100644
--- a/python/paddle/distributed/auto_parallel/utils.py
+++ b/python/paddle/distributed/auto_parallel/utils.py
@@ -33,6 +33,9 @@ from paddle.distributed.auto_parallel.dist_attribute import (
     OperatorDistributedAttribute,
 )

+OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
+OpRole = core.op_proto_and_checker_maker.OpRole
+
 __no_shape_var_type__ = [
     core.VarDesc.VarType.READER,
     core.VarDesc.VarType.STEP_SCOPES,
@@ -1181,7 +1184,6 @@ def _get_split_indices(

 def set_grad_var_shape(program, dist_context):
     from .operators.common import infer_shape
-    from paddle.distributed.fleet.meta_optimizers.common import OpRole

     block = program.global_block()
     vars = block.vars
@@ -1315,10 +1317,6 @@ def set_grad_var_shape(program, dist_context):
             grad_var.desc.set_shape(ref_shape)


-OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
-OpRole = core.op_proto_and_checker_maker.OpRole
-
-
 def is_forward_op(op):
     op_role = int(op.attr('op_role'))
     return OP_ROLE_KEY in op.attr_names and (
@@ -1896,6 +1894,39 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank):
         server_socket.close()


+def set_recompute_ckpts(model, strategy):
+    from .interface import _g_recompute_idx
+
+    if _g_recompute_idx > -1:
+        return
+
+    recompute = strategy.recompute
+    if not recompute.enable:
+        return
+
+    # NOTE: hack to enable recompute in engine api for GPT-3
+    # TODO support more PaddleNLP/CV models here
+    # extract ckpts by specific model
+    if isinstance(model, paddle.nn.Layer):
+        if hasattr(model, "gpt") and model.__class__.__name__ in [
+            'GPTForPretraining',
+            'GPTForPretrainingAuto',
+        ]:
+            exact_ckpts = model.gpt.checkpoints
+        else:
+            exact_ckpts = recompute.checkpoints
+    else:
+        exact_ckpts = recompute.checkpoints
+
+    # modify strategy
+    recompute.checkpoints = exact_ckpts[:]
+    logs = {
+        'Model Class': model.__class__.__name__,
+        'Applied Recompute ckpts': exact_ckpts,
+    }
+    logging.info(logs)
+
+
 def get_input_split_info(cur_rank, var, dist_context):
     # deduce how the input data is split among the cluster
     tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py
index b725ac004eb010891cc3afff39910658bedf133b..5bdbe9d2dd5d9d2cf1b424e2bdb25c3f68a5f76a 100644
--- a/python/paddle/distributed/passes/auto_parallel_recompute.py
+++ b/python/paddle/distributed/passes/auto_parallel_recompute.py
@@ -17,7 +17,6 @@ import logging
 from .pass_base import PassBase, register_pass
 from paddle.fluid import core, unique_name
 from paddle.fluid import framework as framework
-from paddle.fluid.framework import Variable
 from paddle.fluid.backward import _append_grad_suffix_, _get_no_grad_set_name
 from paddle.fluid.backward import ProgramStats, _rename_arg_, _find_op_path_
 from paddle.distributed.auto_parallel.dist_attribute import (
@@ -33,12 +32,21 @@ from paddle.distributed.auto_parallel.utils import (
 )


+def _to_be_recomputed(op):
+    return op.has_attr('op_namescope') and "/auto_parallel/rc_" in op.attr(
+        'op_namescope'
+    )
+
+
 class RecomputeState(ProgramStats):
     def __init__(self, block, ops):
         super().__init__(block=block, ops=ops)
         self._block = block
         self._ops = ops
+        # {varname: {as_input_ops: op_idx, as_output_ops: op_idx}}
         self.var_op_deps = {}
+        # {segment_name: op_idx}
+        self.seg_op_deps = {}

     def build_stats(self):
         for i, op in enumerate(self._ops):
@@ -58,36 +66,72 @@ class RecomputeState(ProgramStats):
                     self.var_op_deps[name]["var_as_input_ops"] = []
                     self.var_op_deps[name]["var_as_output_ops"] = [i]

-    def get_recompute_segments(self, checkpoints):
-        """get recompute segments from checkpoints"""
+            if not _to_be_recomputed(op):
+                continue
+
+            seg_name = op.attr('op_namescope')
+            if seg_name not in self.seg_op_deps:
+                self.seg_op_deps[seg_name] = [i]
+            else:
+                assert (
+                    self.seg_op_deps[seg_name][-1] + 1 == i
+                ), "The recompute segment's ops should be continuous"
+                self.seg_op_deps[seg_name].extend([i])
+
+    def get_recompute_segments(
+        self, checkpoints_list=None, no_recompute_segments=[]
+    ):
+        """get recompute segments and checkpoints"""
         segments = []
-        start_idx = -1
-        pre_segment_end_idx = -1
-        while start_idx + 1 < len(checkpoints):
-            if start_idx == -1:
-                ckpt_name = checkpoints[start_idx + 1]
-                if ckpt_name not in self.var_op_deps:
-                    start_idx += 1
+        checkpoints = checkpoints_list or []
+
+        if len(checkpoints) == 0:
+            # the segments are marked by the `auto.recompute()` api
+            for segment_idx in self.seg_op_deps.values():
+                if len(segment_idx) == 1:
                     continue
-                op_idx_list = self.var_op_deps[ckpt_name]["var_as_output_ops"]
-                if op_idx_list:
-                    segments.append([0, max(op_idx_list) + 1])
-            else:
-                flag, min_idx, max_idx = self.is_subgraph(
-                    [checkpoints[start_idx]], [checkpoints[start_idx + 1]]
-                )
-                if flag:
-                    min_idx = self._update_segment_start(
-                        min_idx, pre_segment_end_idx
-                    )
-                    segments.append([min_idx, max_idx + 1])
+                segments.append([segment_idx[0], segment_idx[-1] + 1])
+                checkpoints.extend(self._ops[segment_idx[-1]].output_arg_names)
+        else:
+            # the segments are marked by the `strategy.checkpoints` api
+            start_idx = -1
+            pre_segment_end_idx = -1
+            while start_idx + 1 < len(checkpoints):
+                if start_idx == -1:
+                    ckpt_name = checkpoints[start_idx + 1]
+                    if ckpt_name not in self.var_op_deps:
+                        start_idx += 1
+                        continue
+                    op_idx_list = self.var_op_deps[ckpt_name][
+                        "var_as_output_ops"
+                    ]
+                    if op_idx_list:
+                        segments.append([0, max(op_idx_list) + 1])
                 else:
-                    logging.info(
-                        "Could not recompute op range [{}] - [{}] ".format(
-                            min_idx, max_idx + 1
-                        )
+                    flag, min_idx, max_idx = self.is_subgraph(
+                        [checkpoints[start_idx]], [checkpoints[start_idx + 1]]
                     )
-            start_idx += 1
+                    if flag:
+                        min_idx = self._update_segment_start(
+                            min_idx, pre_segment_end_idx
+                        )
+                        segments.append([min_idx, max_idx + 1])
+                    else:
+                        logging.info(
+                            "Could not recompute op range [{}] - [{}] ".format(
+                                min_idx, max_idx + 1
+                            )
+                        )
+                start_idx += 1
+
+        if no_recompute_segments:
+            for i in reversed(sorted(no_recompute_segments)):
+                assert i < len(
+                    segments
+                ), "the no_recompute_segments idx [{}] should be less than the number of segments [{}]".format(
+                    i, len(segments)
+                )
+                segments.pop(i)

         for i, (idx1, idx2) in enumerate(segments):
             logging.info("recompute segment[{}]".format(i))
@@ -106,7 +150,10 @@ class RecomputeState(ProgramStats):
                     )
                 )

-        return segments
+        return segments, checkpoints
+
+    def is_recompute(self):
+        return any([_to_be_recomputed(op) for op in self._ops])

     def modify_forward_desc_for_recompute(self, dist_context):
         """
@@ -162,6 +209,7 @@ class RecomputeState(ProgramStats):
                 outputs={"Out": seed_var},
                 attrs={"seed": seed, "force_cpu": True},
             )
+            seed_op._set_attr('op_namescope', cur_op.attr('op_namescope'))
             # set new seed op's dist_attr
             naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                 seed_op, ref_process_mesh, ref_dims_mapping, dist_context
@@ -196,7 +244,6 @@ def _get_stop_gradients(program, no_grad_set):

     no_grad_set_name = set()
     for var in program.list_vars():
-        assert isinstance(var, Variable)
         if "@GRAD" in var.name:
             break
         if var.stop_gradient:
@@ -244,14 +291,13 @@ class RecomputePass(PassBase):
         self.set_attr("loss", None)
         self.set_attr("dist_context", None)
         self.set_attr("no_grad_set", None)
+        self.set_attr("no_recompute_segments", [])

     def _check_self(self):
         if self.get_attr("dist_context") is None:
             return False
         if self.get_attr("loss") is None:
             return False
-        if self.get_attr("checkpoints") is None:
-            return False
         return True

     def _check_conflict(self, other_pass):
@@ -259,25 +305,32 @@
     def _apply_single_impl(self, main_program, startup_program, context):
         checkpoints = self.get_attr("checkpoints")
+        no_recompute_segments = self.get_attr("no_recompute_segments")
         loss = self.get_attr("loss")
         no_grad_set = self.get_attr("no_grad_set")
         self._dist_context = self.get_attr("dist_context")

+        # 0. get op_path which is related to loss
         main_block = main_program.global_block()
         no_grad_set_name = _get_stop_gradients(main_program, no_grad_set)
-        # get op_path which is related to loss
         op_path = _find_op_path_(main_block, [loss], [], no_grad_set_name)

-        # step 1: build recompute state
+        # 1. build recompute state
         rc_state = RecomputeState(main_block, op_path)
+        if not rc_state.is_recompute() and not checkpoints:
+            return
+
+        # 2. get the segments to be recomputed
         rc_state.modify_forward_desc_for_recompute(self._dist_context)
         rc_state.build_stats()
-        checkpoints = rc_state.sort_checkpoints(checkpoints)
-        segments = rc_state.get_recompute_segments(checkpoints)
-        if segments == []:
+        checkpoints = rc_state.sort_checkpoints(checkpoints or [])
+        segments, checkpoints = rc_state.get_recompute_segments(
+            checkpoints, no_recompute_segments
+        )
+        if segments == [] or checkpoints == []:
             return

-        # step 2: get vars_should_be_hold
+        # 3. get the vars that should be held in memory
         vars_should_be_hold = []
         for segment in segments:
             vars_should_be_hold.extend(
@@ -295,9 +348,9 @@ class RecomputePass(PassBase):
         vars_should_be_hold = list(set(vars_should_be_hold))
         vars_in_memory = vars_should_be_hold + checkpoints

-        # step 3: get recomputed fwd ops desc
-        var_name_dict = {}
-        ckpt_ops_dict = {}
+        # 4. get the fwd ops desc to be recomputed.
+        var_name_dict = {}  # varname --> varname.subprog_XXX
+        ckpt_ops_dict = {}  # ckpt_op_id --> segment_descs
         buffer_block = main_block.program._create_block()
         for i, segment in enumerate(segments[::-1]):
             fwd_ops = op_path[segment[0] : segment[1]]
@@ -362,7 +415,7 @@ class RecomputePass(PassBase):
            ckpt_op = op_path[segment[1] - 1]
            ckpt_ops_dict[ckpt_op.desc.original_id()] = [True, segment_descs]

-        # step 4: insert recomputed fwd ops
+        # 5. insert recomputed fwd ops into the backward pass
         ops = main_block.ops
         loss_op = get_loss_op(main_block)
         loss_op_idx = _find_op_index(main_block, loss_op)
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
index b2935a0b175b3d0f2b55b5aee74711efec9f9b6c..201241cb31e63678af75c33644f93273d304a76c 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
@@ -72,6 +72,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
  py_test_modules(test_parallel_tuner_predict MODULES
                  test_parallel_tuner_predict ENVS ${dist_ENVS})
  set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120)
+  py_test_modules(test_selective_recompute MODULES test_selective_recompute)
+  set_tests_properties(test_selective_recompute PROPERTIES TIMEOUT 50)
  py_test_modules(test_while_op_completion MODULES test_while_op_completion
                  ENVS ${dist_ENVS})
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py
index 1aa83f1a8c97864c0bdf1cb7fc53d71b30dec54e..d9c179dda09b4a2d82b03ebb7bd43cf704d26d92 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py
@@ -22,13 +22,14 @@
 from paddle.fluid.dygraph.parallel import ParallelEnv
 from get_gpt_model import FakeDataset, generate_model


-def apply_pass(use_recompute=False):
+def apply_pass(use_recompute=False, no_recompute_segments=[]):
     strategy = auto.Strategy()
     strategy.auto_mode = "semi"
     strategy.reinit = True
     if use_recompute:
         recompute = strategy.recompute
         recompute.enable = True
+        recompute.no_recompute_segments = no_recompute_segments
     return strategy

@@ -53,10 +54,10 @@ class TestRecomputePass(unittest.TestCase):
         place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
         engine._executor = paddle.static.Executor(place)

-    def get_engine(self, use_recompute=False):
+    def get_engine(self, use_recompute=False, no_recompute_segments=[]):
         reset_prog()

-        strategy = apply_pass(use_recompute)
+        strategy = apply_pass(use_recompute, no_recompute_segments)
         clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
         opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
         model, loss = generate_model("mp")
@@ -88,6 +89,18 @@ class TestRecomputePass(unittest.TestCase):
         rc_losses = np.array(history.history["loss"])
         self.check_results(mp_losses, rc_losses)

+        # mp2 selective recompute training
+        rc1_engine = self.get_engine(True, [0])
+        history = rc1_engine.fit(self.dataset, 3, batch_size=self.batch_size)
+        rc1_losses = np.array(history.history["loss"])
+        self.check_results(mp_losses, rc1_losses)
+
+    def test_recompute_pass_error(self):
+
+        with self.assertRaises(AssertionError):
+            rc_engine = self.get_engine(True, [2])
+            history = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size)
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_selective_recompute.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_selective_recompute.py
new file mode 100644
index 0000000000000000000000000000000000000000..97e175a39801a5f6d9feedafec6284d8f9f106c7
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_selective_recompute.py
@@ -0,0 +1,175 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import unittest
+import random
+import numpy as np
+import paddle
+
+from paddle.distributed.fleet import auto
+from paddle.fluid.dygraph.parallel import ParallelEnv
+from get_gpt_model import FakeDataset
+
+sys.path.append("..")
+import auto_parallel_gpt_model as modeling
+from auto_parallel_gpt_model import (
+    GPTModel,
+    GPTForPretraining,
+    GPTPretrainingCriterion,
+)
+
+
+def generate_model(use_new_recompute, recompute_granularity):
+    modeling.init_global()
+    modeling._global_parallel_strategy = "serial"
+    modeling._global_process_mesh = auto.ProcessMesh(mesh=[0], dim_names=["x"])
+
+    gpt = GPTModel(
+        vocab_size=1000,
+        hidden_size=64,
+        num_hidden_layers=2,
+        num_attention_heads=8,
+        intermediate_size=256,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.0,
+        attention_probs_dropout_prob=0.0,
+        max_position_embeddings=1024,
+        type_vocab_size=1,
+        initializer_range=0.02,
+        pad_token_id=0,
+        eos_token_id=7,
+        bos_token_id=0,
+        eol_token_id=3,
+        use_new_recompute=use_new_recompute,
+        recompute_granularity=recompute_granularity,
+    )
+    model = GPTForPretraining(
+        gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
+    )
+    criterion = GPTPretrainingCriterion()
+    return model, criterion
+
+
+def apply_pass(use_recompute=False, no_recompute_segments=[]):
+    strategy = auto.Strategy()
+    strategy.auto_mode = "semi"
+    strategy.reinit = True
+    if use_recompute:
+        recompute = strategy.recompute
+        recompute.enable = True
+        recompute.no_recompute_segments = no_recompute_segments
+    return strategy
+
+
+def reset_prog():
+    paddle.fluid.framework.switch_main_program(paddle.static.Program())
+    paddle.fluid.framework.switch_startup_program(paddle.static.Program())
+
+
+class TestRecomputePassWithRecomputeAPI(unittest.TestCase):
+    def setUp(self):
+        self.rtol = 1e-6
+        self.atol = 1e-8
+        self.batch_size = 1
+        self.batch_num = 2
+        self.clip_norm = 0.2
+        self.dataset = FakeDataset(self.batch_size * self.batch_num)
+
+    def init(self, engine):
+        paddle.seed(2022)
+        np.random.seed(2022)
+        random.seed(2022)
+        place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
+        engine._executor = paddle.static.Executor(place)
+
+    def get_engine(
+        self,
+        use_recompute=False,
+        use_new_recompute=False,
+        recompute_granularity="full",
+        no_recompute_segments=[],
+    ):
+        reset_prog()
+
+        strategy = apply_pass(use_recompute, no_recompute_segments)
+        clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
+        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
+        model, loss = generate_model(use_new_recompute, recompute_granularity)
+
+        engine = auto.Engine(model, loss, opt, strategy=strategy)
+        self.init(engine)
+        return engine
+
+    def check_results(self, ref_losses, check_losses):
+        np.testing.assert_allclose(
+            ref_losses,
+            check_losses,
+            rtol=self.rtol,
+            atol=self.atol,
+            err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
+                __class__, ref_losses, check_losses, ref_losses - check_losses
+            ),
+        )
+
+    def recompute_vars(self, program):
+        return list(filter(lambda a: "subprog" in a.name, program.list_vars()))
+
+    def test_recompute_pass(self):
+        # mp2 training
+        mp_engine = self.get_engine()
+        history = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
+        mp_losses = np.array(history.history["loss"])
+
+        # mp2 recompute with old api
+        rc4_engine = self.get_engine(True, False)
+        history = rc4_engine.fit(self.dataset, 3, batch_size=self.batch_size)
+        rc4_losses = np.array(history.history["loss"])
+        self.check_results(mp_losses, rc4_losses)
+
+        # mp2 recompute core_attn
+        rc1_engine = self.get_engine(True, True, "core_attn", [0])
+        history = rc1_engine.fit(self.dataset, 3, batch_size=self.batch_size)
+        rc1_losses = np.array(history.history["loss"])
+        self.check_results(mp_losses, rc1_losses)
+
+        # mp2 recompute full_attn
+        rc2_engine = self.get_engine(True, True, "full_attn")
+        history = rc2_engine.fit(self.dataset, 3, batch_size=self.batch_size)
+        rc2_losses = np.array(history.history["loss"])
+        self.check_results(mp_losses, rc2_losses)
+
+        # mp2 recompute full
+        rc3_engine = self.get_engine(True, True, "full")
+        history = rc3_engine.fit(self.dataset, 3, batch_size=self.batch_size)
+        rc3_losses = np.array(history.history["loss"])
+        self.check_results(mp_losses, rc3_losses)
+
+        rc0_vars = self.recompute_vars(mp_engine.main_program)
+        rc1_vars = self.recompute_vars(rc1_engine.main_program)
+        rc2_vars = self.recompute_vars(rc2_engine.main_program)
+        rc3_vars = self.recompute_vars(rc3_engine.main_program)
+
+        assert rc0_vars == []
+        assert len(rc1_vars) < len(rc2_vars) and len(rc2_vars) < len(rc3_vars)
+
+    def test_recompute_pass_error(self):
+
+        with self.assertRaises(AssertionError):
+            rc_engine = self.get_engine(True, True, "full", [2])
+            history = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py b/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py
index 6e96ee6dcf83c9d8e892096a8b59d2030e71bad8..829e7f7a5ddc5c3abff5c45f3d245014ec7fd4e1 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py
@@ -57,6 +57,8 @@ class MultiHeadAttention(nn.Layer):
         bias_attr=None,
         fuse=False,
         mesh_idx=None,
+        use_new_recompute=False,
+        recompute_granularity="full",
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -67,6 +69,9 @@ class MultiHeadAttention(nn.Layer):
         self.need_weights = need_weights
         self.fuse = fuse
         self.mesh_idx = mesh_idx
+        self.use_new_recompute = use_new_recompute
+        self.recompute_granularity = recompute_granularity
+
         self.head_dim = embed_dim // num_heads
         assert (
             self.head_dim * num_heads == self.embed_dim
@@ -225,6 +230,27 @@ class MultiHeadAttention(nn.Layer):
             # incremental_state with initial value, mainly for usage like UniLM
             return self.Cache(key, value)

+    def core_attn(self, q, k, v, attn_mask):
+        product = layers.matmul(
+            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5
+        )
+        if attn_mask is not None:
+            product = product + attn_mask
+        weights = F.softmax(product)
+        if self.dropout:
+            weights = F.dropout(
+                weights,
+                self.dropout,
+                training=self.training,
+                mode="upscale_in_train",
+            )
+        out = tensor.matmul(weights, v)
+        # combine heads
+        out = tensor.transpose(out, perm=[0, 2, 1, 3])
+        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
+
+        return out, weights
+
     def forward(
         self, query, key, value, attn_mask=None, use_cache=False, cache=None
     ):
@@ -244,23 +270,12 @@ class MultiHeadAttention(nn.Layer):
             q, k, v, cache = self._prepare_qkv(
                 query, key, value, use_cache, cache
             )
-        product = layers.matmul(
-            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5
-        )
-        if attn_mask is not None:
-            product = product + attn_mask
-        weights = F.softmax(product)
-        if self.dropout:
-            weights = F.dropout(
-                weights,
-                self.dropout,
-                training=self.training,
-                mode="upscale_in_train",
-            )
-        out = tensor.matmul(weights, v)
-        # combine heads
-        out = tensor.transpose(out, perm=[0, 2, 1, 3])
-        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
+
+        if self.use_new_recompute and self.recompute_granularity == "core_attn":
+            out, weights = auto.recompute(self.core_attn)(q, k, v, attn_mask)
+        else:
+            out, weights = self.core_attn(q, k, v, attn_mask)
+
         # project to output
         out = self.out_proj(out)
         if _global_parallel_strategy == "mp":
@@ -295,12 +310,22 @@ class TransformerDecoder(nn.Layer):
     TransformerDecoder is a stack of N decoder layers.
     """

-    def __init__(self, decoder_layers, num_layers, norm=None, hidden_size=None):
+    def __init__(
+        self,
+        decoder_layers,
+        num_layers,
+        norm=None,
+        hidden_size=None,
+        use_new_recompute=False,
+        recompute_granularity="full",
+    ):
         super().__init__()
         self.num_layers = num_layers
         self.layers = decoder_layers
         self.norm = norm
+        self.use_new_recompute = use_new_recompute
+        self.recompute_granularity = recompute_granularity
         if norm == "LayerNorm":
             self.norm = nn.LayerNorm(hidden_size)
         elif norm is not None:
@@ -348,149 +373,36 @@ class TransformerDecoder(nn.Layer):
                 DPMPPP_MESH_LIST[0],
                 ["x"] + [None for i in range(len(output.shape) - 1)],
             )
+
         for i, mod in enumerate(self.layers):
+            if self.use_new_recompute and self.recompute_granularity == "full":
+                mod = auto.recompute(mod)
+
             if cache is None:
                 if use_cache:
-                    if _global_parallel_strategy == "pp":
-                        output, new_cache = auto.shard_op(
-                            mod, PP_MESH_LIST[mod.mesh_idx]
-                        )(output, memory, tgt_mask, use_cache, cache)
-                        auto.shard_tensor(
-                            output,
-                            PP_MESH_LIST[mod.mesh_idx],
-                            [None for i in range(len(output.shape))],
-                        )
-                    elif _global_parallel_strategy == "dp_pp":
-                        output, new_cache = auto.shard_op(
-                            mod, DPPP_MESH_LIST[mod.mesh_idx]
-                        )(output, memory, tgt_mask, use_cache, cache)
-                        auto.shard_tensor(
-                            output,
-                            DPPP_MESH_LIST[mod.mesh_idx],
-                            ["x"]
-                            + [None for i in range(len(output.shape) - 1)],
-                        )
-                    elif _global_parallel_strategy == "mp_pp":
-                        output, new_cache = auto.shard_op(
-                            mod, MPPP_MESH_LIST[mod.mesh_idx]
-                        )(output, memory, tgt_mask, use_cache, cache)
-                        auto.shard_tensor(
-                            output,
-                            MPPP_MESH_LIST[mod.mesh_idx],
-                            [None for i in range(len(output.shape))],
-                        )
-                    elif _global_parallel_strategy == "dp_mp_pp":
-                        output, new_cache = auto.shard_op(
-                            mod, DPMPPP_MESH_LIST[mod.mesh_idx]
-                        )(output, memory, tgt_mask, use_cache, cache)
-                        auto.shard_tensor(
-                            output,
-                            DPMPPP_MESH_LIST[mod.mesh_idx],
-                            [None for i in range(len(output.shape))],
-                        )
-                    else:
-                        output, new_cache = mod(
-                            output,
-                            memory,
-                            tgt_mask=tgt_mask,
-                            use_cache=use_cache,
-                            cache=cache,
-                        )
-                    new_caches.append(new_cache)
-                else:
-                    if _global_parallel_strategy == "pp":
-                        output = auto.shard_op(mod, PP_MESH_LIST[mod.mesh_idx])(
-                            output, memory, tgt_mask, use_cache, cache
-                        )
-                        auto.shard_tensor(
-                            output,
-                            PP_MESH_LIST[mod.mesh_idx],
-                            [None for i in range(len(output.shape))],
-                        )
-                    elif _global_parallel_strategy == "dp_pp":
-                        output = auto.shard_op(
-                            mod, DPPP_MESH_LIST[mod.mesh_idx]
-                        )(output, memory, tgt_mask, use_cache, cache)
-                        auto.shard_tensor(
-                            output,
-                            DPPP_MESH_LIST[mod.mesh_idx],
-                            ["x"]
-                            + [None for i in range(len(output.shape) - 1)],
-                        )
-                    elif _global_parallel_strategy == "mp_pp":
-                        output = auto.shard_op(
-                            mod, MPPP_MESH_LIST[mod.mesh_idx]
-                        )(output, memory, tgt_mask, use_cache, cache)
-                        auto.shard_tensor(
-                            output,
-                            MPPP_MESH_LIST[mod.mesh_idx],
-                            [None for i in range(len(output.shape))],
-                        )
-                    elif _global_parallel_strategy == "dp_mp_pp":
-                        output = auto.shard_op(
-                            mod, DPMPPP_MESH_LIST[mod.mesh_idx]
-                        )(output, memory, tgt_mask, use_cache, cache)
-                        auto.shard_tensor(
-                            output,
-                            DPMPPP_MESH_LIST[mod.mesh_idx],
-                            ["x"]
-                            + [None for i in range(len(output.shape) - 1)],
-                        )
-                    else:
-                        output = mod(
-                            output,
-                            memory,
-                            tgt_mask=tgt_mask,
-                            use_cache=use_cache,
-                            cache=cache,
-                        )
-            else:
-                if _global_parallel_strategy == "pp":
-                    output, new_cache = auto.shard_op(
-                        mod, PP_MESH_LIST[mod.mesh_idx]
-                    )(output, memory, tgt_mask, use_cache, cache)
-                    auto.shard_tensor(
-                        output,
-                        PP_MESH_LIST[mod.mesh_idx],
-                        [None for i in range(len(output.shape))],
-                    )
-                elif _global_parallel_strategy == "dp_pp":
-                    output, new_cache = auto.shard_op(
-                        mod, DPPP_MESH_LIST[mod.mesh_idx]
-                    )(output, memory, tgt_mask, use_cache, cache)
-                    auto.shard_tensor(
-                        output,
-                        DPPP_MESH_LIST[mod.mesh_idx],
-                        ["x"] + [None for i in range(len(output.shape) - 1)],
-                    )
-                elif _global_parallel_strategy == "mp_pp":
-                    output, new_cache = auto.shard_op(
-                        mod, MPPP_MESH_LIST[mod.mesh_idx]
-                    )(output, memory, tgt_mask, use_cache, cache)
-                    auto.shard_tensor(
-                        output,
-                        MPPP_MESH_LIST[mod.mesh_idx],
-                        [None for i in range(len(output.shape))],
-                    )
-                elif _global_parallel_strategy == "dp_mp_pp":
-                    output, new_cache = auto.shard_op(
-                        mod, DPMPPP_MESH_LIST[mod.mesh_idx]
-                    )(output, memory, tgt_mask, use_cache, cache)
-                    auto.shard_tensor(
-                        output,
-                        DPMPPP_MESH_LIST[mod.mesh_idx],
-                        ["x"] + [None for i in range(len(output.shape) - 1)],
-                    )
-                else:
                     output, new_cache = mod(
                         output,
                         memory,
                         tgt_mask=tgt_mask,
                         use_cache=use_cache,
-                        cache=cache[i],
+                        cache=cache,
                     )
+                    new_caches.append(new_cache)
+                else:
+                    output = mod(output, memory, tgt_mask, use_cache, cache)
+            else:
+                output, new_cache = mod(
+                    output,
+                    memory,
+                    tgt_mask=tgt_mask,
+                    use_cache=use_cache,
+                    cache=cache[i],
+                )
                 new_caches.append(new_cache)
-            self.checkpoints.append(output.name)
+
+        if not self.use_new_recompute:
+            self.checkpoints.append(output.name)
+
         if self.norm is not None:
             output = self.norm(output)
         return output if use_cache is False else (output, new_caches)
@@ -528,6 +440,8 @@ class TransformerDecoderLayer(nn.Layer):
         weight_attr=None,
         bias_attr=None,
         mesh_idx=None,
+        use_new_recompute=False,
+        recompute_granularity="full",
     ):
         self._config = locals()
         self._config.pop("self")
@@ -537,8 +451,12 @@ class TransformerDecoderLayer(nn.Layer):
         attn_dropout = dropout if attn_dropout is None else attn_dropout
         act_dropout = dropout if act_dropout is None else act_dropout
         self.normalize_before = normalize_before
+        self.use_new_recompute = use_new_recompute
+        self.recompute_granularity = recompute_granularity
+
         weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
         bias_attrs = _convert_param_attr_to_list(bias_attr, 3)
+
         self.self_attn = MultiHeadAttention(
             d_model,
             nhead,
@@ -546,6 +464,8 @@ class TransformerDecoderLayer(nn.Layer):
             weight_attr=weight_attrs[0],
             bias_attr=bias_attrs[0],
             mesh_idx=self.mesh_idx,
+            use_new_recompute=self.use_new_recompute,
+            recompute_granularity=self.recompute_granularity,
         )
         self.linear1 = nn.Linear(
             d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2]
@@ -563,12 +483,19 @@ class TransformerDecoderLayer(nn.Layer):
         residual = tgt
         if self.normalize_before:
             tgt = self.norm1(tgt)
+
+        if self.use_new_recompute and self.recompute_granularity == "full_attn":
+            self_attn = auto.recompute(self.self_attn)
+        else:
+            self_attn = self.self_attn
+
         if use_cache is False:
-            tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
+            tgt = self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
         else:
-            tgt, incremental_cache = self.self_attn(
+            tgt, incremental_cache = self_attn(
                 tgt, tgt, tgt, tgt_mask, use_cache, cache
             )
+
         tgt = residual + self.dropout1(tgt)
         if not self.normalize_before:
             tgt = self.norm1(tgt)
@@ -716,12 +643,17 @@ class GPTModel(nn.Layer):
         bos_token_id=0,
         eol_token_id=3,
         pp_degree=None,
+        use_new_recompute=False,
+        recompute_granularity="full",
     ):
         super().__init__()
         self.pad_token_id = pad_token_id
         self.initializer_range = initializer_range
         self.hidden_size = hidden_size
         self.vocab_size = vocab_size
+        self.use_new_recompute = use_new_recompute
+        self.recompute_granularity = recompute_granularity
+
         self.layer_per_stage = None
         self.pipline_mode = pp_degree is not None and pp_degree > 1
         if self.pipline_mode:
@@ -734,6 +666,7 @@ class GPTModel(nn.Layer):
             type_vocab_size,
             self.initializer_range,
         )
+
         decoder_layers = nn.LayerList()
         for i in range(num_hidden_layers):
             mesh_index = None
@@ -756,14 +689,19 @@ class GPTModel(nn.Layer):
                     ),
                     bias_attr=None,
                     mesh_idx=mesh_index,
+                    use_new_recompute=self.use_new_recompute,
+                    recompute_granularity=self.recompute_granularity,
                 )
             )
+
         Decoder = TransformerDecoder

         self.decoder = Decoder(
             decoder_layers,
             num_hidden_layers,
             norm="LayerNorm",
             hidden_size=hidden_size,
+            use_new_recompute=self.use_new_recompute,
+            recompute_granularity=self.recompute_granularity,
         )
         self.checkpoints = []
@@ -817,7 +755,8 @@ class GPTModel(nn.Layer):
             use_cache=use_cache,
             cache=cache,
         )
-        self.checkpoints.extend(self.decoder.checkpoints)
+        if not self.use_new_recompute:
+            self.checkpoints.extend(self.decoder.checkpoints)
         return encoder_outputs
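
Usage note (reviewer-added, not part of the patch): a minimal sketch of how the selective recompute introduced here is driven from user code, distilled from recompute_pass_unittest.py and test_selective_recompute.py above. The model, loss, optimizer, and dataset wiring is only indicated in comments because it depends on the caller; `auto.Strategy`, `strategy.recompute.no_recompute_segments`, `auto.recompute`, and `auto.Engine` are the pieces exercised by this patch.

    import paddle
    from paddle.distributed.fleet import auto

    # 1) In the model code, wrap a sublayer or bound method with auto.recompute().
    #    Every op created inside the wrapped call is tagged with an
    #    "/auto_parallel/rc_<idx>" op_namescope, which the recompute pass turns
    #    into one recompute segment (see interface.py / auto_parallel_recompute.py):
    #
    #        out, weights = auto.recompute(self.core_attn)(q, k, v, attn_mask)

    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    recompute = strategy.recompute
    recompute.enable = True
    # 2) Optionally skip whole marked segments by index; an out-of-range index
    #    trips the AssertionError covered by test_recompute_pass_error above.
    recompute.no_recompute_segments = [0]

    # 3) Build and run the engine as in the unittests (model/loss/dataset not shown):
    #        opt = paddle.optimizer.AdamW(learning_rate=0.00001)
    #        engine = auto.Engine(model, loss, opt, strategy=strategy)
    #        engine.fit(dataset, 3, batch_size=1)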
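
Also for reviewers: a self-contained, pure-Python illustration of the segment bookkeeping that RecomputeState now performs (build_stats groups contiguous ops by their rc_<idx> namescope; get_recompute_segments drops single-op groups and then removes the indices listed in no_recompute_segments). Ops are reduced here to a list of namescope strings, which is a simplification of the real OpDesc objects and not the pass's actual interface.

    def build_segments(op_namescopes, no_recompute_segments=[]):
        # Group consecutive op indices that share a recompute namescope.
        seg_op_deps = {}
        for i, namescope in enumerate(op_namescopes):
            if namescope and "/auto_parallel/rc_" in namescope:
                seg_op_deps.setdefault(namescope, []).append(i)

        # One [start, end) interval per marked group; single-op groups are skipped.
        segments = [
            [idx[0], idx[-1] + 1] for idx in seg_op_deps.values() if len(idx) > 1
        ]

        # Drop the segments the user asked not to recompute, highest index first.
        for i in reversed(sorted(no_recompute_segments)):
            assert i < len(segments), "no_recompute_segments index out of range"
            segments.pop(i)
        return segments

    # Ops 1-3 form segment rc_0, ops 5-6 form rc_1; segment 0 is excluded.
    scopes = [None, "/auto_parallel/rc_0", "/auto_parallel/rc_0",
              "/auto_parallel/rc_0", None, "/auto_parallel/rc_1",
              "/auto_parallel/rc_1"]
    print(build_segments(scopes, no_recompute_segments=[0]))  # -> [[5, 7]]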