Commit 9b4cd92b authored by Megvii Engine Team

fix(mgb/dnn): fix cudnnConvBiasActivation crash on nchw32 int8 with oc > 256

GitOrigin-RevId: 20c0b90575ece88da0d1aafd2ec1751b575137ea
Parent 34773ba3
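Note: the version guards in the hunks below compare CUDNN_VERSION against bare integers. As a reference, here is a minimal sketch of cuDNN 8.x's version encoding (from cudnn_version.h), under which 8004 is release 8.0.4 (where the crash was reported) and the range 8000..8201 spans releases 8.0.0 through 8.2.1:

// Minimal sketch of cuDNN 8.x's version encoding:
//   CUDNN_VERSION = CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL
// so 8004 <-> v8.0.4 and [8000, 8201] <-> v8.0.0 through v8.2.1.
#include <cudnn.h>
#include <cstdio>

int main() {
    std::printf("cuDNN %d.%d.%d -> CUDNN_VERSION = %d\n", CUDNN_MAJOR,
                CUDNN_MINOR, CUDNN_PATCHLEVEL, CUDNN_VERSION);
    return 0;
}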
@@ -50,7 +50,11 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
 #if (CUDNN_MAJOR == 8 && CUDNN_MINOR < 2)
     if (m_cudnn_enum == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM &&
-        param.format == param::ConvBias::Format::NCHW4 &&
+        (param.format == param::ConvBias::Format::NCHW4
+#if (CUDNN_VERSION == 8004)
+         || param.format == param::ConvBias::Format::NCHW32
+#endif
+         ) &&
         args.filter_meta.group * args.filter_meta.ocpg > 256 &&
         args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
         args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS8) {
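A hedged reading of the guard above: args.filter_meta.group * args.filter_meta.ocpg is the total output-channel count, so on affected cuDNN releases the algorithm reports itself unavailable for int8 convolutions with more than 256 output channels instead of crashing inside cuDNN. An illustrative predicate (names here are assumptions, not MegDNN API):

// Illustrative sketch only; function and parameter names are assumed.
bool skip_implicit_precomp_gemm(size_t group, size_t oc_per_group,
                                bool int8_nchw4_or_nchw32) {
    // group * oc_per_group is the total number of output channels.
    return int8_nchw4_or_nchw32 && group * oc_per_group > 256;
}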
@@ -498,7 +498,7 @@ const std::unordered_map<cudnnConvolutionFwdAlgo_t, CudnnAlgoPack::Attr> CudnnAl
 static const std::unordered_map<cudnnConvolutionFwdAlgo_t, CudnnAlgoPack::Attr>
         algos = {
                 DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true, false),
-#if CUDNN_VERSION == 8004
+#if (CUDNN_VERSION >= 8000 && CUDNN_VERSION <= 8201)
                 DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true, true),
 #else
                 DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true, false),
@@ -1223,8 +1223,8 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_TENSORCORE_INT8) {
     for (size_t fh : {3, 5, 7}) {
         for (int ph : {static_cast<int>(fh / 2), 0}) {
             for (int sh : {1, 2}) {
-                for (size_t ih : {9, 11, 12, 13, 16}) {
-                    for (size_t iw : {8, 27, 32, 40}) {
+                for (size_t ih : {9, 11, 12}) {
+                    for (size_t iw : {8, 27, 32}) {
                         param.nonlineMode = mode;
                         param.stride_h = param.stride_w = sh;
                         param.pad_h = param.pad_w = ph;
@@ -1268,6 +1268,29 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_TENSORCORE_INT8) {
             }
         }
     }
+    {  //! convbiasactivation algo crash when oc > 256 && cudnn v8.0.4
+        param.nonlineMode = NonlineMode::RELU;
+        param.stride_h = param.stride_w = 1;
+        param.pad_h = param.pad_w = 0;
+        checker.set_dtype(0, dtype::QuantizedS8(1.3f))
+                .set_dtype(1, dtype::QuantizedS8(1.3f))
+                .set_dtype(2, dtype::QuantizedS32(1.3f * 1.3f))
+                .set_dtype(3, dtype::QuantizedS8(1.7f))
+                .set_dtype(4, dtype::QuantizedS8(1.2f * 1.2f))
+                .set_rng(0, &int_rng)
+                .set_rng(1, &int_rng)
+                .set_rng(2, &int_rng)
+                .set_rng(3, &int_rng)
+                .set_epsilon(1 + 1e-3)
+                .set_param(param)
+                .execs({{2, 8, 12, 12, 32},
+                        {512, 8, 1, 1, 32},
+                        {1, 16, 1, 1, 32},
+                        {},
+                        {}});
+    }
 }
 #if MEGDNN_WITH_BENCHMARK
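For context on the regression test added above: NCHW32 layouts pack 32 channels into the trailing axis, so a hedged decoding of the tensor shapes shows why this case trips the oc > 256 guard:

// Hedged decoding of the test shapes (NCHW32 packs 32 channels per block):
//   src  {2, 8, 12, 12, 32} -> N = 2, C = 8 * 32 = 256, H = W = 12
//   filt {512, 8, 1, 1, 32} -> OC = 512 (> 256), IC = 8 * 32 = 256, 1x1 kernel
//   bias {1, 16, 1, 1, 32}  -> 16 * 32 = 512 values, one per output channel
static_assert(512 > 256, "the oc > 256 case the fix targets");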