From c58f1a5cd9672ab64c729d7943d86262ba16822c Mon Sep 17 00:00:00 2001 From: Juncheng Date: Wed, 22 Jul 2020 20:13:30 +0800 Subject: [PATCH] Fix cudnn bn param desc (#3253) * use cudnnDeriveBNTensorDescriptor * CudnnTensorDescHelper * fix bn min eps Co-authored-by: guo ran <360112263@qq.com> --- .../kernels/normalization_kernel.cpp | 206 ++++++++++-------- .../test/ops/test_batch_normalization.py | 2 +- 2 files changed, 111 insertions(+), 97 deletions(-) diff --git a/oneflow/customized/kernels/normalization_kernel.cpp b/oneflow/customized/kernels/normalization_kernel.cpp index 1a196c0dea..633f9f08af 100644 --- a/oneflow/customized/kernels/normalization_kernel.cpp +++ b/oneflow/customized/kernels/normalization_kernel.cpp @@ -11,7 +11,7 @@ namespace { void InferDimSizeAndDataFormat(const ShapeView& x_shape, const int32_t axis, int32_t* n, int32_t* c, int32_t* h, int32_t* w, cudnnTensorFormat_t* format) { - if (axis != 0 && x_shape.Count(axis + 1) == 1 && CUDNN_VERSION >= 7605) { + if (axis != 0 && x_shape.Count(axis + 1) == 1) { *n = x_shape.At(0); *h = x_shape.Count(1, axis); *w = 1; @@ -26,35 +26,65 @@ void InferDimSizeAndDataFormat(const ShapeView& x_shape, const int32_t axis, int } } -DataType GetParamDataType(const DataType x_data_type) { - return x_data_type == DataType::kFloat16 ? DataType::kFloat : x_data_type; +void InferXYCudnnTensorDesc(const ShapeView& xy_shape, const DataType& data_type, + const int32_t axis, cudnnTensorDescriptor_t xy_desc) { + int32_t n, c, h, w; + cudnnTensorFormat_t format; + InferDimSizeAndDataFormat(xy_shape, axis, &n, &c, &h, &w, &format); + CudaCheck(cudnnSetTensor4dDescriptor(xy_desc, format, GetCudnnDataType(data_type), n, c, h, w)); } -std::function MakeCheckParamTensorFn( - const int32_t param_dim_size, const DataType param_data_type) { - return [=](const user_op::Tensor* tensor) { - CHECK_EQ(tensor->shape().NumAxes(), 1); - CHECK_EQ(tensor->shape().At(0), param_dim_size); - CHECK_EQ(tensor->data_type(), param_data_type); - }; +void InferParamCudnnTensorDesc(const cudnnTensorDescriptor_t xy_desc, cudnnBatchNormMode_t mode, + cudnnTensorDescriptor_t param_desc) { + CudaCheck(cudnnDeriveBNTensorDescriptor(param_desc, xy_desc, mode)); } +class CudnnTensorDescHelper final { + public: + OF_DISALLOW_COPY_AND_MOVE(CudnnTensorDescHelper); + CudnnTensorDescHelper(const ShapeView& xy_shape, const DataType& data_type, const int32_t axis, + cudnnBatchNormMode_t mode) { + CudaCheck(cudnnCreateTensorDescriptor(&xy_desc_)); + InferXYCudnnTensorDesc(xy_shape, data_type, axis, xy_desc_); + CudaCheck(cudnnCreateTensorDescriptor(¶m_desc_)); + InferParamCudnnTensorDesc(xy_desc_, mode, param_desc_); + int n, c, h, w, n_stride, c_stride, h_stride, w_stride; + CudaCheck(cudnnGetTensor4dDescriptor(param_desc_, ¶m_data_type_, &n, &c, &h, &w, &n_stride, + &c_stride, &h_stride, &w_stride)); + param_size_ = c; + } + ~CudnnTensorDescHelper() { + CudaCheck(cudnnDestroyTensorDescriptor(param_desc_)); + CudaCheck(cudnnDestroyTensorDescriptor(xy_desc_)); + } + + cudnnTensorDescriptor_t xy_desc() const { return xy_desc_; } + + cudnnTensorDescriptor_t param_desc() const { return param_desc_; } + + void CheckParamTensor(const user_op::Tensor* tensor) const { + CHECK_EQ(tensor->shape().NumAxes(), 1); + CHECK_EQ(tensor->shape().At(0), param_size_); + CHECK_EQ(GetCudnnDataType(tensor->data_type()), param_data_type_); + } + + private: + cudnnTensorDescriptor_t xy_desc_ = nullptr; + cudnnTensorDescriptor_t param_desc_ = nullptr; + cudnnDataType_t param_data_type_; + int32_t param_size_ = 0; +}; size_t InferTrainWorkspaceSize(const ShapeView& x_shape, const DataType data_type, const int32_t axis) { #if defined(BN_ENABLE_EX_API) - int32_t n, c, h, w; - cudnnTensorFormat_t format; - InferDimSizeAndDataFormat(x_shape, axis, &n, &c, &h, &w, &format); - CudnnTensorDesc xy_desc(format, data_type, n, c, h, w); - CudnnTensorDesc param_desc(format, GetParamDataType(data_type), 1, c, 1, 1); + const CudnnTensorDescHelper desc_helper(x_shape, data_type, axis, + CUDNN_BATCHNORM_SPATIAL_PERSISTENT); size_t size_in_bytes; cudnnHandle_t handle; CudaCheck(cudnnCreate(&handle)); - CudaCheck(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( - handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN, xy_desc.Get(), nullptr, - xy_desc.Get(), param_desc.Get(), nullptr, &size_in_bytes)); - + handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN, desc_helper.xy_desc(), + nullptr, desc_helper.xy_desc(), desc_helper.param_desc(), nullptr, &size_in_bytes)); CudaCheck(cudnnDestroy(handle)); return std::max(size_in_bytes, static_cast(1)); #else @@ -71,17 +101,15 @@ size_t InferTrainTmpSize(user_op::InferContext* ctx) { size_t InferGradWorkspaceSize(const ShapeView& x_shape, const DataType data_type, const int32_t axis) { #if defined(BN_ENABLE_EX_API) - int32_t n, c, h, w; - cudnnTensorFormat_t format; - InferDimSizeAndDataFormat(x_shape, axis, &n, &c, &h, &w, &format); - CudnnTensorDesc xy_desc(format, data_type, n, c, h, w); - CudnnTensorDesc param_desc(format, GetParamDataType(data_type), 1, c, 1, 1); + const CudnnTensorDescHelper desc_helper(x_shape, data_type, axis, + CUDNN_BATCHNORM_SPATIAL_PERSISTENT); size_t size_in_bytes; cudnnHandle_t handle; CudaCheck(cudnnCreate(&handle)); CudaCheck(cudnnGetBatchNormalizationBackwardExWorkspaceSize( - handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN, xy_desc.Get(), nullptr, - xy_desc.Get(), nullptr, xy_desc.Get(), param_desc.Get(), nullptr, &size_in_bytes)); + handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN, desc_helper.xy_desc(), + nullptr, desc_helper.xy_desc(), nullptr, desc_helper.xy_desc(), desc_helper.param_desc(), + nullptr, &size_in_bytes)); CudaCheck(cudnnDestroy(handle)); return std::max(size_in_bytes, static_cast(1)); #else @@ -120,24 +148,17 @@ class NormalizationInferenceKernel final : public user_op::OpKernel { CHECK_GE(axis, 0); CHECK_LT(axis, x->shape().NumAxes()); - int32_t n, c, h, w; - cudnnTensorFormat_t format; - InferDimSizeAndDataFormat(x->shape(), axis, &n, &c, &h, &w, &format); - - CudnnTensorDesc xy_desc(format, data_type, n, c, h, w); - const DataType param_data_type = GetParamDataType(data_type); - const auto CheckParamTensor = MakeCheckParamTensorFn(c, param_data_type); - CheckParamTensor(gamma); - CheckParamTensor(beta); - CheckParamTensor(moving_mean); - CheckParamTensor(moving_variance); - CudnnTensorDesc param_desc(format, param_data_type, 1, c, 1, 1); + const CudnnTensorDescHelper desc_helper(x->shape(), data_type, axis, CUDNN_BATCHNORM_SPATIAL); + desc_helper.CheckParamTensor(gamma); + desc_helper.CheckParamTensor(beta); + desc_helper.CheckParamTensor(moving_mean); + desc_helper.CheckParamTensor(moving_variance); CudaCheck(cudnnBatchNormalizationForwardInference( ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL, CudnnSPOnePtr(), - CudnnSPZeroPtr(), xy_desc.Get(), x->dptr(), xy_desc.Get(), y->mut_dptr(), - param_desc.Get(), gamma->dptr(), beta->dptr(), moving_mean->dptr(), moving_variance->dptr(), - epsilon)); + CudnnSPZeroPtr(), desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(), y->mut_dptr(), + desc_helper.param_desc(), gamma->dptr(), beta->dptr(), moving_mean->dptr(), + moving_variance->dptr(), epsilon)); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } @@ -184,53 +205,50 @@ class NormalizationTrainKernel final : public user_op::OpKernel { CHECK_GE(axis, 0); CHECK_LT(axis, x->shape().NumAxes()); - int32_t n, c, h, w; - cudnnTensorFormat_t format; - InferDimSizeAndDataFormat(x->shape(), axis, &n, &c, &h, &w, &format); - - CudnnTensorDesc xy_desc(format, data_type, n, c, h, w); - const DataType param_data_type = GetParamDataType(data_type); - const auto CheckParamTensor = MakeCheckParamTensorFn(c, param_data_type); - CheckParamTensor(gamma); - CheckParamTensor(beta); - CheckParamTensor(moving_mean); - CheckParamTensor(moving_variance); - CheckParamTensor(mean); - CheckParamTensor(inv_variance); - CudnnTensorDesc param_desc(format, param_data_type, 1, c, 1, 1); + const CudnnTensorDescHelper desc_helper(x->shape(), data_type, axis, + CUDNN_BATCHNORM_SPATIAL_PERSISTENT); + desc_helper.CheckParamTensor(gamma); + desc_helper.CheckParamTensor(beta); + desc_helper.CheckParamTensor(moving_mean); + desc_helper.CheckParamTensor(moving_variance); + desc_helper.CheckParamTensor(mean); + desc_helper.CheckParamTensor(inv_variance); #if defined(BN_ENABLE_EX_API) size_t workspace_size; CudaCheck(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT, - CUDNN_BATCHNORM_OPS_BN, xy_desc.Get(), nullptr, xy_desc.Get(), param_desc.Get(), nullptr, - &workspace_size)); + CUDNN_BATCHNORM_OPS_BN, desc_helper.xy_desc(), nullptr, desc_helper.xy_desc(), + desc_helper.param_desc(), nullptr, &workspace_size)); size_t reserve_space_size; CudaCheck(cudnnGetBatchNormalizationTrainingExReserveSpaceSize( ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT, - CUDNN_BATCHNORM_OPS_BN, nullptr, xy_desc.Get(), &reserve_space_size)); + CUDNN_BATCHNORM_OPS_BN, nullptr, desc_helper.xy_desc(), &reserve_space_size)); auto* workspace = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); if (reserve_space_size == 0 && workspace_size <= workspace->shape().elem_cnt()) { CudaCheck(cudnnBatchNormalizationForwardTrainingEx( ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT, - CUDNN_BATCHNORM_OPS_BN, CudnnSPOnePtr(), CudnnSPZeroPtr(), xy_desc.Get(), x->dptr(), - nullptr, nullptr, xy_desc.Get(), y->mut_dptr(), param_desc.Get(), gamma->dptr(), - beta->dptr(), 1.0 - momentum, moving_mean->mut_dptr(), moving_variance->mut_dptr(), - epsilon, mean->mut_dptr(), inv_variance->mut_dptr(), nullptr, workspace->mut_dptr(), - workspace->shape().elem_cnt(), nullptr, 0)); + CUDNN_BATCHNORM_OPS_BN, CudnnSPOnePtr(), CudnnSPZeroPtr(), desc_helper.xy_desc(), + x->dptr(), nullptr, nullptr, desc_helper.xy_desc(), y->mut_dptr(), + desc_helper.param_desc(), gamma->dptr(), beta->dptr(), 1.0 - momentum, + moving_mean->mut_dptr(), moving_variance->mut_dptr(), epsilon, mean->mut_dptr(), + inv_variance->mut_dptr(), nullptr, workspace->mut_dptr(), workspace->shape().elem_cnt(), + nullptr, 0)); } else { CudaCheck(cudnnBatchNormalizationForwardTraining( ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CudnnSPOnePtr(), - CudnnSPZeroPtr(), xy_desc.Get(), x->dptr(), xy_desc.Get(), y->mut_dptr(), - param_desc.Get(), gamma->dptr(), beta->dptr(), 1.0 - momentum, moving_mean->mut_dptr(), - moving_variance->mut_dptr(), epsilon, mean->mut_dptr(), inv_variance->mut_dptr())); + CudnnSPZeroPtr(), desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(), + y->mut_dptr(), desc_helper.param_desc(), gamma->dptr(), beta->dptr(), 1.0 - momentum, + moving_mean->mut_dptr(), moving_variance->mut_dptr(), epsilon, mean->mut_dptr(), + inv_variance->mut_dptr())); } #else CudaCheck(cudnnBatchNormalizationForwardTraining( ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CudnnSPOnePtr(), - CudnnSPZeroPtr(), xy_desc.Get(), x->dptr(), xy_desc.Get(), y->mut_dptr(), - param_desc.Get(), gamma->dptr(), beta->dptr(), 1.0 - momentum, moving_mean->mut_dptr(), - moving_variance->mut_dptr(), epsilon, mean->mut_dptr(), inv_variance->mut_dptr())); + CudnnSPZeroPtr(), desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(), y->mut_dptr(), + desc_helper.param_desc(), gamma->dptr(), beta->dptr(), 1.0 - momentum, + moving_mean->mut_dptr(), moving_variance->mut_dptr(), epsilon, mean->mut_dptr(), + inv_variance->mut_dptr())); #endif } @@ -264,53 +282,49 @@ class NormalizationGradUserKernel final : public user_op::OpKernel { CHECK_GE(axis, 0); CHECK_LT(axis, x->shape().NumAxes()); - int32_t n, c, h, w; - cudnnTensorFormat_t format; - InferDimSizeAndDataFormat(x->shape(), axis, &n, &c, &h, &w, &format); - - CudnnTensorDesc xy_desc(format, data_type, n, c, h, w); - const DataType param_data_type = GetParamDataType(data_type); - const auto CheckParamTensor = MakeCheckParamTensorFn(c, param_data_type); - CheckParamTensor(gamma); - CheckParamTensor(gamma_diff); - CheckParamTensor(beta_diff); - CudnnTensorDesc param_desc(format, param_data_type, 1, c, 1, 1); + const CudnnTensorDescHelper desc_helper(x->shape(), data_type, axis, + CUDNN_BATCHNORM_SPATIAL_PERSISTENT); + desc_helper.CheckParamTensor(gamma); + desc_helper.CheckParamTensor(gamma_diff); + desc_helper.CheckParamTensor(beta_diff); + desc_helper.CheckParamTensor(mean); + desc_helper.CheckParamTensor(inv_variance); #if defined(BN_ENABLE_EX_API) size_t workspace_size; CudaCheck(cudnnGetBatchNormalizationBackwardExWorkspaceSize( ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT, - CUDNN_BATCHNORM_OPS_BN, xy_desc.Get(), nullptr, xy_desc.Get(), nullptr, xy_desc.Get(), - param_desc.Get(), nullptr, &workspace_size)); + CUDNN_BATCHNORM_OPS_BN, desc_helper.xy_desc(), nullptr, desc_helper.xy_desc(), nullptr, + desc_helper.xy_desc(), desc_helper.param_desc(), nullptr, &workspace_size)); size_t reserve_space_size; CudaCheck(cudnnGetBatchNormalizationTrainingExReserveSpaceSize( ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT, - CUDNN_BATCHNORM_OPS_BN, nullptr, xy_desc.Get(), &reserve_space_size)); + CUDNN_BATCHNORM_OPS_BN, nullptr, desc_helper.xy_desc(), &reserve_space_size)); auto* workspace = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); if (reserve_space_size == 0 && workspace_size <= workspace->shape().elem_cnt()) { CudaCheck(cudnnBatchNormalizationBackwardEx( ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN, CudnnSPOnePtr(), CudnnSPZeroPtr(), CudnnSPOnePtr(), - CudnnSPZeroPtr(), xy_desc.Get(), x->dptr(), nullptr, nullptr, xy_desc.Get(), - dy->dptr(), nullptr, nullptr, xy_desc.Get(), dx->mut_dptr(), param_desc.Get(), - gamma->dptr(), nullptr, gamma_diff->mut_dptr(), beta_diff->mut_dptr(), epsilon, - mean->dptr(), inv_variance->dptr(), nullptr, workspace->mut_dptr(), - workspace->shape().elem_cnt(), nullptr, 0)); + CudnnSPZeroPtr(), desc_helper.xy_desc(), x->dptr(), nullptr, nullptr, + desc_helper.xy_desc(), dy->dptr(), nullptr, nullptr, desc_helper.xy_desc(), + dx->mut_dptr(), desc_helper.param_desc(), gamma->dptr(), nullptr, gamma_diff->mut_dptr(), + beta_diff->mut_dptr(), epsilon, mean->dptr(), inv_variance->dptr(), nullptr, + workspace->mut_dptr(), workspace->shape().elem_cnt(), nullptr, 0)); } else { CudaCheck(cudnnBatchNormalizationBackward( ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CudnnSPOnePtr(), - CudnnSPZeroPtr(), CudnnSPOnePtr(), CudnnSPZeroPtr(), xy_desc.Get(), x->dptr(), - xy_desc.Get(), dy->dptr(), xy_desc.Get(), dx->mut_dptr(), param_desc.Get(), gamma->dptr(), - gamma_diff->mut_dptr(), beta_diff->mut_dptr(), epsilon, mean->dptr(), - inv_variance->dptr())); + CudnnSPZeroPtr(), CudnnSPOnePtr(), CudnnSPZeroPtr(), desc_helper.xy_desc(), + x->dptr(), desc_helper.xy_desc(), dy->dptr(), desc_helper.xy_desc(), dx->mut_dptr(), + desc_helper.param_desc(), gamma->dptr(), gamma_diff->mut_dptr(), beta_diff->mut_dptr(), + epsilon, mean->dptr(), inv_variance->dptr())); } #else CudaCheck(cudnnBatchNormalizationBackward( ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CudnnSPOnePtr(), - CudnnSPZeroPtr(), CudnnSPOnePtr(), CudnnSPZeroPtr(), xy_desc.Get(), x->dptr(), - xy_desc.Get(), dy->dptr(), xy_desc.Get(), dx->mut_dptr(), param_desc.Get(), gamma->dptr(), - gamma_diff->mut_dptr(), beta_diff->mut_dptr(), epsilon, mean->dptr(), - inv_variance->dptr())); + CudnnSPZeroPtr(), CudnnSPOnePtr(), CudnnSPZeroPtr(), desc_helper.xy_desc(), + x->dptr(), desc_helper.xy_desc(), dy->dptr(), desc_helper.xy_desc(), dx->mut_dptr(), + desc_helper.param_desc(), gamma->dptr(), gamma_diff->mut_dptr(), beta_diff->mut_dptr(), + epsilon, mean->dptr(), inv_variance->dptr())); #endif } diff --git a/oneflow/python/test/ops/test_batch_normalization.py b/oneflow/python/test/ops/test_batch_normalization.py index 548dd94dca..21de626986 100644 --- a/oneflow/python/test/ops/test_batch_normalization.py +++ b/oneflow/python/test/ops/test_batch_normalization.py @@ -346,7 +346,7 @@ def test_nn_batchnorm(test_case): arg_dict["input_shape"] = [(2, 4, 3, 5)] arg_dict["data_type"] = ["float32"] arg_dict["axis"] = [1, -1] - arg_dict["epsilon"] = [1e-5, 1e-4] + arg_dict["epsilon"] = [1.001e-5, 1e-4] for arg in GenArgDict(arg_dict): CompareNnBnWithTensorFlow(**arg) -- GitLab