From 1a417a4c74364ec5d1ce5bbd411fee0d2c76041b Mon Sep 17 00:00:00 2001 From: ceci3 Date: Fri, 30 Apr 2021 12:58:07 +0800 Subject: [PATCH] remove is_test=True in grad (#32683) --- paddle/fluid/operators/batch_norm_op.cc | 11 +++-------- paddle/fluid/operators/batch_norm_op.cu | 9 ++------- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index fc31885824..edad20435b 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -575,7 +575,7 @@ class BatchNormGradKernel // SavedVariance have been reverted in forward operator const auto *saved_inv_variance = ctx.Input("SavedVariance"); const std::string data_layout_str = ctx.Attr("data_layout"); - const bool use_global_stats = ctx.Attr("use_global_stats"); + bool use_global_stats = ctx.Attr("use_global_stats"); const bool is_test = ctx.Attr("is_test"); const float epsilon = ctx.Attr("epsilon"); const DataLayout data_layout = @@ -585,6 +585,8 @@ class BatchNormGradKernel auto *d_scale = ctx.Output(framework::GradVarName("Scale")); auto *d_bias = ctx.Output(framework::GradVarName("Bias")); + use_global_stats = is_test || use_global_stats; + // batch_norm with inplace as false will take X as grad input, which // is same as cuDNN batch_norm backward calculation, batch_norm // with inplace as true only take Y as input and X should be calculate @@ -605,13 +607,6 @@ class BatchNormGradKernel "X@GRAD and Y@GRAD inplaced in non-inplace mode")); } - PADDLE_ENFORCE_EQ( - is_test, false, - platform::errors::InvalidArgument( - "`is_test = True` CANNOT be used in train program. If " - "you want to use global status in pre_train model, " - "please set `use_global_stats = True`")); - // Get the size for each dimension. // NCHW [batch_size, in_channels, in_height, in_width] const auto &x_dims = x->dims(); diff --git a/paddle/fluid/operators/batch_norm_op.cu b/paddle/fluid/operators/batch_norm_op.cu index 41dc87ac1b..6fc78732b1 100644 --- a/paddle/fluid/operators/batch_norm_op.cu +++ b/paddle/fluid/operators/batch_norm_op.cu @@ -817,7 +817,7 @@ class BatchNormGradKernel platform::errors::InvalidArgument("It must use CUDAPlace.")); double epsilon = static_cast(ctx.Attr("epsilon")); const std::string data_layout_str = ctx.Attr("data_layout"); - const bool use_global_stats = ctx.Attr("use_global_stats"); + bool use_global_stats = ctx.Attr("use_global_stats"); const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); @@ -850,12 +850,7 @@ class BatchNormGradKernel } const bool is_test = ctx.Attr("is_test"); - PADDLE_ENFORCE_EQ( - is_test, false, - platform::errors::InvalidArgument( - "`is_test = True` CANNOT be used in train program. If " - "you want to use global status in pre_train model, " - "please set `use_global_stats = True`")); + use_global_stats = is_test || use_global_stats; const auto &x_dims = x->dims(); -- GitLab