diff --git a/dnn/src/armv7/matrix_mul/algos.cpp b/dnn/src/armv7/matrix_mul/algos.cpp index 3a4a12fd0fa9f1ebd6817eb761307d718cea543e..de2e0efed408a38a2a097a85d5471082a75185ad 100644 --- a/dnn/src/armv7/matrix_mul/algos.cpp +++ b/dnn/src/armv7/matrix_mul/algos.cpp @@ -707,11 +707,11 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8DotK4x8x4, armv7::matmul::gemm_dot_quint8_4x8, uint8_t, int32_t); -/* ======================== Int8 MK4 8x6x4 dot algo ======================== */ +/* ======================== Int8 MK4 8x4x4 dot algo ======================== */ namespace { -void int8_mk4_8x6x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { +void int8_mk4_8x4x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("int8_mk4_8x6x4_dotprod_kern"_hash)) { + midout_iv("int8_mk4_8x4x4_dotprod_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; @@ -720,9 +720,9 @@ void int8_mk4_8x6x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); - armv7::matmul::gemm_mk4_dots8_8x6 strategy(M, N, K, A_type, B_type, + armv7::matmul::gemm_mk4_dots8_8x4 strategy(M, N, K, A_type, B_type, C_type); - megdnn::matmul::GemmInterleaved( + megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); @@ -731,7 +731,7 @@ void int8_mk4_8x6x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { } } // namespace -bool MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::usable( +bool MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::usable( const KernSizeParam& kern_size_param) const { return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && (kern_size_param.A_type.enumv() == DTypeEnum::Int8 || @@ -743,35 +743,35 @@ bool MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::usable( !kern_size_param.trA && !kern_size_param.trB; } -size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::get_workspace( +size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::get_workspace( const KernSizeParam& kern_size_param) const { MIDOUT_BEGIN( megdnn_armv7_matmul_kern, - midout_iv("AlgoInt8x8x32MK4_8x6x4DotProd::get_workspace"_hash)) { + midout_iv("AlgoInt8x8x32MK4_8x4x4DotProd::get_workspace"_hash)) { auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; - armv7::matmul::gemm_mk4_dots8_8x6 strategy(M, N, K, A_type, B_type, + armv7::matmul::gemm_mk4_dots8_8x4 strategy(M, N, K, A_type, B_type, C_type); return megdnn::matmul::GemmInterleaved< - armv7::matmul::gemm_mk4_dots8_8x6>(M, N, K, trA, trB, + armv7::matmul::gemm_mk4_dots8_8x4>(M, N, K, trA, trB, strategy) .get_workspace_size(); } MIDOUT_END(); } -MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::get_kern( +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::get_kern( const KernSizeParam&) const { - return int8_mk4_8x6x4_dotprod_kern; + return int8_mk4_8x4x4_dotprod_kern; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x6x4DotProd, +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x4x4DotProd, megdnn_armv7_matmul_kern, - "AlgoInt8x8x32MK4_8x6x4DotProd"_hash, - armv7::matmul::gemm_mk4_dots8_8x6, int8_t, + "AlgoInt8x8x32MK4_8x4x4DotProd"_hash, + armv7::matmul::gemm_mk4_dots8_8x4, int8_t, int32_t); #endif diff --git a/dnn/src/armv7/matrix_mul/algos.h b/dnn/src/armv7/matrix_mul/algos.h index 923be14a16c40d710156f5ff9d45ff5308a2a923..03d4c36ea4a4df4f3ac38e071ccdcc0f60c4228b 100644 --- a/dnn/src/armv7/matrix_mul/algos.h +++ b/dnn/src/armv7/matrix_mul/algos.h @@ -94,11 +94,11 @@ public: MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; -class MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd final : public AlgoBase { +class MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd final : public AlgoBase { public: bool is_reproducible() const override { return true; } const char* name() const override { - return "AARCH32_INT8_MK4_8X6X4_DOTPROD"; + return "AARCH32_INT8_MK4_8X4X4_DOTPROD"; } bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; diff --git a/dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x6x4.h b/dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h similarity index 51% rename from dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x6x4.h rename to dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h index e0cc5f4afb822e742a04b551f57d615d3561d4af..1b76ef8a7b15f1ec160b68d8c1a3cbff1563432a 100644 --- a/dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x6x4.h +++ b/dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h @@ -1,5 +1,5 @@ /** - * \file dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x6x4.h + * \file dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -17,205 +17,7 @@ namespace megdnn { namespace armv7 { -namespace matmul_mk4_dot_8x6x4 { - -// Overview of register layout: -// -// A 1x6x4 cell of Rhs is stored in 8bit in q0, q1. -// A 2x1x4x4 cell of Lhs is stored in 8bit in q2, q3 -// A 2x6x4 block of accumulators is stored in 8bit in q4-q15 -// -// +--------+ -// Rhs |q0[0-16]| -// |q1[0-16]| -// +--------+ -// Lhs | | -// +-------+-------+ - - - - +--------+ -// | q2[0-16]| | q4[0-4]| -// | q3[0-16]| | q5[0-4]| -// +---------+ | q6[0-4]| -// | q7[0-4]| -// | q8[0-4]| -// | q9[0-4]| -// |q10[0-4]| -// |q11[0-4]| -// |q12[0-4]| -// |q13[0-4]| -// |q14[0-4]| -// |q15[0-4]| -// +--------+ -// Accumulator - -static void kern_8x6(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k) { - K /= 4; - const int8_t* a_ptr = packA; - const int8_t* b_ptr = packB; - // Fix up for odd lengths - set a flag if K is odd, but make - // sure we round up the iteration count. - int oddk = (K & 1); - int k = (K + 1) / 2 - 1; - - LDC = LDC * sizeof(int32_t); - int32_t* outptr0 = output; - int32_t* outptr1; - - asm volatile( - // load accumulator C - "add %[outptr1], %[outptr0], %[LDC]\n" - "cmp %[is_first_k], #1\n" - "beq 1f\n" - "vld1.32 {d8, d9}, [%[outptr0]]!\n" - "vld1.32 {d10, d11}, [%[outptr0]]!\n" - "vld1.32 {d12, d13}, [%[outptr0]]!\n" - "vld1.32 {d14, d15}, [%[outptr0]]!\n" - "vld1.32 {d16, d17}, [%[outptr0]]!\n" - "vld1.32 {d18, d19}, [%[outptr0]]!\n" - "vld1.32 {d20, d21}, [%[outptr1]]!\n" - "vld1.32 {d22, d23}, [%[outptr1]]!\n" - "vld1.32 {d24, d25}, [%[outptr1]]!\n" - "vld1.32 {d26, d27}, [%[outptr1]]!\n" - "vld1.32 {d28, d29}, [%[outptr1]]!\n" - "vld1.32 {d30, d31}, [%[outptr1]]!\n" - - "b 2f\n" - - "1:\n" - "veor.s32 q4, q4, q4\n" - "veor.s32 q5, q5, q5\n" - "veor.s32 q6, q6, q6\n" - "veor.s32 q7, q7, q7\n" - "veor.s32 q8, q8, q8\n" - "veor.s32 q9, q9, q9\n" - "veor.s32 q10, q10, q10\n" - "veor.s32 q11, q11, q11\n" - "veor.s32 q12, q12, q12\n" - "veor.s32 q13, q13, q13\n" - "veor.s32 q14, q14, q14\n" - "veor.s32 q15, q15, q15\n" - - "2: \n" - "vld1.s8 {q0}, [%[b_ptr]]!\n" - "vld1.s8 {d2}, [%[b_ptr]]!\n" - "vld1.s8 {q2}, [%[a_ptr]]!\n" - "vld1.s8 {q3}, [%[a_ptr]]!\n" - - "cmp %[k], #0 \n" - "beq 4f \n" - - "3:\n" - "vsdot.s8 q4 , q2, d0[0]\n" - "vsdot.s8 q5 , q2, d0[1]\n" - "vsdot.s8 q6 , q2, d1[0]\n" - "vsdot.s8 q7 , q2, d1[1]\n" - "vsdot.s8 q8 , q2, d2[0]\n" - "vsdot.s8 q9 , q2, d2[1]\n" - "vsdot.s8 q10 , q3, d0[0]\n" - "vsdot.s8 q11 , q3, d0[1]\n" - "vsdot.s8 q12 , q3, d1[0]\n" - "vsdot.s8 q13 , q3, d1[1]\n" - "vsdot.s8 q14 , q3, d2[0]\n" - "vsdot.s8 q15 , q3, d2[1]\n" - - "vld1.s8 {q0}, [%[b_ptr]]!\n" - "vld1.s8 {d2}, [%[b_ptr]]!\n" - "vld1.s8 {q2}, [%[a_ptr]]!\n" - "vld1.s8 {q3}, [%[a_ptr]]!\n" - "vsdot.s8 q4 , q2, d0[0]\n" - "vsdot.s8 q5 , q2, d0[1]\n" - "vsdot.s8 q6 , q2, d1[0]\n" - "vsdot.s8 q7 , q2, d1[1]\n" - "vsdot.s8 q8 , q2, d2[0]\n" - "vsdot.s8 q9 , q2, d2[1]\n" - "vsdot.s8 q10 , q3, d0[0]\n" - "vsdot.s8 q11 , q3, d0[1]\n" - "vsdot.s8 q12 , q3, d1[0]\n" - "vsdot.s8 q13 , q3, d1[1]\n" - "vsdot.s8 q14 , q3, d2[0]\n" - "vsdot.s8 q15 , q3, d2[1]\n" - - "vld1.s8 {q0}, [%[b_ptr]]!\n" - "vld1.s8 {d2}, [%[b_ptr]]!\n" - "vld1.s8 {q2}, [%[a_ptr]]!\n" - "vld1.s8 {q3}, [%[a_ptr]]!\n" - - "subs %[k], %[k], #1\n" - "bne 3b\n" - - // Target to use when K is 1 or 2 (i.e. zero iterations of main - // loop) - "4:\n" - - "cmp %[oddk], #0 \n" - "bne 5f \n" - "vsdot.s8 q4 , q2, d0[0]\n" - "vsdot.s8 q5 , q2, d0[1]\n" - "vsdot.s8 q6 , q2, d1[0]\n" - "vsdot.s8 q7 , q2, d1[1]\n" - "vsdot.s8 q8 , q2, d2[0]\n" - "vsdot.s8 q9 , q2, d2[1]\n" - "vsdot.s8 q10 , q3, d0[0]\n" - "vsdot.s8 q11 , q3, d0[1]\n" - "vsdot.s8 q12 , q3, d1[0]\n" - "vsdot.s8 q13 , q3, d1[1]\n" - "vsdot.s8 q14 , q3, d2[0]\n" - "vsdot.s8 q15 , q3, d2[1]\n" - - "vld1.s8 {q0}, [%[b_ptr]]!\n" - "vld1.s8 {d2}, [%[b_ptr]]!\n" - "vld1.s8 {q2}, [%[a_ptr]]!\n" - "vld1.s8 {q3}, [%[a_ptr]]!\n" - "vsdot.s8 q4 , q2, d0[0]\n" - "vsdot.s8 q5 , q2, d0[1]\n" - "vsdot.s8 q6 , q2, d1[0]\n" - "vst1.32 {d8, d9}, [%[outptr0]]!\n" - "vsdot.s8 q7 , q2, d1[1]\n" - "vsdot.s8 q8 , q2, d2[0]\n" - "vsdot.s8 q9 , q2, d2[1]\n" - "vst1.32 {d10, d11}, [%[outptr0]]!\n" - "vsdot.s8 q10 , q3, d0[0]\n" - "vsdot.s8 q11 , q3, d0[1]\n" - "vsdot.s8 q12 , q3, d1[0]\n" - "vst1.32 {d12, d13}, [%[outptr0]]!\n" - "vsdot.s8 q13 , q3, d1[1]\n" - "vsdot.s8 q14 , q3, d2[0]\n" - "vsdot.s8 q15 , q3, d2[1]\n" - - "b 6f\n" - "5: \n" - "vsdot.s8 q4 , q2, d0[0]\n" - "vsdot.s8 q5 , q2, d0[1]\n" - "vsdot.s8 q6 , q2, d1[0]\n" - "vst1.32 {d8, d9}, [%[outptr0]]!\n" - "vsdot.s8 q7 , q2, d1[1]\n" - "vsdot.s8 q8 , q2, d2[0]\n" - "vsdot.s8 q9 , q2, d2[1]\n" - "vst1.32 {d10, d11}, [%[outptr0]]!\n" - "vsdot.s8 q10 , q3, d0[0]\n" - "vsdot.s8 q11 , q3, d0[1]\n" - "vsdot.s8 q12 , q3, d1[0]\n" - "vst1.32 {d12, d13}, [%[outptr0]]!\n" - "vsdot.s8 q13 , q3, d1[1]\n" - "vsdot.s8 q14 , q3, d2[0]\n" - "vsdot.s8 q15 , q3, d2[1]\n" - - "6: \n" - "vst1.32 {d14, d15}, [%[outptr0]]!\n" - "vst1.32 {d16, d17}, [%[outptr0]]!\n" - "vst1.32 {d18, d19}, [%[outptr0]]!\n" - "vst1.32 {d20, d21}, [%[outptr1]]!\n" - "vst1.32 {d22, d23}, [%[outptr1]]!\n" - "vst1.32 {d24, d25}, [%[outptr1]]!\n" - "vst1.32 {d26, d27}, [%[outptr1]]!\n" - "vst1.32 {d28, d29}, [%[outptr1]]!\n" - "vst1.32 {d30, d31}, [%[outptr1]]!\n" - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [LDC] "+r"(LDC), - [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [k] "+r"(k), - [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) - : - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", - "q11", "q12", "q14", "q15", "cc", "memory"); -} +namespace matmul_mk4_dot_8x4x4 { // Overview of register layout: // @@ -390,144 +192,6 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, #undef STORE_C } -// Overview of register layout: -// -// A 1x6x4 pingpong cell of Rhs is stored in 8bit in q0-q3. -// A 1x1x4x4 pingpong cell of Lhs is stored in 8bit in q4-q5 -// A 2x6x4 block of accumulators is stored in 8bit in q10-q15 -// -// +--------+ -// Rhs |q0[0-16]| -// |q1[0-16]| -// +--------+ -// Lhs | | -// +-------+-------+ - - - - +--------+ -// | q4[0-16]| |q10[0-4]| -// | q5[0-16]| |q11[0-4]| -// +---------+ |q12[0-4]| -// |q13[0-4]| -// |q14[0-4]| -// |q15[0-4]| -// +--------+ -// Accumulator - -static void kern_4x6(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k) { - K /= 4; - const int8_t* a_ptr = packA; - const int8_t* b_ptr = packB; - // Fix up for odd lengths - set a flag if K is odd, but make - // sure we round up the iteration count. - int oddk = (K & 1); - int k = (K + 1) / 2 - 1; - - LDC = LDC * sizeof(int32_t); - int32_t* outptr0 = output; - - asm volatile( - // load accumulator C - "cmp %[is_first_k], #1\n" - "beq 1f\n" - "vld1.32 {d20, d21}, [%[outptr0]]!\n" - "vld1.32 {d22, d23}, [%[outptr0]]!\n" - "vld1.32 {d24, d25}, [%[outptr0]]!\n" - "vld1.32 {d26, d27}, [%[outptr0]]!\n" - "vld1.32 {d28, d29}, [%[outptr0]]!\n" - "vld1.32 {d30, d31}, [%[outptr0]]!\n" - - "b 2f\n" - - "1:\n" - "veor.s32 q10, q10, q10\n" - "veor.s32 q11, q11, q11\n" - "veor.s32 q12, q12, q12\n" - "veor.s32 q13, q13, q13\n" - "veor.s32 q14, q14, q14\n" - "veor.s32 q15, q15, q15\n" - - "2: \n" - "vld1.s8 {q0}, [%[b_ptr]]!\n" - "vld1.s8 {d2}, [%[b_ptr]]!\n" - "vld1.s8 {q4}, [%[a_ptr]]!\n" - - "cmp %[k], #0 \n" - "beq 4f \n" - - "3:\n" - "vsdot.s8 q10 , q4, d0[0]\n" - "vsdot.s8 q11 , q4, d0[1]\n" - "vsdot.s8 q12 , q4, d1[0]\n" - "vld1.s8 {q2}, [%[b_ptr]]!\n" - "vld1.s8 {d6}, [%[b_ptr]]!\n" - "vld1.s8 {q5}, [%[a_ptr]]!\n" - "vsdot.s8 q13 , q4, d1[1]\n" - "vsdot.s8 q14 , q4, d2[0]\n" - "vsdot.s8 q15 , q4, d2[1]\n" - - "vld1.s8 {q0}, [%[b_ptr]]!\n" - "vsdot.s8 q10 , q5, d4[0]\n" - "vsdot.s8 q11 , q5, d4[1]\n" - "vsdot.s8 q12 , q5, d5[0]\n" - "vld1.s8 {d2}, [%[b_ptr]]!\n" - "vsdot.s8 q13 , q5, d5[1]\n" - "vsdot.s8 q14 , q5, d6[0]\n" - "vsdot.s8 q15 , q5, d6[1]\n" - "vld1.s8 {q4}, [%[a_ptr]]!\n" - - "subs %[k], %[k], #1\n" - "bne 3b\n" - - // Target to use when K is 1 or 2 (i.e. zero iterations of main - // loop) - "4:\n" - - "cmp %[oddk], #0 \n" - "bne 5f \n" - - "vsdot.s8 q10 , q4, d0[0]\n" - "vsdot.s8 q11 , q4, d0[1]\n" - "vsdot.s8 q12 , q4, d1[0]\n" - "vld1.s8 {q2}, [%[b_ptr]]!\n" - "vld1.s8 {d6}, [%[b_ptr]]!\n" - "vld1.s8 {q5}, [%[a_ptr]]!\n" - "vsdot.s8 q13 , q4, d1[1]\n" - "vsdot.s8 q14 , q4, d2[0]\n" - "vsdot.s8 q15 , q4, d2[1]\n" - - "vsdot.s8 q10 , q5, d4[0]\n" - "vsdot.s8 q11 , q5, d4[1]\n" - "vsdot.s8 q12 , q5, d5[0]\n" - "vst1.32 {d20, d21}, [%[outptr0]]!\n" - "vsdot.s8 q13 , q5, d5[1]\n" - "vsdot.s8 q14 , q5, d6[0]\n" - "vsdot.s8 q15 , q5, d6[1]\n" - "vst1.32 {d22, d23}, [%[outptr0]]!\n" - - "b 6f\n" - "5: \n" - - "vsdot.s8 q10 , q4, d0[0]\n" - "vsdot.s8 q11 , q4, d0[1]\n" - "vsdot.s8 q12 , q4, d1[0]\n" - "vst1.32 {d20, d21}, [%[outptr0]]!\n" - "vsdot.s8 q13 , q4, d1[1]\n" - "vsdot.s8 q14 , q4, d2[0]\n" - "vsdot.s8 q15 , q4, d2[1]\n" - "vst1.32 {d22, d23}, [%[outptr0]]!\n" - - "6: \n" - "vst1.32 {d24, d25}, [%[outptr0]]!\n" - "vst1.32 {d26, d27}, [%[outptr0]]!\n" - "vst1.32 {d28, d29}, [%[outptr0]]!\n" - "vst1.32 {d30, d31}, [%[outptr0]]!\n" - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [LDC] "+r"(LDC), - [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [k] "+r"(k), - [outptr0] "+r"(outptr0) - : - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", - "q11", "q12", "q14", "q15", "cc", "memory"); -} - // Overview of register layout: // // A 2x4x4 cell of Rhs is stored in 8bit in q1, q3. @@ -671,7 +335,7 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, #undef STORE_C } -static void gemm_dots8_8x6_pack_A(dt_int8* outptr, const dt_int8* inptr, +static void gemm_dots8_8x4_pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, int kmax) { int y = y0, y_start = y0 / 4; @@ -692,14 +356,12 @@ static void gemm_dots8_8x6_pack_A(dt_int8* outptr, const dt_int8* inptr, } } -static void gemm_dots8_8x6_pack_B(dt_int8* out, const dt_int8* in, int ldin, +static void gemm_dots8_8x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { const int ksize = kmax - k0; const int ksize4 = ksize * 4; - const int ksize6 = ksize * 6; int8_t* outptr = out; int8_t* outptr_base = out; - int8_t* outptr_base4 = out + ((xmax - x0) / 6) * ksize6; int k = k0; for (; k + 3 < kmax; k += 4) { @@ -708,13 +370,6 @@ static void gemm_dots8_8x6_pack_B(dt_int8* out, const dt_int8* in, int ldin, outptr = outptr_base; int x = x0; - for (; x + 5 < xmax; x += 6) { - memcpy(outptr, inptr, sizeof(int8_t) * 24); - outptr += ksize6; - inptr += 24; - } - - outptr = outptr_base4; for (; x + 3 < xmax; x += 4) { memcpy(outptr, inptr, sizeof(int8_t) * 16); outptr += ksize4; @@ -735,12 +390,11 @@ static void gemm_dots8_8x6_pack_B(dt_int8* out, const dt_int8* in, int ldin, *outptr++ = 0; } } - outptr_base += 24; - outptr_base4 += 16; + outptr_base += 16; } } -} // namespace matmul_mk4_dot_8x6x4 +} // namespace matmul_mk4_dot_8x4x4 } // namespace armv7 } // namespace megdnn #endif diff --git a/dnn/src/armv7/matrix_mul/int8/strategy.cpp b/dnn/src/armv7/matrix_mul/int8/strategy.cpp index bdca391b9bf64412f39e377ba156f761c2d0013d..08934700ed5dd74491cd7621046221a6d9315dcc 100644 --- a/dnn/src/armv7/matrix_mul/int8/strategy.cpp +++ b/dnn/src/armv7/matrix_mul/int8/strategy.cpp @@ -16,7 +16,7 @@ #include "src/armv7/matrix_mul/int8/kernel_4x8x8.h" #include "src/armv7/matrix_mul/int8/kernel_6x8x4.h" #include "src/armv7/matrix_mul/int8/kernel_mk4_4x2x16.h" -#include "src/armv7/matrix_mul/int8/kernel_mk4_dot_8x6x4.h" +#include "src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h" #include "src/common/utils.h" #include "src/fallback/matrix_mul/gemm_common.h" @@ -254,10 +254,10 @@ void gemm_dots8_6x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, } } -// ===========================gemm_mk4_dots8_8x6====================================== -MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_dots8_8x6); +// ===========================gemm_mk4_dots8_8x4====================================== +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_dots8_8x4); -void gemm_mk4_dots8_8x6::pack_A(dt_int8* out, const dt_int8* in, int ldin, +void gemm_mk4_dots8_8x4::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, bool transpose) const { megdnn_assert(!transpose, @@ -266,49 +266,39 @@ void gemm_mk4_dots8_8x6::pack_A(dt_int8* out, const dt_int8* in, int ldin, "mk4 format matmul with m is not times of 4."); megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0, "mk4 format matmul with k is not times of 4."); - matmul_mk4_dot_8x6x4::gemm_dots8_8x6_pack_A(out, in, ldin, y0, ymax, k0, + matmul_mk4_dot_8x4x4::gemm_dots8_8x4_pack_A(out, in, ldin, y0, ymax, k0, kmax); } -void gemm_mk4_dots8_8x6::pack_B(dt_int8* out, const dt_int8* in, int ldin, +void gemm_mk4_dots8_8x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, bool transpose) const { megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix B is not supported"); megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0, "mk4 format matmul with k is not times of 4."); - matmul_mk4_dot_8x6x4::gemm_dots8_8x6_pack_B(out, in, ldin, x0, xmax, k0, + matmul_mk4_dot_8x4x4::gemm_dots8_8x4_pack_B(out, in, ldin, x0, xmax, k0, kmax); } -void gemm_mk4_dots8_8x6::kern(const dt_int8* packA, const dt_int8* packB, +void gemm_mk4_dots8_8x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, dt_int32* C, size_t LDC, bool is_first_k, const dt_int32* bias, dt_int32* workspace) const { MEGDNN_MARK_USED_VAR(bias); constexpr size_t A_INTERLEAVE = 8; - constexpr size_t B_INTERLEAVE = 6; //! K is packed to times of 4 K = round_up(K, 4); const int K4 = K * 4; - const int K6 = K * 6; const int K8 = K * 8; size_t m = 0; for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { int32_t* output = C + ((m >> 2) * LDC); const dt_int8* cur_packB = packB; - size_t n = 0; - for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_mk4_dot_8x6x4::kern_8x6(packA, cur_packB, K, output, LDC, - is_first_k); - output += 24; - cur_packB += K6; - } - - for (; n < N; n += 4) { + for (size_t n = 0; n < N; n += 4) { size_t n_remain = std::min(N - n, 4); - matmul_mk4_dot_8x6x4::kern_8x4(packA, cur_packB, K, output, LDC, + matmul_mk4_dot_8x4x4::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, n_remain); output += 16; cur_packB += K4; @@ -318,16 +308,9 @@ void gemm_mk4_dots8_8x6::kern(const dt_int8* packA, const dt_int8* packB, for (; m < M; m += 4) { int32_t* output = C + ((m >> 2) * LDC); const dt_int8* cur_packB = packB; - size_t n = 0; - for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_mk4_dot_8x6x4::kern_4x6(packA, cur_packB, K, output, LDC, - is_first_k); - output += 24; - cur_packB += K6; - } - for (; n < N; n += 4) { + for (size_t n = 0; n < N; n += 4) { size_t n_remain = std::min(N - n, 4); - matmul_mk4_dot_8x6x4::kern_4x4(packA, cur_packB, K, output, LDC, + matmul_mk4_dot_8x4x4::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, n_remain); output += 16; cur_packB += K4; diff --git a/dnn/src/armv7/matrix_mul/int8/strategy.h b/dnn/src/armv7/matrix_mul/int8/strategy.h index bcd9588f3cc809665583081a8c5630cbfe6df155..0f2b15e9c72bb3b02a85d9b0ae353e1b7372e6e3 100644 --- a/dnn/src/armv7/matrix_mul/int8/strategy.h +++ b/dnn/src/armv7/matrix_mul/int8/strategy.h @@ -27,8 +27,8 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 2, 16, false, false, MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 6, 8, 4, false, false, gemm_dots8_6x8); -MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 6, 4, false, false, - gemm_mk4_dots8_8x6); +MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 4, 4, false, false, + gemm_mk4_dots8_8x4); #endif } // namespace matmul } // namespace armv7 diff --git a/dnn/src/armv7/matrix_mul/opr_impl.cpp b/dnn/src/armv7/matrix_mul/opr_impl.cpp index ca157fc56c63a7e00bcbb57b8c7a880890fedbe4..4bb770820efb761e333e3c0c8e06e09fe5442776 100644 --- a/dnn/src/armv7/matrix_mul/opr_impl.cpp +++ b/dnn/src/armv7/matrix_mul/opr_impl.cpp @@ -29,7 +29,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { #if __ARM_FEATURE_DOTPROD AlgoInt8x8x32K6x8x4 int8_k6x8x4; AlgoQuint8DotK4x8x4 quint8_k4x8x4; - AlgoInt8x8x32MK4_8x6x4DotProd int8x8x32_mk4_8x6x4_dotprod; + AlgoInt8x8x32MK4_8x4x4DotProd int8x8x32_mk4_8x4x4_dotprod; #endif AlgoF32Gemv f32_gemv; AlgoInt8x8x32MK4_4x2x16 int8x8x32_mk4_4x2x16; @@ -57,7 +57,7 @@ public: all_algos.emplace_back(&f16_mk8_4x8); #endif #if __ARM_FEATURE_DOTPROD - all_algos.emplace_back(&int8x8x32_mk4_8x6x4_dotprod); + all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod); all_algos.emplace_back(&int8_k6x8x4); all_algos.emplace_back(&quint8_k4x8x4); #endif diff --git a/dnn/src/armv7/matrix_mul/opr_impl.h b/dnn/src/armv7/matrix_mul/opr_impl.h index 4b11b5ba8d5b1b4fe950b3c663e212fa079a6220..a9973dafef9e03d30f297ff595d2d9c82f6713c4 100644 --- a/dnn/src/armv7/matrix_mul/opr_impl.h +++ b/dnn/src/armv7/matrix_mul/opr_impl.h @@ -42,7 +42,7 @@ private: #if __ARM_FEATURE_DOTPROD class AlgoInt8x8x32K6x8x4; // Armv7 Int8 Kernel 6x8x4 class AlgoQuint8DotK4x8x4; // Armv7 Quint8 Kernel 6x8x4 - class AlgoInt8x8x32MK4_8x6x4DotProd; // Armv7 nchw44 Int8x8x32 Kernel 8x6x4 + class AlgoInt8x8x32MK4_8x4x4DotProd; // Armv7 nchw44 Int8x8x32 Kernel 8x4x4 // DotProduct #endif class AlgoPack; diff --git a/dnn/test/armv7/matrix_mul.cpp b/dnn/test/armv7/matrix_mul.cpp index 6ef2d23e2886a4841be8ab95d116d8995b7ca54a..0d17be3aa79000c459ddac9d98f0a16a3383d1dc 100644 --- a/dnn/test/armv7/matrix_mul.cpp +++ b/dnn/test/armv7/matrix_mul.cpp @@ -94,7 +94,7 @@ TEST_F(ARMV7, MATRIX_MUL_MK4_DOT_INT8) { for (size_t k : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 33, 34}) args.emplace_back(m, n, k, 0); matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, - handle(), "AARCH32_INT8_MK4_8X6X4_DOTPROD", + handle(), "AARCH32_INT8_MK4_8X4X4_DOTPROD", param::MatrixMul::Format::MK4_DOT, 1, 1e-3, std::move(args)); } @@ -315,7 +315,7 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x32_MK4_DOT) { param.format = MatrixMul::Param::Format::MK4_DOT; Benchmarker benchmarker_mk4_dot(handle()); benchmarker_mk4_dot.set_before_exec_callback( - AlgoChecker("AARCH32_INT8_MK4_8X6X4_DOTPROD")); + AlgoChecker("AARCH32_INT8_MK4_8X4X4_DOTPROD")); benchmarker_mk4_dot.set_param(param) .set_dtype(0, dtype::Int8()) .set_dtype(1, dtype::Int8())