/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <cfloat>
#include <string>
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"

// CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. This mode can be faster in
// some tasks because an optimized path may be selected for CUDNN_DATA_FLOAT
// and CUDNN_DATA_HALF data types, compute capability 6.0 or higher. The
// reason we set it to false by default is that this mode may use scaled
// atomic integer reduction that may cause a numerical overflow for certain
// input data range.
DEFINE_bool(cudnn_batchnorm_spatial_persistent, false, "Whether enable CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode for cudnn " "batch_norm, default is False."); namespace paddle { namespace operators { using Tensor = framework::Tensor; using DataLayout = framework::DataLayout; template using CudnnDataType = platform::CudnnDataType; template using BatchNormParamType = typename CudnnDataType::BatchNormParamType; template class BatchNormKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "It must use CUDAPlace."); double epsilon = static_cast(ctx.Attr("epsilon")); const float momentum = ctx.Attr("momentum"); const bool is_test = ctx.Attr("is_test"); const bool use_global_stats = ctx.Attr("use_global_stats"); const std::string data_layout_str = ctx.Attr("data_layout"); const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); // Get the size for each dimension. // NCHW [batch_size, in_channels, in_height, in_width] const auto *x = ctx.Input("X"); const auto &x_dims = x->dims(); PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5, "The Input dim size should be between 2 and 5"); int N, C, H, W, D; ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D); auto *y = ctx.Output("Y"); y->mutable_data(ctx.GetPlace()); // ------------------- cudnn descriptors --------------------- cudnnTensorDescriptor_t data_desc_; cudnnTensorDescriptor_t bn_param_desc_; cudnnBatchNormMode_t mode_; CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&data_desc_)); CUDNN_ENFORCE( platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_)); if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) { LOG(ERROR) << "Provided epsilon is smaller than " << "CUDNN_BN_MIN_EPSILON. 
Setting it to " << "CUDNN_BN_MIN_EPSILON instead."; } epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON); #if CUDNN_VERSION_MIN(7, 0, 0) if (FLAGS_cudnn_batchnorm_spatial_persistent) { mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; } else { mode_ = CUDNN_BATCHNORM_SPATIAL; } #else mode_ = CUDNN_BATCHNORM_SPATIAL; #endif VLOG(3) << "Setting descriptors."; std::vector dims; std::vector strides; if (data_layout == DataLayout::kNCHW) { dims = {N, C, H, W, D}; strides = {C * H * W * D, H * W * D, W * D, D, 1}; } else { dims = {N, C, H, W, D}; strides = {H * W * D * C, 1, W * D * C, D * C, C}; } CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( data_desc_, CudnnDataType::type, x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data())); // Note: PERSISTENT not implemented for inference CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor( bn_param_desc_, data_desc_, is_test ? CUDNN_BATCHNORM_SPATIAL : mode_)); const auto *scale = ctx.Input("Scale"); const auto *bias = ctx.Input("Bias"); auto &dev_ctx = ctx.template device_context(); auto handle = dev_ctx.cudnn_handle(); // Now, depending on whether we are running test or not, we have two paths. if (is_test || use_global_stats) { // only when test we use input to do computation. const auto *est_mean = ctx.Input("Mean"); const auto *est_var = ctx.Input("Variance"); // Run inference mode. 
PADDLE_ENFORCE_EQ(est_mean->dims().size(), 1UL); PADDLE_ENFORCE_EQ(est_var->dims().size(), 1UL); PADDLE_ENFORCE_EQ(est_mean->dims()[0], C); PADDLE_ENFORCE_EQ(est_var->dims()[0], C); CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardInference( handle, // Note: PERSISTENT not implemented for inference CUDNN_BATCHNORM_SPATIAL, CudnnDataType::kOne(), CudnnDataType::kZero(), data_desc_, x->template data(), data_desc_, y->template mutable_data(ctx.GetPlace()), bn_param_desc_, scale->template data>(), bias->template data>(), est_mean->template data>(), est_var->template data>(), epsilon)); } else { // Run training mode. // obtain running mean and running inv var, and see if we need to // initialize them. auto *mean_out = ctx.Output("MeanOut"); auto *variance_out = ctx.Output("VarianceOut"); mean_out->mutable_data>(ctx.GetPlace()); variance_out->mutable_data>(ctx.GetPlace()); auto *saved_mean = ctx.Output("SavedMean"); auto *saved_variance = ctx.Output("SavedVariance"); saved_mean->mutable_data>(ctx.GetPlace()); saved_variance->mutable_data>(ctx.GetPlace()); math::SetConstant> functor; functor(dev_ctx, saved_mean, static_cast>(0)); functor(dev_ctx, saved_variance, static_cast>(0)); if ((N * H * W * D) == 1) { LOG(WARNING) << "Only 1 element in normalization dimension, " << "we skip the batch norm calculation, let y = x."; framework::TensorCopy(*x, ctx.GetPlace(), y); } else { double this_factor = 1. - momentum; CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining( handle, mode_, CudnnDataType::kOne(), CudnnDataType::kZero(), data_desc_, x->template data(), data_desc_, y->template mutable_data(ctx.GetPlace()), bn_param_desc_, scale->template data>(), bias->template data>(), this_factor, mean_out->template mutable_data>( ctx.GetPlace()), variance_out->template mutable_data>( ctx.GetPlace()), epsilon, saved_mean->template mutable_data>( ctx.GetPlace()), saved_variance->template mutable_data>( ctx.GetPlace()))); } } // clean when exit. 
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_)); CUDNN_ENFORCE( platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); } }; template static __global__ void KeBNBackwardScaleBias( const T *dy, const T *x, const BatchNormParamType *mean, const BatchNormParamType *variance, const double epsilon, const int N, const int C, const int HxW, BatchNormParamType *dscale, BatchNormParamType *dbias) { const int outer_size = C; const int inner_size = N * HxW; typedef cub::BlockReduce, BlockDim> BlockReduce; __shared__ typename BlockReduce::TempStorage ds_storage; __shared__ typename BlockReduce::TempStorage db_storage; for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { BatchNormParamType ds_sum = static_cast>(0); BatchNormParamType db_sum = static_cast>(0); BatchNormParamType inv_var_i = 1.0 / sqrt(variance[i] + epsilon); BatchNormParamType mean_i = mean[i]; for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { const int index = layout == framework::DataLayout::kNCHW ? (j / HxW * C + i) * HxW + j % HxW : j * outer_size + i; ds_sum += static_cast>(dy[index]) * (static_cast>(x[index]) - mean_i); db_sum += static_cast>(dy[index]); } ds_sum = BlockReduce(ds_storage).Reduce(ds_sum, cub::Sum()); db_sum = BlockReduce(db_storage).Reduce(db_sum, cub::Sum()); if (threadIdx.x == 0) { dscale[i] = ds_sum * inv_var_i; dbias[i] = db_sum; } __syncthreads(); } } template static __global__ void KeBNBackwardData(const T *dy, const BatchNormParamType *scale, const BatchNormParamType *variance, const double epsilon, const int C, const int HxW, const int num, T *dx) { int gid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; for (int i = gid; i < num; i += stride) { const int c = layout == framework::DataLayout::kNCHW ? 
i / HxW % C : i % C; BatchNormParamType inv_var = 1.0 / sqrt(variance[c] + epsilon); dx[i] = static_cast(static_cast>(dy[i]) * scale[c] * inv_var); } } template class BatchNormGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "It must use CUDAPlace."); double epsilon = static_cast(ctx.Attr("epsilon")); const std::string data_layout_str = ctx.Attr("data_layout"); const bool use_global_stats = ctx.Attr("use_global_stats"); const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); const auto *x = ctx.Input("X"); const auto *d_y = ctx.Input(framework::GradVarName("Y")); const auto *scale = ctx.Input("Scale"); const auto &x_dims = x->dims(); PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5, "The Input dim size should be between 2 and 5"); int N, C, H, W, D; ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D); // init output auto *d_x = ctx.Output(framework::GradVarName("X")); auto *d_scale = ctx.Output(framework::GradVarName("Scale")); auto *d_bias = ctx.Output(framework::GradVarName("Bias")); d_x->mutable_data(ctx.GetPlace()); if (d_scale && d_bias) { d_scale->mutable_data>(ctx.GetPlace()); d_bias->mutable_data>(ctx.GetPlace()); } PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL); PADDLE_ENFORCE_EQ(scale->dims()[0], C); std::vector dims; std::vector strides; if (data_layout == DataLayout::kNCHW) { dims = {N, C, H, W, D}; strides = {C * H * W * D, H * W * D, W * D, D, 1}; } else { dims = {N, C, H, W, D}; strides = {H * W * C * D, 1, W * D * C, D * C, C}; } auto &dev_ctx = ctx.template device_context(); if (!use_global_stats) { if ((N * H * W * D) == 1) { framework::TensorCopy(*d_y, ctx.GetPlace(), d_x); math::SetConstant> functor; functor(dev_ctx, d_scale, static_cast>(0)); functor(dev_ctx, d_bias, static_cast>(0)); return; } // ------------------- cudnn descriptors --------------------- cudnnTensorDescriptor_t data_desc_; 
cudnnTensorDescriptor_t bn_param_desc_; cudnnBatchNormMode_t mode_; CUDNN_ENFORCE( platform::dynload::cudnnCreateTensorDescriptor(&data_desc_)); CUDNN_ENFORCE( platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_)); if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) { LOG(ERROR) << "Provided epsilon is smaller than " << "CUDNN_BN_MIN_EPSILON. Setting it to " << "CUDNN_BN_MIN_EPSILON instead."; } epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON); #if CUDNN_VERSION_MIN(7, 0, 0) if (FLAGS_cudnn_batchnorm_spatial_persistent) { mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; } else { mode_ = CUDNN_BATCHNORM_SPATIAL; } #else mode_ = CUDNN_BATCHNORM_SPATIAL; #endif CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( data_desc_, CudnnDataType::type, x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data())); CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor( bn_param_desc_, data_desc_, mode_)); const auto *saved_mean = ctx.Input("SavedMean"); const auto *saved_var = ctx.Input("SavedVariance"); const void *saved_mean_data = saved_mean->template data>(); const void *saved_var_data = saved_var->template data>(); CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward( dev_ctx.cudnn_handle(), mode_, CudnnDataType::kOne(), CudnnDataType::kZero(), CudnnDataType::kOne(), CudnnDataType::kZero(), data_desc_, x->template data(), data_desc_, d_y->template data(), data_desc_, d_x->template mutable_data(ctx.GetPlace()), bn_param_desc_, scale->template data>(), d_scale->template mutable_data>(ctx.GetPlace()), d_bias->template mutable_data>(ctx.GetPlace()), epsilon, saved_mean_data, saved_var_data)); // clean when exit. 
CUDNN_ENFORCE( platform::dynload::cudnnDestroyTensorDescriptor(data_desc_)); CUDNN_ENFORCE( platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); } else { const auto *running_mean = ctx.Input("Mean"); const auto *running_var = ctx.Input("Variance"); const auto *running_mean_data = running_mean->template data>(); const auto *running_var_data = running_var->template data>(); const int num = x->numel(); const int block = 512; int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); const int max_blocks = std::max(max_threads / block, 1); int grid1 = (num + block - 1) / block; int grid2 = std::min(C, max_blocks); if (data_layout == framework::DataLayout::kNCHW) { if (d_x) { KeBNBackwardData<<< grid1, block, 0, dev_ctx.stream()>>>( d_y->data(), scale->data>(), running_var_data, epsilon, C, H * W, num, d_x->data()); } if (d_scale && d_bias) { KeBNBackwardScaleBias<<< grid2, block, 0, dev_ctx.stream()>>>( d_y->data(), x->data(), running_mean_data, running_var_data, epsilon, N, C, H * W * D, d_scale->data>(), d_bias->data>()); } } else { if (d_x) { KeBNBackwardData<<< grid1, block, 0, dev_ctx.stream()>>>( d_y->data(), scale->data>(), running_var_data, epsilon, C, H * W, num, d_x->data()); } if (d_scale && d_bias) { KeBNBackwardScaleBias<<< grid2, block, 0, dev_ctx.stream()>>>( d_y->data(), x->data(), running_mean_data, running_var_data, epsilon, N, C, H * W * D, d_scale->data>(), d_bias->data>()); } } } } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( batch_norm, ops::BatchNormKernel, ops::BatchNormKernel, ops::BatchNormKernel); REGISTER_OP_CUDA_KERNEL( batch_norm_grad, ops::BatchNormGradKernel, ops::BatchNormGradKernel, ops::BatchNormGradKernel);