From 731d45a39ab1a076519cedb0167c04801d1e7c84 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Thu, 29 Nov 2018 15:18:56 +0800 Subject: [PATCH] Enable BatchNorm to use global mean and variane during training (#14630) * Enable BatchNorm to use global mean and variane during training * Update doc and follow comments. --- paddle/fluid/API.spec | 2 +- .../fluid/operators/batch_norm_mkldnn_op.cc | 17 +- paddle/fluid/operators/batch_norm_op.cc | 176 +++++++++---- .../{batch_norm_op.cu.cc => batch_norm_op.cu} | 234 +++++++++++++----- python/paddle/fluid/layers/nn.py | 30 ++- .../tests/unittests/test_batch_norm_op.py | 133 ++++++++-- .../fluid/tests/unittests/test_layers.py | 9 + 7 files changed, 456 insertions(+), 145 deletions(-) rename paddle/fluid/operators/{batch_norm_op.cu.cc => batch_norm_op.cu} (57%) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index c40f603341..6b5ed10244 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -69,7 +69,7 @@ paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name'] paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None)) paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True)) paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True)) -paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False)) +paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu', 'use_global_stats'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False, False)) paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)) paddle.fluid.layers.conv3d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)) diff --git a/paddle/fluid/operators/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/batch_norm_mkldnn_op.cc index de641cb08e..29d950967f 100644 --- a/paddle/fluid/operators/batch_norm_mkldnn_op.cc +++ b/paddle/fluid/operators/batch_norm_mkldnn_op.cc @@ -146,7 +146,9 @@ class BatchNormMKLDNNOpKernel : 
public paddle::framework::OpKernel { const float epsilon = ctx.Attr("epsilon"); const float momentum = ctx.Attr("momentum"); const bool is_test = ctx.Attr("is_test"); + const bool use_global_stats = ctx.Attr("use_global_stats"); const bool fuse_with_relu = ctx.Attr("fuse_with_relu"); + bool global_stats = is_test || use_global_stats; const auto *x = ctx.Input("X"); const auto *mean = ctx.Input("Mean"); @@ -177,13 +179,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { T *batch_mean_data = nullptr; T *batch_variance_data = nullptr; - if (!is_test) { + if (!global_stats) { batch_mean_data = batch_mean->mutable_data(ctx.GetPlace()); batch_variance_data = batch_variance->mutable_data(ctx.GetPlace()); } - auto propagation = is_test == true ? mkldnn::prop_kind::forward_scoring - : mkldnn::prop_kind::forward_training; + auto propagation = global_stats == true + ? mkldnn::prop_kind::forward_scoring + : mkldnn::prop_kind::forward_training; auto src_tz = paddle::framework::vectorize2int(x->dims()); auto scale_tz = paddle::framework::vectorize2int(scale->dims()); @@ -199,7 +202,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { shift->data() + ic, &scaleshift_data); unsigned flags = mkldnn::use_scale_shift; - if (is_test) flags |= mkldnn::use_global_stats; + if (global_stats) flags |= mkldnn::use_global_stats; if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu; // create mkldnn memory from input x tensor @@ -208,7 +211,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { // keys for backward pass const std::string key = BatchNormMKLDNNHandler::GetHash( - src_tz, epsilon, flags, is_test, input_format, + src_tz, epsilon, flags, global_stats, input_format, ctx.op().Output("SavedMean")); const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd"; @@ -239,7 +242,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { batch_norm_fwd_pd->dst_primitive_desc().desc(), y_data); std::shared_ptr batch_norm_p; - if (is_test) { + if (global_stats) { // create mkldnn memory for stats (as input) std::shared_ptr mean_memory = handler.AcquireMeanMemoryFromPrimitive(to_void_cast(mean_data)); @@ -269,7 +272,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { pipeline.push_back(*batch_norm_p); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); - if (!is_test) { + if (!global_stats) { // mkldnn only compute stats for current batch // so we need compute momentum stats via Eigen lib EigenVectorArrayMap batch_mean_e(batch_mean_data, ic); diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 2463c939bc..f66813989c 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -159,6 +159,14 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("fuse_with_relu", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); + AddAttr("use_global_stats", + "(bool, default false) Whether to use global mean and " + "variance. In inference or test mode, setting use_global_stats " + "to true or is_test to true has the same effect. " + "In train mode, when use_global_stats is set to True, the " + "global mean and variance are also used during training, " + "and the BN layer acts as a simple scale and shift.") + .SetDefault(false); AddComment(R"DOC( Batch Normalization. 
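For reference, the switch above (global_stats = is_test || use_global_stats) only changes which statistics feed the normalization and whether the running statistics are updated. A minimal NumPy sketch of the intended forward behavior, assuming NCHW input; this is illustrative pseudocode, not the operator's actual implementation, and all names are placeholders:

    import numpy as np

    def bn_forward_nchw(x, scale, bias, running_mean, running_var,
                        momentum=0.9, epsilon=1e-5, use_global_stats=False):
        # x: (N, C, H, W); scale, bias, running_mean, running_var: (C,)
        if use_global_stats:
            mean, var = running_mean, running_var      # frozen (global) statistics
        else:
            mean = x.mean(axis=(0, 2, 3))              # per-channel batch statistics
            var = x.var(axis=(0, 2, 3))
            running_mean = momentum * running_mean + (1. - momentum) * mean
            running_var = momentum * running_var + (1. - momentum) * var
        m = mean.reshape(1, -1, 1, 1)
        v = var.reshape(1, -1, 1, 1)
        y = scale.reshape(1, -1, 1, 1) * (x - m) / np.sqrt(v + epsilon) \
            + bias.reshape(1, -1, 1, 1)
        return y, running_mean, running_var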
@@ -190,6 +198,10 @@ class BatchNormKernel const float epsilon = ctx.Attr("epsilon"); const float momentum = ctx.Attr("momentum"); const bool is_test = ctx.Attr("is_test"); + const bool use_global_stats = ctx.Attr("use_global_stats"); + + bool global_stats = is_test || use_global_stats; + const std::string data_layout_str = ctx.Attr("data_layout"); const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); @@ -217,7 +229,7 @@ class BatchNormKernel saved_mean->mutable_data(ctx.GetPlace()); saved_variance->mutable_data(ctx.GetPlace()); - if (!is_test) { + if (!global_stats) { // saved_xx is use just in this batch of data EigenVectorArrayMap saved_mean_e( saved_mean->mutable_data(ctx.GetPlace()), C); @@ -234,7 +246,7 @@ class BatchNormKernel if ((N * sample_size) == 1) { LOG(WARNING) << "Only 1 element in normalization dimension, " << "we skip the batch norm calculation, let y = x."; - framework::TensorCopySync(*x, ctx.GetPlace(), y); + framework::TensorCopy(*x, ctx.GetPlace(), y); return; } @@ -277,7 +289,7 @@ class BatchNormKernel // use SavedMean and SavedVariance to do normalize Eigen::Array inv_std(C); - if (is_test) { + if (global_stats) { ConstEigenVectorArrayMap var_arr( ctx.Input("Variance")->data(), C); inv_std = (var_arr + epsilon).sqrt().inverse(); @@ -289,8 +301,8 @@ class BatchNormKernel inv_std = saved_inv_std; } ConstEigenVectorArrayMap mean_arr( - is_test ? ctx.Input("Mean")->data() - : ctx.Output("SavedMean")->data(), + global_stats ? ctx.Input("Mean")->data() + : ctx.Output("SavedMean")->data(), C); // ((x - est_mean) * (inv_var) * scale + bias @@ -336,15 +348,27 @@ class BatchNormGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override { // check input PADDLE_ENFORCE(ctx->HasInput("X")); - PADDLE_ENFORCE(ctx->HasInput("Scale"), ""); - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), ""); - PADDLE_ENFORCE(ctx->HasInput("SavedMean"), ""); - PADDLE_ENFORCE(ctx->HasInput("SavedVariance"), ""); + PADDLE_ENFORCE(ctx->HasInput("Scale"), "Input(scale) should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), + "Input(Y@GRAD) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("SavedMean"), + "Input(SavedMean) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("SavedVariance"), + "Input(SavedVariance) should not be null"); // check output PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), ""); - PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Scale")), ""); - PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")), ""); + if (ctx->HasOutput(framework::GradVarName("Scale"))) { + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")), + "Output(Scale@GRAD) and Output(Bias@GRAD) should not be " + "null at same time"); + } + const bool use_global_stats = ctx->Attrs().Get("use_global_stats"); + if (use_global_stats) { + PADDLE_ENFORCE(!ctx->Attrs().Get("use_mkldnn"), + "Using global stats during training is not supported " + "in gradient op kernel of batch_norm_mkldnn_op now."); + } const auto x_dims = ctx->GetInputDim("X"); const DataLayout data_layout = framework::StringToDataLayout( @@ -354,8 +378,10 @@ class BatchNormGradOp : public framework::OperatorWithKernel { : x_dims[x_dims.size() - 1]); ctx->SetOutputDim(framework::GradVarName("X"), x_dims); - ctx->SetOutputDim(framework::GradVarName("Scale"), {C}); - ctx->SetOutputDim(framework::GradVarName("Bias"), {C}); + if (ctx->HasOutput(framework::GradVarName("Scale"))) { + 
ctx->SetOutputDim(framework::GradVarName("Scale"), {C}); + ctx->SetOutputDim(framework::GradVarName("Bias"), {C}); + } } protected: @@ -405,6 +431,8 @@ class BatchNormGradKernel // SavedVariance have been reverted in forward operator const auto *saved_inv_variance = ctx.Input("SavedVariance"); const std::string data_layout_str = ctx.Attr("data_layout"); + const bool use_global_stats = ctx.Attr("use_global_stats"); + const float epsilon = ctx.Attr("epsilon"); const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); @@ -419,38 +447,60 @@ class BatchNormGradKernel : x_dims[x_dims.size() - 1]); const int sample_size = x->numel() / N / C; - ConstEigenVectorArrayMap scale_arr(scale->data(), C); - ConstEigenVectorArrayMap mean_arr(saved_mean->data(), C); - ConstEigenVectorArrayMap inv_var_arr(saved_inv_variance->data(), C); - // init output auto *d_x = ctx.Output(framework::GradVarName("X")); auto *d_scale = ctx.Output(framework::GradVarName("Scale")); auto *d_bias = ctx.Output(framework::GradVarName("Bias")); d_x->mutable_data(ctx.GetPlace()); - d_scale->mutable_data(ctx.GetPlace()); - d_bias->mutable_data(ctx.GetPlace()); + + const T *mean_data = saved_mean->data(); + const T *inv_var_data = saved_inv_variance->data(); + Tensor inv_var_tensor; + if (use_global_stats) { + const auto *running_mean = ctx.Input("Mean"); + const auto *running_variance = ctx.Input("Variance"); + mean_data = running_mean->data(); + T *running_inv_var_data = inv_var_tensor.mutable_data(ctx.GetPlace()); + EigenVectorArrayMap inv_var_tmp(running_inv_var_data, C); + ConstEigenVectorArrayMap var_arr(running_variance->data(), C); + + inv_var_tmp = (var_arr + epsilon).sqrt().inverse().eval(); + inv_var_data = running_inv_var_data; + } + + ConstEigenVectorArrayMap scale_arr(scale->data(), C); + ConstEigenVectorArrayMap mean_arr(mean_data, C); + ConstEigenVectorArrayMap inv_var_arr(inv_var_data, C); + + T *d_bias_data = nullptr; + T *d_scale_data = nullptr; + if (d_scale && d_bias) { + d_scale->mutable_data(ctx.GetPlace()); + d_bias->mutable_data(ctx.GetPlace()); + d_bias_data = d_bias->mutable_data(ctx.GetPlace()); + d_scale_data = d_scale->mutable_data(ctx.GetPlace()); + } // d_bias = np.sum(d_y, axis=0) // d_scale = np.sum((X - mean) / inv_std * dy, axis=0) // d_x = (1. / N) * scale * inv_var * (N * d_y - np.sum(d_y, axis=0) // - (X - mean) * inv_var * inv_var * np.sum(d_y * (X - mean), axis=0)) + EigenVectorArrayMap d_bias_arr(d_bias_data, C); + EigenVectorArrayMap d_scale_arr(d_scale_data, C); - EigenVectorArrayMap d_bias_arr(d_bias->mutable_data(ctx.GetPlace()), - C); - EigenVectorArrayMap d_scale_arr(d_scale->mutable_data(ctx.GetPlace()), - C); - - d_bias_arr.setZero(); - d_scale_arr.setZero(); + if (d_scale && d_bias) { + d_bias_arr.setZero(); + d_scale_arr.setZero(); + } - if ((N * sample_size) == 1) { - framework::TensorCopySync(*d_y, ctx.GetPlace(), d_x); + if ((N * sample_size) == 1 && !use_global_stats) { + framework::TensorCopy(*d_y, ctx.GetPlace(), d_x); return; } - const auto scale_inv_var_nhw = scale_arr * inv_var_arr / (N * sample_size); + int scale_coefff = use_global_stats ? 
1 : N * sample_size; + const auto scale_inv_var_nhw = scale_arr * inv_var_arr / scale_coefff; switch (data_layout) { case DataLayout::kNCHW: { @@ -460,19 +510,29 @@ class BatchNormGradKernel sample_size, N * C); d_x_arr.setZero(); - for (int nc = 0; nc < N * C; ++nc) { - int c = nc % C; - d_bias_arr(c) += d_y_arr.col(nc).sum(); - d_scale_arr(c) += - ((x_arr.col(nc) - mean_arr(c)) * inv_var_arr(c) * d_y_arr.col(nc)) - .sum(); + if (d_scale && d_bias) { + for (int nc = 0; nc < N * C; ++nc) { + int c = nc % C; + d_bias_arr(c) += d_y_arr.col(nc).sum(); + d_scale_arr(c) += ((x_arr.col(nc) - mean_arr(c)) * inv_var_arr(c) * + d_y_arr.col(nc)) + .sum(); + } } - for (int nc = 0; nc < N * C; ++nc) { - int c = nc % C; - d_x_arr.col(nc) += - scale_inv_var_nhw(c) * - (d_y_arr.col(nc) * N * sample_size - d_bias_arr(c) - - (x_arr.col(nc) - mean_arr[c]) * d_scale_arr(c) * inv_var_arr(c)); + if (!use_global_stats) { + for (int nc = 0; nc < N * C; ++nc) { + int c = nc % C; + d_x_arr.col(nc) += + scale_inv_var_nhw(c) * + (d_y_arr.col(nc) * N * sample_size - d_bias_arr(c) - + (x_arr.col(nc) - mean_arr[c]) * d_scale_arr(c) * + inv_var_arr(c)); + } + } else { + for (int nc = 0; nc < N * C; ++nc) { + int c = nc % C; + d_x_arr.col(nc) += scale_inv_var_nhw(c) * d_y_arr.col(nc); + } } break; } @@ -488,15 +548,27 @@ class BatchNormGradKernel const auto d_y_mul_x_minus_mean_row_sum = (d_y_arr * x_minus_mean).rowwise().sum(); const auto inv_var_sqr = inv_var_arr * inv_var_arr; - for (int nhw = 0; nhw < N * sample_size; ++nhw) { - d_bias_arr += d_y_arr.col(nhw); - d_scale_arr += - (x_arr.col(nhw) - mean_arr) * inv_var_arr * d_y_arr.col(nhw); - d_x_arr.col(nhw) += - scale_inv_var_nhw * - (d_y_arr.col(nhw) * N * sample_size - d_y_row_sum - - x_minus_mean.col(nhw) * inv_var_sqr * - d_y_mul_x_minus_mean_row_sum); + + if (d_scale && d_bias) { + for (int nhw = 0; nhw < N * sample_size; ++nhw) { + d_bias_arr += d_y_arr.col(nhw); + d_scale_arr += + (x_arr.col(nhw) - mean_arr) * inv_var_arr * d_y_arr.col(nhw); + } + } + + if (!use_global_stats) { + for (int nhw = 0; nhw < N * sample_size; ++nhw) { + d_x_arr.col(nhw) += + scale_inv_var_nhw * + (d_y_arr.col(nhw) * N * sample_size - d_y_row_sum - + x_minus_mean.col(nhw) * inv_var_sqr * + d_y_mul_x_minus_mean_row_sum); + } + } else { + for (int nhw = 0; nhw < N * sample_size; ++nhw) { + d_x_arr.col(nhw) += scale_inv_var_nhw * d_y_arr.col(nhw); + } } break; } @@ -522,6 +594,10 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker { op->SetInput("SavedMean", Output("SavedMean")); op->SetInput("SavedVariance", Output("SavedVariance")); + // used when setting use_global_stats True during training + op->SetInput("Mean", Output("MeanOut")); + op->SetInput("Variance", Output("VarianceOut")); + op->SetAttrMap(Attrs()); op->SetOutput(framework::GradVarName("X"), InputGrad("X")); diff --git a/paddle/fluid/operators/batch_norm_op.cu.cc b/paddle/fluid/operators/batch_norm_op.cu similarity index 57% rename from paddle/fluid/operators/batch_norm_op.cu.cc rename to paddle/fluid/operators/batch_norm_op.cu index aaed335c90..1c45746a92 100644 --- a/paddle/fluid/operators/batch_norm_op.cu.cc +++ b/paddle/fluid/operators/batch_norm_op.cu @@ -12,9 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "paddle/fluid/operators/batch_norm_op.h" +#include #include +#include +#include +#include "cub/cub.cuh" #include "paddle/fluid/framework/data_layout.h" +#include "paddle/fluid/operators/batch_norm_op.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/float16.h" @@ -59,6 +63,7 @@ class BatchNormKernel double epsilon = static_cast(ctx.Attr("epsilon")); const float momentum = ctx.Attr("momentum"); const bool is_test = ctx.Attr("is_test"); + const bool use_global_stats = ctx.Attr("use_global_stats"); const std::string data_layout_str = ctx.Attr("data_layout"); const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); @@ -121,7 +126,7 @@ class BatchNormKernel auto handle = dev_ctx.cudnn_handle(); // Now, depending on whether we are running test or not, we have two paths. - if (is_test) { + if (is_test || use_global_stats) { // only when test we use input to do computation. const auto *est_mean = ctx.Input("Mean"); const auto *est_var = ctx.Input("Variance"); @@ -163,7 +168,7 @@ class BatchNormKernel if ((N * H * W * D) == 1) { LOG(WARNING) << "Only 1 element in normalization dimension, " << "we skip the batch norm calculation, let y = x."; - framework::TensorCopySync(*x, ctx.GetPlace(), y); + framework::TensorCopy(*x, ctx.GetPlace(), y); } else { double this_factor = 1. - momentum; @@ -191,6 +196,58 @@ class BatchNormKernel } }; +template +static __global__ void KeBNBackwardData(const T *dy, + const BatchNormParamType *scale, + const BatchNormParamType *variance, + const double epsilon, const int C, + const int HxW, const int num, T *dx) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (int i = gid; i < num; i += stride) { + const int c = layout == framework::DataLayout::kNCHW ? i / HxW % C : i % C; + BatchNormParamType inv_var = 1.0 / sqrt(variance[c] + epsilon); + dx[i] = static_cast(static_cast>(dy[i]) * + scale[c] * inv_var); + } +} + +template +static __global__ void KeBNBackwardScaleBias( + const T *dy, const T *x, const BatchNormParamType *mean, + const BatchNormParamType *variance, const double epsilon, const int N, + const int C, const int HxW, BatchNormParamType *dscale, + BatchNormParamType *dbias) { + const int outer_size = C; + const int inner_size = N * HxW; + typedef cub::BlockReduce, BlockDim> BlockReduce; + __shared__ typename BlockReduce::TempStorage ds_storage; + __shared__ typename BlockReduce::TempStorage db_storage; + + for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { + BatchNormParamType ds_sum = static_cast>(0); + BatchNormParamType db_sum = static_cast>(0); + + BatchNormParamType inv_var_i = 1.0 / sqrt(variance[i] + epsilon); + BatchNormParamType mean_i = mean[i]; + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int index = layout == framework::DataLayout::kNCHW + ? 
(j / HxW * C + i) * HxW + j % HxW + : j * outer_size + i; + ds_sum += static_cast>(dy[index]) * + (static_cast>(x[index]) - mean_i); + db_sum += static_cast>(dy[index]); + } + ds_sum = BlockReduce(ds_storage).Reduce(ds_sum, cub::Sum()); + db_sum = BlockReduce(db_storage).Reduce(db_sum, cub::Sum()); + if (threadIdx.x == 0) { + dscale[i] = ds_sum * inv_var_i; + dbias[i] = db_sum; + } + __syncthreads(); + } +} + template class BatchNormGradKernel : public framework::OpKernel { @@ -200,6 +257,8 @@ class BatchNormGradKernel "It must use CUDAPlace."); double epsilon = static_cast(ctx.Attr("epsilon")); const std::string data_layout_str = ctx.Attr("data_layout"); + const bool use_global_stats = ctx.Attr("use_global_stats"); + const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); const auto *x = ctx.Input("X"); @@ -219,42 +278,13 @@ class BatchNormGradKernel auto *d_bias = ctx.Output(framework::GradVarName("Bias")); d_x->mutable_data(ctx.GetPlace()); - d_scale->mutable_data>(ctx.GetPlace()); - d_bias->mutable_data>(ctx.GetPlace()); - - auto &dev_ctx = ctx.template device_context(); - if ((N * H * W * D) == 1) { - framework::TensorCopySync(*d_y, ctx.GetPlace(), d_x); - math::SetConstant> - functor; - functor(dev_ctx, d_scale, static_cast>(0)); - functor(dev_ctx, d_bias, static_cast>(0)); - return; + if (d_scale && d_bias) { + d_scale->mutable_data>(ctx.GetPlace()); + d_bias->mutable_data>(ctx.GetPlace()); } - PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL); PADDLE_ENFORCE_EQ(scale->dims()[0], C); - // ------------------- cudnn descriptors --------------------- - cudnnTensorDescriptor_t data_desc_; - cudnnTensorDescriptor_t bn_param_desc_; - cudnnBatchNormMode_t mode_; - - CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&data_desc_)); - CUDNN_ENFORCE( - platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_)); - if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) { - LOG(ERROR) << "Provided epsilon is smaller than " - << "CUDNN_BN_MIN_EPSILON. Setting it to " - << "CUDNN_BN_MIN_EPSILON instead."; - } - epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON); -#if CUDNN_VERSION_MIN(7, 0, 0) - mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; -#else - mode_ = CUDNN_BATCHNORM_SPATIAL; -#endif - std::vector dims; std::vector strides; if (data_layout == DataLayout::kNCHW) { @@ -264,34 +294,114 @@ class BatchNormGradKernel dims = {N, C, H, W, D}; strides = {H * W * C * D, 1, W * D * C, D * C, C}; } - CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( - data_desc_, CudnnDataType::type, - x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data())); - CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor( - bn_param_desc_, data_desc_, mode_)); - - const auto *saved_mean = ctx.Input("SavedMean"); - const auto *saved_var = ctx.Input("SavedVariance"); - const void *saved_mean_data = - saved_mean->template data>(); - const void *saved_var_data = - saved_var->template data>(); - - CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward( - dev_ctx.cudnn_handle(), mode_, CudnnDataType::kOne(), - CudnnDataType::kZero(), CudnnDataType::kOne(), - CudnnDataType::kZero(), data_desc_, x->template data(), - data_desc_, d_y->template data(), data_desc_, - d_x->template mutable_data(ctx.GetPlace()), bn_param_desc_, - scale->template data>(), - d_scale->template mutable_data>(ctx.GetPlace()), - d_bias->template mutable_data>(ctx.GetPlace()), - epsilon, saved_mean_data, saved_var_data)); - // clean when exit. 
- CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_)); - CUDNN_ENFORCE( - platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); + auto &dev_ctx = ctx.template device_context(); + if (!use_global_stats) { + if ((N * H * W * D) == 1) { + framework::TensorCopy(*d_y, ctx.GetPlace(), d_x); + math::SetConstant> + functor; + functor(dev_ctx, d_scale, static_cast>(0)); + functor(dev_ctx, d_bias, static_cast>(0)); + return; + } + + // ------------------- cudnn descriptors --------------------- + cudnnTensorDescriptor_t data_desc_; + cudnnTensorDescriptor_t bn_param_desc_; + cudnnBatchNormMode_t mode_; + + CUDNN_ENFORCE( + platform::dynload::cudnnCreateTensorDescriptor(&data_desc_)); + CUDNN_ENFORCE( + platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_)); + if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) { + LOG(ERROR) << "Provided epsilon is smaller than " + << "CUDNN_BN_MIN_EPSILON. Setting it to " + << "CUDNN_BN_MIN_EPSILON instead."; + } + epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON); +#if CUDNN_VERSION_MIN(7, 0, 0) + mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; +#else + mode_ = CUDNN_BATCHNORM_SPATIAL; +#endif + + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + data_desc_, CudnnDataType::type, + x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data())); + CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor( + bn_param_desc_, data_desc_, mode_)); + + const auto *saved_mean = ctx.Input("SavedMean"); + const auto *saved_var = ctx.Input("SavedVariance"); + const void *saved_mean_data = + saved_mean->template data>(); + const void *saved_var_data = + saved_var->template data>(); + + CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward( + dev_ctx.cudnn_handle(), mode_, CudnnDataType::kOne(), + CudnnDataType::kZero(), CudnnDataType::kOne(), + CudnnDataType::kZero(), data_desc_, x->template data(), + data_desc_, d_y->template data(), data_desc_, + d_x->template mutable_data(ctx.GetPlace()), bn_param_desc_, + scale->template data>(), + d_scale->template mutable_data>(ctx.GetPlace()), + d_bias->template mutable_data>(ctx.GetPlace()), + epsilon, saved_mean_data, saved_var_data)); + + // clean when exit. 
+ CUDNN_ENFORCE( + platform::dynload::cudnnDestroyTensorDescriptor(data_desc_)); + CUDNN_ENFORCE( + platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); + } else { + const auto *running_mean = ctx.Input("Mean"); + const auto *running_var = ctx.Input("Variance"); + + const auto *running_mean_data = + running_mean->template data>(); + const auto *running_var_data = + running_var->template data>(); + + const int num = x->numel(); + const int block = 512; + int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + const int max_blocks = std::max(max_threads / block, 1); + int grid1 = (num + block - 1) / block; + int grid2 = std::min(C, max_blocks); + + if (data_layout == framework::DataLayout::kNCHW) { + if (d_x) { + KeBNBackwardData<<< + grid1, block, 0, dev_ctx.stream()>>>( + d_y->data(), scale->data>(), + running_var_data, epsilon, C, H * W, num, d_x->data()); + } + if (d_scale && d_bias) { + KeBNBackwardScaleBias<<< + grid2, block, 0, dev_ctx.stream()>>>( + d_y->data(), x->data(), running_mean_data, running_var_data, + epsilon, C, H * W, num, d_scale->data>(), + d_bias->data>()); + } + } else { + if (d_x) { + KeBNBackwardData<<< + grid1, block, 0, dev_ctx.stream()>>>( + d_y->data(), scale->data>(), + running_var_data, epsilon, C, H * W, num, d_x->data()); + } + if (d_scale && d_bias) { + KeBNBackwardScaleBias<<< + grid2, block, 0, dev_ctx.stream()>>>( + d_y->data(), x->data(), running_mean_data, running_var_data, + epsilon, C, H * W, num, d_scale->data>(), + d_bias->data>()); + } + } + } } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 4df74edfce..b1fc6b808c 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -2300,7 +2300,8 @@ def batch_norm(input, moving_mean_name=None, moving_variance_name=None, do_model_average_for_mean_and_var=False, - fuse_with_relu=False): + fuse_with_relu=False, + use_global_stats=False): """ **Batch Normalization Layer** @@ -2327,6 +2328,19 @@ def batch_norm(input, \\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\ y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift + + When use_global_stats = True, the :math:`\\mu_{\\beta}` + and :math:`\\sigma_{\\beta}^{2}` are not the statistics of one mini-batch. + They are global (or running) statistics. (It usually got from the + pre-trained model.) + The training and testing (or inference) have the same behavior: + + .. math:: + + \\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\ + \\sigma_{\\beta}^{2} + \\epsilon}} \\\\ + y_i &\\gets \\gamma \\hat{x_i} + \\beta + Args: input(variable): The input variable which is a LoDTensor. act(string, Default None): Activation type, linear|relu|prelu|... @@ -2349,6 +2363,11 @@ def batch_norm(input, moving_variance_name(string, Default None): The name of the moving_variance which store the global Variance. do_model_average_for_mean_and_var(bool, Default False): Do model average for mean and variance or not. fuse_with_relu (bool): if True, this OP performs relu after batch norm. + use_global_stats(bool, Default False): Whether to use global mean and + variance. In inference or test mode, set use_global_stats to true + or is_test to true, and the behavior is equivalent. + In train mode, when setting use_global_stats True, the global mean + and variance are also used during train period. Returns: Variable: A tensor variable which is the result after applying batch normalization on the input. 
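Illustrative usage of the documented argument (input name and shape are placeholders, not part of the patch):

    import paddle.fluid as fluid

    data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
    # Normalize with the pre-trained running mean/variance even in train mode,
    # so the layer reduces to a per-channel scale and shift.
    hidden = fluid.layers.batch_norm(input=data, use_global_stats=True)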
@@ -2381,9 +2400,15 @@ def batch_norm(input, shape=param_shape, dtype=dtype, default_initializer=Constant(1.0)) + # setting stop_gradient=True to reduce computation + if use_global_stats and helper.param_attr.learning_rate == 0.: + scale.stop_gradient = True bias = helper.create_parameter( attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True) + # setting stop_gradient=True to reduce computation + if use_global_stats and helper.bias_attr.learning_rate == 0.: + bias.stop_gradient = True mean = helper.create_parameter( attr=ParamAttr( @@ -2439,7 +2464,8 @@ def batch_norm(input, "epsilon": epsilon, "is_test": is_test, "use_mkldnn": False, - "fuse_with_relu": fuse_with_relu + "fuse_with_relu": fuse_with_relu, + "use_global_stats": use_global_stats }) return helper.append_activation(batch_norm_out) diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index 80261eff4e..2869a6ba53 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -54,6 +54,19 @@ def _reference_testing(x, scale, offset, mean, var, epsilon, data_format): return y +def _cal_mean_variance(x, epsilon, data_format): + assert data_format in ['NCHW', 'NHWC'] + x_square = x * x + axis = (0, 2, 3) if data_format == 'NCHW' else (0, 1, 2) + C = x.shape[1] if data_format == 'NCHW' else x.shape[-1] + x_square_sum = np.sum(x_square, axis) + x_sum = np.sum(x, axis=axis) + element_count = np.size(x) / C + mean = x_sum / element_count + var = x_square_sum / element_count - mean * mean + return mean, var + + def _reference_training(x, scale, offset, epsilon, data_format): x_shape = x.shape @@ -294,7 +307,18 @@ class TestBatchNormOpTraining(unittest.TestCase): self.use_mkldnn = False self.fuse_with_relu = False self.data_formats = ["NCHW", "NHWC"] + self.momentum = 0.9 + self.epsilon = 0.00001 self.init_kernel_type() + self.init_test_case() + + def init_test_case(self): + self.use_global_stats = False + self.no_grad_set = set() + self.fetch_list = [ + 'y', 'mean', 'variance', 'saved_mean', 'saved_variance', 'x@GRAD', + 'scale@GRAD', 'bias@GRAD' + ] def __assert_close(self, tensor, np_array, msg, atol=1e-4): np.allclose(np.array(tensor), np_array, atol=atol) @@ -313,11 +337,22 @@ class TestBatchNormOpTraining(unittest.TestCase): return y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad + def set_mean_variance(self, scale_shape, x, data_layout): + mean = np.zeros(scale_shape).astype(np.float32) + variance = np.ones(scale_shape).astype(np.float32) + # computing global mean/variance for one step + if self.use_global_stats: + mom = self.momentum + x_mean, x_var = _cal_mean_variance(x, self.epsilon, data_layout) + mean = x_mean * (1. - mom) + mom * mean + variance = x_var * (1. 
- mom) + mom * variance + return mean, variance + def test_forward_backward(self): def test_with_place(place, data_layout, shape): # attr - epsilon = 0.00001 - momentum = 0.9 + epsilon = self.epsilon + momentum = self.momentum if data_layout == "NCHW": n, c, h, w = shape[0], shape[1], shape[2], shape[3] else: @@ -328,9 +363,7 @@ class TestBatchNormOpTraining(unittest.TestCase): x = np.random.random_sample(shape).astype(np.float32) scale = np.random.random_sample(scale_shape).astype(np.float32) bias = np.random.random_sample(scale_shape).astype(np.float32) - mean = np.zeros(scale_shape).astype(np.float32) - variance = np.ones(scale_shape).astype(np.float32) - + mean, variance = self.set_mean_variance(scale_shape, x, data_layout) y_grad = np.random.random_sample(shape).astype(np.float32) y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad = self.ref_forward_backward( @@ -339,6 +372,9 @@ class TestBatchNormOpTraining(unittest.TestCase): var_dict = locals() var_dict['y@GRAD'] = y_grad + var_dict['x@GRAD'] = x_grad + var_dict['scale@GRAD'] = scale_grad + var_dict['bias@GRAD'] = bias_grad var_names = [ 'x', 'scale', 'bias', 'mean', 'variance', 'y', 'saved_mean', @@ -365,9 +401,8 @@ class TestBatchNormOpTraining(unittest.TestCase): }, outputs={ "Y": block.var('y'), - "MeanOut": block.var('mean'), # share the same memory - "VarianceOut": - block.var('variance'), # share the same memory + "MeanOut": block.var('mean'), # share memory + "VarianceOut": block.var('variance'), # share memory "SavedMean": block.var('saved_mean'), "SavedVariance": block.var('saved_variance') }, @@ -377,13 +412,14 @@ class TestBatchNormOpTraining(unittest.TestCase): "is_test": False, "data_layout": data_layout, "use_mkldnn": self.use_mkldnn, - "fuse_with_relu": self.fuse_with_relu + "fuse_with_relu": self.fuse_with_relu, + "use_global_stats": self.use_global_stats }) block.create_var(name='y@GRAD', dtype='float32', shape=y.shape) # generate backward op_desc grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc( - bn_op.desc, set(), []) + bn_op.desc, self.no_grad_set, []) grad_op_desc = grad_op_desc_list[0] new_op_desc = block.desc.append_op() new_op_desc.copy_from(grad_op_desc) @@ -403,20 +439,10 @@ class TestBatchNormOpTraining(unittest.TestCase): for name in ['x', 'scale', 'bias', 'mean', 'variance', 'y@GRAD'] }, - fetch_list=[ - 'y', 'mean', 'variance', 'saved_mean', 'saved_variance', - 'x@GRAD', 'scale@GRAD', 'bias@GRAD' - ]) - - self.__assert_close(y, out[0], "y") - self.__assert_close(mean_out, out[1], "mean") - self.__assert_close(variance_out, out[2], "variance", 1e-3) - self.__assert_close(saved_mean, out[3], "saved_mean") - self.__assert_close(saved_variance, out[4], "saved_variance", 1e-3) - self.__assert_close(x_grad, out[5], "x_grad") - self.__assert_close(scale_grad, out[6], "scale_grad") - self.__assert_close(bias_grad, out[7], "bias_grad") + fetch_list=self.fetch_list) + for id, name in enumerate(self.fetch_list): + self.__assert_close(var_dict[name], out[id], name) print("op test forward passed: ", str(place), data_layout) places = [core.CPUPlace()] @@ -432,5 +458,66 @@ class TestBatchNormOpTraining(unittest.TestCase): pass +class TestBatchNormOpFreezeStatsTraining(TestBatchNormOpTraining): + def init_test_case(self): + self.use_global_stats = True + self.no_grad_set = set() + self.fetch_list = [ + 'y', 'mean', 'variance', 'x@GRAD', 'scale@GRAD', 'bias@GRAD' + ] + + def reference_grad(self, x, y_grad, scale, mean, var, epsilon, data_format): + if data_format == "NCHW": 
+ x = np.transpose(x, (0, 2, 3, 1)) + y_grad = np.transpose(y_grad, (0, 2, 3, 1)) + + x_grad = scale * y_grad / np.sqrt(var + epsilon) + grad_scale = np.sum(y_grad * (x - mean) / np.sqrt(var + epsilon), + axis=(0, 1, 2)) + grad_offset = np.sum(y_grad, axis=(0, 1, 2)) + + # transfer back to N, C, H, W + if data_format == "NCHW": + x_grad = np.transpose(x_grad, (0, 3, 1, 2)) + x = np.transpose(x, (0, 3, 1, 2)) + y_grad = np.transpose(y_grad, (0, 3, 1, 2)) + + return x_grad, grad_scale, grad_offset + + def ref_forward_backward(self, x, y_grad, scale, bias, mean, variance, + epsilon, momentum, shape, data_layout): + if data_layout != "NCHW" and data_layout != "NHWC": + raise ValueError("Unknown data order.") + + if data_layout == "NCHW": + x = np.transpose(x, (0, 2, 3, 1)) + + # run normalizaton + normalized = (x - mean) / np.sqrt(variance + epsilon) + y = normalized * scale + bias + + # transfer back to N, C, H, W + if data_layout == "NCHW": + x = np.transpose(x, (0, 3, 1, 2)) + y = np.transpose(y, (0, 3, 1, 2)) + + mean_out = mean + variance_out = variance + saved_variance = 1. / np.sqrt(variance + epsilon) + # run backward + x_grad, scale_grad, bias_grad = self.reference_grad( + x, y_grad, scale, mean, variance, epsilon, data_layout) + + return y, mean_out, variance_out, mean, saved_variance, x_grad, scale_grad, bias_grad + + +class TestBatchNormOpFreezeStatsAndScaleBiasTraining( + TestBatchNormOpFreezeStatsTraining): + def init_test_case(self): + self.use_global_stats = True + self.no_grad_set = set(['scale@GRAD', 'bias@GRAD']) + self.fetch_list = ['y', 'mean', 'variance', 'x@GRAD'] + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 5411607711..2004c91793 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -955,6 +955,15 @@ class TestBook(unittest.TestCase): print(str(program)) + def test_batch_norm(self): + program = Program() + with program_guard(program): + data = layers.data( + name='data', shape=[32, 128, 128], dtype="float32") + out = layers.batch_norm(data) + + print(str(program)) + if __name__ == '__main__': unittest.main() -- GitLab
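For reference, a NumPy sketch of the frozen-statistics backward pass that the new CPU/CUDA use_global_stats branches and the test's reference_grad compute, assuming NHWC input; this is an illustrative reimplementation under those assumptions, not code from the patch:

    import numpy as np

    def bn_frozen_backward_nhwc(x, dy, scale, mean, var, epsilon=1e-5):
        # With use_global_stats=True the batch statistics drop out of the
        # gradient, so d_x carries no mean/variance correction terms.
        inv_std = 1. / np.sqrt(var + epsilon)                       # shape (C,)
        dx = dy * scale * inv_std                                   # broadcast over N, H, W
        dscale = np.sum(dy * (x - mean) * inv_std, axis=(0, 1, 2))
        dbias = np.sum(dy, axis=(0, 1, 2))
        return dx, dscale, dbias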