提交 48d1ac14 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(dnn/arm): fix consistence between create_conv1x1_strategy and can_create_conv1x1_strategy

GitOrigin-RevId: 2d32998aca1ae77d35a1185c44ce84f5e04816ea
上级 6d0d5e5a
......@@ -67,7 +67,8 @@ struct StrategyHashParamEqual {
return flags;
};
};
//! NOTE: must keep consistence with can_make_conv1x1_strategy when you modify
//! this function
std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy(
const ConvBiasImpl::NCBKernSizeParam& param,
MatrixMulImpl::AlgoBase::PackMode pack_mode,
......@@ -211,14 +212,64 @@ Conv1x1StrategyBase* Conv1x1Factory::make_conv1x1_strategy(
bool Conv1x1Factory::can_make_conv1x1_strategy(
const ConvBiasImpl::NCBKernSizeParam& param,
MatrixMulImpl::AlgoBase::PackMode pack_mode, param::ConvBias::Format) {
bool ok_default_cb1 =
param.src_type.enumv() == DTypeTrait<dt_float32>::enumv;
bool ok_default_cb2 =
param.filter_type.enumv() == param.src_type.enumv() &&
((param.src_type.enumv() == DTypeTrait<dt_int8>::enumv &&
param.dst_type.enumv() == DTypeTrait<dt_int32>::enumv) ||
(param.src_type.enumv() == DTypeTrait<dt_int8>::enumv &&
param.dst_type.enumv() == DTypeTrait<dt_int16>::enumv) ||
(param.src_type.enumv() == DTypeTrait<dtype::QuantizedS8>::enumv &&
param.dst_type.enumv() ==
DTypeTrait<dtype::QuantizedS32>::enumv) ||
(param.src_type.enumv() == DTypeTrait<dtype::QuantizedS8>::enumv &&
param.dst_type.enumv() == DTypeTrait<dtype::QuantizedS8>::enumv));
bool ok_default_cb1_fp16 = false;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC || !MEGDNN_DISABLE_FLOAT16
if ((pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK ||
pack_mode == MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA) &&
param.src_type.enumv() == DTypeTrait<dt_float16>::enumv) {
return false;
}
ok_default_cb1_fp16 =
param.src_type.enumv() == DTypeTrait<dt_float16>::enumv;
#endif
bool ok_default_cb2_arm = false;
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
ok_default_cb2_arm = param.filter_type.enumv() == param.src_type.enumv() &&
((param.src_type.enumv() ==
DTypeTrait<dtype::Quantized8Asymm>::enumv &&
param.dst_type.enumv() ==
DTypeTrait<dtype::QuantizedS32>::enumv) ||
(param.src_type.enumv() ==
DTypeTrait<dtype::Quantized8Asymm>::enumv &&
param.dst_type.enumv() ==
DTypeTrait<dtype::Quantized8Asymm>::enumv));
#endif
return true;
bool ok_only_packa_cb1 =
param.src_type.enumv() == DTypeTrait<dt_float32>::enumv;
bool ok_no_pack_cb1 =
param.src_type.enumv() == DTypeTrait<dt_float32>::enumv;
bool ok_no_pack_cb2 =
param.filter_type.enumv() == param.src_type.enumv() &&
((param.src_type.enumv() == DTypeTrait<dt_int8>::enumv &&
param.dst_type.enumv() == DTypeTrait<dt_int16>::enumv) ||
(param.src_type.enumv() == DTypeTrait<dt_int8>::enumv &&
param.dst_type.enumv() == DTypeTrait<dt_int32>::enumv) ||
(param.src_type.enumv() == DTypeTrait<dtype::QuantizedS8>::enumv &&
param.dst_type.enumv() ==
DTypeTrait<dtype::QuantizedS32>::enumv));
switch (pack_mode) {
case MatrixMulImpl::AlgoBase::PackMode::DEFAULT:
return ok_default_cb1 || ok_default_cb2 || ok_default_cb1_fp16 ||
ok_default_cb2_arm;
break;
case MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA:
return ok_only_packa_cb1;
break;
case MatrixMulImpl::AlgoBase::PackMode::NO_PACK:
return ok_no_pack_cb1 || ok_no_pack_cb2;
break;
default:
return false;
}
}
} // namespace conv1x1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册