From dec9094d1be5ceb5623d46e83ca71dbd67ab5a12 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Tue, 22 Mar 2022 11:22:38 +0800 Subject: [PATCH] polish batch norm sig (#40746) --- paddle/phi/ops/compat/batch_norm_sig.cc | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/paddle/phi/ops/compat/batch_norm_sig.cc b/paddle/phi/ops/compat/batch_norm_sig.cc index fa1fac5d23..803bb50b43 100644 --- a/paddle/phi/ops/compat/batch_norm_sig.cc +++ b/paddle/phi/ops/compat/batch_norm_sig.cc @@ -18,10 +18,17 @@ namespace phi { KernelSignature BatchNormOpArgumentMapping(const ArgumentMappingContext& ctx) { bool is_test = paddle::any_cast(ctx.Attr("is_test")); - bool use_global_stats = paddle::any_cast(ctx.Attr("use_global_stats")); + bool use_global_stats = + ctx.HasAttr("use_global_stats") + ? paddle::any_cast(ctx.Attr("use_global_stats")) + : false; bool trainable_statistics = - paddle::any_cast(ctx.Attr("trainable_statistics")); - bool fuse_with_relu = paddle::any_cast(ctx.Attr("fuse_with_relu")); + ctx.HasAttr("trainable_statistics") + ? paddle::any_cast(ctx.Attr("trainable_statistics")) + : false; + bool fuse_with_relu = ctx.HasAttr("fuse_with_relu") + ? paddle::any_cast(ctx.Attr("fuse_with_relu")) + : false; // Dispenable `MomentumTensor` is useless now if (is_test && !use_global_stats && !trainable_statistics && !fuse_with_relu) { -- GitLab