From b3e54eade1dc303b0f002c3b7df88a9bf4b5eb74 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 17 Aug 2021 17:45:50 +0800 Subject: [PATCH] feat(dnn/bn): use new cudnn BN kernel to support NHWC GitOrigin-RevId: 9d80f2009d0496f532b267fcf841785b74c0b50c --- dnn/include/megdnn/oprs/nn.h | 46 +++-- dnn/scripts/opr_param_defs.py | 3 +- dnn/src/common/batch_normalization.cpp | 48 +++-- dnn/src/common/opr_trait.h | 4 +- dnn/src/cuda/batch_normalization/opr_impl.cpp | 189 ++++++++++++++---- dnn/src/cuda/batch_normalization/opr_impl.h | 48 +++-- .../naive/batch_normalization/opr_impl.cpp | 14 +- dnn/src/naive/batch_normalization/opr_impl.h | 18 +- dnn/src/rocm/batch_normalization/opr_impl.cpp | 5 +- dnn/src/rocm/batch_normalization/opr_impl.h | 16 +- 10 files changed, 268 insertions(+), 123 deletions(-) diff --git a/dnn/include/megdnn/oprs/nn.h b/dnn/include/megdnn/oprs/nn.h index 31ce6912e..ca6920a76 100644 --- a/dnn/include/megdnn/oprs/nn.h +++ b/dnn/include/megdnn/oprs/nn.h @@ -682,7 +682,8 @@ public: * http://deeplearning.net/software/theano/library/tensor/nnet/neighbours.html * * \f$ dst_{n, c, oh, ow, wh, ww} = src_{n, c, ih+wh, iw+fw}\f$, - * where \f$ ih=-pad_h+oh*stride_h+(wh-1)*(dilation_h-1), iw=-pad_w+ow*stride_w+(ww-1)*(dilation_w-1)\f$. + * where \f$ ih=-pad_h+oh*stride_h+(wh-1)*(dilation_h-1), + * iw=-pad_w+ow*stride_w+(ww-1)*(dilation_w-1)\f$. 
*/ virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; @@ -724,7 +725,8 @@ protected: }; class SlidingWindowTransposeForward : public SlidingWindowTransposeBase { - DEF_OPR_IMPL(SlidingWindowTransposeForward, SlidingWindowTransposeBase, 1, 1); + DEF_OPR_IMPL(SlidingWindowTransposeForward, SlidingWindowTransposeBase, 1, + 1); public: /** @@ -744,7 +746,8 @@ protected: using SlidingWindowTranspose = SlidingWindowTransposeForward; class SlidingWindowTransposeBackward : public SlidingWindowTransposeBase { - DEF_OPR_IMPL(SlidingWindowTransposeBackward, SlidingWindowTransposeBase, 1, 1); + DEF_OPR_IMPL(SlidingWindowTransposeBackward, SlidingWindowTransposeBase, 1, + 1); public: /** @@ -975,7 +978,7 @@ protected: }; class BNForward : public BNBase { - DEF_OPR_IMPL(BNForward, BNBase, 6, 5); + DEF_OPR_IMPL(BNForward, BNBase, 6, 6); public: /** @@ -986,10 +989,11 @@ public: * \param[out] dst (n, c, h, w) * \param[out] mean (see m_param.ParamDim) Global mean. * \param[out] variance (see m_param.ParamDim) Global variance. - * \Param[out] batch_mean (see m_param.ParamDim) + * \param[out] batch_mean (see m_param.ParamDim) * Optionally cached intermediate mean from forward pass - * \Param[out] batch_inv_variance (see m_param.ParamDim) + * \param[out] batch_inv_variance (see m_param.ParamDim) * Optionally cached intermediate variance from forward pass + * \param[out] reserve (see cudnnBatchNormalizationForwardTrainingEx) * src and dst must have the same shape. * src and dst must be contiguous. 
*/ @@ -998,17 +1002,20 @@ public: _megdnn_tensor_inout variance, _megdnn_tensor_out batch_mean, _megdnn_tensor_out batch_inv_variance, - _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout& src, TensorLayout& bn_scale, - TensorLayout& bn_bias, TensorLayout& mean, + _megdnn_tensor_out reserve, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + void deduce_layout(const TensorLayout& src, const TensorLayout& bn_scale, + const TensorLayout& bn_bias, TensorLayout& mean, TensorLayout& variance, TensorLayout& batch_mean, - TensorLayout& batch_inv_variance, TensorLayout& dst); + TensorLayout& batch_inv_variance, TensorLayout& reserve, + TensorLayout& dst); virtual size_t get_workspace_in_bytes( const TensorLayout& src, const TensorLayout& bn_scale, const TensorLayout& bn_bias, const TensorLayout& mean, const TensorLayout& variance, const TensorLayout& batch_mean, - const TensorLayout& batch_inv_variance, + const TensorLayout& batch_inv_variance, const TensorLayout& reserve, const TensorLayout& dst) = 0; + virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0; protected: void check_exec(const TensorLayout& src, const TensorLayout& bn_scale, @@ -1016,12 +1023,13 @@ protected: const TensorLayout& variance, const TensorLayout& batch_mean, const TensorLayout& batch_inv_variance, - const TensorLayout& dst, size_t workspace_in_bytes); + const TensorLayout& dst, size_t workspace_in_bytes, + size_t reserve_in_bytes = 0); }; using BN = BNForward; class BNBackward : public BNBase { - DEF_OPR_IMPL(BNBackward, BNBase, 5, 3); + DEF_OPR_IMPL(BNBackward, BNBase, 6, 3); public: /** @@ -1035,19 +1043,23 @@ public: Calculated in the forwardpropagation. * \param[in] saved_batch_variance of the input batch. Calculated in the forwardpropagation. 
+ * \param[in] reserve (see cudnnBatchNormalizationBackwardEx) */ virtual void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, _megdnn_tensor_in saved_batch_mean, _megdnn_tensor_in saved_batch_variance, - _megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale, + _megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve, + _megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, _megdnn_workspace workspace) = 0; virtual size_t get_workspace_in_bytes( const TensorLayout& x, const TensorLayout& dy, const TensorLayout& saved_batch_mean, const TensorLayout& saved_batch_variance, - const TensorLayout& bn_scale, const TensorLayout& d_bn_scale, - const TensorLayout& d_bn_bias, const TensorLayout& dx) = 0; + const TensorLayout& bn_scale, const TensorLayout& reserve, + const TensorLayout& d_bn_scale, const TensorLayout& d_bn_bias, + const TensorLayout& dx) = 0; + virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0; protected: void check_exec(const TensorLayout& x, const TensorLayout& dy, @@ -1056,7 +1068,7 @@ protected: const TensorLayout& bn_scale, const TensorLayout& d_bn_scale, const TensorLayout& d_bn_bias, const TensorLayout& dx, - size_t workspace_in_bytes); + size_t workspace_in_bytes, size_t reserve_in_bytes = 0); }; class LRNBase : public OperatorBase { diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index d8634b334..f825a598d 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -253,7 +253,7 @@ pdef('Axis').add_fields('int32', 'axis', 0) add_enum_alias('Format', 'Convolution') ) -(pdef('AdaptivePooling', version=0,is_legacy=True). +(pdef('AdaptivePooling', version=0, is_legacy=True). add_enum_alias('Mode', 'PoolingV0'). 
add_enum_alias('Format', 'ConvolutionV0') ) @@ -276,6 +276,7 @@ pdef('Axis').add_fields('int32', 'axis', 0) Doc('DIM_11HW = 0', 'Dim of params (Sigma, Mu) is 1 x 1 x H x W'), Doc('DIM_1CHW = 1', 'Dim of params (Sigma, Mu) is 1 x C x H x W'), Doc('DIM_1C11 = 2', 'Dim of params (Sigma, Mu) is 1 x C x 1 x 1'), + Doc('DIM_111C = 3', 'Dim of params (Sigma, Mu) is 1 x 1 x 1 x C'), name_field='param_dim' ). add_enum( diff --git a/dnn/src/common/batch_normalization.cpp b/dnn/src/common/batch_normalization.cpp index 55a25f11a..133670226 100644 --- a/dnn/src/common/batch_normalization.cpp +++ b/dnn/src/common/batch_normalization.cpp @@ -4,9 +4,9 @@ * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ #include "megdnn/oprs.h" @@ -14,28 +14,32 @@ namespace megdnn { -void BNForward::deduce_layout(const TensorLayout& src, TensorLayout&, - TensorLayout&, TensorLayout&, TensorLayout&, - TensorLayout&, TensorLayout&, TensorLayout& dst) { +void BNForward::deduce_layout(const TensorLayout& src, const TensorLayout&, + const TensorLayout&, TensorLayout&, TensorLayout&, + TensorLayout&, TensorLayout&, + TensorLayout& reserve, TensorLayout& dst) { + reserve = {{get_reserve_in_bytes(src)}, dtype::Byte()}; dst = src; } -void BNForward::check_exec(const TensorLayout& src, const TensorLayout& bn_scale, - const TensorLayout& bn_bias, const TensorLayout& mean, - const TensorLayout& variance, - const TensorLayout& batch_mean, - const TensorLayout& batch_inv_variance, - const TensorLayout& dst, size_t workspace_in_bytes) { +void BNForward::check_exec( + const TensorLayout& src, const TensorLayout& bn_scale, + const TensorLayout& bn_bias, const TensorLayout& mean, + const TensorLayout& variance, const TensorLayout& batch_mean, + const TensorLayout& batch_inv_variance, const TensorLayout& dst, + size_t workspace_in_bytes, size_t reserve_in_bytes) { megdnn_assert_contiguous(src); megdnn_assert_eq_layout(src, dst); megdnn_assert_eq_layout(bn_scale, bn_bias); megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT); megdnn_assert(bn_scale.dtype.category() == DTypeCategory::FLOAT); - auto required_workspace_in_bytes = - get_workspace_in_bytes(src, bn_scale, bn_bias, mean, variance, - batch_mean, batch_inv_variance, dst); + auto required_workspace_in_bytes = get_workspace_in_bytes( + src, bn_scale, bn_bias, mean, variance, batch_mean, + batch_inv_variance, {}, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); + auto required_reserve_in_bytes = get_reserve_in_bytes(src); + megdnn_assert(reserve_in_bytes >= required_reserve_in_bytes); } void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy, @@ -44,7 +48,8 @@ void BNBackward::check_exec(const 
TensorLayout& x, const TensorLayout& dy, const TensorLayout& bn_scale, const TensorLayout& d_bn_scale, const TensorLayout& d_bn_bias, - const TensorLayout& dx, size_t workspace_in_bytes) { + const TensorLayout& dx, size_t workspace_in_bytes, + size_t reserve_in_bytes) { megdnn_assert_contiguous(x); megdnn_assert_eq_layout(x, dy); megdnn_assert_eq_layout(x, dx); @@ -54,11 +59,14 @@ void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy, megdnn_assert_eq_layout(saved_batch_mean, bn_scale); megdnn_assert(x.dtype.category() == DTypeCategory::FLOAT); megdnn_assert(bn_scale.dtype.category() == DTypeCategory::FLOAT); - auto required_workspace_in_bytes = - get_workspace_in_bytes(x, dy, saved_batch_mean, saved_batch_variance, - bn_scale, d_bn_scale, d_bn_bias, dx); + auto required_workspace_in_bytes = get_workspace_in_bytes( + x, dy, saved_batch_mean, saved_batch_variance, bn_scale, {}, + d_bn_scale, d_bn_bias, dx); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); - megdnn_assert(param().fwd_mode == Param::FwdMode::TRAINING, "BNBackward only support TRAINING mode"); + auto required_reserve_in_bytes = get_reserve_in_bytes(x); + megdnn_assert(reserve_in_bytes >= required_reserve_in_bytes); + megdnn_assert(param().fwd_mode == Param::FwdMode::TRAINING, + "BNBackward only support TRAINING mode"); } } // namespace megdnn diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h index fcf4fa10f..fbc02c75c 100644 --- a/dnn/src/common/opr_trait.h +++ b/dnn/src/common/opr_trait.h @@ -55,8 +55,8 @@ DEF(GroupLocalBackwardData, 3, true, false); DEF(GroupLocalBackwardFilter, 3, true, false); DEF(LRNForward, 2, true, true); DEF(LRNBackward, 4, true, false); -DEF(BNForward, 8, true, true); -DEF(BNBackward, 8, true, false); +DEF(BNForward, 9, true, true); +DEF(BNBackward, 9, true, false); DEF(ROIPoolingForward, 4, true, false); DEF(ROIPoolingBackward, 5, true, false); DEF(CorrelationForward, 3, true, true); diff --git 
a/dnn/src/cuda/batch_normalization/opr_impl.cpp b/dnn/src/cuda/batch_normalization/opr_impl.cpp index 4a6551a07..c0606a8db 100644 --- a/dnn/src/cuda/batch_normalization/opr_impl.cpp +++ b/dnn/src/cuda/batch_normalization/opr_impl.cpp @@ -4,9 +4,9 @@ * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #include "./opr_impl.h" @@ -17,9 +17,11 @@ namespace cuda { namespace batch_normalization { -void BNTensorDescHolder::setup(const TensorLayout& x, - const ParamDim& param_dim) { +BNTensorDescHolder::BNTensorDescHolder(const TensorLayout& x, + const ParamDim& param_dim, + const FwdMode& fwd_mode) { TensorShape xy_shape(x); + Format xy_format = Format::NCHW; switch (param_dim) { case ParamDim::DIM_11HW: @@ -34,50 +36,116 @@ void BNTensorDescHolder::setup(const TensorLayout& x, case ParamDim::DIM_1C11: bn_mode = CUDNN_BATCHNORM_SPATIAL; break; + case ParamDim::DIM_111C: + bn_mode = CUDNN_BATCHNORM_SPATIAL; + xy_format = Format::NHWC; +#if CUDNN_VERSION >= 7410 + if (fwd_mode == FwdMode::TRAINING) { + bn_mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; + } +#endif // CUDNN_VERSION >= 7410 + break; default: megdnn_throw("Unknown param dim type of batch normalization."); } - xy_desc.set(TensorLayout(xy_shape, x.dtype)); + xy_desc.set(TensorLayout(xy_shape, x.dtype), xy_format); param_desc.set(xy_desc.desc, bn_mode); } +size_t get_reserve_size(const cudnnHandle_t& handle, + const BNTensorDescHolder& tensor_desc) { +#if CUDNN_VERSION >= 7410 + size_t reserve_size; + cudnn_check(cudnnGetBatchNormalizationTrainingExReserveSpaceSize( + handle, 
tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, + nullptr, // activationDesc + tensor_desc.xy_desc.desc, // xDesc + &reserve_size)); + return reserve_size; +#else + return 0; +#endif // CUDNN_VERSION >= 7410 +} } // namespace batch_normalization +using batch_normalization::BNTensorDescHolder; + +size_t BNForwardImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&, const TensorLayout&) { +#if CUDNN_VERSION >= 7410 + auto handle = cudnn_handle(this->handle()); + BNTensorDescHolder tensor_desc(src, m_param.param_dim, m_param.fwd_mode); + + size_t workspace_size; + cudnn_check(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( + handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, + tensor_desc.xy_desc.desc, // xDesc + tensor_desc.xy_desc.desc, // yDesc + tensor_desc.xy_desc.desc, // zDesc + tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc + nullptr, // activationDesc + &workspace_size)); + return workspace_size; +#else + return 0; +#endif // CUDNN_VERSION >= 7410 +} + +size_t BNForwardImpl::get_reserve_in_bytes(const TensorLayout& src) { + BNTensorDescHolder tensor_desc(src, m_param.param_dim, m_param.fwd_mode); + return batch_normalization::get_reserve_size(cudnn_handle(this->handle()), + tensor_desc); +} + void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, _megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, _megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, _megdnn_tensor_out batch_inv_variance, - _megdnn_tensor_out dst, _megdnn_workspace workspace) { + _megdnn_tensor_out reserve, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout, variance.layout, batch_mean.layout, batch_inv_variance.layout, - dst.layout, workspace.size); + dst.layout, workspace.size, reserve.layout.access_bytes()); auto handle = 
cudnn_handle(this->handle()); - m_tensor_desc.setup(src.layout, m_param.param_dim); + BNTensorDescHolder tensor_desc(src.layout, m_param.param_dim, + m_param.fwd_mode); float alpha = 1.0f, beta = 0.0f; switch (m_param.fwd_mode) { case param::BN::FwdMode::TRAINING: +#if CUDNN_VERSION >= 7410 + cudnn_check(cudnnBatchNormalizationForwardTrainingEx( + handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, + &alpha, &beta, // one & zero + tensor_desc.xy_desc.desc, src.raw_ptr, // xDesc & x + nullptr, nullptr, // zDesc & z + tensor_desc.xy_desc.desc, dst.raw_ptr, // yDesc & y + tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc + bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor, + mean.raw_ptr, variance.raw_ptr, m_param.epsilon, + batch_mean.raw_ptr, batch_inv_variance.raw_ptr, nullptr, + workspace.raw_ptr, workspace.size, reserve.raw_ptr, + reserve.layout.access_bytes())); +#else cudnn_check(cudnnBatchNormalizationForwardTraining( - handle, m_tensor_desc.bn_mode, - &alpha, &beta, - m_tensor_desc.xy_desc.desc, // xDesc - src.raw_ptr, // x - m_tensor_desc.xy_desc.desc, // yDesc - dst.raw_ptr, // y - m_tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc + handle, tensor_desc.bn_mode, &alpha, &beta, + tensor_desc.xy_desc.desc, src.raw_ptr, // xDesc & x + tensor_desc.xy_desc.desc, dst.raw_ptr, // yDesc & y + tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor, mean.raw_ptr, variance.raw_ptr, m_param.epsilon, batch_mean.raw_ptr, batch_inv_variance.raw_ptr)); - +#endif // CUDNN_VERSION >= 7410 break; case param::BN::FwdMode::INFERENCE: cudnn_check(cudnnBatchNormalizationForwardInference( - handle, m_tensor_desc.bn_mode, - &alpha, &beta, - m_tensor_desc.xy_desc.desc, src.raw_ptr, - m_tensor_desc.xy_desc.desc, dst.raw_ptr, - m_tensor_desc.param_desc.desc, bn_scale.raw_ptr, + handle, tensor_desc.bn_mode, &alpha, &beta, + tensor_desc.xy_desc.desc, src.raw_ptr, + tensor_desc.xy_desc.desc, dst.raw_ptr, + 
tensor_desc.param_desc.desc, bn_scale.raw_ptr, bn_bias.raw_ptr, mean.raw_ptr, variance.raw_ptr, m_param.epsilon)); break; @@ -86,30 +154,79 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, } } +size_t BNBackwardImpl::get_workspace_in_bytes( + const TensorLayout& x, const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&, const TensorLayout&) { +#if CUDNN_VERSION >= 7410 + auto handle = cudnn_handle(this->handle()); + BNTensorDescHolder tensor_desc(x, m_param.param_dim, m_param.fwd_mode); + + size_t workspace_size; + cudnn_check(cudnnGetBatchNormalizationBackwardExWorkspaceSize( + handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, + tensor_desc.xy_desc.desc, // xDesc + tensor_desc.xy_desc.desc, // yDesc + tensor_desc.xy_desc.desc, // dyDesc + nullptr, // dzDesc + tensor_desc.xy_desc.desc, // dxDesc + tensor_desc.param_desc.desc, // dBnScaleBiasDesc + nullptr, // activationDesc + &workspace_size)); + return workspace_size; +#else + return 0; +#endif // CUDNN_VERSION >= 7410 +} + +size_t BNBackwardImpl::get_reserve_in_bytes(const TensorLayout& src) { + BNTensorDescHolder tensor_desc(src, m_param.param_dim, m_param.fwd_mode); + return batch_normalization::get_reserve_size(cudnn_handle(this->handle()), + tensor_desc); +} + void BNBackwardImpl::exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, _megdnn_tensor_in saved_batch_mean, _megdnn_tensor_in saved_batch_inv_variance, - _megdnn_tensor_in bn_scale, + _megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve, _megdnn_tensor_out d_bn_scale, - _megdnn_tensor_out d_bn_bias, - _megdnn_tensor_out dx, _megdnn_workspace workspace) { + _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, + _megdnn_workspace workspace) { check_exec(x.layout, dy.layout, saved_batch_mean.layout, saved_batch_inv_variance.layout, bn_scale.layout, - d_bn_scale.layout, d_bn_bias.layout, dx.layout, - workspace.size); + 
d_bn_scale.layout, d_bn_bias.layout, dx.layout, workspace.size, + reserve.layout.access_bytes()); auto handle = cudnn_handle(this->handle()); - m_tensor_desc.setup(x.layout, m_param.param_dim); + BNTensorDescHolder tensor_desc(x.layout, m_param.param_dim, + m_param.fwd_mode); float alpha = 1.0, beta = 0.0; +#if CUDNN_VERSION >= 7410 + cudnn_check(cudnnBatchNormalizationBackwardEx( + handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, &alpha, &beta, + &alpha, &beta, tensor_desc.xy_desc.desc, + x.raw_ptr, // xDesc & x + nullptr, nullptr, // yDesc & y + tensor_desc.xy_desc.desc, dy.raw_ptr, // dyDesc & dy + nullptr, nullptr, // dzDesc & dz + tensor_desc.xy_desc.desc, dx.raw_ptr, // dxDesc & dx + tensor_desc.param_desc.desc, bn_scale.raw_ptr, // bnScale + nullptr, // bnBias + d_bn_scale.raw_ptr, d_bn_bias.raw_ptr, // dScale, dBias + m_param.epsilon, saved_batch_mean.raw_ptr, + saved_batch_inv_variance.raw_ptr, nullptr, workspace.raw_ptr, + workspace.size, reserve.raw_ptr, reserve.layout.access_bytes())); +#else cudnn_check(cudnnBatchNormalizationBackward( - handle, m_tensor_desc.bn_mode, - &alpha, &beta, &alpha, &beta, - m_tensor_desc.xy_desc.desc, x.raw_ptr, - m_tensor_desc.xy_desc.desc, dy.raw_ptr, - m_tensor_desc.xy_desc.desc, dx.raw_ptr, - m_tensor_desc.param_desc.desc, bn_scale.raw_ptr, - d_bn_scale.raw_ptr, d_bn_bias.raw_ptr, m_param.epsilon, - saved_batch_mean.raw_ptr, saved_batch_inv_variance.raw_ptr)); + handle, tensor_desc.bn_mode, &alpha, &beta, &alpha, &beta, + tensor_desc.xy_desc.desc, x.raw_ptr, // xDesc & x + tensor_desc.xy_desc.desc, dy.raw_ptr, // dyDesc & dy + tensor_desc.xy_desc.desc, dx.raw_ptr, // dxDesc & dx + tensor_desc.param_desc.desc, bn_scale.raw_ptr, // bnScale + d_bn_scale.raw_ptr, d_bn_bias.raw_ptr, // dScale, dBias + m_param.epsilon, saved_batch_mean.raw_ptr, + saved_batch_inv_variance.raw_ptr)); +#endif } } // namespace cuda diff --git a/dnn/src/cuda/batch_normalization/opr_impl.h b/dnn/src/cuda/batch_normalization/opr_impl.h index 
c5055623e..acfb29a8d 100644 --- a/dnn/src/cuda/batch_normalization/opr_impl.h +++ b/dnn/src/cuda/batch_normalization/opr_impl.h @@ -4,9 +4,9 @@ * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once #include "megdnn/oprs.h" @@ -20,14 +20,20 @@ namespace batch_normalization { struct BNTensorDescHolder { using ParamDim = param::BN::ParamDim; + using FwdMode = param::BN::FwdMode; + using Format = param::Convolution::Format; TensorDesc xy_desc; BNParamDesc param_desc; cudnnBatchNormMode_t bn_mode; - void setup(const TensorLayout& x, const ParamDim& param_dim); + BNTensorDescHolder(const TensorLayout& x, const ParamDim& param_dim, + const FwdMode& fwd_mode); }; +size_t get_reserve_size(const cudnnHandle_t& handle, + const BNTensorDescHolder& tensor_desc); + } // namespace batch_normalization class BNForwardImpl final : public BNForward { @@ -36,19 +42,15 @@ public: void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, _megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, _megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, - _megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; + _megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve, + _megdnn_tensor_out dst, _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, + size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&, - const 
TensorLayout&, - const TensorLayout&) override { - return 0; - } - -private: - batch_normalization::BNTensorDescHolder m_tensor_desc; + const TensorLayout&, const TensorLayout&, + const TensorLayout&) override; + size_t get_reserve_in_bytes(const TensorLayout& src) override; }; class BNBackwardImpl final : public BNBackward { @@ -57,20 +59,16 @@ public: void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, _megdnn_tensor_in saved_batch_mean, _megdnn_tensor_in saved_batch_inv_variance, - _megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale, - _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, - _megdnn_workspace workspace) override; + _megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve, + _megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias, + _megdnn_tensor_out dx, _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, + size_t get_workspace_in_bytes(const TensorLayout& x, const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&, - const TensorLayout&, - const TensorLayout&) override { - return 0; - } - -private: - batch_normalization::BNTensorDescHolder m_tensor_desc; + const TensorLayout&, const TensorLayout&, + const TensorLayout&) override; + size_t get_reserve_in_bytes(const TensorLayout& src) override; }; } // namespace cuda diff --git a/dnn/src/naive/batch_normalization/opr_impl.cpp b/dnn/src/naive/batch_normalization/opr_impl.cpp index 61036a075..51bacf2b5 100644 --- a/dnn/src/naive/batch_normalization/opr_impl.cpp +++ b/dnn/src/naive/batch_normalization/opr_impl.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #include "src/naive/batch_normalization/opr_impl.h" @@ -219,13 +220,14 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, _megdnn_tensor_inout variance, _megdnn_tensor_out batch_mean, _megdnn_tensor_out batch_inv_variance, - _megdnn_tensor_out dst, _megdnn_workspace workspace) { + _megdnn_tensor_out, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout, variance.layout, batch_mean.layout, batch_inv_variance.layout, dst.layout, workspace.size); DNN_INC_FLOAT16(if (src.layout.dtype == dtype::Float16() && - bn_scale.layout.dtype == dtype::Float32()) { + bn_scale.layout.dtype == dtype::Float32()) { MEGDNN_DISPATCH_CPU_KERN_OPR(({ using T0 = typename DTypeTrait::ctype; using T1 = typename DTypeTrait::ctype; @@ -263,7 +265,7 @@ WorkspaceBundle BNBackwardImpl::get_workspace_bundle(size_t x_size, size_t BNBackwardImpl::get_workspace_in_bytes( const TensorLayout& x, const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout& bn_scale, const TensorLayout&, - const TensorLayout&, const TensorLayout&) { + const TensorLayout&, const TensorLayout&, const TensorLayout&) { auto x_size = x.total_nr_elems(), param_size = bn_scale.total_nr_elems(); return get_workspace_bundle(x_size, param_size).total_size_in_bytes(); } @@ -271,7 +273,7 @@ size_t BNBackwardImpl::get_workspace_in_bytes( void BNBackwardImpl::exec(_megdnn_tensor_in x_in, _megdnn_tensor_in dy_in, _megdnn_tensor_in saved_batch_mean, _megdnn_tensor_in saved_batch_inv_variance, - _megdnn_tensor_in bn_scale, + _megdnn_tensor_in bn_scale, _megdnn_tensor_in, _megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx_out, @@ -286,7 +288,7 @@ void BNBackwardImpl::exec(_megdnn_tensor_in x_in, _megdnn_tensor_in dy_in, workspace.raw_ptr); DNN_INC_FLOAT16(if (x_in.layout.dtype == dtype::Float16() && - bn_scale.layout.dtype == dtype::Float32()) { + bn_scale.layout.dtype == 
dtype::Float32()) { MEGDNN_DISPATCH_CPU_KERN_OPR(({ using T0 = typename DTypeTrait::ctype; using T1 = typename DTypeTrait::ctype; diff --git a/dnn/src/naive/batch_normalization/opr_impl.h b/dnn/src/naive/batch_normalization/opr_impl.h index ecaee7ccb..a980dfbc0 100644 --- a/dnn/src/naive/batch_normalization/opr_impl.h +++ b/dnn/src/naive/batch_normalization/opr_impl.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once #include "megdnn/oprs.h" @@ -21,16 +22,17 @@ public: void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, _megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, _megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, - _megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; + _megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve, + _megdnn_tensor_out dst, _megdnn_workspace workspace) override; size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&, - const TensorLayout&, + const TensorLayout&, const TensorLayout&, const TensorLayout&) override { return 0; } + size_t get_reserve_in_bytes(const TensorLayout&) override { return 0; } }; class BNBackwardImpl final : public BNBackward { @@ -39,15 +41,17 @@ public: void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, _megdnn_tensor_in saved_batch_mean, _megdnn_tensor_in saved_batch_inv_variance, - _megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale, - _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, - _megdnn_workspace workspace) override; + _megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve, + _megdnn_tensor_out d_bn_scale, _megdnn_tensor_out 
d_bn_bias, + _megdnn_tensor_out dx, _megdnn_workspace workspace) override; size_t get_workspace_in_bytes(const TensorLayout& x, const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout& bn_scale, const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&) override; + size_t get_reserve_in_bytes(const TensorLayout&) override { return 0; } private: WorkspaceBundle get_workspace_bundle(size_t x_size, size_t param_size, diff --git a/dnn/src/rocm/batch_normalization/opr_impl.cpp b/dnn/src/rocm/batch_normalization/opr_impl.cpp index 90f9213ba..d011c931c 100644 --- a/dnn/src/rocm/batch_normalization/opr_impl.cpp +++ b/dnn/src/rocm/batch_normalization/opr_impl.cpp @@ -49,7 +49,8 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, _megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, _megdnn_tensor_out batch_inv_variance, - _megdnn_tensor_out dst, _megdnn_workspace workspace) { + _megdnn_tensor_out, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout, variance.layout, batch_mean.layout, batch_inv_variance.layout, dst.layout, workspace.size); @@ -88,7 +89,7 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, void BNBackwardImpl::exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, _megdnn_tensor_in saved_batch_mean, _megdnn_tensor_in saved_batch_inv_variance, - _megdnn_tensor_in bn_scale, + _megdnn_tensor_in bn_scale, _megdnn_tensor_in, _megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, _megdnn_workspace workspace) { diff --git a/dnn/src/rocm/batch_normalization/opr_impl.h b/dnn/src/rocm/batch_normalization/opr_impl.h index 2c4267962..c222944d9 100644 --- a/dnn/src/rocm/batch_normalization/opr_impl.h +++ b/dnn/src/rocm/batch_normalization/opr_impl.h @@ -37,16 +37,17 @@ public: void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, _megdnn_tensor_in bn_bias, 
_megdnn_tensor_out mean, _megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, - _megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; + _megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve, + _megdnn_tensor_out dst, _megdnn_workspace workspace) override; size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&, - const TensorLayout&, + const TensorLayout&, const TensorLayout&, const TensorLayout&) override { return 0; } + size_t get_reserve_in_bytes(const TensorLayout&) override { return 0; } private: batch_normalization::BNTensorDescHolder m_tensor_desc; @@ -58,17 +59,18 @@ public: void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, _megdnn_tensor_in saved_batch_mean, _megdnn_tensor_in saved_batch_inv_variance, - _megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale, - _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, - _megdnn_workspace workspace) override; + _megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve, + _megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias, + _megdnn_tensor_out dx, _megdnn_workspace workspace) override; size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&, - const TensorLayout&, + const TensorLayout&, const TensorLayout&, const TensorLayout&) override { return 0; } + size_t get_reserve_in_bytes(const TensorLayout&) override { return 0; } private: batch_normalization::BNTensorDescHolder m_tensor_desc; -- GitLab