From 5004c33a3834ccfe330ae0b1e134f947dce5b9a5 Mon Sep 17 00:00:00 2001
From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com>
Date: Thu, 10 Nov 2022 09:53:47 +0800
Subject: [PATCH] fix amp cast bug for bn (#47802)

---
 paddle/fluid/eager/amp_utils.h           | 3 +++
 paddle/fluid/imperative/amp_auto_cast.cc | 4 +++-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/eager/amp_utils.h b/paddle/fluid/eager/amp_utils.h
index 666e569d125..115811f6a3d 100644
--- a/paddle/fluid/eager/amp_utils.h
+++ b/paddle/fluid/eager/amp_utils.h
@@ -103,6 +103,9 @@ inline paddle::experimental::DataType GetAmpDestDtype(
         return paddle::experimental::DataType::FLOAT16;
       } else if (paddle::imperative::AmpOperators::Instance()
                      .GetMutableBlockOps()
+                     ->count(op_name) ||
+                 paddle::imperative::AmpOperators::Instance()
+                     .GetMutableUnsupportedFp16Ops()
                      ->count(op_name)) {
         return paddle::experimental::DataType::FLOAT32;
       } else {
diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc
index cc0c9c7871a..55c15208208 100644
--- a/paddle/fluid/imperative/amp_auto_cast.cc
+++ b/paddle/fluid/imperative/amp_auto_cast.cc
@@ -353,7 +353,9 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
       }
     }
     return new_ins;
-  } else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
+  } else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type) ||
+             AmpOperators::Instance().GetMutableUnsupportedFp16Ops()->count(
+                 op_type)) {
     for (auto& pair : new_ins) {
       VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
               << GetDtypeStr(*pair.second.cbegin()) << " to float";
-- 
GitLab