未验证 提交 5004c33a 编写于 作者: Z zhangbo9674 提交者: GitHub

fix amp cast bug for bn (#47802)

上级 7c302538
...@@ -103,6 +103,9 @@ inline paddle::experimental::DataType GetAmpDestDtype( ...@@ -103,6 +103,9 @@ inline paddle::experimental::DataType GetAmpDestDtype(
return paddle::experimental::DataType::FLOAT16; return paddle::experimental::DataType::FLOAT16;
} else if (paddle::imperative::AmpOperators::Instance() } else if (paddle::imperative::AmpOperators::Instance()
.GetMutableBlockOps() .GetMutableBlockOps()
->count(op_name) ||
paddle::imperative::AmpOperators::Instance()
.GetMutableUnsupportedFp16Ops()
->count(op_name)) { ->count(op_name)) {
return paddle::experimental::DataType::FLOAT32; return paddle::experimental::DataType::FLOAT32;
} else { } else {
......
...@@ -353,7 +353,9 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type, ...@@ -353,7 +353,9 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
} }
} }
return new_ins; return new_ins;
} else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) { } else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type) ||
AmpOperators::Instance().GetMutableUnsupportedFp16Ops()->count(
op_type)) {
for (auto& pair : new_ins) { for (auto& pair : new_ins) {
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to float"; << GetDtypeStr(*pair.second.cbegin()) << " to float";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册