From 5dab0b0dc10f0547bf98fad8891ca0e0c14c3d23 Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Tue, 27 Sep 2022 22:15:20 +0800 Subject: [PATCH] [AutoParallel] fix amp o1 (#46391) (#46481) --- .../distributed/passes/auto_parallel_amp.py | 40 +++++++++++-------- .../auto_parallel/amp_pass_unittest.py | 3 ++ 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py index 458cb26ccd4..3545783ba17 100644 --- a/python/paddle/distributed/passes/auto_parallel_amp.py +++ b/python/paddle/distributed/passes/auto_parallel_amp.py @@ -38,14 +38,18 @@ class AMPState(object): self._op_fp16_dict = { } # op_id --> True/False. 'True' means that the current op is in fp16 mode. self._var_name_dict = {} # fwd_op_id --> {old_name: cast_name} + self.is_train = False def _is_fp16_op(self, op_id): return self._op_fp16_dict.get(op_id, None) - def _build_stats(self, amp_lists, dist_context): + def _build_state(self, amp_lists, dist_context): ops = self._block.ops dist_op_context = dist_context.dist_op_context for op in ops: + if int(op.attr('op_role')) == 257: + self.is_train = True + if int(op.attr('op_role')) == int(OpRole.Forward): self._mark_black_white_ops(amp_lists) elif int(op.attr('op_role')) == int(OpRole.Backward): @@ -59,6 +63,8 @@ class AMPState(object): elif int(op.attr('op_role')) == int(OpRole.Optimize): break + return self.is_train + def _mark_black_white_ops(self, amp_lists): """ this function is modified from paddle.fluid.contrib.mixed_precision @@ -546,23 +552,25 @@ class AMPPass(PassBase): set(self.get_attr("custom_black_list")), set(self.get_attr("custom_black_varnames"))) - amp_state = AMPState(main_program.global_block()) - amp_state._build_stats(amp_lists, self.dist_context) - with paddle.static.program_guard(main_program, startup_program): + amp_state = AMPState(main_program.global_block()) + is_train = amp_state._build_state(amp_lists, self.dist_context) + amp_state.cast_forward_program(self.dist_context) - amp_state.cast_backward_program(params_grads, self.dist_context) - # TODO (JZ-LIANG)support cast forward program only when inference - self._init_amp_var() - self._scale_loss() - - if self.get_attr("use_dynamic_loss_scaling" - ) or self.get_attr("init_loss_scaling") != 1.0: - grads, found_inf = _check_and_update_gradient( - params_grads, self._loss_scaling, self.dist_context) - - if self.get_attr("use_dynamic_loss_scaling"): - self._update_loss_scaling(grads, found_inf) + + if is_train: + with paddle.static.program_guard(main_program, startup_program): + amp_state.cast_backward_program(params_grads, self.dist_context) + self._init_amp_var() + self._scale_loss() + + if self.get_attr("use_dynamic_loss_scaling" + ) or self.get_attr("init_loss_scaling") != 1.0: + grads, found_inf = _check_and_update_gradient( + params_grads, self._loss_scaling, self.dist_context) + + if self.get_attr("use_dynamic_loss_scaling"): + self._update_loss_scaling(grads, found_inf) def _init_amp_var(self): self._loss_scaling = paddle.static.create_global_var( diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py index a00d3073630..45ca5695af4 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py @@ -97,6 +97,7 @@ class TestAMPPass(unittest.TestCase): 3, batch_size=self.batch_size) amp_o1_losses = np.array(amp_o1_losses["loss"]) + amp_o1_engine.evaluate(self.dataset, 3, batch_size=self.batch_size) # self.check_results(mp_losses, amp_o1_losses) # mp2 amp-o2 training @@ -105,6 +106,7 @@ class TestAMPPass(unittest.TestCase): 3, batch_size=self.batch_size) amp_o2_losses = np.array(amp_o2_losses["loss"]) + amp_o2_engine.evaluate(self.dataset, 3, batch_size=self.batch_size) # self.check_results(mp_losses, amp_o2_losses) # mp2 amp-o3 training @@ -113,6 +115,7 @@ class TestAMPPass(unittest.TestCase): 3, batch_size=self.batch_size) amp_o3_losses = np.array(amp_o3_losses["loss"]) + amp_o3_engine.evaluate(self.dataset, 3, batch_size=self.batch_size) # self.check_results(mp_losses, amp_o3_losses) -- GitLab