From 6ebc318e46e3c9b74d33cf4cf12bad515d69261c Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Tue, 16 Nov 2021 14:25:40 +0800 Subject: [PATCH] for pure fp16 (#37230) Add pure fp16 support for fused transformer. --- paddle/fluid/imperative/amp_auto_cast.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc index f2ea692ad08..5f314b0f925 100644 --- a/paddle/fluid/imperative/amp_auto_cast.cc +++ b/paddle/fluid/imperative/amp_auto_cast.cc @@ -266,6 +266,13 @@ NameVarBaseMap CastPureFp16Inputs(const std::string& op_type, pair.first != "X") { continue; } + if ((op_type == "fused_attention" || op_type == "fused_feedforward")) { + if (pair.first == "LnScale" || pair.first == "LnBias" || + pair.first == "Ln2Scale" || pair.first == "Ln2Bias" || + pair.first == "Ln1Scale" || pair.first == "Ln1Bias") { + continue; + } + } VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " << GetDtypeStr(*pair.second.cbegin()) << " to " << framework::DataTypeToString(dst_type); -- GitLab