From 95d79b6d00baadb2cb2af2d84e80907643486ea3 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Thu, 2 Jan 2020 17:28:40 +0800 Subject: [PATCH] update error log for batch_norm_grad (#22017) * update error information about batch_norm_grad * update bn,test=develop --- paddle/fluid/operators/batch_norm_op.cc | 8 ++++++++ paddle/fluid/operators/batch_norm_op.cu | 7 +++++++ python/paddle/fluid/layers/nn.py | 1 + 3 files changed, 16 insertions(+) diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 8152660826..e03c9cfcb8 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -528,10 +528,18 @@ class BatchNormGradKernel const auto *saved_inv_variance = ctx.Input("SavedVariance"); const std::string data_layout_str = ctx.Attr("data_layout"); const bool use_global_stats = ctx.Attr("use_global_stats"); + const bool is_test = ctx.Attr("is_test"); const float epsilon = ctx.Attr("epsilon"); const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); + PADDLE_ENFORCE_EQ( + is_test, false, + platform::errors::InvalidArgument( + "`is_test = True` CANNOT be used in train program. If " + "you want to use global status in pre_train model, " + "please set `use_global_stats = True`")); + // Get the size for each dimension. // NCHW [batch_size, in_channels, in_height, in_width] const auto &x_dims = x->dims(); diff --git a/paddle/fluid/operators/batch_norm_op.cu b/paddle/fluid/operators/batch_norm_op.cu index f8e2d9f393..a47e20dde5 100644 --- a/paddle/fluid/operators/batch_norm_op.cu +++ b/paddle/fluid/operators/batch_norm_op.cu @@ -423,6 +423,13 @@ class BatchNormGradKernel const auto *x = ctx.Input("X"); const auto *d_y = ctx.Input(framework::GradVarName("Y")); const auto *scale = ctx.Input("Scale"); + const bool is_test = ctx.Attr("is_test"); + PADDLE_ENFORCE_EQ( + is_test, false, + platform::errors::InvalidArgument( + "`is_test = True` CANNOT be used in train program. If " + "you want to use global status in pre_train model, " + "please set `use_global_stats = True`")); const auto &x_dims = x->dims(); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 6b65aaf356..63425fbc12 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -2461,6 +2461,7 @@ def batch_norm(input, Note: if build_strategy.sync_batch_norm=True, the batch_norm in network will use sync_batch_norm automatically. + `is_test = True` can only be used in test program and inference program, `is_test` CANNOT be set to True in train program, if you want to use global status from pre_train model in train program, please set `use_global_stats = True`. Args: input(Variable): The rank of input variable can be 2, 3, 4, 5. The data type -- GitLab