diff --git a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp index e32f6d3adfe6677b0533aff32bb7a1e325c56fa8..c3a982fe003870ed1b25b436b9b421d6dbde108a 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp +++ b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp @@ -67,7 +67,8 @@ struct StrategyHashParamEqual { return flags; }; }; - +//! NOTE: must keep consistence with can_make_conv1x1_strategy when you modify +//! this function std::unique_ptr create_conv1x1_strategy( const ConvBiasImpl::NCBKernSizeParam& param, MatrixMulImpl::AlgoBase::PackMode pack_mode, @@ -211,14 +212,64 @@ Conv1x1StrategyBase* Conv1x1Factory::make_conv1x1_strategy( bool Conv1x1Factory::can_make_conv1x1_strategy( const ConvBiasImpl::NCBKernSizeParam& param, MatrixMulImpl::AlgoBase::PackMode pack_mode, param::ConvBias::Format) { + bool ok_default_cb1 = + param.src_type.enumv() == DTypeTrait::enumv; + bool ok_default_cb2 = + param.filter_type.enumv() == param.src_type.enumv() && + ((param.src_type.enumv() == DTypeTrait::enumv && + param.dst_type.enumv() == DTypeTrait::enumv) || + (param.src_type.enumv() == DTypeTrait::enumv && + param.dst_type.enumv() == DTypeTrait::enumv) || + (param.src_type.enumv() == DTypeTrait::enumv && + param.dst_type.enumv() == + DTypeTrait::enumv) || + (param.src_type.enumv() == DTypeTrait::enumv && + param.dst_type.enumv() == DTypeTrait::enumv)); + bool ok_default_cb1_fp16 = false; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC || !MEGDNN_DISABLE_FLOAT16 - if ((pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK || - pack_mode == MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA) && - param.src_type.enumv() == DTypeTrait::enumv) { - return false; - } + ok_default_cb1_fp16 = + param.src_type.enumv() == DTypeTrait::enumv; +#endif + bool ok_default_cb2_arm = false; +#if MEGDNN_AARCH64 || MEGDNN_ARMV7 + ok_default_cb2_arm = param.filter_type.enumv() == param.src_type.enumv() && + ((param.src_type.enumv() == + DTypeTrait::enumv && + param.dst_type.enumv() == + DTypeTrait::enumv) || + (param.src_type.enumv() == + DTypeTrait::enumv && + param.dst_type.enumv() == + DTypeTrait::enumv)); #endif - return true; + + bool ok_only_packa_cb1 = + param.src_type.enumv() == DTypeTrait::enumv; + bool ok_no_pack_cb1 = + param.src_type.enumv() == DTypeTrait::enumv; + bool ok_no_pack_cb2 = + param.filter_type.enumv() == param.src_type.enumv() && + ((param.src_type.enumv() == DTypeTrait::enumv && + param.dst_type.enumv() == DTypeTrait::enumv) || + (param.src_type.enumv() == DTypeTrait::enumv && + param.dst_type.enumv() == DTypeTrait::enumv) || + (param.src_type.enumv() == DTypeTrait::enumv && + param.dst_type.enumv() == + DTypeTrait::enumv)); + switch (pack_mode) { + case MatrixMulImpl::AlgoBase::PackMode::DEFAULT: + return ok_default_cb1 || ok_default_cb2 || ok_default_cb1_fp16 || + ok_default_cb2_arm; + break; + case MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA: + return ok_only_packa_cb1; + break; + case MatrixMulImpl::AlgoBase::PackMode::NO_PACK: + return ok_no_pack_cb1 || ok_no_pack_cb2; + break; + default: + return false; + } } } // namespace conv1x1