From 1e62c239d323354eccfc974d4e2e6496f93d848e Mon Sep 17 00:00:00 2001 From: Roc <30228238+sljlp@users.noreply.github.com> Date: Tue, 22 Jun 2021 16:09:04 +0800 Subject: [PATCH] Dynamic amp support sync_batch_norm op (#32770) (#33709) --- paddle/fluid/imperative/amp_auto_cast.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc index fd2bb6e5c9..b4154737e0 100644 --- a/paddle/fluid/imperative/amp_auto_cast.cc +++ b/paddle/fluid/imperative/amp_auto_cast.cc @@ -160,7 +160,8 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type, if (AmpOperators::Instance().GetMutableAllowOps()->count(op_type)) { for (auto& pair : new_ins) { // NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16. - if ((op_type == "batch_norm" || op_type == "layer_norm") && + if ((op_type == "batch_norm" || op_type == "layer_norm" || + op_type == "sync_batch_norm") && pair.first != "X") { continue; } @@ -191,7 +192,8 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type, } for (auto& pair : new_ins) { // NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16. - if ((op_type == "batch_norm" || op_type == "layer_norm") && + if ((op_type == "batch_norm" || op_type == "layer_norm" || + op_type == "sync_batch_norm") && pair.first == "X" && dst_type == framework::proto::VarType::FP32) { continue; } -- GitLab