From 56b723c40d06623c716124fc7a0b61bfcfb0f78a Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Wed, 25 Oct 2017 15:54:08 -0700 Subject: [PATCH] Cudnn batch norm op (#5067) * init cudnn batch norm op * rename batch_norm_cudnn_op.cc batch_norm_op.cu * correct name style * add ExtractNCWHD, simplify code * fix ExtractNCWHD * use CUDNN_ENFORCE instead of PADDLE_ENFORCE --- paddle/operators/batch_norm_op.cu | 262 ++++++++++++++++++++++++++++++ paddle/platform/cudnn_helper.h | 59 +++++++ paddle/platform/dynload/cudnn.h | 1 + 3 files changed, 322 insertions(+) create mode 100644 paddle/operators/batch_norm_op.cu diff --git a/paddle/operators/batch_norm_op.cu b/paddle/operators/batch_norm_op.cu new file mode 100644 index 0000000000..6ba6ee12ec --- /dev/null +++ b/paddle/operators/batch_norm_op.cu @@ -0,0 +1,262 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/batch_norm_op.h" + +#include +#include "paddle/operators/math/math_function.h" +#include "paddle/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using CudnnDataType = platform::CudnnDataType; + +void ExtractNCWHD(const framework::DDim &dims, + const TensorFormat &tensor_format, int *N, int *C, int *H, + int *W, int *D) { + *N = dims[0]; + *C = tensor_format == TensorFormat::NCHW ? dims[1] : dims[dims.size() - 1]; + *H = tensor_format == TensorFormat::NCHW ? dims[2] : dims[1]; + *W = dims.size() > 3 + ? (tensor_format == TensorFormat::NCHW ? dims[3] : dims[2]) + : 1; + *D = dims.size() > 4 + ? (tensor_format == TensorFormat::NCHW ? dims[4] : dims[3]) + : 1; +} + +template +class BatchNormKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + double epsilon = static_cast(ctx.Attr("epsilon")); + const float momentum = ctx.Attr("momentum"); + const bool is_test = ctx.Attr("is_test"); + const std::string tensor_format_str = + ctx.Attr("tensor_format"); + const TensorFormat tensor_format = StringToTensorFormat(tensor_format_str); + + // Get the size for each dimension. + // NCHW [batch_size, in_channels, in_height, in_width] + const auto *x = ctx.Input("X"); + const auto &x_dims = x->dims(); + PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5, + "The Input dim size should be between 3 and 5"); + int N, C, H, W, D; + ExtractNCWHD(x_dims, tensor_format, &N, &C, &H, &W, &D); + + // ------------------- cudnn descriptors --------------------- + cudnnTensorDescriptor_t data_desc_; + cudnnTensorDescriptor_t bn_param_desc_; + cudnnBatchNormMode_t mode_; + + CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&data_desc_)); + CUDNN_ENFORCE( + platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_)); + + if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) { + LOG(ERROR) << "Provided epsilon is smaller than " + << "CUDNN_BN_MIN_EPSILON. Setting it to " + << "CUDNN_BN_MIN_EPSILON instead."; + } + epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON); +#if CUDNN_VERSION_MIN(7, 0, 0) + mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; +#else + mode_ = CUDNN_BATCHNORM_SPATIAL; +#endif + + VLOG(1) << "Setting descriptors."; + std::vector dims; + std::vector strides; + if (tensor_format == TensorFormat::NCHW) { + dims = {N, C, H, W, D}; + strides = {C * H * W * D, H * W * D, W * D, D, 1}; + } else { + dims = {N, C, H, W, D}; + strides = {H * W * D * C, 1, W * D * C, D * C, C}; + } + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + data_desc_, CudnnDataType::type, + x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data())); + CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor( + bn_param_desc_, data_desc_, mode_)); + + const auto *scale = ctx.Input("Scale"); + const auto *bias = ctx.Input("Bias"); + + auto *y = ctx.Output("Y"); + auto *mean_out = ctx.Output("MeanOut"); + auto *variance_out = ctx.Output("VarianceOut"); + auto *saved_mean = ctx.Output("SavedMean"); + auto *saved_variance = ctx.Output("SavedVariance"); + + // alloc memory + y->mutable_data(ctx.GetPlace()); + mean_out->mutable_data(ctx.GetPlace()); + variance_out->mutable_data(ctx.GetPlace()); + saved_mean->mutable_data(ctx.GetPlace()); + saved_variance->mutable_data(ctx.GetPlace()); + + math::SetConstant functor; + functor(ctx.device_context(), saved_mean, 0); + functor(ctx.device_context(), saved_variance, 0); + // FIXME(qiao) should not set zero self + functor(ctx.device_context(), mean_out, 0); + functor(ctx.device_context(), variance_out, 0); + + auto handle = ctx.cuda_device_context().cudnn_handle(); + + // Now, depending on whether we are running test or not, we have two paths. + if (is_test) { + // only when test we use input to do computation. + const auto *est_mean = ctx.Input("Mean"); + const auto *est_var = ctx.Input("Variance"); + // Run inference mode. + PADDLE_ENFORCE_EQ(est_mean->dims().size(), 1UL); + PADDLE_ENFORCE_EQ(est_var->dims().size(), 1UL); + PADDLE_ENFORCE_EQ(est_mean->dims()[0], C); + PADDLE_ENFORCE_EQ(est_var->dims()[0], C); + + CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardInference( + handle, + // Note: PERSISTENT not implemented for inference + CUDNN_BATCHNORM_SPATIAL, CudnnDataType::kOne(), + CudnnDataType::kZero(), data_desc_, x->template data(), + data_desc_, y->template mutable_data(ctx.GetPlace()), + bn_param_desc_, scale->template data(), bias->template data(), + est_mean->template data(), est_var->template data(), epsilon)); + } else { + // Run training mode. + // obtain running mean and running inv var, and see if we need to + // initialize them. + double this_factor = 1. - momentum; + + CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining( + handle, mode_, CudnnDataType::kOne(), CudnnDataType::kZero(), + data_desc_, x->template data(), data_desc_, + y->template mutable_data(ctx.GetPlace()), bn_param_desc_, + scale->template data(), bias->template data(), this_factor, + mean_out->template mutable_data(ctx.GetPlace()), + variance_out->template mutable_data(ctx.GetPlace()), epsilon, + saved_mean->template mutable_data(ctx.GetPlace()), + saved_variance->template mutable_data(ctx.GetPlace()))); + } + + // clean when exit. + CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_)); + CUDNN_ENFORCE( + platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); + } +}; + +template +class BatchNormGradKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + double epsilon = static_cast(ctx.Attr("epsilon")); + const std::string tensor_format_str = + ctx.Attr("tensor_format"); + const TensorFormat tensor_format = StringToTensorFormat(tensor_format_str); + const auto *x = ctx.Input("X"); + const auto *d_y = ctx.Input(framework::GradVarName("Y")); + const auto *scale = ctx.Input("Scale"); + + const auto &x_dims = x->dims(); + + PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5, + "The Input dim size should be between 3 and 5"); + int N, C, H, W, D; + ExtractNCWHD(x_dims, tensor_format, &N, &C, &H, &W, &D); + + PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL); + PADDLE_ENFORCE_EQ(scale->dims()[0], C); + + // ------------------- cudnn descriptors --------------------- + cudnnTensorDescriptor_t data_desc_; + cudnnTensorDescriptor_t bn_param_desc_; + cudnnBatchNormMode_t mode_; + + CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&data_desc_)); + CUDNN_ENFORCE( + platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_)); + if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) { + LOG(ERROR) << "Provided epsilon is smaller than " + << "CUDNN_BN_MIN_EPSILON. Setting it to " + << "CUDNN_BN_MIN_EPSILON instead."; + } + epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON); +#if CUDNN_VERSION_MIN(7, 0, 0) + mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; +#else + mode_ = CUDNN_BATCHNORM_SPATIAL; +#endif + + std::vector dims = {N, C, H, W, D}; + std::vector strides = {H * W * C * D, 1, W * D * C, D * C, C}; + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + data_desc_, CudnnDataType::type, + x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data())); + CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor( + bn_param_desc_, data_desc_, mode_)); + + // init output + auto *d_x = ctx.Output(framework::GradVarName("X")); + auto *d_scale = ctx.Output(framework::GradVarName("Scale")); + auto *d_bias = ctx.Output(framework::GradVarName("Bias")); + + d_x->mutable_data(ctx.GetPlace()); + d_scale->mutable_data(ctx.GetPlace()); + d_bias->mutable_data(ctx.GetPlace()); + + const auto *saved_mean = ctx.Input("SavedMean"); + const auto *saved_var = ctx.Input("SavedVariance"); + const void *saved_mean_data = saved_mean->template data(); + const void *saved_var_data = saved_var->template data(); + + CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward( + ctx.cuda_device_context().cudnn_handle(), mode_, + CudnnDataType::kOne(), CudnnDataType::kZero(), + CudnnDataType::kOne(), CudnnDataType::kZero(), data_desc_, + x->template data(), data_desc_, d_y->template data(), data_desc_, + d_x->template mutable_data(ctx.GetPlace()), bn_param_desc_, + scale->template data(), + d_scale->template mutable_data(ctx.GetPlace()), + d_bias->template mutable_data(ctx.GetPlace()), epsilon, + saved_mean_data, saved_var_data)); + + // clean when exit. + CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_)); + CUDNN_ENFORCE( + platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(batch_norm, + ops::BatchNormKernel); +REGISTER_OP_GPU_KERNEL( + batch_norm_grad, + ops::BatchNormGradKernel); diff --git a/paddle/platform/cudnn_helper.h b/paddle/platform/cudnn_helper.h index 0c5719ef51..ce3421a3cb 100644 --- a/paddle/platform/cudnn_helper.h +++ b/paddle/platform/cudnn_helper.h @@ -22,6 +22,47 @@ limitations under the License. */ namespace paddle { namespace platform { +inline const char* cudnnGetErrorString(cudnnStatus_t status) { + switch (status) { + case CUDNN_STATUS_SUCCESS: + return "CUDNN_STATUS_SUCCESS"; + case CUDNN_STATUS_NOT_INITIALIZED: + return "CUDNN_STATUS_NOT_INITIALIZED"; + case CUDNN_STATUS_ALLOC_FAILED: + return "CUDNN_STATUS_ALLOC_FAILED"; + case CUDNN_STATUS_BAD_PARAM: + return "CUDNN_STATUS_BAD_PARAM"; + case CUDNN_STATUS_INTERNAL_ERROR: + return "CUDNN_STATUS_INTERNAL_ERROR"; + case CUDNN_STATUS_INVALID_VALUE: + return "CUDNN_STATUS_INVALID_VALUE"; + case CUDNN_STATUS_ARCH_MISMATCH: + return "CUDNN_STATUS_ARCH_MISMATCH"; + case CUDNN_STATUS_MAPPING_ERROR: + return "CUDNN_STATUS_MAPPING_ERROR"; + case CUDNN_STATUS_EXECUTION_FAILED: + return "CUDNN_STATUS_EXECUTION_FAILED"; + case CUDNN_STATUS_NOT_SUPPORTED: + return "CUDNN_STATUS_NOT_SUPPORTED"; + case CUDNN_STATUS_LICENSE_ERROR: + return "CUDNN_STATUS_LICENSE_ERROR"; + default: + return "Unknown cudnn error number"; + } +} + +#define CUDNN_VERSION_MIN(major, minor, patch) \ + (CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch))) + +#define CUDNN_ENFORCE(condition) \ + do { \ + cudnnStatus_t status = condition; \ + if (status != CUDNN_STATUS_SUCCESS) { \ + VLOG(1) << ::paddle::platform::cudnnGetErrorString(status); \ + PADDLE_THROW("cuDNN call failed"); \ + } \ + } while (false) + enum class DataLayout { kNHWC, kNCHW, @@ -40,12 +81,30 @@ template <> class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_FLOAT; + typedef const float ScalingParamType; + static ScalingParamType* kOne() { + static ScalingParamType v = 1.0; + return &v; + } + static ScalingParamType* kZero() { + static ScalingParamType v = 0.0; + return &v; + } }; template <> class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_DOUBLE; + typedef const double ScalingParamType; + static ScalingParamType* kOne() { + static ScalingParamType v = 1.0; + return &v; + } + static ScalingParamType* kZero() { + static ScalingParamType v = 0.0; + return &v; + } }; inline cudnnTensorFormat_t GetCudnnTensorFormat(const DataLayout& order) { diff --git a/paddle/platform/dynload/cudnn.h b/paddle/platform/dynload/cudnn.h index 0120625b7c..b2d69da93b 100644 --- a/paddle/platform/dynload/cudnn.h +++ b/paddle/platform/dynload/cudnn.h @@ -83,6 +83,7 @@ extern void* cudnn_dso_handle; __macro(cudnnDestroyConvolutionDescriptor); \ __macro(cudnnSetConvolutionNdDescriptor); \ __macro(cudnnGetConvolutionNdDescriptor); \ + __macro(cudnnDeriveBNTensorDescriptor); \ __macro(cudnnCreate); \ __macro(cudnnDestroy); \ __macro(cudnnSetStream); \ -- GitLab