From 23ab01e306effc92a54177d04168244f16b7de1e Mon Sep 17 00:00:00 2001 From: Roc Date: Mon, 10 May 2021 10:52:26 +0800 Subject: [PATCH] Dynamic amp support sync_batch_norm op (#32770) --- paddle/fluid/imperative/amp_auto_cast.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc index fd2bb6e5c9..b4154737e0 100644 --- a/paddle/fluid/imperative/amp_auto_cast.cc +++ b/paddle/fluid/imperative/amp_auto_cast.cc @@ -160,7 +160,8 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type, if (AmpOperators::Instance().GetMutableAllowOps()->count(op_type)) { for (auto& pair : new_ins) { // NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16. - if ((op_type == "batch_norm" || op_type == "layer_norm") && + if ((op_type == "batch_norm" || op_type == "layer_norm" || + op_type == "sync_batch_norm") && pair.first != "X") { continue; } @@ -191,7 +192,8 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type, } for (auto& pair : new_ins) { // NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16. - if ((op_type == "batch_norm" || op_type == "layer_norm") && + if ((op_type == "batch_norm" || op_type == "layer_norm" || + op_type == "sync_batch_norm") && pair.first == "X" && dst_type == framework::proto::VarType::FP32) { continue; } -- GitLab