Unverified commit 6ebc318e, authored by: Z zhangkaihuo, committed by: GitHub

for pure fp16 (#37230)

Add pure fp16 support for fused transformer.
Parent commit: 56810f45
...@@ -266,6 +266,13 @@ NameVarBaseMap CastPureFp16Inputs(const std::string& op_type,
        pair.first != "X") {
      continue;
    }
// For the fused transformer ops, exempt the LayerNorm scale/bias inputs
// from the pure-fp16 cast: "continue" skips the cast for these input
// slots, so they keep their original dtype. Presumably they are kept in
// fp32 for LayerNorm numerical stability -- confirm against the
// fused_attention / fused_feedforward kernels.
if ((op_type == "fused_attention" || op_type == "fused_feedforward")) {
// Covers all three LayerNorm parameter naming schemes used by the two
// fused ops (Ln*, Ln1*, Ln2*).
if (pair.first == "LnScale" || pair.first == "LnBias" ||
pair.first == "Ln2Scale" || pair.first == "Ln2Bias" ||
pair.first == "Ln1Scale" || pair.first == "Ln1Bias") {
continue;
}
}
    VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
            << GetDtypeStr(*pair.second.cbegin()) << " to "
            << framework::DataTypeToString(dst_type);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register