diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc index f2ea692ad088085becd56b6ebfdde2af84abe468..5f314b0f925759844e9a4fce94623c1059ecb7fe 100644 --- a/paddle/fluid/imperative/amp_auto_cast.cc +++ b/paddle/fluid/imperative/amp_auto_cast.cc @@ -266,6 +266,13 @@ NameVarBaseMap CastPureFp16Inputs(const std::string& op_type, pair.first != "X") { continue; } + if ((op_type == "fused_attention" || op_type == "fused_feedforward")) { + if (pair.first == "LnScale" || pair.first == "LnBias" || + pair.first == "Ln2Scale" || pair.first == "Ln2Bias" || + pair.first == "Ln1Scale" || pair.first == "Ln1Bias") { + continue; + } + } VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " << GetDtypeStr(*pair.second.cbegin()) << " to " << framework::DataTypeToString(dst_type);