diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc index fd2bb6e5c995222cdabedefab93cd696c7c3d9e1..b4154737e0fbc6245617fb0208f6623e4ebb5943 100644 --- a/paddle/fluid/imperative/amp_auto_cast.cc +++ b/paddle/fluid/imperative/amp_auto_cast.cc @@ -160,7 +160,8 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type, if (AmpOperators::Instance().GetMutableAllowOps()->count(op_type)) { for (auto& pair : new_ins) { // NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16. - if ((op_type == "batch_norm" || op_type == "layer_norm") && + if ((op_type == "batch_norm" || op_type == "layer_norm" || + op_type == "sync_batch_norm") && pair.first != "X") { continue; } @@ -191,7 +192,8 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type, } for (auto& pair : new_ins) { // NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16. - if ((op_type == "batch_norm" || op_type == "layer_norm") && + if ((op_type == "batch_norm" || op_type == "layer_norm" || + op_type == "sync_batch_norm") && pair.first == "X" && dst_type == framework::proto::VarType::FP32) { continue; }