Commit 8ffed043 authored by Megvii Engine Team, committed by Xu Xinran

fix(dnn/x86): fix matrix_mul quantized performance on vnni

GitOrigin-RevId: 4af6b8be60dd654003576bbd6b817411252fd306
Parent 1d860f4d
@@ -95,6 +95,14 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
 public:
     AlgoPack() {
+        //! FIXME: preference to use mkldnn algo on VNNI devices
+        //! But now mkldnn algo preference issue with NCHW->NHWC->NCHW
+#if MEGDNN_X86_WITH_MKL_DNN
+        //! Create the mkldnn algo
+        all_algos.emplace_back(&mkldnn_conv_fp32);
+        all_algos.emplace_back(&mkldnn_matmul_qint8);
+        all_algos.emplace_back(&mkldnn_qint8);
+#endif
         all_algos.emplace_back(&stride1_direct_large_group);
         all_algos.emplace_back(&stride1_direct_small_group);
         all_algos.emplace_back(&stride2_direct_large_group);
@@ -105,14 +113,6 @@ public:
         all_algos.emplace_back(&avx2_stride2_chanwsie_qint8);
         all_algos.emplace_back(&matmul);
-        //! preference to use mkldnn algo on VNNI devices
-#if MEGDNN_X86_WITH_MKL_DNN
-        //! Create the mkldnn algo
-        all_algos.emplace_back(&mkldnn_conv_fp32);
-        all_algos.emplace_back(&mkldnn_matmul_qint8);
-        all_algos.emplace_back(&mkldnn_qint8);
-#endif
         static CpuOprDelegationStorage<> storage;
         auto matmul_opr = storage.get<MatrixMul>();
         auto&& matmul_algos =
@@ -172,15 +172,18 @@ bool ConvBiasImpl::is_matmul_quantized_prefer(
                 chanwise_avx2_stride2_qint8_usable_preferred(param) ||
                 direct_avx2_stride1_int8_usable_preferred(param) ||
                 direct_avx2_stride2_int8_usable_preferred(param);
-    }
 #if MEGDNN_X86_WITH_MKL_DNN
         conv_direct_chanwise_mkldnn_usable =
                 conv_direct_chanwise_mkldnn_usable ||
                 mkldnn_qint8_usable_preferred(param) ||
                 mkldnn_matmul_qint8_usable_preferred(param);
 #endif
+    }
-    return !conv_direct_chanwise_mkldnn_usable;
+    return !conv_direct_chanwise_mkldnn_usable ||
+           (is_supported(SIMDType::VNNI) &&
+            !chanwise_avx2_stride1_qint8_usable_preferred(param) &&
+            !chanwise_avx2_stride2_qint8_usable_preferred(param));
 }

 // vim: syntax=cpp.doxygen
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "src/x86/matrix_mul/algos.h"
@@ -17,7 +18,6 @@
 #include "src/x86/matrix_mul/f32/strategy.h"
 MIDOUT_DECL(megdnn_x86_matmul_kern)
 MIDOUT_DECL(megdnn_x86_matmul_kern_mk8_8x8)
 using namespace megdnn;
@@ -45,17 +45,16 @@ void f32_blas_kern(const MatrixMulImpl::KernParam& kern_param) {
 #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
 void f32_blas_kern_only_packA(const MatrixMulImpl::KernParam& kern_param,
                               const void* a_panel, const void* b_panel) {
     MEGDNN_MARK_USED_VAR(b_panel);
     auto m = kern_param.M, n = kern_param.N, k = kern_param.K;
     const auto Bptr = kern_param.B<dt_float32>();
     auto Cptr = kern_param.C<dt_float32>();
     auto Atrd = kern_param.LDA, Btrd = kern_param.LDB, Ctrd = kern_param.LDC;
     disable_denorm();
     cblas_sgemm_compute(CblasRowMajor, CblasPacked, CblasNoTrans, m, n, k,
-                        static_cast<const float*>(a_panel), Atrd,
-                        Bptr, Btrd, 0.0f, Cptr,
-                        Ctrd);
+                        static_cast<const float*>(a_panel), Atrd, Bptr, Btrd,
+                        0.0f, Cptr, Ctrd);
 }
 #endif
@@ -111,8 +110,9 @@ WorkspaceBundle MatrixMulImpl::AlgoF32MKLPackA::get_bundle(
     return {nullptr, {a_size, 0, 0}};
 }
-void MatrixMulImpl::AlgoF32MKLPackA::pack_A(const KernParam& kern_param, void* out,
-                                            size_t index, size_t stride) const {
+void MatrixMulImpl::AlgoF32MKLPackA::pack_A(const KernParam& kern_param,
+                                            void* out, size_t index,
+                                            size_t stride) const {
     MEGDNN_MARK_USED_VAR(stride);
     MEGDNN_MARK_USED_VAR(index);
     auto m = kern_param.M, n = kern_param.N, k = kern_param.K;
@@ -164,7 +164,7 @@ size_t get_kern_workspace(MatrixMulImpl::KernSizeParam kern_size_param) {
 bool MatrixMulImpl::AlgoInt8x8x32Vnni::usable(
         const KernSizeParam& kern_size_param) const {
-    return kern_size_param.A_type == kern_size_param.B_type &&
+    return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
            ((kern_size_param.A_type.enumv() == DTypeEnum::Int8 &&
              kern_size_param.C_type.enumv() == DTypeEnum::Int32) ||
             (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
@@ -389,9 +389,10 @@ size_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_workspace(
                    m, n, k, trans_a, trans_b, strategy, cacheline)
             .get_workspace_size();
 }
-MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32AVX2M2N4K16, megdnn_x86_matmul_kern,
-                                     8, x86::matmul::gemm_avx2_s8s8s32_2x4x16,
-                                     dt_int8, dt_int32);
+MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32AVX2M2N4K16,
+                                     megdnn_x86_matmul_kern, 8,
+                                     x86::matmul::gemm_avx2_s8s8s32_2x4x16,
+                                     dt_int8, dt_int32);
 /*************************AlgoInt8x8x32SSEM4N8K2********************/
 MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_kern(
@@ -426,9 +427,10 @@ size_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_workspace(
                    m, n, k, trans_a, trans_b, strategy, cacheline)
             .get_workspace_size();
 }
-MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
-        AlgoInt8x8x32SSEM4N8K2, megdnn_x86_matmul_kern, 9,
-        x86::matmul::gemm_sse_s8s8s32_4x8x2, dt_int8, dt_int32, dt_int16);
+MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32SSEM4N8K2,
+                                            megdnn_x86_matmul_kern, 9,
+                                            x86::matmul::gemm_sse_s8s8s32_4x8x2,
+                                            dt_int8, dt_int32, dt_int16);
 /*************************AlgoF32MK8_8x8********************/
 MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK8_8x8::get_kern(
...
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "src/x86/matrix_mul/opr_impl.h"
@@ -41,9 +42,6 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
 public:
     AlgoPack() {
         if (is_supported(SIMDType::VNNI)) {
-#if MEGDNN_X86_WITH_MKL_DNN
-            all_algos.emplace_back(&algoint8x8x32mkldnn);
-#endif
 #if MEGDNN_X86_WITH_VNNI
             all_algos.emplace_back(&algoint8x8x32vnni);
 #endif
...