diff --git a/paddle/phi/ops/compat/batch_norm_sig.cc b/paddle/phi/ops/compat/batch_norm_sig.cc index fa1fac5d23779597fee7f8a6e4e467c02d6d4c94..803bb50b438a58eb49beb50caa43f8fb0f408f8c 100644 --- a/paddle/phi/ops/compat/batch_norm_sig.cc +++ b/paddle/phi/ops/compat/batch_norm_sig.cc @@ -18,10 +18,17 @@ namespace phi { KernelSignature BatchNormOpArgumentMapping(const ArgumentMappingContext& ctx) { bool is_test = paddle::any_cast(ctx.Attr("is_test")); - bool use_global_stats = paddle::any_cast(ctx.Attr("use_global_stats")); + bool use_global_stats = + ctx.HasAttr("use_global_stats") + ? paddle::any_cast(ctx.Attr("use_global_stats")) + : false; bool trainable_statistics = - paddle::any_cast(ctx.Attr("trainable_statistics")); - bool fuse_with_relu = paddle::any_cast(ctx.Attr("fuse_with_relu")); + ctx.HasAttr("trainable_statistics") + ? paddle::any_cast(ctx.Attr("trainable_statistics")) + : false; + bool fuse_with_relu = ctx.HasAttr("fuse_with_relu") + ? paddle::any_cast(ctx.Attr("fuse_with_relu")) + : false; // Dispenable `MomentumTensor` is useless now if (is_test && !use_global_stats && !trainable_statistics && !fuse_with_relu) {