提交 b3e54ead 编写于 作者: M Megvii Engine Team

feat(dnn/bn): use new cudnn BN kernel to support NHWC

GitOrigin-RevId: 9d80f2009d0496f532b267fcf841785b74c0b50c
上级 6b863cc5
......@@ -682,7 +682,8 @@ public:
* http://deeplearning.net/software/theano/library/tensor/nnet/neighbours.html
*
* \f$ dst_{n, c, oh, ow, wh, ww} = src_{n, c, ih+wh, iw+fw}\f$,
* where \f$ ih=-pad_h+oh*stride_h+(wh-1)*(dilation_h-1), iw=-pad_w+ow*stride_w+(ww-1)*(dilation_w-1)\f$.
* where \f$ ih=-pad_h+oh*stride_h+(wh-1)*(dilation_h-1),
* iw=-pad_w+ow*stride_w+(ww-1)*(dilation_w-1)\f$.
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
......@@ -724,7 +725,8 @@ protected:
};
class SlidingWindowTransposeForward : public SlidingWindowTransposeBase {
DEF_OPR_IMPL(SlidingWindowTransposeForward, SlidingWindowTransposeBase, 1, 1);
DEF_OPR_IMPL(SlidingWindowTransposeForward, SlidingWindowTransposeBase, 1,
1);
public:
/**
......@@ -744,7 +746,8 @@ protected:
using SlidingWindowTranspose = SlidingWindowTransposeForward;
class SlidingWindowTransposeBackward : public SlidingWindowTransposeBase {
DEF_OPR_IMPL(SlidingWindowTransposeBackward, SlidingWindowTransposeBase, 1, 1);
DEF_OPR_IMPL(SlidingWindowTransposeBackward, SlidingWindowTransposeBase, 1,
1);
public:
/**
......@@ -975,7 +978,7 @@ protected:
};
class BNForward : public BNBase {
DEF_OPR_IMPL(BNForward, BNBase, 6, 5);
DEF_OPR_IMPL(BNForward, BNBase, 6, 6);
public:
/**
......@@ -986,10 +989,11 @@ public:
* \param[out] dst (n, c, h, w)
* \param[out] mean (see m_param.ParamDim) Global mean.
* \param[out] variance (see m_param.ParamDim) Global variance.
* \Param[out] batch_mean (see m_param.ParamDim)
* \param[out] batch_mean (see m_param.ParamDim)
* Optionally cached intermediate mean from forward pass
* \Param[out] batch_inv_variance (see m_param.ParamDim)
* \param[out] batch_inv_variance (see m_param.ParamDim)
* Optionally cached intermediate variance from forward pass
* \param[out] reserve (see cudnnBatchNormalizationForwardTrainingEx)
* src and dst must have the same shape.
* src and dst must be contiguous.
*/
......@@ -998,17 +1002,20 @@ public:
_megdnn_tensor_inout variance,
_megdnn_tensor_out batch_mean,
_megdnn_tensor_out batch_inv_variance,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& bn_scale,
TensorLayout& bn_bias, TensorLayout& mean,
_megdnn_tensor_out reserve, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, const TensorLayout& bn_scale,
const TensorLayout& bn_bias, TensorLayout& mean,
TensorLayout& variance, TensorLayout& batch_mean,
TensorLayout& batch_inv_variance, TensorLayout& dst);
TensorLayout& batch_inv_variance, TensorLayout& reserve,
TensorLayout& dst);
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& bn_scale,
const TensorLayout& bn_bias, const TensorLayout& mean,
const TensorLayout& variance, const TensorLayout& batch_mean,
const TensorLayout& batch_inv_variance,
const TensorLayout& batch_inv_variance, const TensorLayout& reserve,
const TensorLayout& dst) = 0;
virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0;
protected:
void check_exec(const TensorLayout& src, const TensorLayout& bn_scale,
......@@ -1016,12 +1023,13 @@ protected:
const TensorLayout& variance,
const TensorLayout& batch_mean,
const TensorLayout& batch_inv_variance,
const TensorLayout& dst, size_t workspace_in_bytes);
const TensorLayout& dst, size_t workspace_in_bytes,
size_t reserve_in_bytes = 0);
};
using BN = BNForward;
class BNBackward : public BNBase {
DEF_OPR_IMPL(BNBackward, BNBase, 5, 3);
DEF_OPR_IMPL(BNBackward, BNBase, 6, 3);
public:
/**
......@@ -1035,19 +1043,23 @@ public:
Calculated in the forwardpropagation.
* \param[in] saved_batch_variance of the input batch.
Calculated in the forwardpropagation.
* \param[in] reserve (see cudnnBatchNormalizationBackwardEx)
*/
virtual void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy,
_megdnn_tensor_in saved_batch_mean,
_megdnn_tensor_in saved_batch_variance,
_megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale,
_megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve,
_megdnn_tensor_out d_bn_scale,
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout& dy,
const TensorLayout& saved_batch_mean,
const TensorLayout& saved_batch_variance,
const TensorLayout& bn_scale, const TensorLayout& d_bn_scale,
const TensorLayout& d_bn_bias, const TensorLayout& dx) = 0;
const TensorLayout& bn_scale, const TensorLayout& reserve,
const TensorLayout& d_bn_scale, const TensorLayout& d_bn_bias,
const TensorLayout& dx) = 0;
virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0;
protected:
void check_exec(const TensorLayout& x, const TensorLayout& dy,
......@@ -1056,7 +1068,7 @@ protected:
const TensorLayout& bn_scale,
const TensorLayout& d_bn_scale,
const TensorLayout& d_bn_bias, const TensorLayout& dx,
size_t workspace_in_bytes);
size_t workspace_in_bytes, size_t reserve_in_bytes = 0);
};
class LRNBase : public OperatorBase {
......
......@@ -253,7 +253,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
add_enum_alias('Format', 'Convolution')
)
(pdef('AdaptivePooling', version=0,is_legacy=True).
(pdef('AdaptivePooling', version=0, is_legacy=True).
add_enum_alias('Mode', 'PoolingV0').
add_enum_alias('Format', 'ConvolutionV0')
)
......@@ -276,6 +276,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
Doc('DIM_11HW = 0', 'Dim of params (Sigma, Mu) is 1 x 1 x H x W'),
Doc('DIM_1CHW = 1', 'Dim of params (Sigma, Mu) is 1 x C x H x W'),
Doc('DIM_1C11 = 2', 'Dim of params (Sigma, Mu) is 1 x C x 1 x 1'),
Doc('DIM_111C = 3', 'Dim of params (Sigma, Mu) is 1 x 1 x 1 x C'),
name_field='param_dim'
).
add_enum(
......
......@@ -4,9 +4,9 @@
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megdnn/oprs.h"
......@@ -14,28 +14,32 @@
namespace megdnn {
void BNForward::deduce_layout(const TensorLayout& src, TensorLayout&,
TensorLayout&, TensorLayout&, TensorLayout&,
TensorLayout&, TensorLayout&, TensorLayout& dst) {
void BNForward::deduce_layout(const TensorLayout& src, const TensorLayout&,
const TensorLayout&, TensorLayout&, TensorLayout&,
TensorLayout&, TensorLayout&,
TensorLayout& reserve, TensorLayout& dst) {
reserve = {{get_reserve_in_bytes(src)}, dtype::Byte()};
dst = src;
}
void BNForward::check_exec(const TensorLayout& src, const TensorLayout& bn_scale,
const TensorLayout& bn_bias, const TensorLayout& mean,
const TensorLayout& variance,
const TensorLayout& batch_mean,
const TensorLayout& batch_inv_variance,
const TensorLayout& dst, size_t workspace_in_bytes) {
void BNForward::check_exec(
const TensorLayout& src, const TensorLayout& bn_scale,
const TensorLayout& bn_bias, const TensorLayout& mean,
const TensorLayout& variance, const TensorLayout& batch_mean,
const TensorLayout& batch_inv_variance, const TensorLayout& dst,
size_t workspace_in_bytes, size_t reserve_in_bytes) {
megdnn_assert_contiguous(src);
megdnn_assert_eq_layout(src, dst);
megdnn_assert_eq_layout(bn_scale, bn_bias);
megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT);
megdnn_assert(bn_scale.dtype.category() == DTypeCategory::FLOAT);
auto required_workspace_in_bytes =
get_workspace_in_bytes(src, bn_scale, bn_bias, mean, variance,
batch_mean, batch_inv_variance, dst);
auto required_workspace_in_bytes = get_workspace_in_bytes(
src, bn_scale, bn_bias, mean, variance, batch_mean,
batch_inv_variance, {}, dst);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
auto required_reserve_in_bytes = get_reserve_in_bytes(src);
megdnn_assert(reserve_in_bytes >= required_reserve_in_bytes);
}
void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy,
......@@ -44,7 +48,8 @@ void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy,
const TensorLayout& bn_scale,
const TensorLayout& d_bn_scale,
const TensorLayout& d_bn_bias,
const TensorLayout& dx, size_t workspace_in_bytes) {
const TensorLayout& dx, size_t workspace_in_bytes,
size_t reserve_in_bytes) {
megdnn_assert_contiguous(x);
megdnn_assert_eq_layout(x, dy);
megdnn_assert_eq_layout(x, dx);
......@@ -54,11 +59,14 @@ void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy,
megdnn_assert_eq_layout(saved_batch_mean, bn_scale);
megdnn_assert(x.dtype.category() == DTypeCategory::FLOAT);
megdnn_assert(bn_scale.dtype.category() == DTypeCategory::FLOAT);
auto required_workspace_in_bytes =
get_workspace_in_bytes(x, dy, saved_batch_mean, saved_batch_variance,
bn_scale, d_bn_scale, d_bn_bias, dx);
auto required_workspace_in_bytes = get_workspace_in_bytes(
x, dy, saved_batch_mean, saved_batch_variance, bn_scale, {},
d_bn_scale, d_bn_bias, dx);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
megdnn_assert(param().fwd_mode == Param::FwdMode::TRAINING, "BNBackward only support TRAINING mode");
auto required_reserve_in_bytes = get_reserve_in_bytes(x);
megdnn_assert(reserve_in_bytes >= required_reserve_in_bytes);
megdnn_assert(param().fwd_mode == Param::FwdMode::TRAINING,
"BNBackward only support TRAINING mode");
}
} // namespace megdnn
......
......@@ -55,8 +55,8 @@ DEF(GroupLocalBackwardData, 3, true, false);
DEF(GroupLocalBackwardFilter, 3, true, false);
DEF(LRNForward, 2, true, true);
DEF(LRNBackward, 4, true, false);
DEF(BNForward, 8, true, true);
DEF(BNBackward, 8, true, false);
DEF(BNForward, 9, true, true);
DEF(BNBackward, 9, true, false);
DEF(ROIPoolingForward, 4, true, false);
DEF(ROIPoolingBackward, 5, true, false);
DEF(CorrelationForward, 3, true, true);
......
......@@ -4,9 +4,9 @@
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./opr_impl.h"
......@@ -17,9 +17,11 @@ namespace cuda {
namespace batch_normalization {
void BNTensorDescHolder::setup(const TensorLayout& x,
const ParamDim& param_dim) {
BNTensorDescHolder::BNTensorDescHolder(const TensorLayout& x,
const ParamDim& param_dim,
const FwdMode& fwd_mode) {
TensorShape xy_shape(x);
Format xy_format = Format::NCHW;
switch (param_dim) {
case ParamDim::DIM_11HW:
......@@ -34,50 +36,116 @@ void BNTensorDescHolder::setup(const TensorLayout& x,
case ParamDim::DIM_1C11:
bn_mode = CUDNN_BATCHNORM_SPATIAL;
break;
case ParamDim::DIM_111C:
bn_mode = CUDNN_BATCHNORM_SPATIAL;
xy_format = Format::NHWC;
#if CUDNN_VERSION >= 7410
if (fwd_mode == FwdMode::TRAINING) {
bn_mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
}
#endif // CUDNN_VERSION >= 7400
break;
default:
megdnn_throw("Unknown param dim type of batch normalization.");
}
xy_desc.set(TensorLayout(xy_shape, x.dtype));
xy_desc.set(TensorLayout(xy_shape, x.dtype), xy_format);
param_desc.set(xy_desc.desc, bn_mode);
}
size_t get_reserve_size(const cudnnHandle_t& handle,
const BNTensorDescHolder& tensor_desc) {
#if CUDNN_VERSION >= 7410
size_t reserve_size;
cudnn_check(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN,
nullptr, // activationDesc
tensor_desc.xy_desc.desc, // xDesc
&reserve_size));
return reserve_size;
#else
return 0;
#endif // CUDNN_VERSION >= 7410
}
} // namespace batch_normalization
using batch_normalization::BNTensorDescHolder;
size_t BNForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&) {
#if CUDNN_VERSION >= 7410
auto handle = cudnn_handle(this->handle());
BNTensorDescHolder tensor_desc(src, m_param.param_dim, m_param.fwd_mode);
size_t workspace_size;
cudnn_check(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN,
tensor_desc.xy_desc.desc, // xDesc
tensor_desc.xy_desc.desc, // yDesc
tensor_desc.xy_desc.desc, // zDesc
tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc
nullptr, // activationDesc
&workspace_size));
return workspace_size;
#else
return 0;
#endif // CUDNN_VERSION >= 7410
}
size_t BNForwardImpl::get_reserve_in_bytes(const TensorLayout& src) {
BNTensorDescHolder tensor_desc(src, m_param.param_dim, m_param.fwd_mode);
return batch_normalization::get_reserve_size(cudnn_handle(this->handle()),
tensor_desc);
}
void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
_megdnn_tensor_in bn_bias, _megdnn_tensor_out mean,
_megdnn_tensor_out variance,
_megdnn_tensor_out batch_mean,
_megdnn_tensor_out batch_inv_variance,
_megdnn_tensor_out dst, _megdnn_workspace workspace) {
_megdnn_tensor_out reserve, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout,
variance.layout, batch_mean.layout, batch_inv_variance.layout,
dst.layout, workspace.size);
dst.layout, workspace.size, reserve.layout.access_bytes());
auto handle = cudnn_handle(this->handle());
m_tensor_desc.setup(src.layout, m_param.param_dim);
BNTensorDescHolder tensor_desc(src.layout, m_param.param_dim,
m_param.fwd_mode);
float alpha = 1.0f, beta = 0.0f;
switch (m_param.fwd_mode) {
case param::BN::FwdMode::TRAINING:
#if CUDNN_VERSION >= 7410
cudnn_check(cudnnBatchNormalizationForwardTrainingEx(
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN,
&alpha, &beta, // one & zero
tensor_desc.xy_desc.desc, src.raw_ptr, // xDesc & x
nullptr, nullptr, // zDesc & z
tensor_desc.xy_desc.desc, dst.raw_ptr, // yDesc & y
tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc
bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor,
mean.raw_ptr, variance.raw_ptr, m_param.epsilon,
batch_mean.raw_ptr, batch_inv_variance.raw_ptr, nullptr,
workspace.raw_ptr, workspace.size, reserve.raw_ptr,
reserve.layout.access_bytes()));
#else
cudnn_check(cudnnBatchNormalizationForwardTraining(
handle, m_tensor_desc.bn_mode,
&alpha, &beta,
m_tensor_desc.xy_desc.desc, // xDesc
src.raw_ptr, // x
m_tensor_desc.xy_desc.desc, // yDesc
dst.raw_ptr, // y
m_tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc
handle, tensor_desc.bn_mode, &alpha, &beta,
tensor_desc.xy_desc.desc, src.raw_ptr, // xDesc & x
tensor_desc.xy_desc.desc, dst.raw_ptr, // yDesc & y
tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc
bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor,
mean.raw_ptr, variance.raw_ptr, m_param.epsilon,
batch_mean.raw_ptr, batch_inv_variance.raw_ptr));
#endif // CUDNN_VERSION >= 7410
break;
case param::BN::FwdMode::INFERENCE:
cudnn_check(cudnnBatchNormalizationForwardInference(
handle, m_tensor_desc.bn_mode,
&alpha, &beta,
m_tensor_desc.xy_desc.desc, src.raw_ptr,
m_tensor_desc.xy_desc.desc, dst.raw_ptr,
m_tensor_desc.param_desc.desc, bn_scale.raw_ptr,
handle, tensor_desc.bn_mode, &alpha, &beta,
tensor_desc.xy_desc.desc, src.raw_ptr,
tensor_desc.xy_desc.desc, dst.raw_ptr,
tensor_desc.param_desc.desc, bn_scale.raw_ptr,
bn_bias.raw_ptr, mean.raw_ptr, variance.raw_ptr,
m_param.epsilon));
break;
......@@ -86,30 +154,79 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
}
}
size_t BNBackwardImpl::get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&) {
#if CUDNN_VERSION >= 7410
auto handle = cudnn_handle(this->handle());
BNTensorDescHolder tensor_desc(x, m_param.param_dim, m_param.fwd_mode);
size_t workspace_size;
cudnn_check(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN,
tensor_desc.xy_desc.desc, // xDesc
tensor_desc.xy_desc.desc, // yDesc
tensor_desc.xy_desc.desc, // dyDesc
nullptr, // dzDesc
tensor_desc.xy_desc.desc, // dxDesc
tensor_desc.param_desc.desc, // dBnScaleBiasDesc
nullptr, // activationDesc
&workspace_size));
return workspace_size;
#else
return 0;
#endif // CUDNN_VERSION >= 7410
}
size_t BNBackwardImpl::get_reserve_in_bytes(const TensorLayout& src) {
BNTensorDescHolder tensor_desc(src, m_param.param_dim, m_param.fwd_mode);
return batch_normalization::get_reserve_size(cudnn_handle(this->handle()),
tensor_desc);
}
void BNBackwardImpl::exec(_megdnn_tensor_in x, _megdnn_tensor_in dy,
_megdnn_tensor_in saved_batch_mean,
_megdnn_tensor_in saved_batch_inv_variance,
_megdnn_tensor_in bn_scale,
_megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve,
_megdnn_tensor_out d_bn_scale,
_megdnn_tensor_out d_bn_bias,
_megdnn_tensor_out dx, _megdnn_workspace workspace) {
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx,
_megdnn_workspace workspace) {
check_exec(x.layout, dy.layout, saved_batch_mean.layout,
saved_batch_inv_variance.layout, bn_scale.layout,
d_bn_scale.layout, d_bn_bias.layout, dx.layout,
workspace.size);
d_bn_scale.layout, d_bn_bias.layout, dx.layout, workspace.size,
reserve.layout.access_bytes());
auto handle = cudnn_handle(this->handle());
m_tensor_desc.setup(x.layout, m_param.param_dim);
BNTensorDescHolder tensor_desc(x.layout, m_param.param_dim,
m_param.fwd_mode);
float alpha = 1.0, beta = 0.0;
#if CUDNN_VERSION >= 7410
cudnn_check(cudnnBatchNormalizationBackwardEx(
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, &alpha, &beta,
&alpha, &beta, tensor_desc.xy_desc.desc,
x.raw_ptr, // xDesc & x
nullptr, nullptr, // yDesc & y
tensor_desc.xy_desc.desc, dy.raw_ptr, // dyDesc & dy
nullptr, nullptr, // dzDesc & dz
tensor_desc.xy_desc.desc, dx.raw_ptr, // dxDesc & dx
tensor_desc.param_desc.desc, bn_scale.raw_ptr, // bnScale
nullptr, // bnBias
d_bn_scale.raw_ptr, d_bn_bias.raw_ptr, // dScale, dBias
m_param.epsilon, saved_batch_mean.raw_ptr,
saved_batch_inv_variance.raw_ptr, nullptr, workspace.raw_ptr,
workspace.size, reserve.raw_ptr, reserve.layout.access_bytes()));
#else
cudnn_check(cudnnBatchNormalizationBackward(
handle, m_tensor_desc.bn_mode,
&alpha, &beta, &alpha, &beta,
m_tensor_desc.xy_desc.desc, x.raw_ptr,
m_tensor_desc.xy_desc.desc, dy.raw_ptr,
m_tensor_desc.xy_desc.desc, dx.raw_ptr,
m_tensor_desc.param_desc.desc, bn_scale.raw_ptr,
d_bn_scale.raw_ptr, d_bn_bias.raw_ptr, m_param.epsilon,
saved_batch_mean.raw_ptr, saved_batch_inv_variance.raw_ptr));
handle, tensor_desc.bn_mode, &alpha, &beta, &alpha, &beta,
tensor_desc.xy_desc.desc, x.raw_ptr, // xDesc & x
tensor_desc.xy_desc.desc, dy.raw_ptr, // dyDesc & dy
tensor_desc.xy_desc.desc, dx.raw_ptr, // dxDesc & dx
tensor_desc.param_desc.desc, bn_scale.raw_ptr, // bnScale
d_bn_scale.raw_ptr, d_bn_bias.raw_ptr, // dScale, dBias
m_param.epsilon, saved_batch_mean.raw_ptr,
saved_batch_inv_variance.raw_ptr));
#endif
}
} // namespace cuda
......
......@@ -4,9 +4,9 @@
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
......@@ -20,14 +20,20 @@ namespace batch_normalization {
struct BNTensorDescHolder {
using ParamDim = param::BN::ParamDim;
using FwdMode = param::BN::FwdMode;
using Format = param::Convolution::Format;
TensorDesc xy_desc;
BNParamDesc param_desc;
cudnnBatchNormMode_t bn_mode;
void setup(const TensorLayout& x, const ParamDim& param_dim);
BNTensorDescHolder(const TensorLayout& x, const ParamDim& param_dim,
const FwdMode& fwd_mode);
};
size_t get_reserve_size(const cudnnHandle_t& handle,
const BNTensorDescHolder& tensor_desc);
} // namespace batch_normalization
class BNForwardImpl final : public BNForward {
......@@ -36,19 +42,15 @@ public:
void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
_megdnn_tensor_in bn_bias, _megdnn_tensor_out mean,
_megdnn_tensor_out variance, _megdnn_tensor_out batch_mean,
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve,
_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout&,
const TensorLayout&) override {
return 0;
}
private:
batch_normalization::BNTensorDescHolder m_tensor_desc;
const TensorLayout&, const TensorLayout&,
const TensorLayout&) override;
size_t get_reserve_in_bytes(const TensorLayout& src) override;
};
class BNBackwardImpl final : public BNBackward {
......@@ -57,20 +59,16 @@ public:
void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy,
_megdnn_tensor_in saved_batch_mean,
_megdnn_tensor_in saved_batch_inv_variance,
_megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale,
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx,
_megdnn_workspace workspace) override;
_megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve,
_megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias,
_megdnn_tensor_out dx, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
size_t get_workspace_in_bytes(const TensorLayout& x, const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout&,
const TensorLayout&) override {
return 0;
}
private:
batch_normalization::BNTensorDescHolder m_tensor_desc;
const TensorLayout&, const TensorLayout&,
const TensorLayout&) override;
size_t get_reserve_in_bytes(const TensorLayout& src) override;
};
} // namespace cuda
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/naive/batch_normalization/opr_impl.h"
......@@ -219,13 +220,14 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
_megdnn_tensor_inout variance,
_megdnn_tensor_out batch_mean,
_megdnn_tensor_out batch_inv_variance,
_megdnn_tensor_out dst, _megdnn_workspace workspace) {
_megdnn_tensor_out, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout,
variance.layout, batch_mean.layout, batch_inv_variance.layout,
dst.layout, workspace.size);
DNN_INC_FLOAT16(if (src.layout.dtype == dtype::Float16() &&
bn_scale.layout.dtype == dtype::Float32()) {
bn_scale.layout.dtype == dtype::Float32()) {
MEGDNN_DISPATCH_CPU_KERN_OPR(({
using T0 = typename DTypeTrait<dtype::Float16>::ctype;
using T1 = typename DTypeTrait<dtype::Float32>::ctype;
......@@ -263,7 +265,7 @@ WorkspaceBundle BNBackwardImpl::get_workspace_bundle(size_t x_size,
size_t BNBackwardImpl::get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout& bn_scale, const TensorLayout&,
const TensorLayout&, const TensorLayout&) {
const TensorLayout&, const TensorLayout&, const TensorLayout&) {
auto x_size = x.total_nr_elems(), param_size = bn_scale.total_nr_elems();
return get_workspace_bundle(x_size, param_size).total_size_in_bytes();
}
......@@ -271,7 +273,7 @@ size_t BNBackwardImpl::get_workspace_in_bytes(
void BNBackwardImpl::exec(_megdnn_tensor_in x_in, _megdnn_tensor_in dy_in,
_megdnn_tensor_in saved_batch_mean,
_megdnn_tensor_in saved_batch_inv_variance,
_megdnn_tensor_in bn_scale,
_megdnn_tensor_in bn_scale, _megdnn_tensor_in,
_megdnn_tensor_out d_bn_scale,
_megdnn_tensor_out d_bn_bias,
_megdnn_tensor_out dx_out,
......@@ -286,7 +288,7 @@ void BNBackwardImpl::exec(_megdnn_tensor_in x_in, _megdnn_tensor_in dy_in,
workspace.raw_ptr);
DNN_INC_FLOAT16(if (x_in.layout.dtype == dtype::Float16() &&
bn_scale.layout.dtype == dtype::Float32()) {
bn_scale.layout.dtype == dtype::Float32()) {
MEGDNN_DISPATCH_CPU_KERN_OPR(({
using T0 = typename DTypeTrait<dtype::Float16>::ctype;
using T1 = typename DTypeTrait<dtype::Float32>::ctype;
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
......@@ -21,16 +22,17 @@ public:
void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
_megdnn_tensor_in bn_bias, _megdnn_tensor_out mean,
_megdnn_tensor_out variance, _megdnn_tensor_out batch_mean,
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve,
_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}
size_t get_reserve_in_bytes(const TensorLayout&) override { return 0; }
};
class BNBackwardImpl final : public BNBackward {
......@@ -39,15 +41,17 @@ public:
void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy,
_megdnn_tensor_in saved_batch_mean,
_megdnn_tensor_in saved_batch_inv_variance,
_megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale,
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx,
_megdnn_workspace workspace) override;
_megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve,
_megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias,
_megdnn_tensor_out dx, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout& x, const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout& bn_scale,
const TensorLayout&, const TensorLayout&,
const TensorLayout&,
const TensorLayout&) override;
size_t get_reserve_in_bytes(const TensorLayout&) override { return 0; }
private:
WorkspaceBundle get_workspace_bundle(size_t x_size, size_t param_size,
......
......@@ -49,7 +49,8 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
_megdnn_tensor_out variance,
_megdnn_tensor_out batch_mean,
_megdnn_tensor_out batch_inv_variance,
_megdnn_tensor_out dst, _megdnn_workspace workspace) {
_megdnn_tensor_out, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout,
variance.layout, batch_mean.layout, batch_inv_variance.layout,
dst.layout, workspace.size);
......@@ -88,7 +89,7 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
void BNBackwardImpl::exec(_megdnn_tensor_in x, _megdnn_tensor_in dy,
_megdnn_tensor_in saved_batch_mean,
_megdnn_tensor_in saved_batch_inv_variance,
_megdnn_tensor_in bn_scale,
_megdnn_tensor_in bn_scale, _megdnn_tensor_in,
_megdnn_tensor_out d_bn_scale,
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx,
_megdnn_workspace workspace) {
......
......@@ -37,16 +37,17 @@ public:
void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
_megdnn_tensor_in bn_bias, _megdnn_tensor_out mean,
_megdnn_tensor_out variance, _megdnn_tensor_out batch_mean,
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve,
_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}
size_t get_reserve_in_bytes(const TensorLayout&) override { return 0; }
private:
batch_normalization::BNTensorDescHolder m_tensor_desc;
......@@ -58,17 +59,18 @@ public:
void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy,
_megdnn_tensor_in saved_batch_mean,
_megdnn_tensor_in saved_batch_inv_variance,
_megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale,
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx,
_megdnn_workspace workspace) override;
_megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve,
_megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias,
_megdnn_tensor_out dx, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}
size_t get_reserve_in_bytes(const TensorLayout&) override { return 0; }
private:
batch_normalization::BNTensorDescHolder m_tensor_desc;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册