提交 40e79e9d 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(dnn/x86): fix x86 matmul algos' usable() checks ignoring the matrix format

GitOrigin-RevId: 40fe508aca366efaa8b9ef1a799740a760e11531
上级 eab7ab05
...@@ -10,13 +10,13 @@ ...@@ -10,13 +10,13 @@
* implied. * implied.
*/ */
#include "src/x86/matrix_mul/algos.h"
#include "midout.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_impl.h" #include "src/fallback/matrix_mul/gemm_impl.h"
#include "src/x86/matrix_mul/algos.h"
#include "src/x86/matrix_mul/f32/strategy.h"
#include "src/x86/matrix_mul/int8/strategy.h" #include "src/x86/matrix_mul/int8/strategy.h"
#include "src/x86/matrix_mul/f32/strategy.h" #include "midout.h"
MIDOUT_DECL(megdnn_x86_matmul_kern) MIDOUT_DECL(megdnn_x86_matmul_kern)
MIDOUT_DECL(megdnn_x86_matmul_kern_mk8_8x8) MIDOUT_DECL(megdnn_x86_matmul_kern_mk8_8x8)
...@@ -170,6 +170,7 @@ bool MatrixMulImpl::AlgoInt8x8x32Vnni::usable( ...@@ -170,6 +170,7 @@ bool MatrixMulImpl::AlgoInt8x8x32Vnni::usable(
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 && (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) && kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == Param::Format::DEFAULT &&
preferred(kern_size_param) && is_supported(SIMDType::VNNI); preferred(kern_size_param) && is_supported(SIMDType::VNNI);
} }
...@@ -230,6 +231,7 @@ bool MatrixMulImpl::AlgoInt8x8x32Mkldnn::usable( ...@@ -230,6 +231,7 @@ bool MatrixMulImpl::AlgoInt8x8x32Mkldnn::usable(
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 && (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) && kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == Param::Format::DEFAULT &&
is_supported(SIMDType::VNNI) && preferred(kern_size_param); is_supported(SIMDType::VNNI) && preferred(kern_size_param);
} }
...@@ -365,8 +367,10 @@ bool MatrixMulImpl::AlgoInt8x8x16AVX2::usable( ...@@ -365,8 +367,10 @@ bool MatrixMulImpl::AlgoInt8x8x16AVX2::usable(
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16)); kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16));
bool is_mode_ok = bool is_mode_ok =
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == Param::Format::DEFAULT &&
is_supported(SIMDType::AVX2); is_supported(SIMDType::AVX2);
bool is_param_ok = is_ab_same && is_type_ok && is_mode_ok; bool is_param_ok = is_ab_same && is_type_ok && is_mode_ok;
return is_param_ok; return is_param_ok;
} }
bool MatrixMulImpl::AlgoInt8x8x16AVX2::preferred(const KernSizeParam&) const { bool MatrixMulImpl::AlgoInt8x8x16AVX2::preferred(const KernSizeParam&) const {
...@@ -440,6 +444,7 @@ bool MatrixMulImpl::AlgoInt8x8x16SSE::usable( ...@@ -440,6 +444,7 @@ bool MatrixMulImpl::AlgoInt8x8x16SSE::usable(
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16)); kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16));
bool is_mode_ok = bool is_mode_ok =
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == Param::Format::DEFAULT &&
is_supported(SIMDType::SSE4_1); is_supported(SIMDType::SSE4_1);
bool is_param_ok = is_ab_same && is_type_ok && is_mode_ok; bool is_param_ok = is_ab_same && is_type_ok && is_mode_ok;
return is_param_ok; return is_param_ok;
...@@ -478,13 +483,16 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_kern( ...@@ -478,13 +483,16 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_kern(
} }
bool MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::usable( bool MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::usable(
const KernSizeParam& kern_size_param) const { const KernSizeParam& kern_size_param) const {
return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && bool is_param_ok =
((kern_size_param.A_type.enumv() == DTypeEnum::Int8 && kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
kern_size_param.C_type.enumv() == DTypeEnum::Int32) || ((kern_size_param.A_type.enumv() == DTypeEnum::Int8 &&
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 && kern_size_param.C_type.enumv() == DTypeEnum::Int32) ||
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) && (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) &&
is_supported(SIMDType::AVX2); kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == Param::Format::DEFAULT &&
is_supported(SIMDType::AVX2);
return is_param_ok;
} }
size_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_workspace( size_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_workspace(
const KernSizeParam& kern_param) const { const KernSizeParam& kern_param) const {
...@@ -522,6 +530,7 @@ bool MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::usable( ...@@ -522,6 +530,7 @@ bool MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::usable(
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 && (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) && kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == Param::Format::DEFAULT &&
is_supported(SIMDType::AVX2); is_supported(SIMDType::AVX2);
} }
size_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_workspace( size_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_workspace(
...@@ -562,6 +571,7 @@ bool MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::usable( ...@@ -562,6 +571,7 @@ bool MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::usable(
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 && (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) && kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == Param::Format::DEFAULT &&
is_supported(SIMDType::SSE4_1); is_supported(SIMDType::SSE4_1);
} }
size_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_workspace( size_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_workspace(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册