未验证 提交 bd8d35a2 编写于 作者: C ceci3 提交者: GitHub

Remove the `is_test = True` restriction in batch_norm grad kernels (#32678)

上级 5ada0329
...@@ -575,7 +575,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -575,7 +575,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
// SavedVariance have been reverted in forward operator // SavedVariance have been reverted in forward operator
const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance"); const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout"); const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats"); bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
const float epsilon = ctx.Attr<float>("epsilon"); const float epsilon = ctx.Attr<float>("epsilon");
const DataLayout data_layout = const DataLayout data_layout =
...@@ -585,6 +585,8 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -585,6 +585,8 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale")); auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias")); auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
use_global_stats = is_test || use_global_stats;
// batch_norm with inplace as false will take X as grad input, which // batch_norm with inplace as false will take X as grad input, which
// is same as cuDNN batch_norm backward calculation, batch_norm // is same as cuDNN batch_norm backward calculation, batch_norm
// with inplace as true only takes Y as input and X should be calculated // with inplace as true only takes Y as input and X should be calculated
...@@ -605,13 +607,6 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -605,13 +607,6 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
"X@GRAD and Y@GRAD inplaced in non-inplace mode")); "X@GRAD and Y@GRAD inplaced in non-inplace mode"));
} }
PADDLE_ENFORCE_EQ(
is_test, false,
platform::errors::InvalidArgument(
"`is_test = True` CANNOT be used in train program. If "
"you want to use global status in pre_train model, "
"please set `use_global_stats = True`"));
// Get the size for each dimension. // Get the size for each dimension.
// NCHW [batch_size, in_channels, in_height, in_width] // NCHW [batch_size, in_channels, in_height, in_width]
const auto &x_dims = x->dims(); const auto &x_dims = x->dims();
......
...@@ -817,7 +817,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -817,7 +817,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
platform::errors::InvalidArgument("It must use CUDAPlace.")); platform::errors::InvalidArgument("It must use CUDAPlace."));
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon")); double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const std::string data_layout_str = ctx.Attr<std::string>("data_layout"); const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats"); bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const DataLayout data_layout = const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str); framework::StringToDataLayout(data_layout_str);
...@@ -850,12 +850,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -850,12 +850,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
} }
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
PADDLE_ENFORCE_EQ( use_global_stats = is_test || use_global_stats;
is_test, false,
platform::errors::InvalidArgument(
"`is_test = True` CANNOT be used in train program. If "
"you want to use global status in pre_train model, "
"please set `use_global_stats = True`"));
const auto &x_dims = x->dims(); const auto &x_dims = x->dims();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册