diff --git a/dnn/src/x86/matrix_mul/algos.cpp b/dnn/src/x86/matrix_mul/algos.cpp index 5ec7d957c7b4e38fb0ed02047b6030fed94543fc..e07c1ee22add65e46d6a05906e0203548ca5f69a 100644 --- a/dnn/src/x86/matrix_mul/algos.cpp +++ b/dnn/src/x86/matrix_mul/algos.cpp @@ -10,13 +10,13 @@ * implied. */ -#include "src/x86/matrix_mul/algos.h" -#include "midout.h" #include "src/common/utils.h" #include "src/fallback/matrix_mul/gemm_impl.h" +#include "src/x86/matrix_mul/algos.h" +#include "src/x86/matrix_mul/f32/strategy.h" #include "src/x86/matrix_mul/int8/strategy.h" -#include "src/x86/matrix_mul/f32/strategy.h" +#include "midout.h" MIDOUT_DECL(megdnn_x86_matmul_kern) MIDOUT_DECL(megdnn_x86_matmul_kern_mk8_8x8) @@ -170,6 +170,7 @@ bool MatrixMulImpl::AlgoInt8x8x32Vnni::usable( (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 && kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) && kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == Param::Format::DEFAULT && preferred(kern_size_param) && is_supported(SIMDType::VNNI); } @@ -230,6 +231,7 @@ bool MatrixMulImpl::AlgoInt8x8x32Mkldnn::usable( (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 && kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) && kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == Param::Format::DEFAULT && is_supported(SIMDType::VNNI) && preferred(kern_size_param); } @@ -365,8 +367,10 @@ bool MatrixMulImpl::AlgoInt8x8x16AVX2::usable( kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16)); bool is_mode_ok = kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == Param::Format::DEFAULT && is_supported(SIMDType::AVX2); bool is_param_ok = is_ab_same && is_type_ok && is_mode_ok; + return is_param_ok; } bool MatrixMulImpl::AlgoInt8x8x16AVX2::preferred(const KernSizeParam&) const { @@ -440,6 +444,7 @@ bool MatrixMulImpl::AlgoInt8x8x16SSE::usable( kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16)); bool is_mode_ok = kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == Param::Format::DEFAULT && is_supported(SIMDType::SSE4_1); bool is_param_ok = is_ab_same && is_type_ok && is_mode_ok; return is_param_ok; @@ -478,13 +483,16 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_kern( } bool MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::usable( const KernSizeParam& kern_size_param) const { - return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && - ((kern_size_param.A_type.enumv() == DTypeEnum::Int8 && - kern_size_param.C_type.enumv() == DTypeEnum::Int32) || - (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 && - kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) && - kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && - is_supported(SIMDType::AVX2); + bool is_param_ok = + kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && + ((kern_size_param.A_type.enumv() == DTypeEnum::Int8 && + kern_size_param.C_type.enumv() == DTypeEnum::Int32) || + (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 && + kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) && + kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == Param::Format::DEFAULT && + is_supported(SIMDType::AVX2); + return is_param_ok; } size_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_workspace( const KernSizeParam& kern_param) const { @@ -522,6 +530,7 @@ bool MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::usable( (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 && kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) && kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == Param::Format::DEFAULT && is_supported(SIMDType::AVX2); } size_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_workspace( @@ -562,6 +571,7 @@ bool MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::usable( (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 && kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) && kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == Param::Format::DEFAULT && is_supported(SIMDType::SSE4_1); } size_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_workspace(