From ca2828ddcb2960e11bff10e80787e479e5927991 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 28 Jun 2021 20:58:57 +0800 Subject: [PATCH] fix(dnn/x86): fix x86 int8 matmul ldc bug GitOrigin-RevId: 2502f99000d5e90fdc410b1d0bf731668cd1077c --- .../matrix_mul/int8/avx2_strategy_2x4x16.cpp | 8 ++--- .../matrix_mul/int8/avx2_strategy_4x16x2.cpp | 8 ++--- .../matrix_mul/int8/sse_strategy_4x8x2.cpp | 8 ++--- dnn/test/common/checker.h | 12 ++++++- dnn/test/common/matrix_mul.cpp | 34 +++++++++++-------- dnn/test/common/matrix_mul.h | 8 +++-- dnn/test/x86/matrix_mul.cpp | 22 ++++++++---- 7 files changed, 63 insertions(+), 37 deletions(-) diff --git a/dnn/src/x86/matrix_mul/int8/avx2_strategy_2x4x16.cpp b/dnn/src/x86/matrix_mul/int8/avx2_strategy_2x4x16.cpp index 78af84482..6326771a4 100644 --- a/dnn/src/x86/matrix_mul/int8/avx2_strategy_2x4x16.cpp +++ b/dnn/src/x86/matrix_mul/int8/avx2_strategy_2x4x16.cpp @@ -71,13 +71,13 @@ void gemm_avx2_s8s8s32_2x4x16::kern(const dt_int8* pack_a_ptr, auto iter_a_ptr = pack_a_ptr + m_offset * roundup_k; for (size_t n_offset = 0; n_offset < n_end; n_offset += n_tile) { auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k; - auto iter_c_ptr = c_ptr + m_offset * n + n_offset; + auto iter_c_ptr = c_ptr + m_offset * ldc + n_offset; matmul_avx2_2x4x16::kern_gemm_s8s8s32_2x4x16(iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k); } if (n_end < n) { auto iter_b_ptr = pack_b_ptr + n_end * roundup_k; - auto iter_c_ptr = c_ptr + m_offset * n + n_end; + auto iter_c_ptr = c_ptr + m_offset * ldc + n_end; matmul_avx2_2x4x16::kern_gemm_s8s8s32_2x4x16_remain( iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_tile, n_remain); @@ -87,14 +87,14 @@ void gemm_avx2_s8s8s32_2x4x16::kern(const dt_int8* pack_a_ptr, auto iter_a_ptr = pack_a_ptr + m_end * roundup_k; for (size_t n_offset = 0; n_offset < n_end; n_offset += n_tile) { auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k; - auto iter_c_ptr = c_ptr + m_end * n + n_offset; + auto iter_c_ptr = c_ptr + m_end * 
ldc + n_offset; matmul_avx2_2x4x16::kern_gemm_s8s8s32_2x4x16_remain( iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain, n_tile); } if (n_end < n) { auto iter_b_ptr = pack_b_ptr + n_end * roundup_k; - auto iter_c_ptr = c_ptr + m_end * n + n_end; + auto iter_c_ptr = c_ptr + m_end * ldc + n_end; matmul_avx2_2x4x16::kern_gemm_s8s8s32_2x4x16_remain( iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain, n_remain); diff --git a/dnn/src/x86/matrix_mul/int8/avx2_strategy_4x16x2.cpp b/dnn/src/x86/matrix_mul/int8/avx2_strategy_4x16x2.cpp index 681727c1e..2c606314a 100644 --- a/dnn/src/x86/matrix_mul/int8/avx2_strategy_4x16x2.cpp +++ b/dnn/src/x86/matrix_mul/int8/avx2_strategy_4x16x2.cpp @@ -59,13 +59,13 @@ static inline void gemm_kern(const dt_int16* pack_a_ptr, auto iter_a_ptr = pack_a_ptr + m_offset * roundup_k; for (size_t n_offset = 0; n_offset < n_end; n_offset += n_tile) { auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k; - auto iter_c_ptr = c_ptr + m_offset * n + n_offset; + auto iter_c_ptr = c_ptr + m_offset * ldc + n_offset; matmul_avx2_4x16x2::kern_gemm_s8s8s32_avx2_4x16x2( iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k); } if (n_remain > 0) { auto iter_b_ptr = pack_b_ptr + n_end * roundup_k; - auto iter_c_ptr = c_ptr + m_offset * n + n_end; + auto iter_c_ptr = c_ptr + m_offset * ldc + n_end; if (n_remain <= 8) { matmul_avx2_4x16x2::kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_n( iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, n_remain); @@ -79,13 +79,13 @@ static inline void gemm_kern(const dt_int16* pack_a_ptr, auto iter_a_ptr = pack_a_ptr + m_end * roundup_k; for (size_t n_offset = 0; n_offset < n_end; n_offset += n_tile) { auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k; - auto iter_c_ptr = c_ptr + m_end * n + n_offset; + auto iter_c_ptr = c_ptr + m_end * ldc + n_offset; matmul_avx2_4x16x2::kern_gemm_s8s8s32_avx2_4x16x2_remain_m( iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain); } if (n_remain > 0) { auto iter_b_ptr = pack_b_ptr + n_end * roundup_k; - auto 
iter_c_ptr = c_ptr + m_end * n + n_end; + auto iter_c_ptr = c_ptr + m_end * ldc + n_end; if (n_remain <= 8) { matmul_avx2_4x16x2::kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_m_n( iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain, diff --git a/dnn/src/x86/matrix_mul/int8/sse_strategy_4x8x2.cpp b/dnn/src/x86/matrix_mul/int8/sse_strategy_4x8x2.cpp index d47e7f048..afaa4fdbf 100644 --- a/dnn/src/x86/matrix_mul/int8/sse_strategy_4x8x2.cpp +++ b/dnn/src/x86/matrix_mul/int8/sse_strategy_4x8x2.cpp @@ -59,13 +59,13 @@ static inline void gemm_kern(const dt_int16* pack_a_ptr, auto iter_a_ptr = pack_a_ptr + m_offset * roundup_k; for (int n_offset = 0; n_offset < n_end; n_offset += n_tile) { auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k; - auto iter_c_ptr = c_ptr + m_offset * n + n_offset; + auto iter_c_ptr = c_ptr + m_offset * ldc + n_offset; matmul_sse_4x8x2::kern_gemm_s8s8s32_sse_4x8x2( iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k); } if (n_remain > 0) { auto iter_b_ptr = pack_b_ptr + n_end * roundup_k; - auto iter_c_ptr = c_ptr + m_offset * n + n_end; + auto iter_c_ptr = c_ptr + m_offset * ldc + n_end; matmul_sse_4x8x2::kern_gemm_s8s8s32_sse_4x8x2_remain_n( iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, n_remain); } @@ -74,13 +74,13 @@ static inline void gemm_kern(const dt_int16* pack_a_ptr, auto iter_a_ptr = pack_a_ptr + m_end * roundup_k; for (int n_offset = 0; n_offset < n_end; n_offset += n_tile) { auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k; - auto iter_c_ptr = c_ptr + m_end * n + n_offset; + auto iter_c_ptr = c_ptr + m_end * ldc + n_offset; matmul_sse_4x8x2::kern_gemm_s8s8s32_sse_4x8x2_remain_m( iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain); } if (n_remain > 0) { auto iter_b_ptr = pack_b_ptr + n_end * roundup_k; - auto iter_c_ptr = c_ptr + m_end * n + n_end; + auto iter_c_ptr = c_ptr + m_end * ldc + n_end; matmul_sse_4x8x2::kern_gemm_s8s8s32_sse_4x8x2_remain_m_n( iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain, n_remain); diff --git 
a/dnn/test/common/checker.h b/dnn/test/common/checker.h index 4241250de..c00632f29 100644 --- a/dnn/test/common/checker.h +++ b/dnn/test/common/checker.h @@ -78,6 +78,7 @@ protected: TensorsConstriant m_tensor_constraint; bool m_no_naive_and_check = false; bool m_stable_check = false; + bool m_force_deduce_dst = true; /** * the offset from the start of malloc memory * @@ -236,6 +237,12 @@ public: return *this; } + //! force deduce dst + Checker& set_force_deduce_dst(bool force_deduce_dst) { + m_force_deduce_dst = force_deduce_dst; + return *this; + } + Checker& set_no_naive_check(bool no_naive_and_check) { m_no_naive_and_check = no_naive_and_check; return *this; @@ -343,7 +350,10 @@ void Checker::exec(TensorLayoutArray layouts) { auto opr_cur = this->opr(); opr_naive->param() = m_param; opr_cur->param() = m_param; - m_naive_proxy.deduce_layout(opr_naive.get(), layouts); + bool deduce_layout = layouts.back().ndim == 0; + if (deduce_layout || m_force_deduce_dst) { + m_naive_proxy.deduce_layout(opr_naive.get(), layouts); + } auto exec_naive = [this, &opr_naive, &layouts, &opr_relayout](const TensorValueArray& values) { TensorValueArray contig_values = values; diff --git a/dnn/test/common/matrix_mul.cpp b/dnn/test/common/matrix_mul.cpp index 00b63cdd0..e703fce15 100644 --- a/dnn/test/common/matrix_mul.cpp +++ b/dnn/test/common/matrix_mul.cpp @@ -101,7 +101,7 @@ std::vector matrix_mul::get_matmul_args_mask( size_t Astride = mask & 1 ? m + 2 : k + 2; // B: (k, n) size_t Bstride = mask & 2 ? 
k + 2 : n + 2; - size_t Cstride = n + 2; + size_t Cstride = n * 2 + 2; args.emplace_back(m, n, k, mask, Astride, Bstride, Cstride); } return args; @@ -183,9 +183,11 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, const ExecutionPolicyAlgoName& algo, param::MatrixMul::Format format, size_t nbase, - float eps, std::vector&& user_args) { + float eps, std::vector&& user_args, + bool force_deduce_dst) { megdnn_assert(A_dtype.enumv() == B_dtype.enumv()); Checker checker(handle); + checker.set_force_deduce_dst(force_deduce_dst); if (!algo.name.empty()) { checker.set_before_exec_callback(AlgoChecker(algo)); } @@ -245,16 +247,16 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, for (auto& arg : args) { size_t m = arg.m, n = arg.n, k = arg.k; -#if MEGDNN_WITH_CUDA - //[NOTE]: cublas can only process 4B aligned 8-bit input matrix; - bool is_dt_8bit = A_dtype.enumv() == DTypeEnum::Int8 || - A_dtype.enumv() == DTypeEnum::QuantizedS8 || - A_dtype.enumv() == DTypeEnum::Uint8 || - A_dtype.enumv() == DTypeEnum::Quantized8Asymm; - if (is_dt_8bit && ((m % 4 != 0) || (n % 4 != 0))) { - continue; + if (handle->type() == Handle::HandleType::CUDA) { + //! 
NOTE: cublas can only process 4B aligned 8-bit input matrix; + bool is_dt_8bit = A_dtype.enumv() == DTypeEnum::Int8 || + A_dtype.enumv() == DTypeEnum::QuantizedS8 || + A_dtype.enumv() == DTypeEnum::Uint8 || + A_dtype.enumv() == DTypeEnum::Quantized8Asymm; + if (is_dt_8bit && ((m % 4 != 0) || (n % 4 != 0))) { + continue; + } } -#endif Param param; param.transposeA = arg.mask & 0x1; @@ -312,20 +314,22 @@ void matrix_mul::check_batched_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, const ExecutionPolicyAlgoName& algo, float eps, - std::vector&& args) { + std::vector&& args, + bool force_deduce_dst) { check_matrix_mul( A_dtype, B_dtype, C_dtype, handle, algo, param::MatrixMul::Format::DEFAULT, 8, eps, - std::forward(args)); + std::forward(args), force_deduce_dst); } void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, const ExecutionPolicyAlgoName& algo, param::MatrixMul::Format format, size_t nbase, - float eps) { + float eps, bool force_deduce_dst) { check_matrix_mul(A_dtype, B_dtype, C_dtype, handle, algo, - format, nbase, eps); + format, nbase, eps, {}, + force_deduce_dst); } #if MEGDNN_WITH_BENCHMARK diff --git a/dnn/test/common/matrix_mul.h b/dnn/test/common/matrix_mul.h index ab3057e09..7c6da5292 100644 --- a/dnn/test/common/matrix_mul.h +++ b/dnn/test/common/matrix_mul.h @@ -68,19 +68,21 @@ void check_matrix_mul( DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, const ExecutionPolicyAlgoName& algo = {"", {}}, param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT, - size_t nbase = 8, float eps = 1e-3, std::vector&& args = {}); + size_t nbase = 8, float eps = 1e-3, std::vector&& args = {}, + bool force_deduce_dst = true); void check_matrix_mul( DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, const ExecutionPolicyAlgoName& algo = {"", {}}, param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT, - size_t nbase = 8, float eps = 1e-3); + size_t nbase 
= 8, float eps = 1e-3, bool force_deduce_dst = true); void check_batched_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, const ExecutionPolicyAlgoName& algo = {"", {}}, float eps = 1e-3, - std::vector&& args = {}); + std::vector&& args = {}, + bool force_deduce_dst = true); #if MEGDNN_WITH_BENCHMARK std::vector get_benchmark_matmul_args(); diff --git a/dnn/test/x86/matrix_mul.cpp b/dnn/test/x86/matrix_mul.cpp index 356d69c46..5f8903a11 100644 --- a/dnn/test/x86/matrix_mul.cpp +++ b/dnn/test/x86/matrix_mul.cpp @@ -44,21 +44,31 @@ TEST_F(X86, MATRIX_MUL_MKLDNN_8X8X32) { //! FIXME: need to add tests of GEMV and QUINT8 TEST_F(X86, MATRIX_MUL_AVX2_8X8X32) { matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, - handle(), "X86_INT8X8X32_AVX2_2X4X16"); + handle(), "X86_INT8X8X32_AVX2_2X4X16", + param::MatrixMul::Format::DEFAULT, 8, 1e-3, + false); matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, - handle(), "X86_INT8X8X32_AVX2_4X16X2"); + handle(), "X86_INT8X8X32_AVX2_4X16X2", + param::MatrixMul::Format::DEFAULT, 8, 1e-3, + false); } TEST_F(X86, MATRIX_MUL_AVX2_8X8X16) { matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, - handle(), "X86_INT8X8X16_AVX2"); + handle(), "X86_INT8X8X16_AVX2", + param::MatrixMul::Format::DEFAULT, 8, 1e-3, + false); } TEST_F(X86, MATRIX_MUL_SSE_8X8X16) { matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, - handle(), "X86_INT8X8X16_SSE"); + handle(), "X86_INT8X8X16_SSE", + param::MatrixMul::Format::DEFAULT, 8, 1e-3, + false); } TEST_F(X86, MATRIX_MUL_SSE_8X8X32) { matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, - handle(), "X86_INT8X8X32_SSE_4X8X2"); + handle(), "X86_INT8X8X32_SSE_4X8X2", + param::MatrixMul::Format::DEFAULT, 8, 1e-3, + false); } #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM @@ -72,7 +82,7 @@ TEST_F(X86, MATRIX_MUL_MKL_PACKA) { TEST_F(X86, MATRIX_MUL_AVX2_MK8_8X8) { 
matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "X86_F32MK8_8X8", - param::MatrixMul::Format::MK8, 1); + param::MatrixMul::Format::MK8, 1, 1e-3, false); } #if MEGDNN_WITH_BENCHMARK -- GitLab