提交 b3e54ead 编写于 作者: M Megvii Engine Team

feat(dnn/bn): use new cudnn BN kernel to support NHWC

GitOrigin-RevId: 9d80f2009d0496f532b267fcf841785b74c0b50c
上级 6b863cc5
...@@ -682,7 +682,8 @@ public: ...@@ -682,7 +682,8 @@ public:
* http://deeplearning.net/software/theano/library/tensor/nnet/neighbours.html * http://deeplearning.net/software/theano/library/tensor/nnet/neighbours.html
* *
* \f$ dst_{n, c, oh, ow, wh, ww} = src_{n, c, ih+wh, iw+ww}\f$, * \f$ dst_{n, c, oh, ow, wh, ww} = src_{n, c, ih+wh, iw+ww}\f$,
* where \f$ ih=-pad_h+oh*stride_h+(wh-1)*(dilation_h-1), iw=-pad_w+ow*stride_w+(ww-1)*(dilation_w-1)\f$. * where \f$ ih=-pad_h+oh*stride_h+(wh-1)*(dilation_h-1),
* iw=-pad_w+ow*stride_w+(ww-1)*(dilation_w-1)\f$.
*/ */
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0; _megdnn_workspace workspace) = 0;
...@@ -724,7 +725,8 @@ protected: ...@@ -724,7 +725,8 @@ protected:
}; };
class SlidingWindowTransposeForward : public SlidingWindowTransposeBase { class SlidingWindowTransposeForward : public SlidingWindowTransposeBase {
DEF_OPR_IMPL(SlidingWindowTransposeForward, SlidingWindowTransposeBase, 1, 1); DEF_OPR_IMPL(SlidingWindowTransposeForward, SlidingWindowTransposeBase, 1,
1);
public: public:
/** /**
...@@ -744,7 +746,8 @@ protected: ...@@ -744,7 +746,8 @@ protected:
using SlidingWindowTranspose = SlidingWindowTransposeForward; using SlidingWindowTranspose = SlidingWindowTransposeForward;
class SlidingWindowTransposeBackward : public SlidingWindowTransposeBase { class SlidingWindowTransposeBackward : public SlidingWindowTransposeBase {
DEF_OPR_IMPL(SlidingWindowTransposeBackward, SlidingWindowTransposeBase, 1, 1); DEF_OPR_IMPL(SlidingWindowTransposeBackward, SlidingWindowTransposeBase, 1,
1);
public: public:
/** /**
...@@ -975,7 +978,7 @@ protected: ...@@ -975,7 +978,7 @@ protected:
}; };
class BNForward : public BNBase { class BNForward : public BNBase {
DEF_OPR_IMPL(BNForward, BNBase, 6, 5); DEF_OPR_IMPL(BNForward, BNBase, 6, 6);
public: public:
/** /**
...@@ -986,10 +989,11 @@ public: ...@@ -986,10 +989,11 @@ public:
* \param[out] dst (n, c, h, w) * \param[out] dst (n, c, h, w)
* \param[out] mean (see m_param.ParamDim) Global mean. * \param[out] mean (see m_param.ParamDim) Global mean.
* \param[out] variance (see m_param.ParamDim) Global variance. * \param[out] variance (see m_param.ParamDim) Global variance.
* \Param[out] batch_mean (see m_param.ParamDim) * \param[out] batch_mean (see m_param.ParamDim)
* Optionally cached intermediate mean from forward pass * Optionally cached intermediate mean from forward pass
* \Param[out] batch_inv_variance (see m_param.ParamDim) * \param[out] batch_inv_variance (see m_param.ParamDim)
* Optionally cached intermediate variance from forward pass * Optionally cached intermediate variance from forward pass
* \param[out] reserve (see cudnnBatchNormalizationForwardTrainingEx)
* src and dst must have the same shape. * src and dst must have the same shape.
* src and dst must be contiguous. * src and dst must be contiguous.
*/ */
...@@ -998,17 +1002,20 @@ public: ...@@ -998,17 +1002,20 @@ public:
_megdnn_tensor_inout variance, _megdnn_tensor_inout variance,
_megdnn_tensor_out batch_mean, _megdnn_tensor_out batch_mean,
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out batch_inv_variance,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; _megdnn_tensor_out reserve, _megdnn_tensor_out dst,
void deduce_layout(const TensorLayout& src, TensorLayout& bn_scale, _megdnn_workspace workspace) = 0;
TensorLayout& bn_bias, TensorLayout& mean, void deduce_layout(const TensorLayout& src, const TensorLayout& bn_scale,
const TensorLayout& bn_bias, TensorLayout& mean,
TensorLayout& variance, TensorLayout& batch_mean, TensorLayout& variance, TensorLayout& batch_mean,
TensorLayout& batch_inv_variance, TensorLayout& dst); TensorLayout& batch_inv_variance, TensorLayout& reserve,
TensorLayout& dst);
virtual size_t get_workspace_in_bytes( virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& bn_scale, const TensorLayout& src, const TensorLayout& bn_scale,
const TensorLayout& bn_bias, const TensorLayout& mean, const TensorLayout& bn_bias, const TensorLayout& mean,
const TensorLayout& variance, const TensorLayout& batch_mean, const TensorLayout& variance, const TensorLayout& batch_mean,
const TensorLayout& batch_inv_variance, const TensorLayout& batch_inv_variance, const TensorLayout& reserve,
const TensorLayout& dst) = 0; const TensorLayout& dst) = 0;
virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0;
protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& bn_scale, void check_exec(const TensorLayout& src, const TensorLayout& bn_scale,
...@@ -1016,12 +1023,13 @@ protected: ...@@ -1016,12 +1023,13 @@ protected:
const TensorLayout& variance, const TensorLayout& variance,
const TensorLayout& batch_mean, const TensorLayout& batch_mean,
const TensorLayout& batch_inv_variance, const TensorLayout& batch_inv_variance,
const TensorLayout& dst, size_t workspace_in_bytes); const TensorLayout& dst, size_t workspace_in_bytes,
size_t reserve_in_bytes = 0);
}; };
using BN = BNForward; using BN = BNForward;
class BNBackward : public BNBase { class BNBackward : public BNBase {
DEF_OPR_IMPL(BNBackward, BNBase, 5, 3); DEF_OPR_IMPL(BNBackward, BNBase, 6, 3);
public: public:
/** /**
...@@ -1035,19 +1043,23 @@ public: ...@@ -1035,19 +1043,23 @@ public:
Calculated in the forward propagation. Calculated in the forward propagation.
* \param[in] saved_batch_variance of the input batch. * \param[in] saved_batch_variance of the input batch.
Calculated in the forward propagation. Calculated in the forward propagation.
* \param[in] reserve (see cudnnBatchNormalizationBackwardEx)
*/ */
virtual void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, virtual void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy,
_megdnn_tensor_in saved_batch_mean, _megdnn_tensor_in saved_batch_mean,
_megdnn_tensor_in saved_batch_variance, _megdnn_tensor_in saved_batch_variance,
_megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale, _megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve,
_megdnn_tensor_out d_bn_scale,
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx,
_megdnn_workspace workspace) = 0; _megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes( virtual size_t get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout& dy, const TensorLayout& x, const TensorLayout& dy,
const TensorLayout& saved_batch_mean, const TensorLayout& saved_batch_mean,
const TensorLayout& saved_batch_variance, const TensorLayout& saved_batch_variance,
const TensorLayout& bn_scale, const TensorLayout& d_bn_scale, const TensorLayout& bn_scale, const TensorLayout& reserve,
const TensorLayout& d_bn_bias, const TensorLayout& dx) = 0; const TensorLayout& d_bn_scale, const TensorLayout& d_bn_bias,
const TensorLayout& dx) = 0;
virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0;
protected: protected:
void check_exec(const TensorLayout& x, const TensorLayout& dy, void check_exec(const TensorLayout& x, const TensorLayout& dy,
...@@ -1056,7 +1068,7 @@ protected: ...@@ -1056,7 +1068,7 @@ protected:
const TensorLayout& bn_scale, const TensorLayout& bn_scale,
const TensorLayout& d_bn_scale, const TensorLayout& d_bn_scale,
const TensorLayout& d_bn_bias, const TensorLayout& dx, const TensorLayout& d_bn_bias, const TensorLayout& dx,
size_t workspace_in_bytes); size_t workspace_in_bytes, size_t reserve_in_bytes = 0);
}; };
class LRNBase : public OperatorBase { class LRNBase : public OperatorBase {
......
...@@ -253,7 +253,7 @@ pdef('Axis').add_fields('int32', 'axis', 0) ...@@ -253,7 +253,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
add_enum_alias('Format', 'Convolution') add_enum_alias('Format', 'Convolution')
) )
(pdef('AdaptivePooling', version=0,is_legacy=True). (pdef('AdaptivePooling', version=0, is_legacy=True).
add_enum_alias('Mode', 'PoolingV0'). add_enum_alias('Mode', 'PoolingV0').
add_enum_alias('Format', 'ConvolutionV0') add_enum_alias('Format', 'ConvolutionV0')
) )
...@@ -276,6 +276,7 @@ pdef('Axis').add_fields('int32', 'axis', 0) ...@@ -276,6 +276,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
Doc('DIM_11HW = 0', 'Dim of params (Sigma, Mu) is 1 x 1 x H x W'), Doc('DIM_11HW = 0', 'Dim of params (Sigma, Mu) is 1 x 1 x H x W'),
Doc('DIM_1CHW = 1', 'Dim of params (Sigma, Mu) is 1 x C x H x W'), Doc('DIM_1CHW = 1', 'Dim of params (Sigma, Mu) is 1 x C x H x W'),
Doc('DIM_1C11 = 2', 'Dim of params (Sigma, Mu) is 1 x C x 1 x 1'), Doc('DIM_1C11 = 2', 'Dim of params (Sigma, Mu) is 1 x C x 1 x 1'),
Doc('DIM_111C = 3', 'Dim of params (Sigma, Mu) is 1 x 1 x 1 x C'),
name_field='param_dim' name_field='param_dim'
). ).
add_enum( add_enum(
......
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
* *
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing, software
* software distributed under the License is distributed on an * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
...@@ -14,28 +14,32 @@ ...@@ -14,28 +14,32 @@
namespace megdnn { namespace megdnn {
void BNForward::deduce_layout(const TensorLayout& src, TensorLayout&, void BNForward::deduce_layout(const TensorLayout& src, const TensorLayout&,
TensorLayout&, TensorLayout&, TensorLayout&, const TensorLayout&, TensorLayout&, TensorLayout&,
TensorLayout&, TensorLayout&, TensorLayout& dst) { TensorLayout&, TensorLayout&,
TensorLayout& reserve, TensorLayout& dst) {
reserve = {{get_reserve_in_bytes(src)}, dtype::Byte()};
dst = src; dst = src;
} }
void BNForward::check_exec(const TensorLayout& src, const TensorLayout& bn_scale, void BNForward::check_exec(
const TensorLayout& src, const TensorLayout& bn_scale,
const TensorLayout& bn_bias, const TensorLayout& mean, const TensorLayout& bn_bias, const TensorLayout& mean,
const TensorLayout& variance, const TensorLayout& variance, const TensorLayout& batch_mean,
const TensorLayout& batch_mean, const TensorLayout& batch_inv_variance, const TensorLayout& dst,
const TensorLayout& batch_inv_variance, size_t workspace_in_bytes, size_t reserve_in_bytes) {
const TensorLayout& dst, size_t workspace_in_bytes) {
megdnn_assert_contiguous(src); megdnn_assert_contiguous(src);
megdnn_assert_eq_layout(src, dst); megdnn_assert_eq_layout(src, dst);
megdnn_assert_eq_layout(bn_scale, bn_bias); megdnn_assert_eq_layout(bn_scale, bn_bias);
megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT); megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT);
megdnn_assert(bn_scale.dtype.category() == DTypeCategory::FLOAT); megdnn_assert(bn_scale.dtype.category() == DTypeCategory::FLOAT);
auto required_workspace_in_bytes = auto required_workspace_in_bytes = get_workspace_in_bytes(
get_workspace_in_bytes(src, bn_scale, bn_bias, mean, variance, src, bn_scale, bn_bias, mean, variance, batch_mean,
batch_mean, batch_inv_variance, dst); batch_inv_variance, {}, dst);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
auto required_reserve_in_bytes = get_reserve_in_bytes(src);
megdnn_assert(reserve_in_bytes >= required_reserve_in_bytes);
} }
void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy, void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy,
...@@ -44,7 +48,8 @@ void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy, ...@@ -44,7 +48,8 @@ void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy,
const TensorLayout& bn_scale, const TensorLayout& bn_scale,
const TensorLayout& d_bn_scale, const TensorLayout& d_bn_scale,
const TensorLayout& d_bn_bias, const TensorLayout& d_bn_bias,
const TensorLayout& dx, size_t workspace_in_bytes) { const TensorLayout& dx, size_t workspace_in_bytes,
size_t reserve_in_bytes) {
megdnn_assert_contiguous(x); megdnn_assert_contiguous(x);
megdnn_assert_eq_layout(x, dy); megdnn_assert_eq_layout(x, dy);
megdnn_assert_eq_layout(x, dx); megdnn_assert_eq_layout(x, dx);
...@@ -54,11 +59,14 @@ void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy, ...@@ -54,11 +59,14 @@ void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy,
megdnn_assert_eq_layout(saved_batch_mean, bn_scale); megdnn_assert_eq_layout(saved_batch_mean, bn_scale);
megdnn_assert(x.dtype.category() == DTypeCategory::FLOAT); megdnn_assert(x.dtype.category() == DTypeCategory::FLOAT);
megdnn_assert(bn_scale.dtype.category() == DTypeCategory::FLOAT); megdnn_assert(bn_scale.dtype.category() == DTypeCategory::FLOAT);
auto required_workspace_in_bytes = auto required_workspace_in_bytes = get_workspace_in_bytes(
get_workspace_in_bytes(x, dy, saved_batch_mean, saved_batch_variance, x, dy, saved_batch_mean, saved_batch_variance, bn_scale, {},
bn_scale, d_bn_scale, d_bn_bias, dx); d_bn_scale, d_bn_bias, dx);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
megdnn_assert(param().fwd_mode == Param::FwdMode::TRAINING, "BNBackward only support TRAINING mode"); auto required_reserve_in_bytes = get_reserve_in_bytes(x);
megdnn_assert(reserve_in_bytes >= required_reserve_in_bytes);
megdnn_assert(param().fwd_mode == Param::FwdMode::TRAINING,
"BNBackward only support TRAINING mode");
} }
} // namespace megdnn } // namespace megdnn
......
...@@ -55,8 +55,8 @@ DEF(GroupLocalBackwardData, 3, true, false); ...@@ -55,8 +55,8 @@ DEF(GroupLocalBackwardData, 3, true, false);
DEF(GroupLocalBackwardFilter, 3, true, false); DEF(GroupLocalBackwardFilter, 3, true, false);
DEF(LRNForward, 2, true, true); DEF(LRNForward, 2, true, true);
DEF(LRNBackward, 4, true, false); DEF(LRNBackward, 4, true, false);
DEF(BNForward, 8, true, true); DEF(BNForward, 9, true, true);
DEF(BNBackward, 8, true, false); DEF(BNBackward, 9, true, false);
DEF(ROIPoolingForward, 4, true, false); DEF(ROIPoolingForward, 4, true, false);
DEF(ROIPoolingBackward, 5, true, false); DEF(ROIPoolingBackward, 5, true, false);
DEF(CorrelationForward, 3, true, true); DEF(CorrelationForward, 3, true, true);
......
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
* *
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing, software
* software distributed under the License is distributed on an * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#include "./opr_impl.h" #include "./opr_impl.h"
...@@ -17,9 +17,11 @@ namespace cuda { ...@@ -17,9 +17,11 @@ namespace cuda {
namespace batch_normalization { namespace batch_normalization {
void BNTensorDescHolder::setup(const TensorLayout& x, BNTensorDescHolder::BNTensorDescHolder(const TensorLayout& x,
const ParamDim& param_dim) { const ParamDim& param_dim,
const FwdMode& fwd_mode) {
TensorShape xy_shape(x); TensorShape xy_shape(x);
Format xy_format = Format::NCHW;
switch (param_dim) { switch (param_dim) {
case ParamDim::DIM_11HW: case ParamDim::DIM_11HW:
...@@ -34,50 +36,116 @@ void BNTensorDescHolder::setup(const TensorLayout& x, ...@@ -34,50 +36,116 @@ void BNTensorDescHolder::setup(const TensorLayout& x,
case ParamDim::DIM_1C11: case ParamDim::DIM_1C11:
bn_mode = CUDNN_BATCHNORM_SPATIAL; bn_mode = CUDNN_BATCHNORM_SPATIAL;
break; break;
case ParamDim::DIM_111C:
bn_mode = CUDNN_BATCHNORM_SPATIAL;
xy_format = Format::NHWC;
#if CUDNN_VERSION >= 7410
if (fwd_mode == FwdMode::TRAINING) {
bn_mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
}
#endif // CUDNN_VERSION >= 7410
break;
default: default:
megdnn_throw("Unknown param dim type of batch normalization."); megdnn_throw("Unknown param dim type of batch normalization.");
} }
xy_desc.set(TensorLayout(xy_shape, x.dtype)); xy_desc.set(TensorLayout(xy_shape, x.dtype), xy_format);
param_desc.set(xy_desc.desc, bn_mode); param_desc.set(xy_desc.desc, bn_mode);
} }
size_t get_reserve_size(const cudnnHandle_t& handle,
const BNTensorDescHolder& tensor_desc) {
#if CUDNN_VERSION >= 7410
size_t reserve_size;
cudnn_check(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN,
nullptr, // activationDesc
tensor_desc.xy_desc.desc, // xDesc
&reserve_size));
return reserve_size;
#else
return 0;
#endif // CUDNN_VERSION >= 7410
}
} // namespace batch_normalization } // namespace batch_normalization
using batch_normalization::BNTensorDescHolder;
size_t BNForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&) {
#if CUDNN_VERSION >= 7410
auto handle = cudnn_handle(this->handle());
BNTensorDescHolder tensor_desc(src, m_param.param_dim, m_param.fwd_mode);
size_t workspace_size;
cudnn_check(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN,
tensor_desc.xy_desc.desc, // xDesc
tensor_desc.xy_desc.desc, // yDesc
tensor_desc.xy_desc.desc, // zDesc
tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc
nullptr, // activationDesc
&workspace_size));
return workspace_size;
#else
return 0;
#endif // CUDNN_VERSION >= 7410
}
size_t BNForwardImpl::get_reserve_in_bytes(const TensorLayout& src) {
BNTensorDescHolder tensor_desc(src, m_param.param_dim, m_param.fwd_mode);
return batch_normalization::get_reserve_size(cudnn_handle(this->handle()),
tensor_desc);
}
void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
_megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, _megdnn_tensor_in bn_bias, _megdnn_tensor_out mean,
_megdnn_tensor_out variance, _megdnn_tensor_out variance,
_megdnn_tensor_out batch_mean, _megdnn_tensor_out batch_mean,
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out batch_inv_variance,
_megdnn_tensor_out dst, _megdnn_workspace workspace) { _megdnn_tensor_out reserve, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout, check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout,
variance.layout, batch_mean.layout, batch_inv_variance.layout, variance.layout, batch_mean.layout, batch_inv_variance.layout,
dst.layout, workspace.size); dst.layout, workspace.size, reserve.layout.access_bytes());
auto handle = cudnn_handle(this->handle()); auto handle = cudnn_handle(this->handle());
m_tensor_desc.setup(src.layout, m_param.param_dim); BNTensorDescHolder tensor_desc(src.layout, m_param.param_dim,
m_param.fwd_mode);
float alpha = 1.0f, beta = 0.0f; float alpha = 1.0f, beta = 0.0f;
switch (m_param.fwd_mode) { switch (m_param.fwd_mode) {
case param::BN::FwdMode::TRAINING: case param::BN::FwdMode::TRAINING:
#if CUDNN_VERSION >= 7410
cudnn_check(cudnnBatchNormalizationForwardTrainingEx(
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN,
&alpha, &beta, // one & zero
tensor_desc.xy_desc.desc, src.raw_ptr, // xDesc & x
nullptr, nullptr, // zDesc & z
tensor_desc.xy_desc.desc, dst.raw_ptr, // yDesc & y
tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc
bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor,
mean.raw_ptr, variance.raw_ptr, m_param.epsilon,
batch_mean.raw_ptr, batch_inv_variance.raw_ptr, nullptr,
workspace.raw_ptr, workspace.size, reserve.raw_ptr,
reserve.layout.access_bytes()));
#else
cudnn_check(cudnnBatchNormalizationForwardTraining( cudnn_check(cudnnBatchNormalizationForwardTraining(
handle, m_tensor_desc.bn_mode, handle, tensor_desc.bn_mode, &alpha, &beta,
&alpha, &beta, tensor_desc.xy_desc.desc, src.raw_ptr, // xDesc & x
m_tensor_desc.xy_desc.desc, // xDesc tensor_desc.xy_desc.desc, dst.raw_ptr, // yDesc & y
src.raw_ptr, // x tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc
m_tensor_desc.xy_desc.desc, // yDesc
dst.raw_ptr, // y
m_tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc
bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor, bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor,
mean.raw_ptr, variance.raw_ptr, m_param.epsilon, mean.raw_ptr, variance.raw_ptr, m_param.epsilon,
batch_mean.raw_ptr, batch_inv_variance.raw_ptr)); batch_mean.raw_ptr, batch_inv_variance.raw_ptr));
#endif // CUDNN_VERSION >= 7410
break; break;
case param::BN::FwdMode::INFERENCE: case param::BN::FwdMode::INFERENCE:
cudnn_check(cudnnBatchNormalizationForwardInference( cudnn_check(cudnnBatchNormalizationForwardInference(
handle, m_tensor_desc.bn_mode, handle, tensor_desc.bn_mode, &alpha, &beta,
&alpha, &beta, tensor_desc.xy_desc.desc, src.raw_ptr,
m_tensor_desc.xy_desc.desc, src.raw_ptr, tensor_desc.xy_desc.desc, dst.raw_ptr,
m_tensor_desc.xy_desc.desc, dst.raw_ptr, tensor_desc.param_desc.desc, bn_scale.raw_ptr,
m_tensor_desc.param_desc.desc, bn_scale.raw_ptr,
bn_bias.raw_ptr, mean.raw_ptr, variance.raw_ptr, bn_bias.raw_ptr, mean.raw_ptr, variance.raw_ptr,
m_param.epsilon)); m_param.epsilon));
break; break;
...@@ -86,30 +154,79 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, ...@@ -86,30 +154,79 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
} }
} }
size_t BNBackwardImpl::get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&) {
#if CUDNN_VERSION >= 7410
auto handle = cudnn_handle(this->handle());
BNTensorDescHolder tensor_desc(x, m_param.param_dim, m_param.fwd_mode);
size_t workspace_size;
cudnn_check(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN,
tensor_desc.xy_desc.desc, // xDesc
tensor_desc.xy_desc.desc, // yDesc
tensor_desc.xy_desc.desc, // dyDesc
nullptr, // dzDesc
tensor_desc.xy_desc.desc, // dxDesc
tensor_desc.param_desc.desc, // dBnScaleBiasDesc
nullptr, // activationDesc
&workspace_size));
return workspace_size;
#else
return 0;
#endif // CUDNN_VERSION >= 7410
}
size_t BNBackwardImpl::get_reserve_in_bytes(const TensorLayout& src) {
BNTensorDescHolder tensor_desc(src, m_param.param_dim, m_param.fwd_mode);
return batch_normalization::get_reserve_size(cudnn_handle(this->handle()),
tensor_desc);
}
void BNBackwardImpl::exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, void BNBackwardImpl::exec(_megdnn_tensor_in x, _megdnn_tensor_in dy,
_megdnn_tensor_in saved_batch_mean, _megdnn_tensor_in saved_batch_mean,
_megdnn_tensor_in saved_batch_inv_variance, _megdnn_tensor_in saved_batch_inv_variance,
_megdnn_tensor_in bn_scale, _megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve,
_megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_scale,
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx,
_megdnn_tensor_out dx, _megdnn_workspace workspace) { _megdnn_workspace workspace) {
check_exec(x.layout, dy.layout, saved_batch_mean.layout, check_exec(x.layout, dy.layout, saved_batch_mean.layout,
saved_batch_inv_variance.layout, bn_scale.layout, saved_batch_inv_variance.layout, bn_scale.layout,
d_bn_scale.layout, d_bn_bias.layout, dx.layout, d_bn_scale.layout, d_bn_bias.layout, dx.layout, workspace.size,
workspace.size); reserve.layout.access_bytes());
auto handle = cudnn_handle(this->handle()); auto handle = cudnn_handle(this->handle());
m_tensor_desc.setup(x.layout, m_param.param_dim); BNTensorDescHolder tensor_desc(x.layout, m_param.param_dim,
m_param.fwd_mode);
float alpha = 1.0, beta = 0.0; float alpha = 1.0, beta = 0.0;
#if CUDNN_VERSION >= 7410
cudnn_check(cudnnBatchNormalizationBackwardEx(
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, &alpha, &beta,
&alpha, &beta, tensor_desc.xy_desc.desc,
x.raw_ptr, // xDesc & x
nullptr, nullptr, // yDesc & y
tensor_desc.xy_desc.desc, dy.raw_ptr, // dyDesc & dy
nullptr, nullptr, // dzDesc & dz
tensor_desc.xy_desc.desc, dx.raw_ptr, // dxDesc & dx
tensor_desc.param_desc.desc, bn_scale.raw_ptr, // bnScale
nullptr, // bnBias
d_bn_scale.raw_ptr, d_bn_bias.raw_ptr, // dScale, dBias
m_param.epsilon, saved_batch_mean.raw_ptr,
saved_batch_inv_variance.raw_ptr, nullptr, workspace.raw_ptr,
workspace.size, reserve.raw_ptr, reserve.layout.access_bytes()));
#else
cudnn_check(cudnnBatchNormalizationBackward( cudnn_check(cudnnBatchNormalizationBackward(
handle, m_tensor_desc.bn_mode, handle, tensor_desc.bn_mode, &alpha, &beta, &alpha, &beta,
&alpha, &beta, &alpha, &beta, tensor_desc.xy_desc.desc, x.raw_ptr, // xDesc & x
m_tensor_desc.xy_desc.desc, x.raw_ptr, tensor_desc.xy_desc.desc, dy.raw_ptr, // dyDesc & dy
m_tensor_desc.xy_desc.desc, dy.raw_ptr, tensor_desc.xy_desc.desc, dx.raw_ptr, // dxDesc & dx
m_tensor_desc.xy_desc.desc, dx.raw_ptr, tensor_desc.param_desc.desc, bn_scale.raw_ptr, // bnScale
m_tensor_desc.param_desc.desc, bn_scale.raw_ptr, d_bn_scale.raw_ptr, d_bn_bias.raw_ptr, // dScale, dBias
d_bn_scale.raw_ptr, d_bn_bias.raw_ptr, m_param.epsilon, m_param.epsilon, saved_batch_mean.raw_ptr,
saved_batch_mean.raw_ptr, saved_batch_inv_variance.raw_ptr)); saved_batch_inv_variance.raw_ptr));
#endif
} }
} // namespace cuda } // namespace cuda
......
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
* *
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing, software
* software distributed under the License is distributed on an * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#pragma once #pragma once
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
...@@ -20,14 +20,20 @@ namespace batch_normalization { ...@@ -20,14 +20,20 @@ namespace batch_normalization {
struct BNTensorDescHolder { struct BNTensorDescHolder {
using ParamDim = param::BN::ParamDim; using ParamDim = param::BN::ParamDim;
using FwdMode = param::BN::FwdMode;
using Format = param::Convolution::Format;
TensorDesc xy_desc; TensorDesc xy_desc;
BNParamDesc param_desc; BNParamDesc param_desc;
cudnnBatchNormMode_t bn_mode; cudnnBatchNormMode_t bn_mode;
void setup(const TensorLayout& x, const ParamDim& param_dim); BNTensorDescHolder(const TensorLayout& x, const ParamDim& param_dim,
const FwdMode& fwd_mode);
}; };
size_t get_reserve_size(const cudnnHandle_t& handle,
const BNTensorDescHolder& tensor_desc);
} // namespace batch_normalization } // namespace batch_normalization
class BNForwardImpl final : public BNForward { class BNForwardImpl final : public BNForward {
...@@ -36,19 +42,15 @@ public: ...@@ -36,19 +42,15 @@ public:
void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
_megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, _megdnn_tensor_in bn_bias, _megdnn_tensor_out mean,
_megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, _megdnn_tensor_out variance, _megdnn_tensor_out batch_mean,
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out dst, _megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve,
_megdnn_workspace workspace) override; _megdnn_tensor_out dst, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&) override { const TensorLayout&) override;
return 0; size_t get_reserve_in_bytes(const TensorLayout& src) override;
}
private:
batch_normalization::BNTensorDescHolder m_tensor_desc;
}; };
class BNBackwardImpl final : public BNBackward { class BNBackwardImpl final : public BNBackward {
...@@ -57,20 +59,16 @@ public: ...@@ -57,20 +59,16 @@ public:
void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy,
_megdnn_tensor_in saved_batch_mean, _megdnn_tensor_in saved_batch_mean,
_megdnn_tensor_in saved_batch_inv_variance, _megdnn_tensor_in saved_batch_inv_variance,
_megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale, _megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve,
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, _megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias,
_megdnn_workspace workspace) override; _megdnn_tensor_out dx, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, size_t get_workspace_in_bytes(const TensorLayout& x, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&) override { const TensorLayout&) override;
return 0; size_t get_reserve_in_bytes(const TensorLayout& src) override;
}
private:
batch_normalization::BNTensorDescHolder m_tensor_desc;
}; };
} // namespace cuda } // namespace cuda
......
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "src/naive/batch_normalization/opr_impl.h" #include "src/naive/batch_normalization/opr_impl.h"
...@@ -219,7 +220,8 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, ...@@ -219,7 +220,8 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
_megdnn_tensor_inout variance, _megdnn_tensor_inout variance,
_megdnn_tensor_out batch_mean, _megdnn_tensor_out batch_mean,
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out batch_inv_variance,
_megdnn_tensor_out dst, _megdnn_workspace workspace) { _megdnn_tensor_out, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout, check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout,
variance.layout, batch_mean.layout, batch_inv_variance.layout, variance.layout, batch_mean.layout, batch_inv_variance.layout,
dst.layout, workspace.size); dst.layout, workspace.size);
...@@ -263,7 +265,7 @@ WorkspaceBundle BNBackwardImpl::get_workspace_bundle(size_t x_size, ...@@ -263,7 +265,7 @@ WorkspaceBundle BNBackwardImpl::get_workspace_bundle(size_t x_size,
size_t BNBackwardImpl::get_workspace_in_bytes( size_t BNBackwardImpl::get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout&, const TensorLayout&, const TensorLayout& x, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout& bn_scale, const TensorLayout&, const TensorLayout&, const TensorLayout& bn_scale, const TensorLayout&,
const TensorLayout&, const TensorLayout&) { const TensorLayout&, const TensorLayout&, const TensorLayout&) {
auto x_size = x.total_nr_elems(), param_size = bn_scale.total_nr_elems(); auto x_size = x.total_nr_elems(), param_size = bn_scale.total_nr_elems();
return get_workspace_bundle(x_size, param_size).total_size_in_bytes(); return get_workspace_bundle(x_size, param_size).total_size_in_bytes();
} }
...@@ -271,7 +273,7 @@ size_t BNBackwardImpl::get_workspace_in_bytes( ...@@ -271,7 +273,7 @@ size_t BNBackwardImpl::get_workspace_in_bytes(
void BNBackwardImpl::exec(_megdnn_tensor_in x_in, _megdnn_tensor_in dy_in, void BNBackwardImpl::exec(_megdnn_tensor_in x_in, _megdnn_tensor_in dy_in,
_megdnn_tensor_in saved_batch_mean, _megdnn_tensor_in saved_batch_mean,
_megdnn_tensor_in saved_batch_inv_variance, _megdnn_tensor_in saved_batch_inv_variance,
_megdnn_tensor_in bn_scale, _megdnn_tensor_in bn_scale, _megdnn_tensor_in,
_megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_scale,
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out d_bn_bias,
_megdnn_tensor_out dx_out, _megdnn_tensor_out dx_out,
......
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#pragma once #pragma once
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
...@@ -21,16 +22,17 @@ public: ...@@ -21,16 +22,17 @@ public:
void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
_megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, _megdnn_tensor_in bn_bias, _megdnn_tensor_out mean,
_megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, _megdnn_tensor_out variance, _megdnn_tensor_out batch_mean,
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out dst, _megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve,
_megdnn_workspace workspace) override; _megdnn_tensor_out dst, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&) override { const TensorLayout&) override {
return 0; return 0;
} }
size_t get_reserve_in_bytes(const TensorLayout&) override { return 0; }
}; };
class BNBackwardImpl final : public BNBackward { class BNBackwardImpl final : public BNBackward {
...@@ -39,15 +41,17 @@ public: ...@@ -39,15 +41,17 @@ public:
void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy,
_megdnn_tensor_in saved_batch_mean, _megdnn_tensor_in saved_batch_mean,
_megdnn_tensor_in saved_batch_inv_variance, _megdnn_tensor_in saved_batch_inv_variance,
_megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale, _megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve,
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, _megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias,
_megdnn_workspace workspace) override; _megdnn_tensor_out dx, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout& x, const TensorLayout&, size_t get_workspace_in_bytes(const TensorLayout& x, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout& bn_scale, const TensorLayout& bn_scale,
const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&,
const TensorLayout&) override; const TensorLayout&) override;
size_t get_reserve_in_bytes(const TensorLayout&) override { return 0; }
private: private:
WorkspaceBundle get_workspace_bundle(size_t x_size, size_t param_size, WorkspaceBundle get_workspace_bundle(size_t x_size, size_t param_size,
......
...@@ -49,7 +49,8 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, ...@@ -49,7 +49,8 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
_megdnn_tensor_out variance, _megdnn_tensor_out variance,
_megdnn_tensor_out batch_mean, _megdnn_tensor_out batch_mean,
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out batch_inv_variance,
_megdnn_tensor_out dst, _megdnn_workspace workspace) { _megdnn_tensor_out, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout, check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout,
variance.layout, batch_mean.layout, batch_inv_variance.layout, variance.layout, batch_mean.layout, batch_inv_variance.layout,
dst.layout, workspace.size); dst.layout, workspace.size);
...@@ -88,7 +89,7 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, ...@@ -88,7 +89,7 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
void BNBackwardImpl::exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, void BNBackwardImpl::exec(_megdnn_tensor_in x, _megdnn_tensor_in dy,
_megdnn_tensor_in saved_batch_mean, _megdnn_tensor_in saved_batch_mean,
_megdnn_tensor_in saved_batch_inv_variance, _megdnn_tensor_in saved_batch_inv_variance,
_megdnn_tensor_in bn_scale, _megdnn_tensor_in bn_scale, _megdnn_tensor_in,
_megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_scale,
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx,
_megdnn_workspace workspace) { _megdnn_workspace workspace) {
......
...@@ -37,16 +37,17 @@ public: ...@@ -37,16 +37,17 @@ public:
void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
_megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, _megdnn_tensor_in bn_bias, _megdnn_tensor_out mean,
_megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, _megdnn_tensor_out variance, _megdnn_tensor_out batch_mean,
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out dst, _megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve,
_megdnn_workspace workspace) override; _megdnn_tensor_out dst, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&) override { const TensorLayout&) override {
return 0; return 0;
} }
size_t get_reserve_in_bytes(const TensorLayout&) override { return 0; }
private: private:
batch_normalization::BNTensorDescHolder m_tensor_desc; batch_normalization::BNTensorDescHolder m_tensor_desc;
...@@ -58,17 +59,18 @@ public: ...@@ -58,17 +59,18 @@ public:
void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy,
_megdnn_tensor_in saved_batch_mean, _megdnn_tensor_in saved_batch_mean,
_megdnn_tensor_in saved_batch_inv_variance, _megdnn_tensor_in saved_batch_inv_variance,
_megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale, _megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve,
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, _megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias,
_megdnn_workspace workspace) override; _megdnn_tensor_out dx, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&) override { const TensorLayout&) override {
return 0; return 0;
} }
size_t get_reserve_in_bytes(const TensorLayout&) override { return 0; }
private: private:
batch_normalization::BNTensorDescHolder m_tensor_desc; batch_normalization::BNTensorDescHolder m_tensor_desc;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册