diff --git a/python/paddle/distributed/auto_parallel/static/tuner/optimization_tuner.py b/python/paddle/distributed/auto_parallel/static/tuner/optimization_tuner.py
index 890e5fd2be57c99b2ac28a6c5d8d0d9b0b306f24..64eaca28c06ea515fc65d5fbe6123dd9f9db97fe 100644
--- a/python/paddle/distributed/auto_parallel/static/tuner/optimization_tuner.py
+++ b/python/paddle/distributed/auto_parallel/static/tuner/optimization_tuner.py
@@ -308,7 +308,7 @@ class OptimizationTuner:
                 self._baseline_dist_context.serial_feed_vars["inputs"]
                 + self._baseline_dist_context.serial_feed_vars["labels"]
             )
-            if config["use_pure_fp16"]:
+            if config["dtype"] == "float16" and config["level"] == "o2":
                 config["base_opt"] = dist_context.serial_optimizer
                 auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
                 auto_parallel_fp16_pass.apply(
diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py
index cb875f63e3d259c45dcbdb02824c40413370b493..34ab1c29534a9a00e65eae6a34fb47453297b8e6 100644
--- a/python/paddle/distributed/passes/auto_parallel_amp.py
+++ b/python/paddle/distributed/passes/auto_parallel_amp.py
@@ -728,7 +728,7 @@ class AMPPass(PassBase):
 
         if is_train:
             self._update_backward_cast_ops()
-            self._cast_loss()
+            self._cast_loss(self.amp_dtype)
 
         if is_train and self.amp_dtype == "float16":
             self._init_amp_var()
@@ -913,7 +913,7 @@ class AMPPass(PassBase):
                 world_process_group.ranks,
             )
 
-    def _cast_loss(self):
+    def _cast_loss(self, target_dtype):
         main_block = paddle.static.default_main_program().global_block()
         main_block._sync_with_cpp()
 
@@ -957,11 +957,18 @@ class AMPPass(PassBase):
             )
 
             # backward
-            first_backward_op = main_block.ops[loss_op_idx + 2]
-            assert (
-                first_backward_op.type == "fill_constant"
-                and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
-            )
+            first_backward_op = None
+            insert_op_offset = 3
+            for idx, op in enumerate(main_block.ops[loss_op_idx:]):
+                if op.type == "fill_constant" and is_loss_grad_op(op):
+                    first_backward_op = op
+                    insert_op_offset = idx + 1
+                    break
+                if is_backward_op(op):
+                    break
+
+            assert first_backward_op is not None, "There is not loss_grad op."
+
             cast_loss_grad = main_block.create_var(
                 name=unique_name.generate(tmp_name + "@GRAD"),
                 shape=loss.shape,
@@ -984,13 +991,13 @@ class AMPPass(PassBase):
                 self.dist_context,
             )
             cast_grad_op = main_block._insert_op(
-                loss_op_idx + 3,
+                loss_op_idx + insert_op_offset,
                 type='cast',
                 inputs={'X': [cast_loss_grad]},
                 outputs={'Out': [pre_grad_name]},
                 attrs={
                     "in_dtype": core.VarDesc.VarType.FP32,
-                    "out_dtype": _str_to_dtype(self.amp_dtype),
+                    "out_dtype": _str_to_dtype(target_dtype),
                     "op_role": OpRole.Backward,
                 },
             )
@@ -1002,6 +1009,7 @@ class AMPPass(PassBase):
             )
             loss_op = cast_op
             loss = cast_loss
+            self.set_attr("loss", loss)
         self._loss = loss
         main_block._sync_with_cpp()
 
diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py
index fc1c5bd9807a8f0e900bc317cb0bb58a414784f8..c9a0f772db5cecd74dbc4eb960e90cd7a9235251 100644
--- a/python/paddle/distributed/passes/auto_parallel_fp16.py
+++ b/python/paddle/distributed/passes/auto_parallel_fp16.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 from collections import defaultdict
 
 import paddle
@@ -408,8 +409,8 @@ class FP16State:
                         (cast_name, in_var.name, dst_dtype, src_dtype, in_name)
                     ]
 
-                    in_var_dist_attr = consume_op_attr.get_input_dist_attr(
-                        in_var.name
+                    in_var_dist_attr = copy.deepcopy(
+                        consume_op_attr.get_input_dist_attr(in_var.name)
                     )
                     assert in_var_dist_attr is not None
                     # truly insert cast op
@@ -800,6 +801,8 @@ class FP16Pass(AMPPass):
             is_train = fp16_state._build_state()
 
             cast_startup_program()
+            if is_train:
+                self._cast_loss(self.target_dtype)
 
         if is_train:
             if self.target_dtype == "float16":
diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt
index 458c273951e058f54ce3996b2cc008b2048332e4..aeb00a0fc7296c35b9f9164266f844857961b4ed 100644
--- a/test/auto_parallel/CMakeLists.txt
+++ b/test/auto_parallel/CMakeLists.txt
@@ -105,6 +105,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   set_tests_properties(test_selective_recompute PROPERTIES TIMEOUT 50)
   py_test_modules(test_tuning_recompute MODULES test_tuning_recompute)
   set_tests_properties(test_tuning_recompute PROPERTIES TIMEOUT 300)
+  py_test_modules(test_tuning_recompute_with_amp MODULES
+                  test_tuning_recompute_with_amp)
+  set_tests_properties(test_tuning_recompute_with_amp PROPERTIES TIMEOUT 60)
   py_test_modules(test_fused_linear_pass MODULES test_fused_linear_pass)
   set_tests_properties(test_fused_linear_pass PROPERTIES TIMEOUT 40)
   py_test_modules(test_align_tool MODULES test_align_tool)
diff --git a/test/auto_parallel/test_fp16_assign.py b/test/auto_parallel/test_fp16_assign.py
index 5eabd0501ac5660a1373539f70b9fb1b571a92cb..a7257cb0254705b697c66bbf97e8095a8479416b 100644
--- a/test/auto_parallel/test_fp16_assign.py
+++ b/test/auto_parallel/test_fp16_assign.py
@@ -80,7 +80,8 @@ def parallelizer(program_func, rank):
     strategy = auto.Strategy()
     amp = strategy.amp
     amp.enable = True
-    amp.use_pure_fp16 = True
+    amp.dtype = "float16"
+    amp.level = "o2"
     amp.init_loss_scaling = 32768
     amp.use_fp16_guard = False
     amp.custom_black_list = ['where']
diff --git a/test/auto_parallel/test_tuning_recompute_with_amp.py b/test/auto_parallel/test_tuning_recompute_with_amp.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7f093ab0545c31b1806b114f21f3179796f56fc
--- /dev/null
+++ b/test/auto_parallel/test_tuning_recompute_with_amp.py
@@ -0,0 +1,113 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import unittest
+
+from get_gpt_model import FakeDataset
+
+import paddle
+from paddle.distributed.fleet import auto
+
+sys.path.append("../legacy_test")
+import auto_parallel_gpt_model as modeling
+from auto_parallel_gpt_model import (
+    GPTForPretraining,
+    GPTModel,
+    GPTPretrainingCriterion,
+)
+
+paddle.enable_static()
+
+
+def generate_model():
+    modeling.init_global()
+    modeling._global_parallel_strategy = "serial"
+    ranks = list(range(paddle.distributed.get_world_size()))
+    modeling._global_process_mesh = auto.ProcessMesh(
+        mesh=ranks, dim_names=["x"]
+    )
+
+    gpt = GPTModel(
+        vocab_size=50304,
+        hidden_size=1024,
+        num_hidden_layers=8,
+        num_attention_heads=16,
+        intermediate_size=1024 * 4,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        max_position_embeddings=1024,
+        type_vocab_size=1,
+        initializer_range=0.02,
+        pad_token_id=0,
+        eos_token_id=7,
+        bos_token_id=0,
+        eol_token_id=3,
+        use_new_recompute=True,
+        recompute_granularity="full",
+    )
+    model = GPTForPretraining(
+        gpt, vocab_size=50304, hidden_size=1024, initializer_range=0.02
+    )
+    criterion = GPTPretrainingCriterion()
+    return model, criterion
+
+
+def apply_pass():
+    strategy = auto.Strategy()
+    strategy.auto_mode = "semi"
+
+    recompute = strategy.recompute
+    recompute.enable = True
+    recompute.enable_tuning = True
+
+    tuning = strategy.tuning
+    tuning.enable = True
+    tuning.profile_start_step = 1
+    tuning.profile_end_step = 2
+    tuning.run_after_tuning = True
+    tuning.verbose = True
+
+    amp = strategy.amp
+    amp.enable = True
+    amp.dtype = "float16"
+    amp.level = "o2"
+
+    return strategy
+
+
+class TestRecomputeWithAMPPassTuning(unittest.TestCase):
+    def setUp(self):
+        self.batch_size = 2
+        self.batch_num = 10
+        self.dataset = FakeDataset(
+            self.batch_size * self.batch_num,
+            vocab_size=50304,
+            sequence_len=1024,
+        )
+
+    def test_recompute_with_amp_pass(self):
+        strategy = apply_pass()
+        clip = paddle.nn.ClipGradByGlobalNorm(0.2)
+        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
+        model, loss = generate_model()
+
+        engine = auto.Engine(model, loss, opt, strategy=strategy)
+        # engine.fit(self.dataset, 3, batch_size=self.batch_size)
+        engine._tune(self.dataset, 3, batch_size=self.batch_size)
+
+
+if __name__ == "__main__":
+    unittest.main()