From 4d35397bdfa48324055934c42b2c7d6bedd2f29e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 16 Jun 2020 15:43:44 +0800 Subject: [PATCH] fix(dnn/fallback): fix conv1x1/im2col usable and fuse-conv-bias get fp32xfp32-->qint8 error GitOrigin-RevId: 5a3bfedd8a433cd17e735e3ba822027589263b69 --- dnn/src/fallback/conv_bias/conv1x1/algos.cpp | 33 +++++++++++--------- dnn/src/fallback/conv_bias/im2col/algos.cpp | 27 ++++++++++------ dnn/test/fallback/conv_bias.cpp | 18 +++++++++++ src/gopt/impl/inference.cpp | 4 ++- 4 files changed, 56 insertions(+), 26 deletions(-) diff --git a/dnn/src/fallback/conv_bias/conv1x1/algos.cpp b/dnn/src/fallback/conv_bias/conv1x1/algos.cpp index 28435abbd..611d7d68c 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/algos.cpp +++ b/dnn/src/fallback/conv_bias/conv1x1/algos.cpp @@ -201,24 +201,27 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1) return false; + if (param.src_type.enumv() != param.filter_type.enumv() && + param.src_type.enumv() != DTypeEnum::Int8 && + param.src_type.enumv() != DTypeEnum::QuantizedS8 && + param.src_type.enumv() != DTypeEnum::Quantized8Asymm && +#if !MEGDNN_DISABLE_FLOAT16 + param.src_type.enumv() != DTypeEnum::Float16 && +#endif + param.src_type.enumv() != DTypeEnum::Float32) { + return false; + } //! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode //! is identity otherwise return false mean that 8x8x32 and 8x8x16 //! not support PostProcess - if (param.src_type.enumv() == param.filter_type.enumv() && - (param.src_type.enumv() == DTypeEnum::Int8 && - (param.dst_type.enumv() == DTypeEnum::Int16 || - param.dst_type.enumv() == DTypeEnum::Int32)) && - param.bias_mode != megdnn::BiasMode::NO_BIAS && - param.nonlineMode != megdnn::NonlineMode::IDENTITY) - return false; - - if (param.src_type.enumv() == param.filter_type.enumv() && - ((param.src_type.enumv() == DTypeEnum::QuantizedS8 || - param.src_type.enumv() == DTypeEnum::Quantized8Asymm) && - param.dst_type.enumv() == DTypeEnum::QuantizedS32) && - param.bias_mode != megdnn::BiasMode::NO_BIAS && - param.nonlineMode != megdnn::NonlineMode::IDENTITY) - return false; + if (param.dst_type.enumv() == DTypeEnum::Int16 || + param.dst_type.enumv() == DTypeEnum::Int32 || + param.dst_type.enumv() == DTypeEnum::QuantizedS32) { + if (param.bias_mode != megdnn::BiasMode::NO_BIAS || + param.nonlineMode != megdnn::NonlineMode::IDENTITY) { + return false; + } + } if (opr->param().format == param::ConvBias::Format::NCHW44 || opr->param().format == param::ConvBias::Format::NCHW44_DOT) { diff --git a/dnn/src/fallback/conv_bias/im2col/algos.cpp b/dnn/src/fallback/conv_bias/im2col/algos.cpp index 5507f51a0..9d4e89e70 100644 --- a/dnn/src/fallback/conv_bias/im2col/algos.cpp +++ b/dnn/src/fallback/conv_bias/im2col/algos.cpp @@ -647,19 +647,26 @@ bool ConvBiasImpl::AlgoIm2col::usable( return false; } + if (param.src_type.enumv() != param.filter_type.enumv() && + param.src_type.enumv() != DTypeEnum::Int8 && + param.src_type.enumv() != DTypeEnum::QuantizedS8 && + param.src_type.enumv() != DTypeEnum::Quantized8Asymm && +#if !MEGDNN_DISABLE_FLOAT16 + param.src_type.enumv() != DTypeEnum::Float16 && +#endif + param.src_type.enumv() != DTypeEnum::Float32) { + return false; + } //! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode is //! identity otherwise return false mean that 8x8x32 and 8x8x16 not //! support PostProcess - if (param.src_type.enumv() == param.filter_type.enumv() && - ((param.src_type.enumv() == DTypeEnum::Int8 && - (param.dst_type.enumv() == DTypeEnum::Int16 || - param.dst_type.enumv() == DTypeEnum::Int32)) || - ((param.src_type.enumv() == DTypeEnum::QuantizedS8 || - param.src_type.enumv() == DTypeEnum::Quantized8Asymm) && - param.dst_type.enumv() == DTypeEnum::QuantizedS32)) && - param.bias_mode != megdnn::BiasMode::NO_BIAS && - param.nonlineMode != megdnn::NonlineMode::IDENTITY) { - return false; + if (param.dst_type.enumv() == DTypeEnum::Int16 || + param.dst_type.enumv() == DTypeEnum::Int32 || + param.dst_type.enumv() == DTypeEnum::QuantizedS32) { + if (param.bias_mode != megdnn::BiasMode::NO_BIAS || + param.nonlineMode != megdnn::NonlineMode::IDENTITY) { + return false; + } } if (opr->param().format == param::ConvBias::Format::NCHW44 || opr->param().format == param::ConvBias::Format::NCHW44_DOT) { diff --git a/dnn/test/fallback/conv_bias.cpp b/dnn/test/fallback/conv_bias.cpp index 3fc82d124..d00719f07 100644 --- a/dnn/test/fallback/conv_bias.cpp +++ b/dnn/test/fallback/conv_bias.cpp @@ -188,6 +188,24 @@ void checker_conv_bias(std::vector args, Handle* handle, } } +TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD_IM2COL_8X8X16) { + using namespace conv_bias; + param::ConvBias cur_param; + using NLMode = param::ConvBias::NonlineMode; + std::vector args = get_conv_bias_args( + {1, 3}, {0}, {NLMode::IDENTITY, NLMode::RELU}, {1}, false, true); + NormalRNG default_rng; + Checker checker(handle()); + checker.set_dtype(0, dtype::Int8{}); + checker.set_dtype(1, dtype::Int8{}); + checker.set_dtype(2, dtype::Int16{}); + checker.set_dtype(4, dtype::Int16{}); + for (auto&& arg : args) { + checker.set_param(arg.param).execs( + {arg.src, arg.filter, arg.bias, {}, {}}); + } +} + TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD) { using namespace conv_bias; param::ConvBias cur_param; diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index c43dd0d65..4d641942e 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -1671,7 +1671,9 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const { rewriter.get_var(typecvt->input(0))->owner_opr()); if (!conv_bias || m_deps.count(typecvt->input(0)) != 1 || typecvt->output(0)->dtype().enumv() != - DTypeTrait::enumv) + DTypeTrait::enumv || + typecvt->input(0)->dtype().enumv() != + DTypeTrait::enumv) return nullptr; auto config = conv_bias->config(); -- GitLab