From 5004c33a3834ccfe330ae0b1e134f947dce5b9a5 Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Thu, 10 Nov 2022 09:53:47 +0800 Subject: [PATCH] fix amp cast bug for bn (#47802) --- paddle/fluid/eager/amp_utils.h | 3 +++ paddle/fluid/imperative/amp_auto_cast.cc | 4 +++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/eager/amp_utils.h b/paddle/fluid/eager/amp_utils.h index 666e569d125..115811f6a3d 100644 --- a/paddle/fluid/eager/amp_utils.h +++ b/paddle/fluid/eager/amp_utils.h @@ -103,6 +103,9 @@ inline paddle::experimental::DataType GetAmpDestDtype( return paddle::experimental::DataType::FLOAT16; } else if (paddle::imperative::AmpOperators::Instance() .GetMutableBlockOps() + ->count(op_name) || + paddle::imperative::AmpOperators::Instance() + .GetMutableUnsupportedFp16Ops() ->count(op_name)) { return paddle::experimental::DataType::FLOAT32; } else { diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc index cc0c9c7871a..55c15208208 100644 --- a/paddle/fluid/imperative/amp_auto_cast.cc +++ b/paddle/fluid/imperative/amp_auto_cast.cc @@ -353,7 +353,9 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type, } } return new_ins; - } else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) { + } else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type) || + AmpOperators::Instance().GetMutableUnsupportedFp16Ops()->count( + op_type)) { for (auto& pair : new_ins) { VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " << GetDtypeStr(*pair.second.cbegin()) << " to float"; -- GitLab