提交 4d35397b 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(dnn/fallback): fix conv1x1/im2col usable and fuse-conv-bias get fp32xfp32-->qint8 error

GitOrigin-RevId: 5a3bfedd8a433cd17e735e3ba822027589263b69
上级 12dc36a6
......@@ -201,24 +201,27 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr,
if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1)
return false;
if (param.src_type.enumv() != param.filter_type.enumv() &&
param.src_type.enumv() != DTypeEnum::Int8 &&
param.src_type.enumv() != DTypeEnum::QuantizedS8 &&
param.src_type.enumv() != DTypeEnum::Quantized8Asymm &&
#if !MEGDNN_DISABLE_FLOAT16
param.src_type.enumv() != DTypeEnum::Float16 &&
#endif
param.src_type.enumv() != DTypeEnum::Float32) {
return false;
}
//! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode
//! is identity otherwise return false mean that 8x8x32 and 8x8x16
//! not support PostProcess
if (param.src_type.enumv() == param.filter_type.enumv() &&
(param.src_type.enumv() == DTypeEnum::Int8 &&
(param.dst_type.enumv() == DTypeEnum::Int16 ||
param.dst_type.enumv() == DTypeEnum::Int32)) &&
param.bias_mode != megdnn::BiasMode::NO_BIAS &&
param.nonlineMode != megdnn::NonlineMode::IDENTITY)
return false;
if (param.src_type.enumv() == param.filter_type.enumv() &&
((param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Quantized8Asymm) &&
param.dst_type.enumv() == DTypeEnum::QuantizedS32) &&
param.bias_mode != megdnn::BiasMode::NO_BIAS &&
param.nonlineMode != megdnn::NonlineMode::IDENTITY)
if (param.dst_type.enumv() == DTypeEnum::Int16 ||
param.dst_type.enumv() == DTypeEnum::Int32 ||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) {
if (param.bias_mode != megdnn::BiasMode::NO_BIAS ||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
return false;
}
}
if (opr->param().format == param::ConvBias::Format::NCHW44 ||
opr->param().format == param::ConvBias::Format::NCHW44_DOT) {
......
......@@ -647,20 +647,27 @@ bool ConvBiasImpl::AlgoIm2col::usable(
return false;
}
if (param.src_type.enumv() != param.filter_type.enumv() &&
param.src_type.enumv() != DTypeEnum::Int8 &&
param.src_type.enumv() != DTypeEnum::QuantizedS8 &&
param.src_type.enumv() != DTypeEnum::Quantized8Asymm &&
#if !MEGDNN_DISABLE_FLOAT16
param.src_type.enumv() != DTypeEnum::Float16 &&
#endif
param.src_type.enumv() != DTypeEnum::Float32) {
return false;
}
//! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode is
//! identity otherwise return false mean that 8x8x32 and 8x8x16 not
//! support PostProcess
if (param.src_type.enumv() == param.filter_type.enumv() &&
((param.src_type.enumv() == DTypeEnum::Int8 &&
(param.dst_type.enumv() == DTypeEnum::Int16 ||
param.dst_type.enumv() == DTypeEnum::Int32)) ||
((param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Quantized8Asymm) &&
param.dst_type.enumv() == DTypeEnum::QuantizedS32)) &&
param.bias_mode != megdnn::BiasMode::NO_BIAS &&
if (param.dst_type.enumv() == DTypeEnum::Int16 ||
param.dst_type.enumv() == DTypeEnum::Int32 ||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) {
if (param.bias_mode != megdnn::BiasMode::NO_BIAS ||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
return false;
}
}
if (opr->param().format == param::ConvBias::Format::NCHW44 ||
opr->param().format == param::ConvBias::Format::NCHW44_DOT) {
//! current NCHW44 im2col only support DEFAULT mode matmul
......
......@@ -188,6 +188,24 @@ void checker_conv_bias(std::vector<conv_bias::TestArg> args, Handle* handle,
}
}
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD_IM2COL_8X8X16) {
using namespace conv_bias;
param::ConvBias cur_param;
using NLMode = param::ConvBias::NonlineMode;
std::vector<conv_bias::TestArg> args = get_conv_bias_args(
{1, 3}, {0}, {NLMode::IDENTITY, NLMode::RELU}, {1}, false, true);
NormalRNG default_rng;
Checker<ConvBias> checker(handle());
checker.set_dtype(0, dtype::Int8{});
checker.set_dtype(1, dtype::Int8{});
checker.set_dtype(2, dtype::Int16{});
checker.set_dtype(4, dtype::Int16{});
for (auto&& arg : args) {
checker.set_param(arg.param).execs(
{arg.src, arg.filter, arg.bias, {}, {}});
}
}
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD) {
using namespace conv_bias;
param::ConvBias cur_param;
......
......@@ -1671,7 +1671,9 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const {
rewriter.get_var(typecvt->input(0))->owner_opr());
if (!conv_bias || m_deps.count(typecvt->input(0)) != 1 ||
typecvt->output(0)->dtype().enumv() !=
DTypeTrait<dtype::QuantizedS8>::enumv)
DTypeTrait<dtype::QuantizedS8>::enumv ||
typecvt->input(0)->dtype().enumv() !=
DTypeTrait<dtype::QuantizedS32>::enumv)
return nullptr;
auto config = conv_bias->config();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册