From 48d1ac1433d72f6f2db25f396e03a4ebae50997c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 29 May 2020 21:51:16 +0800 Subject: [PATCH] fix(dnn/arm): fix consistence between create_conv1x1_strategy and can_create_conv1x1_strategy GitOrigin-RevId: 2d32998aca1ae77d35a1185c44ce84f5e04816ea --- .../conv_bias/conv1x1/conv1x1_strategy.cpp | 65 +++++++++++++++++-- 1 file changed, 58 insertions(+), 7 deletions(-) diff --git a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp index e32f6d3a..c3a982fe 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp +++ b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp @@ -67,7 +67,8 @@ struct StrategyHashParamEqual { return flags; }; }; - +//! NOTE: must keep consistence with can_make_conv1x1_strategy when you modify +//! this function std::unique_ptr create_conv1x1_strategy( const ConvBiasImpl::NCBKernSizeParam& param, MatrixMulImpl::AlgoBase::PackMode pack_mode, @@ -211,14 +212,64 @@ Conv1x1StrategyBase* Conv1x1Factory::make_conv1x1_strategy( bool Conv1x1Factory::can_make_conv1x1_strategy( const ConvBiasImpl::NCBKernSizeParam& param, MatrixMulImpl::AlgoBase::PackMode pack_mode, param::ConvBias::Format) { + bool ok_default_cb1 = + param.src_type.enumv() == DTypeTrait::enumv; + bool ok_default_cb2 = + param.filter_type.enumv() == param.src_type.enumv() && + ((param.src_type.enumv() == DTypeTrait::enumv && + param.dst_type.enumv() == DTypeTrait::enumv) || + (param.src_type.enumv() == DTypeTrait::enumv && + param.dst_type.enumv() == DTypeTrait::enumv) || + (param.src_type.enumv() == DTypeTrait::enumv && + param.dst_type.enumv() == + DTypeTrait::enumv) || + (param.src_type.enumv() == DTypeTrait::enumv && + param.dst_type.enumv() == DTypeTrait::enumv)); + bool ok_default_cb1_fp16 = false; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC || !MEGDNN_DISABLE_FLOAT16 - if ((pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK || - pack_mode == MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA) && - param.src_type.enumv() == DTypeTrait::enumv) { - return false; - } + ok_default_cb1_fp16 = + param.src_type.enumv() == DTypeTrait::enumv; +#endif + bool ok_default_cb2_arm = false; +#if MEGDNN_AARCH64 || MEGDNN_ARMV7 + ok_default_cb2_arm = param.filter_type.enumv() == param.src_type.enumv() && + ((param.src_type.enumv() == + DTypeTrait::enumv && + param.dst_type.enumv() == + DTypeTrait::enumv) || + (param.src_type.enumv() == + DTypeTrait::enumv && + param.dst_type.enumv() == + DTypeTrait::enumv)); #endif - return true; + + bool ok_only_packa_cb1 = + param.src_type.enumv() == DTypeTrait::enumv; + bool ok_no_pack_cb1 = + param.src_type.enumv() == DTypeTrait::enumv; + bool ok_no_pack_cb2 = + param.filter_type.enumv() == param.src_type.enumv() && + ((param.src_type.enumv() == DTypeTrait::enumv && + param.dst_type.enumv() == DTypeTrait::enumv) || + (param.src_type.enumv() == DTypeTrait::enumv && + param.dst_type.enumv() == DTypeTrait::enumv) || + (param.src_type.enumv() == DTypeTrait::enumv && + param.dst_type.enumv() == + DTypeTrait::enumv)); + switch (pack_mode) { + case MatrixMulImpl::AlgoBase::PackMode::DEFAULT: + return ok_default_cb1 || ok_default_cb2 || ok_default_cb1_fp16 || + ok_default_cb2_arm; + break; + case MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA: + return ok_only_packa_cb1; + break; + case MatrixMulImpl::AlgoBase::PackMode::NO_PACK: + return ok_no_pack_cb1 || ok_no_pack_cb2; + break; + default: + return false; + } } } // namespace conv1x1 -- GitLab