From e9e07a19b6bc0fcfc16ef3e892a558ebf78e05c5 Mon Sep 17 00:00:00 2001 From: Wennie396 <44974020+Wennie396@users.noreply.github.com> Date: Tue, 5 Sep 2023 14:30:58 +0800 Subject: [PATCH] fix some bugs for amp and test case test_tuning_recompute_with_amp.py (#56864) * replace amp.use_pure_fp16 with amp.dtype and amp.level * old api still use use_pure_fp16 * test_fuse_adamw_pass still use use_pure_fp16 * add test case tuning recompute with amp(float16,o2) * reset new test case properties TIMEOUT 60 * set smaller value of batch_size and batch_num * deepcopy dist_context fix _rename_input problem * fix loss name after cast * set tuning.enable=True and use engine._tune() * restore some changes in _rename_input()/_rename_output() * add self.amp_dtype for _cast_loss() in auto_parallel_amp.py * fix insert op index in _cast_loss() --- .../static/tuner/optimization_tuner.py | 2 +- .../distributed/passes/auto_parallel_amp.py | 26 ++-- .../distributed/passes/auto_parallel_fp16.py | 7 +- test/auto_parallel/CMakeLists.txt | 3 + test/auto_parallel/test_fp16_assign.py | 3 +- .../test_tuning_recompute_with_amp.py | 113 ++++++++++++++++++ 6 files changed, 141 insertions(+), 13 deletions(-) create mode 100644 test/auto_parallel/test_tuning_recompute_with_amp.py diff --git a/python/paddle/distributed/auto_parallel/static/tuner/optimization_tuner.py b/python/paddle/distributed/auto_parallel/static/tuner/optimization_tuner.py index 890e5fd2be5..64eaca28c06 100644 --- a/python/paddle/distributed/auto_parallel/static/tuner/optimization_tuner.py +++ b/python/paddle/distributed/auto_parallel/static/tuner/optimization_tuner.py @@ -308,7 +308,7 @@ class OptimizationTuner: self._baseline_dist_context.serial_feed_vars["inputs"] + self._baseline_dist_context.serial_feed_vars["labels"] ) - if config["use_pure_fp16"]: + if config["dtype"] == "float16" and config["level"] == "o2": config["base_opt"] = dist_context.serial_optimizer auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config) auto_parallel_fp16_pass.apply( diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py index cb875f63e3d..34ab1c29534 100644 --- a/python/paddle/distributed/passes/auto_parallel_amp.py +++ b/python/paddle/distributed/passes/auto_parallel_amp.py @@ -728,7 +728,7 @@ class AMPPass(PassBase): if is_train: self._update_backward_cast_ops() - self._cast_loss() + self._cast_loss(self.amp_dtype) if is_train and self.amp_dtype == "float16": self._init_amp_var() @@ -913,7 +913,7 @@ class AMPPass(PassBase): world_process_group.ranks, ) - def _cast_loss(self): + def _cast_loss(self, target_dtype): main_block = paddle.static.default_main_program().global_block() main_block._sync_with_cpp() @@ -957,11 +957,18 @@ class AMPPass(PassBase): ) # backward - first_backward_op = main_block.ops[loss_op_idx + 2] - assert ( - first_backward_op.type == "fill_constant" - and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257 - ) + first_backward_op = None + insert_op_offset = 3 + for idx, op in enumerate(main_block.ops[loss_op_idx:]): + if op.type == "fill_constant" and is_loss_grad_op(op): + first_backward_op = op + insert_op_offset = idx + 1 + break + if is_backward_op(op): + break + + assert first_backward_op is not None, "There is not loss_grad op." 
+ cast_loss_grad = main_block.create_var( name=unique_name.generate(tmp_name + "@GRAD"), shape=loss.shape, @@ -984,13 +991,13 @@ class AMPPass(PassBase): self.dist_context, ) cast_grad_op = main_block._insert_op( - loss_op_idx + 3, + loss_op_idx + insert_op_offset, type='cast', inputs={'X': [cast_loss_grad]}, outputs={'Out': [pre_grad_name]}, attrs={ "in_dtype": core.VarDesc.VarType.FP32, - "out_dtype": _str_to_dtype(self.amp_dtype), + "out_dtype": _str_to_dtype(target_dtype), "op_role": OpRole.Backward, }, ) @@ -1002,6 +1009,7 @@ class AMPPass(PassBase): ) loss_op = cast_op loss = cast_loss + self.set_attr("loss", loss) self._loss = loss main_block._sync_with_cpp() diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py index fc1c5bd9807..c9a0f772db5 100644 --- a/python/paddle/distributed/passes/auto_parallel_fp16.py +++ b/python/paddle/distributed/passes/auto_parallel_fp16.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy from collections import defaultdict import paddle @@ -408,8 +409,8 @@ class FP16State: (cast_name, in_var.name, dst_dtype, src_dtype, in_name) ] - in_var_dist_attr = consume_op_attr.get_input_dist_attr( - in_var.name + in_var_dist_attr = copy.deepcopy( + consume_op_attr.get_input_dist_attr(in_var.name) ) assert in_var_dist_attr is not None # truly insert cast op @@ -800,6 +801,8 @@ class FP16Pass(AMPPass): is_train = fp16_state._build_state() cast_startup_program() + if is_train: + self._cast_loss(self.target_dtype) if is_train: if self.target_dtype == "float16": diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt index 458c273951e..aeb00a0fc72 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -105,6 +105,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU) set_tests_properties(test_selective_recompute PROPERTIES TIMEOUT 50) py_test_modules(test_tuning_recompute MODULES test_tuning_recompute) set_tests_properties(test_tuning_recompute PROPERTIES TIMEOUT 300) + py_test_modules(test_tuning_recompute_with_amp MODULES + test_tuning_recompute_with_amp) + set_tests_properties(test_tuning_recompute_with_amp PROPERTIES TIMEOUT 60) py_test_modules(test_fused_linear_pass MODULES test_fused_linear_pass) set_tests_properties(test_fused_linear_pass PROPERTIES TIMEOUT 40) py_test_modules(test_align_tool MODULES test_align_tool) diff --git a/test/auto_parallel/test_fp16_assign.py b/test/auto_parallel/test_fp16_assign.py index 5eabd0501ac..a7257cb0254 100644 --- a/test/auto_parallel/test_fp16_assign.py +++ b/test/auto_parallel/test_fp16_assign.py @@ -80,7 +80,8 @@ def parallelizer(program_func, rank): strategy = auto.Strategy() amp = strategy.amp amp.enable = True - amp.use_pure_fp16 = True + amp.dtype = "float16" + amp.level = "o2" amp.init_loss_scaling = 32768 amp.use_fp16_guard = False amp.custom_black_list = ['where'] diff --git a/test/auto_parallel/test_tuning_recompute_with_amp.py b/test/auto_parallel/test_tuning_recompute_with_amp.py new file mode 100644 index 00000000000..a7f093ab054 --- /dev/null +++ b/test/auto_parallel/test_tuning_recompute_with_amp.py @@ -0,0 +1,113 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +from get_gpt_model import FakeDataset + +import paddle +from paddle.distributed.fleet import auto + +sys.path.append("../legacy_test") +import auto_parallel_gpt_model as modeling +from auto_parallel_gpt_model import ( + GPTForPretraining, + GPTModel, + GPTPretrainingCriterion, +) + +paddle.enable_static() + + +def generate_model(): + modeling.init_global() + modeling._global_parallel_strategy = "serial" + ranks = list(range(paddle.distributed.get_world_size())) + modeling._global_process_mesh = auto.ProcessMesh( + mesh=ranks, dim_names=["x"] + ) + + gpt = GPTModel( + vocab_size=50304, + hidden_size=1024, + num_hidden_layers=8, + num_attention_heads=16, + intermediate_size=1024 * 4, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=1024, + type_vocab_size=1, + initializer_range=0.02, + pad_token_id=0, + eos_token_id=7, + bos_token_id=0, + eol_token_id=3, + use_new_recompute=True, + recompute_granularity="full", + ) + model = GPTForPretraining( + gpt, vocab_size=50304, hidden_size=1024, initializer_range=0.02 + ) + criterion = GPTPretrainingCriterion() + return model, criterion + + +def apply_pass(): + strategy = auto.Strategy() + strategy.auto_mode = "semi" + + recompute = strategy.recompute + recompute.enable = True + recompute.enable_tuning = True + + tuning = strategy.tuning + tuning.enable = True + tuning.profile_start_step = 1 + tuning.profile_end_step = 2 + tuning.run_after_tuning = True + tuning.verbose = True + + amp = strategy.amp + amp.enable = True + amp.dtype = "float16" + amp.level = "o2" + + return strategy + + +class TestRecomputeWithAMPPassTuning(unittest.TestCase): + def setUp(self): + self.batch_size = 2 + self.batch_num = 10 + self.dataset = FakeDataset( + self.batch_size * self.batch_num, + vocab_size=50304, + sequence_len=1024, + ) + + def test_recompute_with_amp_pass(self): + strategy = apply_pass() + clip = paddle.nn.ClipGradByGlobalNorm(0.2) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + model, loss = generate_model() + + engine = auto.Engine(model, loss, opt, strategy=strategy) + # engine.fit(self.dataset, 3, batch_size=self.batch_size) + engine._tune(self.dataset, 3, batch_size=self.batch_size) + + +if __name__ == "__main__": + unittest.main() -- GitLab
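For reference, a minimal sketch of the AMP configuration this patch migrates to: `amp.dtype` and `amp.level` replace the legacy `amp.use_pure_fp16` flag, and the tuner now selects the fp16 pass with the dtype/level predicate added in optimization_tuner.py. This is not part of the commit; the usage follows test_fp16_assign.py and the new test_tuning_recompute_with_amp.py, and the concrete values (loss scaling, black list) are illustrative, taken from those tests.

    # Illustrative sketch: configuring auto-parallel AMP with the new
    # dtype/level options instead of the removed use_pure_fp16 flag.
    import paddle
    from paddle.distributed.fleet import auto

    paddle.enable_static()

    strategy = auto.Strategy()
    strategy.auto_mode = "semi"

    amp = strategy.amp
    amp.enable = True
    amp.dtype = "float16"          # old API: amp.use_pure_fp16 = True
    amp.level = "o2"               # pure fp16 is now dtype "float16" + level "o2"
    amp.init_loss_scaling = 32768  # value from test_fp16_assign.py
    amp.use_fp16_guard = False
    amp.custom_black_list = ['where']

    # Same predicate the optimization tuner now uses to pick the
    # auto_parallel_fp16 pass, instead of checking config["use_pure_fp16"]:
    use_fp16_pass = amp.dtype == "float16" and amp.level == "o2"
    print(use_fp16_pass)  # True

The strategy built this way can then be passed to auto.Engine together with a model, loss, and optimizer, as the new test does before calling engine._tune().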