提交 40e79e9d 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(dnn/x86): fix x86 matmul usable() checks ignoring the format field

GitOrigin-RevId: 40fe508aca366efaa8b9ef1a799740a760e11531
上级 eab7ab05
......@@ -10,13 +10,13 @@
* implied.
*/
#include "src/x86/matrix_mul/algos.h"
#include "midout.h"
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_impl.h"
#include "src/x86/matrix_mul/algos.h"
#include "src/x86/matrix_mul/f32/strategy.h"
#include "src/x86/matrix_mul/int8/strategy.h"
#include "src/x86/matrix_mul/f32/strategy.h"
#include "midout.h"
MIDOUT_DECL(megdnn_x86_matmul_kern)
MIDOUT_DECL(megdnn_x86_matmul_kern_mk8_8x8)
......@@ -170,6 +170,7 @@ bool MatrixMulImpl::AlgoInt8x8x32Vnni::usable(
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == Param::Format::DEFAULT &&
preferred(kern_size_param) && is_supported(SIMDType::VNNI);
}
......@@ -230,6 +231,7 @@ bool MatrixMulImpl::AlgoInt8x8x32Mkldnn::usable(
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == Param::Format::DEFAULT &&
is_supported(SIMDType::VNNI) && preferred(kern_size_param);
}
......@@ -365,8 +367,10 @@ bool MatrixMulImpl::AlgoInt8x8x16AVX2::usable(
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16));
bool is_mode_ok =
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == Param::Format::DEFAULT &&
is_supported(SIMDType::AVX2);
bool is_param_ok = is_ab_same && is_type_ok && is_mode_ok;
return is_param_ok;
}
bool MatrixMulImpl::AlgoInt8x8x16AVX2::preferred(const KernSizeParam&) const {
......@@ -440,6 +444,7 @@ bool MatrixMulImpl::AlgoInt8x8x16SSE::usable(
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16));
bool is_mode_ok =
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == Param::Format::DEFAULT &&
is_supported(SIMDType::SSE4_1);
bool is_param_ok = is_ab_same && is_type_ok && is_mode_ok;
return is_param_ok;
......@@ -478,13 +483,16 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_kern(
}
bool MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::usable(
        const KernSizeParam& kern_size_param) const {
    // This algorithm handles 8x8x32 matmul on AVX2; it is usable only for
    // plain-format int8/qint8 inputs accumulated into 32-bit outputs.
    const auto a_dtype = kern_size_param.A_type.enumv();
    const auto b_dtype = kern_size_param.B_type.enumv();
    const auto c_dtype = kern_size_param.C_type.enumv();
    // A and B must carry the same element type.
    const bool ab_match = a_dtype == b_dtype;
    // Supported combinations: Int8 -> Int32, or QuantizedS8 -> QuantizedS32.
    const bool dtype_ok =
            (a_dtype == DTypeEnum::Int8 && c_dtype == DTypeEnum::Int32) ||
            (a_dtype == DTypeEnum::QuantizedS8 &&
             c_dtype == DTypeEnum::QuantizedS32);
    // Only the default compute mode and default layout format are handled,
    // and the host CPU must actually provide AVX2.
    const bool mode_ok =
            kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
            kern_size_param.format == Param::Format::DEFAULT &&
            is_supported(SIMDType::AVX2);
    return ab_match && dtype_ok && mode_ok;
}
size_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_workspace(
const KernSizeParam& kern_param) const {
......@@ -522,6 +530,7 @@ bool MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::usable(
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == Param::Format::DEFAULT &&
is_supported(SIMDType::AVX2);
}
size_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_workspace(
......@@ -562,6 +571,7 @@ bool MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::usable(
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == Param::Format::DEFAULT &&
is_supported(SIMDType::SSE4_1);
}
size_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_workspace(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册