diff --git a/paddle/fluid/eager/amp_utils.h b/paddle/fluid/eager/amp_utils.h index 666e569d125d69d86787920b5e92476d2f8da5a2..115811f6a3d8e4ace1e39a769e3259448fc4c766 100644 --- a/paddle/fluid/eager/amp_utils.h +++ b/paddle/fluid/eager/amp_utils.h @@ -103,6 +103,9 @@ inline paddle::experimental::DataType GetAmpDestDtype( return paddle::experimental::DataType::FLOAT16; } else if (paddle::imperative::AmpOperators::Instance() .GetMutableBlockOps() + ->count(op_name) || + paddle::imperative::AmpOperators::Instance() + .GetMutableUnsupportedFp16Ops() ->count(op_name)) { return paddle::experimental::DataType::FLOAT32; } else { diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc index cc0c9c7871a3780573d8f4ae44608bb1ca230869..55c15208208085feb9207cac22d105b4ceb96e80 100644 --- a/paddle/fluid/imperative/amp_auto_cast.cc +++ b/paddle/fluid/imperative/amp_auto_cast.cc @@ -353,7 +353,9 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type, } } return new_ins; - } else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) { + } else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type) || + AmpOperators::Instance().GetMutableUnsupportedFp16Ops()->count( + op_type)) { for (auto& pair : new_ins) { VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " << GetDtypeStr(*pair.second.cbegin()) << " to float";