diff --git a/dnn/src/fallback/conv_bias/conv1x1/algos.cpp b/dnn/src/fallback/conv_bias/conv1x1/algos.cpp index 28435abbdbe91494529bd50a7190c425b3916a4e..611d7d68c56f6b626669eebd3201a2ec4c4127a7 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/algos.cpp +++ b/dnn/src/fallback/conv_bias/conv1x1/algos.cpp @@ -201,24 +201,27 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1) return false; + if (param.src_type.enumv() != param.filter_type.enumv() && + param.src_type.enumv() != DTypeEnum::Int8 && + param.src_type.enumv() != DTypeEnum::QuantizedS8 && + param.src_type.enumv() != DTypeEnum::Quantized8Asymm && +#if !MEGDNN_DISABLE_FLOAT16 + param.src_type.enumv() != DTypeEnum::Float16 && +#endif + param.src_type.enumv() != DTypeEnum::Float32) { + return false; + } //! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode //! is identity otherwise return false mean that 8x8x32 and 8x8x16 //! not support PostProcess - if (param.src_type.enumv() == param.filter_type.enumv() && - (param.src_type.enumv() == DTypeEnum::Int8 && - (param.dst_type.enumv() == DTypeEnum::Int16 || - param.dst_type.enumv() == DTypeEnum::Int32)) && - param.bias_mode != megdnn::BiasMode::NO_BIAS && - param.nonlineMode != megdnn::NonlineMode::IDENTITY) - return false; - - if (param.src_type.enumv() == param.filter_type.enumv() && - ((param.src_type.enumv() == DTypeEnum::QuantizedS8 || - param.src_type.enumv() == DTypeEnum::Quantized8Asymm) && - param.dst_type.enumv() == DTypeEnum::QuantizedS32) && - param.bias_mode != megdnn::BiasMode::NO_BIAS && - param.nonlineMode != megdnn::NonlineMode::IDENTITY) - return false; + if (param.dst_type.enumv() == DTypeEnum::Int16 || + param.dst_type.enumv() == DTypeEnum::Int32 || + param.dst_type.enumv() == DTypeEnum::QuantizedS32) { + if (param.bias_mode != megdnn::BiasMode::NO_BIAS || + param.nonlineMode != megdnn::NonlineMode::IDENTITY) { + return false; + } + } if (opr->param().format == param::ConvBias::Format::NCHW44 || opr->param().format == param::ConvBias::Format::NCHW44_DOT) { diff --git a/dnn/src/fallback/conv_bias/im2col/algos.cpp b/dnn/src/fallback/conv_bias/im2col/algos.cpp index 5507f51a0e2ea6c35d2ce208ef4c58da432125e5..9d4e89e7011a81eff94cada12150d07ebe43b469 100644 --- a/dnn/src/fallback/conv_bias/im2col/algos.cpp +++ b/dnn/src/fallback/conv_bias/im2col/algos.cpp @@ -647,19 +647,26 @@ bool ConvBiasImpl::AlgoIm2col::usable( return false; } + if (param.src_type.enumv() != param.filter_type.enumv() && + param.src_type.enumv() != DTypeEnum::Int8 && + param.src_type.enumv() != DTypeEnum::QuantizedS8 && + param.src_type.enumv() != DTypeEnum::Quantized8Asymm && +#if !MEGDNN_DISABLE_FLOAT16 + param.src_type.enumv() != DTypeEnum::Float16 && +#endif + param.src_type.enumv() != DTypeEnum::Float32) { + return false; + } //! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode is //! identity otherwise return false mean that 8x8x32 and 8x8x16 not //! support PostProcess - if (param.src_type.enumv() == param.filter_type.enumv() && - ((param.src_type.enumv() == DTypeEnum::Int8 && - (param.dst_type.enumv() == DTypeEnum::Int16 || - param.dst_type.enumv() == DTypeEnum::Int32)) || - ((param.src_type.enumv() == DTypeEnum::QuantizedS8 || - param.src_type.enumv() == DTypeEnum::Quantized8Asymm) && - param.dst_type.enumv() == DTypeEnum::QuantizedS32)) && - param.bias_mode != megdnn::BiasMode::NO_BIAS && - param.nonlineMode != megdnn::NonlineMode::IDENTITY) { - return false; + if (param.dst_type.enumv() == DTypeEnum::Int16 || + param.dst_type.enumv() == DTypeEnum::Int32 || + param.dst_type.enumv() == DTypeEnum::QuantizedS32) { + if (param.bias_mode != megdnn::BiasMode::NO_BIAS || + param.nonlineMode != megdnn::NonlineMode::IDENTITY) { + return false; + } } if (opr->param().format == param::ConvBias::Format::NCHW44 || opr->param().format == param::ConvBias::Format::NCHW44_DOT) { diff --git a/dnn/test/fallback/conv_bias.cpp b/dnn/test/fallback/conv_bias.cpp index 3fc82d1241c4f43b2ebd97c6ae91876ffc724d8c..d00719f074fd4e4c273e68ac40558d5a84593799 100644 --- a/dnn/test/fallback/conv_bias.cpp +++ b/dnn/test/fallback/conv_bias.cpp @@ -188,6 +188,24 @@ void checker_conv_bias(std::vector args, Handle* handle, } } +TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD_IM2COL_8X8X16) { + using namespace conv_bias; + param::ConvBias cur_param; + using NLMode = param::ConvBias::NonlineMode; + std::vector args = get_conv_bias_args( + {1, 3}, {0}, {NLMode::IDENTITY, NLMode::RELU}, {1}, false, true); + NormalRNG default_rng; + Checker checker(handle()); + checker.set_dtype(0, dtype::Int8{}); + checker.set_dtype(1, dtype::Int8{}); + checker.set_dtype(2, dtype::Int16{}); + checker.set_dtype(4, dtype::Int16{}); + for (auto&& arg : args) { + checker.set_param(arg.param).execs( + {arg.src, arg.filter, arg.bias, {}, {}}); + } +} + TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD) { using namespace conv_bias; param::ConvBias cur_param; diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index c43dd0d655f1b319747f50886a2ef1db6b5cc5ce..4d641942e915dcc509c88a37bd5b4f982aa17970 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -1671,7 +1671,9 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const { rewriter.get_var(typecvt->input(0))->owner_opr()); if (!conv_bias || m_deps.count(typecvt->input(0)) != 1 || typecvt->output(0)->dtype().enumv() != - DTypeTrait::enumv) + DTypeTrait::enumv || + typecvt->input(0)->dtype().enumv() != + DTypeTrait::enumv) return nullptr; auto config = conv_bias->config();