未验证 提交 10fbb831 编写于 作者: Q qingqing01 提交者: GitHub

Skip BatchNorm when feature only has 1 element. (#11578)

* Fix batch norm when only 1 elements in normzalize dimension during training.
上级 110c6aed
......@@ -216,6 +216,18 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
saved_mean_e.setZero();
saved_variance_e.setZero();
EigenVectorArrayMap<T> running_mean_arr(
mean_out->mutable_data<T>(ctx.GetPlace()), C);
EigenVectorArrayMap<T> running_var_arr(
variance_out->mutable_data<T>(ctx.GetPlace()), C);
if ((N * sample_size) == 1) {
LOG(WARNING) << "Only 1 element in normalization dimension, "
<< "we skip the batch norm calculation, let y = x.";
framework::TensorCopySync(*x, ctx.GetPlace(), y);
return;
}
switch (data_layout) {
case DataLayout::kNCHW: {
ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C);
......@@ -247,10 +259,6 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
PADDLE_THROW("Unknown storage order: %s", data_layout_str);
}
EigenVectorArrayMap<T> running_mean_arr(
mean_out->mutable_data<T>(ctx.GetPlace()), C);
EigenVectorArrayMap<T> running_var_arr(
variance_out->mutable_data<T>(ctx.GetPlace()), C);
running_mean_arr =
running_mean_arr * momentum + saved_mean_e * (1. - momentum);
running_var_arr =
......@@ -427,6 +435,11 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
d_bias_arr.setZero();
d_scale_arr.setZero();
if ((N * sample_size) == 1) {
framework::TensorCopySync(*d_y, ctx.GetPlace(), d_x);
return;
}
const auto scale_inv_var_nhw = scale_arr * inv_var_arr / (N * sample_size);
switch (data_layout) {
......
......@@ -72,6 +72,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
int N, C, H, W, D;
ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
auto *y = ctx.Output<Tensor>("Y");
y->mutable_data<T>(ctx.GetPlace());
// ------------------- cudnn descriptors ---------------------
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
......@@ -93,7 +96,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
mode_ = CUDNN_BATCHNORM_SPATIAL;
#endif
VLOG(1) << "Setting descriptors.";
VLOG(3) << "Setting descriptors.";
std::vector<int> dims;
std::vector<int> strides;
if (data_layout == DataLayout::kNCHW) {
......@@ -113,11 +116,6 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias");
auto *y = ctx.Output<Tensor>("Y");
// alloc memory
y->mutable_data<T>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
......@@ -162,6 +160,11 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
if ((N * H * W * D) == 1) {
LOG(WARNING) << "Only 1 element in normalization dimension, "
<< "we skip the batch norm calculation, let y = x.";
framework::TensorCopySync(*x, ctx.GetPlace(), y);
} else {
double this_factor = 1. - momentum;
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining(
......@@ -179,6 +182,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
saved_variance->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace())));
}
}
// clean when exit.
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
......@@ -209,6 +213,25 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
int N, C, H, W, D;
ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
d_x->mutable_data<T>(ctx.GetPlace());
d_scale->mutable_data<T>(ctx.GetPlace());
d_bias->mutable_data<T>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
if ((N * H * W * D) == 1) {
framework::TensorCopySync(*d_y, ctx.GetPlace(), d_x);
math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
functor;
functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
functor(dev_ctx, d_bias, static_cast<BatchNormParamType<T>>(0));
return;
}
PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL);
PADDLE_ENFORCE_EQ(scale->dims()[0], C);
......@@ -247,21 +270,11 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
bn_param_desc_, data_desc_, mode_));
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
d_x->mutable_data<T>(ctx.GetPlace());
d_scale->mutable_data<T>(ctx.GetPlace());
d_bias->mutable_data<T>(ctx.GetPlace());
const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
const void *saved_mean_data = saved_mean->template data<T>();
const void *saved_var_data = saved_var->template data<T>();
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
......
......@@ -124,8 +124,7 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
"Tensor<float/double> with shape [N x D].");
AddOutput("Y",
"(Tensor, default Tensor<float>), a 2-D tensor with shape "
"[N x 1]. The cross entropy loss.")
.Reuse("X");
"[N x 1]. The cross entropy loss.");
AddAttr<bool>("soft_label",
"(bool, default false), a flag indicating whether to "
"interpretate the given labels as soft labels.")
......
......@@ -40,7 +40,6 @@ class TestFakeDequantizeMaxAbsOp(OpTest):
self.op_type = "fake_dequantize_max_abs"
x = np.random.randn(31, 65).astype("float32")
yq, scale = quantize_max_abs(x, self.num_bits)
print 'scale ', scale
ydq = dequantize_max_abs(yq, self.num_bits, scale)
self.inputs = {'X': yq}
......
......@@ -113,7 +113,9 @@ class BaseParallelForTest(unittest.TestCase):
generator = callback()
# Automatically insert parallel do if use_parallel = True
if use_parallel:
places = fluid.layers.get_places()
thread_num = fluid.core.get_cuda_device_count(
) if use_gpu else 8
places = fluid.layers.get_places(thread_num)
pd = fluid.layers.ParallelDo(places, use_nccl=use_nccl)
data = next(generator)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册