未验证 提交 95d79b6d 编写于 作者: C ceci3 提交者: GitHub

update error log for batch_norm_grad (#22017)

* update error information about batch_norm_grad

* update bn,test=develop
上级 985e4bae
...@@ -528,10 +528,18 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -528,10 +528,18 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance"); const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout"); const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats"); const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool is_test = ctx.Attr<bool>("is_test");
const float epsilon = ctx.Attr<float>("epsilon"); const float epsilon = ctx.Attr<float>("epsilon");
const DataLayout data_layout = const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str); framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE_EQ(
is_test, false,
platform::errors::InvalidArgument(
"`is_test = True` CANNOT be used in train program. If "
"you want to use global status in pre_train model, "
"please set `use_global_stats = True`"));
// Get the size for each dimension. // Get the size for each dimension.
// NCHW [batch_size, in_channels, in_height, in_width] // NCHW [batch_size, in_channels, in_height, in_width]
const auto &x_dims = x->dims(); const auto &x_dims = x->dims();
......
...@@ -423,6 +423,13 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -423,6 +423,13 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
const auto *x = ctx.Input<Tensor>("X"); const auto *x = ctx.Input<Tensor>("X");
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y")); const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto *scale = ctx.Input<Tensor>("Scale"); const auto *scale = ctx.Input<Tensor>("Scale");
const bool is_test = ctx.Attr<bool>("is_test");
PADDLE_ENFORCE_EQ(
is_test, false,
platform::errors::InvalidArgument(
"`is_test = True` CANNOT be used in train program. If "
"you want to use global status in pre_train model, "
"please set `use_global_stats = True`"));
const auto &x_dims = x->dims(); const auto &x_dims = x->dims();
......
...@@ -2461,6 +2461,7 @@ def batch_norm(input, ...@@ -2461,6 +2461,7 @@ def batch_norm(input,
Note: Note:
if build_strategy.sync_batch_norm=True, the batch_norm in network will use if build_strategy.sync_batch_norm=True, the batch_norm in network will use
sync_batch_norm automatically. sync_batch_norm automatically.
`is_test = True` can only be used in test program and inference program, `is_test` CANNOT be set to True in train program, if you want to use global status from pre_train model in train program, please set `use_global_stats = True`.
Args: Args:
input(Variable): The rank of input variable can be 2, 3, 4, 5. The data type input(Variable): The rank of input variable can be 2, 3, 4, 5. The data type
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册