diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index fc31885824b55f22bba77559d728a1e40d47e784..edad20435b41c9eb59c3df793c00ab3bfe96771b 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -575,7 +575,7 @@ class BatchNormGradKernel // SavedVariance have been reverted in forward operator const auto *saved_inv_variance = ctx.Input("SavedVariance"); const std::string data_layout_str = ctx.Attr("data_layout"); - const bool use_global_stats = ctx.Attr("use_global_stats"); + bool use_global_stats = ctx.Attr("use_global_stats"); const bool is_test = ctx.Attr("is_test"); const float epsilon = ctx.Attr("epsilon"); const DataLayout data_layout = @@ -585,6 +585,8 @@ class BatchNormGradKernel auto *d_scale = ctx.Output(framework::GradVarName("Scale")); auto *d_bias = ctx.Output(framework::GradVarName("Bias")); + use_global_stats = is_test || use_global_stats; + // batch_norm with inplace as false will take X as grad input, which // is same as cuDNN batch_norm backward calculation, batch_norm // with inplace as true only take Y as input and X should be calculate @@ -605,13 +607,6 @@ class BatchNormGradKernel "X@GRAD and Y@GRAD inplaced in non-inplace mode")); } - PADDLE_ENFORCE_EQ( - is_test, false, - platform::errors::InvalidArgument( - "`is_test = True` CANNOT be used in train program. If " - "you want to use global status in pre_train model, " - "please set `use_global_stats = True`")); - // Get the size for each dimension. // NCHW [batch_size, in_channels, in_height, in_width] const auto &x_dims = x->dims(); diff --git a/paddle/fluid/operators/batch_norm_op.cu b/paddle/fluid/operators/batch_norm_op.cu index 41dc87ac1ba4788b89ad0a0dd01c7aba981fd746..6fc78732b1063af04a34de5d690a4f2ed75978f2 100644 --- a/paddle/fluid/operators/batch_norm_op.cu +++ b/paddle/fluid/operators/batch_norm_op.cu @@ -817,7 +817,7 @@ class BatchNormGradKernel platform::errors::InvalidArgument("It must use CUDAPlace.")); double epsilon = static_cast(ctx.Attr("epsilon")); const std::string data_layout_str = ctx.Attr("data_layout"); - const bool use_global_stats = ctx.Attr("use_global_stats"); + bool use_global_stats = ctx.Attr("use_global_stats"); const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); @@ -850,12 +850,7 @@ class BatchNormGradKernel } const bool is_test = ctx.Attr("is_test"); - PADDLE_ENFORCE_EQ( - is_test, false, - platform::errors::InvalidArgument( - "`is_test = True` CANNOT be used in train program. If " - "you want to use global status in pre_train model, " - "please set `use_global_stats = True`")); + use_global_stats = is_test || use_global_stats; const auto &x_dims = x->dims();