未验证 提交 5004c33a 编写于 作者: Z zhangbo9674 提交者: GitHub

fix amp cast bug for bn (#47802)

上级 7c302538
......@@ -103,6 +103,9 @@ inline paddle::experimental::DataType GetAmpDestDtype(
return paddle::experimental::DataType::FLOAT16;
} else if (paddle::imperative::AmpOperators::Instance()
.GetMutableBlockOps()
->count(op_name) ||
paddle::imperative::AmpOperators::Instance()
.GetMutableUnsupportedFp16Ops()
->count(op_name)) {
return paddle::experimental::DataType::FLOAT32;
} else {
......
......@@ -353,7 +353,9 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
}
}
return new_ins;
} else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
} else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type) ||
AmpOperators::Instance().GetMutableUnsupportedFp16Ops()->count(
op_type)) {
for (auto& pair : new_ins) {
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to float";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册