From 5dab0b0dc10f0547bf98fad8891ca0e0c14c3d23 Mon Sep 17 00:00:00 2001
From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Date: Tue, 27 Sep 2022 22:15:20 +0800
Subject: [PATCH] [AutoParallel] fix amp o1 (#46391) (#46481)

---
 .../distributed/passes/auto_parallel_amp.py   | 40 +++++++++++--------
 .../auto_parallel/amp_pass_unittest.py        |  3 ++
 2 files changed, 27 insertions(+), 16 deletions(-)

diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py
index 458cb26ccd4..3545783ba17 100644
--- a/python/paddle/distributed/passes/auto_parallel_amp.py
+++ b/python/paddle/distributed/passes/auto_parallel_amp.py
@@ -38,14 +38,18 @@ class AMPState(object):
         self._op_fp16_dict = {
         }  # op_id --> True/False. 'True' means that the current op is in fp16 mode.
         self._var_name_dict = {}  # fwd_op_id --> {old_name: cast_name}
+        self.is_train = False
 
     def _is_fp16_op(self, op_id):
         return self._op_fp16_dict.get(op_id, None)
 
-    def _build_stats(self, amp_lists, dist_context):
+    def _build_state(self, amp_lists, dist_context):
         ops = self._block.ops
         dist_op_context = dist_context.dist_op_context
         for op in ops:
+            if int(op.attr('op_role')) == 257:
+                self.is_train = True
+
             if int(op.attr('op_role')) == int(OpRole.Forward):
                 self._mark_black_white_ops(amp_lists)
             elif int(op.attr('op_role')) == int(OpRole.Backward):
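
A minimal standalone sketch (not part of the patch) of why the literal 257 above flags a training program, assuming Paddle's usual OpRole bit flags (Backward = 0x0001, Loss = 0x0100): the loss-gradient op emitted when a backward pass is built carries both flags, so seeing it means the program is being trained rather than run for inference.

    # Assumed OpRole bit values; these mirror Paddle's op_role convention
    # rather than being imported from this module.
    OP_ROLE_BACKWARD = 0x0001
    OP_ROLE_LOSS = 0x0100

    def looks_like_loss_grad_op(op_role):
        """True for the loss-grad op that only a training program contains."""
        return op_role == (OP_ROLE_BACKWARD | OP_ROLE_LOSS)

    print(OP_ROLE_BACKWARD | OP_ROLE_LOSS)   # 257
    print(looks_like_loss_grad_op(257))      # True -> backward pass present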
@@ -59,6 +63,8 @@ class AMPState(object):
             elif int(op.attr('op_role')) == int(OpRole.Optimize):
                 break
 
+        return self.is_train
+
     def _mark_black_white_ops(self, amp_lists):
         """
         this function is modified from paddle.fluid.contrib.mixed_precision
@@ -546,23 +552,25 @@ class AMPPass(PassBase):
             set(self.get_attr("custom_black_list")),
             set(self.get_attr("custom_black_varnames")))
 
-        amp_state = AMPState(main_program.global_block())
-        amp_state._build_stats(amp_lists, self.dist_context)
-
         with paddle.static.program_guard(main_program, startup_program):
+            amp_state = AMPState(main_program.global_block())
+            is_train = amp_state._build_state(amp_lists, self.dist_context)
+
             amp_state.cast_forward_program(self.dist_context)
-            amp_state.cast_backward_program(params_grads, self.dist_context)
-            # TODO (JZ-LIANG)support cast forward program only when inference
-            self._init_amp_var()
-            self._scale_loss()
-
-            if self.get_attr("use_dynamic_loss_scaling"
-                             ) or self.get_attr("init_loss_scaling") != 1.0:
-                grads, found_inf = _check_and_update_gradient(
-                    params_grads, self._loss_scaling, self.dist_context)
-
-            if self.get_attr("use_dynamic_loss_scaling"):
-                self._update_loss_scaling(grads, found_inf)
+
+        if is_train:
+            with paddle.static.program_guard(main_program, startup_program):
+                amp_state.cast_backward_program(params_grads, self.dist_context)
+                self._init_amp_var()
+                self._scale_loss()
+
+                if self.get_attr("use_dynamic_loss_scaling"
+                                 ) or self.get_attr("init_loss_scaling") != 1.0:
+                    grads, found_inf = _check_and_update_gradient(
+                        params_grads, self._loss_scaling, self.dist_context)
+
+                if self.get_attr("use_dynamic_loss_scaling"):
+                    self._update_loss_scaling(grads, found_inf)
 
     def _init_amp_var(self):
         self._loss_scaling = paddle.static.create_global_var(
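
A condensed, self-contained sketch (a paraphrase with stand-in names, not the pass itself) of the control flow the hunk above establishes: the forward program is always cast, while backward casting, the AMP variables, loss scaling, and the found_inf check run only when _build_state() detected a training program.

    class FakeAMPState:
        """Stand-in with AMPState-like method names, for illustration only."""
        def cast_forward_program(self):
            print("cast forward ops for O1")
        def cast_backward_program(self):
            print("cast backward ops to match the forward casts")

    def apply_amp_o1(state, is_train, use_dynamic_loss_scaling=True):
        state.cast_forward_program()
        if not is_train:
            # Inference-only program (e.g. built for evaluate/predict):
            # there are no gradients or loss to scale, so stop here.
            return
        state.cast_backward_program()
        print("create loss_scaling / step-counter vars, scale the loss")
        if use_dynamic_loss_scaling:
            print("check gradients for inf/nan, update loss scaling")

    apply_amp_o1(FakeAMPState(), is_train=False)   # forward-only path
    apply_amp_o1(FakeAMPState(), is_train=True)    # full training path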
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py
index a00d3073630..45ca5695af4 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py
@@ -97,6 +97,7 @@ class TestAMPPass(unittest.TestCase):
                                           3,
                                           batch_size=self.batch_size)
         amp_o1_losses = np.array(amp_o1_losses["loss"])
+        amp_o1_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
         # self.check_results(mp_losses, amp_o1_losses)
 
         # mp2 amp-o2 training
@@ -105,6 +106,7 @@ class TestAMPPass(unittest.TestCase):
                                           3,
                                           batch_size=self.batch_size)
         amp_o2_losses = np.array(amp_o2_losses["loss"])
+        amp_o2_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
         # self.check_results(mp_losses, amp_o2_losses)
 
         # mp2 amp-o3 training
@@ -113,6 +115,7 @@ class TestAMPPass(unittest.TestCase):
                                           3,
                                           batch_size=self.batch_size)
         amp_o3_losses = np.array(amp_o3_losses["loss"])
+        amp_o3_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
         # self.check_results(mp_losses, amp_o3_losses)
 
 
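A hedged sketch of the pattern these test additions follow (the Engine construction and pass configuration are elided; the fit/evaluate arguments mirror the calls above): evaluating right after training builds a forward-only program with no backward ops, which exercises the new is_train branch of the AMP pass.

    def run_amp_level(engine, dataset, batch_size):
        # Training program: has loss/backward ops, so the full AMP O1 pass
        # (backward casts, loss scaling, gradient checks) is applied.
        history = engine.fit(dataset, 3, batch_size=batch_size)

        # Inference program: no backward ops; the pass now only casts the
        # forward program instead of failing on missing gradients.
        engine.evaluate(dataset, 3, batch_size=batch_size)
        return history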
-- 
GitLab