未验证 提交 bd8d35a2 编写于 作者: C ceci3 提交者: GitHub

remove is_test=True in grad (#32678)

上级 5ada0329
......@@ -575,7 +575,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
// SavedVariance have been reverted in forward operator
const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool is_test = ctx.Attr<bool>("is_test");
const float epsilon = ctx.Attr<float>("epsilon");
const DataLayout data_layout =
......@@ -585,6 +585,8 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
use_global_stats = is_test || use_global_stats;
// batch_norm with inplace as false will take X as grad input, which
// is same as cuDNN batch_norm backward calculation, batch_norm
// with inplace as true only take Y as input and X should be calculate
......@@ -605,13 +607,6 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
"X@GRAD and Y@GRAD inplaced in non-inplace mode"));
}
PADDLE_ENFORCE_EQ(
is_test, false,
platform::errors::InvalidArgument(
"`is_test = True` CANNOT be used in train program. If "
"you want to use global status in pre_train model, "
"please set `use_global_stats = True`"));
// Get the size for each dimension.
// NCHW [batch_size, in_channels, in_height, in_width]
const auto &x_dims = x->dims();
......
......@@ -817,7 +817,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
platform::errors::InvalidArgument("It must use CUDAPlace."));
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
......@@ -850,12 +850,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
}
const bool is_test = ctx.Attr<bool>("is_test");
PADDLE_ENFORCE_EQ(
is_test, false,
platform::errors::InvalidArgument(
"`is_test = True` CANNOT be used in train program. If "
"you want to use global status in pre_train model, "
"please set `use_global_stats = True`"));
use_global_stats = is_test || use_global_stats;
const auto &x_dims = x->dims();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册