Commit 48d1ac14 authored by Megvii Engine Team, committed by Xu Xinran

fix(dnn/arm): fix consistence between create_conv1x1_strategy and can_create_conv1x1_strategy

GitOrigin-RevId: 2d32998aca1ae77d35a1185c44ce84f5e04816ea
Parent 6d0d5e5a
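The change below makes can_make_conv1x1_strategy report support for exactly the (pack_mode, dtype) combinations that create_conv1x1_strategy can actually build, instead of only rejecting the fp16 NO_PACK/ONLY_PACKA case. As a minimal, self-contained illustration of that invariant (this is NOT the MegDNN source; PackMode, DType, StrategyBase, create_strategy and can_make_strategy are simplified stand-ins), a query helper has to enumerate, per pack mode, the same dtype pairs that the factory accepts, so a positive can_make_*() answer is never contradicted later by create_*() failing to produce a strategy:

// Hedged sketch of the factory/predicate consistency pattern; all names are
// hypothetical stand-ins for the real MegDNN types.
#include <memory>

enum class PackMode { DEFAULT, ONLY_PACKA, NO_PACK };
enum class DType { Float32, Int8, Int32 };

struct StrategyBase {
    virtual ~StrategyBase() = default;
};
struct F32Strategy final : StrategyBase {};
struct I8I32Strategy final : StrategyBase {};

//! NOTE: must stay consistent with can_make_strategy() below.
std::unique_ptr<StrategyBase> create_strategy(PackMode mode, DType src, DType dst) {
    switch (mode) {
        case PackMode::DEFAULT:
        case PackMode::NO_PACK:
            if (src == DType::Float32 && dst == DType::Float32)
                return std::make_unique<F32Strategy>();
            if (src == DType::Int8 && dst == DType::Int32)
                return std::make_unique<I8I32Strategy>();
            return nullptr;
        case PackMode::ONLY_PACKA:
            if (src == DType::Float32 && dst == DType::Float32)
                return std::make_unique<F32Strategy>();
            return nullptr;
    }
    return nullptr;
}

//! Mirrors the dispatch above case by case; any change to one function must
//! be replayed in the other.
bool can_make_strategy(PackMode mode, DType src, DType dst) {
    switch (mode) {
        case PackMode::DEFAULT:
        case PackMode::NO_PACK:
            return (src == DType::Float32 && dst == DType::Float32) ||
                   (src == DType::Int8 && dst == DType::Int32);
        case PackMode::ONLY_PACKA:
            return src == DType::Float32 && dst == DType::Float32;
    }
    return false;
}

The commit keeps two mirrored functions and documents the coupling with a NOTE comment rather than factoring out a shared predicate; the sketch follows the same shape.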
@@ -67,7 +67,8 @@ struct StrategyHashParamEqual {
         return flags;
     };
 };
+//! NOTE: must keep consistence with can_make_conv1x1_strategy when you modify
+//! this function
 std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy(
         const ConvBiasImpl::NCBKernSizeParam& param,
         MatrixMulImpl::AlgoBase::PackMode pack_mode,
@@ -211,14 +212,64 @@ Conv1x1StrategyBase* Conv1x1Factory::make_conv1x1_strategy(
 bool Conv1x1Factory::can_make_conv1x1_strategy(
         const ConvBiasImpl::NCBKernSizeParam& param,
         MatrixMulImpl::AlgoBase::PackMode pack_mode, param::ConvBias::Format) {
-#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC || !MEGDNN_DISABLE_FLOAT16
-    if ((pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK ||
-         pack_mode == MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA) &&
-        param.src_type.enumv() == DTypeTrait<dt_float16>::enumv) {
-        return false;
-    }
-#endif
-    return true;
+    bool ok_default_cb1 =
+            param.src_type.enumv() == DTypeTrait<dt_float32>::enumv;
+    bool ok_default_cb2 =
+            param.filter_type.enumv() == param.src_type.enumv() &&
+            ((param.src_type.enumv() == DTypeTrait<dt_int8>::enumv &&
+              param.dst_type.enumv() == DTypeTrait<dt_int32>::enumv) ||
+             (param.src_type.enumv() == DTypeTrait<dt_int8>::enumv &&
+              param.dst_type.enumv() == DTypeTrait<dt_int16>::enumv) ||
+             (param.src_type.enumv() == DTypeTrait<dtype::QuantizedS8>::enumv &&
+              param.dst_type.enumv() ==
+                      DTypeTrait<dtype::QuantizedS32>::enumv) ||
+             (param.src_type.enumv() == DTypeTrait<dtype::QuantizedS8>::enumv &&
+              param.dst_type.enumv() == DTypeTrait<dtype::QuantizedS8>::enumv));
+    bool ok_default_cb1_fp16 = false;
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC || !MEGDNN_DISABLE_FLOAT16
+    ok_default_cb1_fp16 =
+            param.src_type.enumv() == DTypeTrait<dt_float16>::enumv;
+#endif
+    bool ok_default_cb2_arm = false;
+#if MEGDNN_AARCH64 || MEGDNN_ARMV7
+    ok_default_cb2_arm = param.filter_type.enumv() == param.src_type.enumv() &&
+                         ((param.src_type.enumv() ==
+                                   DTypeTrait<dtype::Quantized8Asymm>::enumv &&
+                           param.dst_type.enumv() ==
+                                   DTypeTrait<dtype::QuantizedS32>::enumv) ||
+                          (param.src_type.enumv() ==
+                                   DTypeTrait<dtype::Quantized8Asymm>::enumv &&
+                           param.dst_type.enumv() ==
+                                   DTypeTrait<dtype::Quantized8Asymm>::enumv));
+#endif
+    bool ok_only_packa_cb1 =
+            param.src_type.enumv() == DTypeTrait<dt_float32>::enumv;
+    bool ok_no_pack_cb1 =
+            param.src_type.enumv() == DTypeTrait<dt_float32>::enumv;
+    bool ok_no_pack_cb2 =
+            param.filter_type.enumv() == param.src_type.enumv() &&
+            ((param.src_type.enumv() == DTypeTrait<dt_int8>::enumv &&
+              param.dst_type.enumv() == DTypeTrait<dt_int16>::enumv) ||
+             (param.src_type.enumv() == DTypeTrait<dt_int8>::enumv &&
+              param.dst_type.enumv() == DTypeTrait<dt_int32>::enumv) ||
+             (param.src_type.enumv() == DTypeTrait<dtype::QuantizedS8>::enumv &&
+              param.dst_type.enumv() ==
+                      DTypeTrait<dtype::QuantizedS32>::enumv));
+    switch (pack_mode) {
+        case MatrixMulImpl::AlgoBase::PackMode::DEFAULT:
+            return ok_default_cb1 || ok_default_cb2 || ok_default_cb1_fp16 ||
+                   ok_default_cb2_arm;
+            break;
+        case MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA:
+            return ok_only_packa_cb1;
+            break;
+        case MatrixMulImpl::AlgoBase::PackMode::NO_PACK:
+            return ok_no_pack_cb1 || ok_no_pack_cb2;
+            break;
+        default:
+            return false;
+    }
 }
 } // namespace conv1x1
...
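A second detail of the new can_make_conv1x1_strategy is how it handles targets that lack fp16 or ARM quantized support: each feature-gated capability is folded into a plain bool initialized to false, so the final switch over the pack mode contains no preprocessor branches. The following standalone sketch shows the same idiom (it is not the MegDNN source; EXAMPLE_TARGET_HAS_FP16, can_handle, PackMode, and SrcType are hypothetical names standing in for checks such as __ARM_FEATURE_FP16_VECTOR_ARITHMETIC):

// Feature-gated support folded into bools so the dispatch switch stays
// free of #if blocks; purely illustrative.
#include <cstdio>

enum class PackMode { DEFAULT, NO_PACK };
enum class SrcType { Float32, Float16 };

bool can_handle(PackMode mode, SrcType src) {
    bool ok_fp32 = (src == SrcType::Float32);  // always supported
    bool ok_fp16 = false;                      // only on capable targets
#if defined(EXAMPLE_TARGET_HAS_FP16)
    ok_fp16 = (src == SrcType::Float16);
#endif
    switch (mode) {
        case PackMode::DEFAULT:
            return ok_fp32 || ok_fp16;
        case PackMode::NO_PACK:
            return ok_fp32;
    }
    return false;
}

int main() {
    // Without EXAMPLE_TARGET_HAS_FP16 defined, this prints 1 then 0.
    std::printf("%d\n", can_handle(PackMode::DEFAULT, SrcType::Float32));
    std::printf("%d\n", can_handle(PackMode::DEFAULT, SrcType::Float16));
}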