提交 34659c2e 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(mgb/dnn): remove armv7 matmul mk4dot block 8x6

GitOrigin-RevId: 4c746ef22895ebc6ab4298c66f71e239b437cc69
上级 48ac1e1a
......@@ -707,11 +707,11 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8DotK4x8x4,
armv7::matmul::gemm_dot_quint8_4x8,
uint8_t, int32_t);
/* ======================== Int8 MK4 8x6x4 dot algo ======================== */
/* ======================== Int8 MK4 8x4x4 dot algo ======================== */
namespace {
void int8_mk4_8x6x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
void int8_mk4_8x4x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_armv7_matmul_kern,
midout_iv("int8_mk4_8x6x4_dotprod_kern"_hash)) {
midout_iv("int8_mk4_8x4x4_dotprod_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto trA = kern_param.trA, trB = kern_param.trB;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
......@@ -720,9 +720,9 @@ void int8_mk4_8x6x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
const auto Aptr = kern_param.A<dt_int8>(),
Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
armv7::matmul::gemm_mk4_dots8_8x6 strategy(M, N, K, A_type, B_type,
armv7::matmul::gemm_mk4_dots8_8x4 strategy(M, N, K, A_type, B_type,
C_type);
megdnn::matmul::GemmInterleaved<armv7::matmul::gemm_mk4_dots8_8x6>(
megdnn::matmul::GemmInterleaved<armv7::matmul::gemm_mk4_dots8_8x4>(
M, N, K, trA, trB, strategy)
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
kern_param.workspace_ptr);
......@@ -731,7 +731,7 @@ void int8_mk4_8x6x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
}
} // namespace
bool MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::usable(
bool MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::usable(
const KernSizeParam& kern_size_param) const {
return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
(kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
......@@ -743,35 +743,35 @@ bool MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::usable(
!kern_size_param.trA && !kern_size_param.trB;
}
size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::get_workspace(
size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::get_workspace(
const KernSizeParam& kern_size_param) const {
MIDOUT_BEGIN(
megdnn_armv7_matmul_kern,
midout_iv("AlgoInt8x8x32MK4_8x6x4DotProd::get_workspace"_hash)) {
midout_iv("AlgoInt8x8x32MK4_8x4x4DotProd::get_workspace"_hash)) {
auto M = kern_size_param.M, N = kern_size_param.N,
K = kern_size_param.K;
auto trA = kern_size_param.trA, trB = kern_size_param.trB;
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
C_type = kern_size_param.C_type;
armv7::matmul::gemm_mk4_dots8_8x6 strategy(M, N, K, A_type, B_type,
armv7::matmul::gemm_mk4_dots8_8x4 strategy(M, N, K, A_type, B_type,
C_type);
return megdnn::matmul::GemmInterleaved<
armv7::matmul::gemm_mk4_dots8_8x6>(M, N, K, trA, trB,
armv7::matmul::gemm_mk4_dots8_8x4>(M, N, K, trA, trB,
strategy)
.get_workspace_size();
}
MIDOUT_END();
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::get_kern(
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::get_kern(
const KernSizeParam&) const {
return int8_mk4_8x6x4_dotprod_kern;
return int8_mk4_8x4x4_dotprod_kern;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x6x4DotProd,
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x4x4DotProd,
megdnn_armv7_matmul_kern,
"AlgoInt8x8x32MK4_8x6x4DotProd"_hash,
armv7::matmul::gemm_mk4_dots8_8x6, int8_t,
"AlgoInt8x8x32MK4_8x4x4DotProd"_hash,
armv7::matmul::gemm_mk4_dots8_8x4, int8_t,
int32_t);
#endif
......
......@@ -94,11 +94,11 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd final : public AlgoBase {
class MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override {
return "AARCH32_INT8_MK4_8X6X4_DOTPROD";
return "AARCH32_INT8_MK4_8X4X4_DOTPROD";
}
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
......
/**
* \file dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x6x4.h
* \file dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......@@ -17,205 +17,7 @@
namespace megdnn {
namespace armv7 {
namespace matmul_mk4_dot_8x6x4 {
// Overview of register layout:
//
// A 1x6x4 cell of Rhs is stored in 8bit in q0, q1.
// A 2x1x4x4 cell of Lhs is stored in 8bit in q2, q3
// A 2x6x4 block of accumulators is stored in 32-bit in q4-q15
//
// +--------+
// Rhs |q0[0-16]|
// |q1[0-16]|
// +--------+
// Lhs | |
// +-------+-------+ - - - - +--------+
// | q2[0-16]| | q4[0-4]|
// | q3[0-16]| | q5[0-4]|
// +---------+ | q6[0-4]|
// | q7[0-4]|
// | q8[0-4]|
// | q9[0-4]|
// |q10[0-4]|
// |q11[0-4]|
// |q12[0-4]|
// |q13[0-4]|
// |q14[0-4]|
// |q15[0-4]|
// +--------+
// Accumulator
// Compute an 8x6 output tile (two MK4 4-row blocks x 6 columns) of
// C (+)= A * B with ARMv7 NEON dot-product (vsdot.s8) instructions.
//
// \param packA      packed LHS panel; two interleaved 4x4 int8 cells per
//                   K-step (loaded into q2/q3).
// \param packB      packed RHS panel; 6 columns x 4 K-values per step
//                   (24 bytes, loaded into q0 and d2).
// \param K          GEMM depth; the packing guarantees a multiple of 4.
// \param output     pointer to the int32 C tile; the second row-block lives
//                   at output + LDC elements.
// \param LDC        leading dimension of C in elements (scaled to bytes
//                   below before use as an address offset).
// \param is_first_k when true the accumulators are zeroed instead of being
//                   loaded from C (first K-slice of a split-K run).
static void kern_8x6(const int8_t* packA, const int8_t* packB, int K,
                     int32_t* output, int LDC, bool is_first_k) {
    K /= 4;
    const int8_t* a_ptr = packA;
    const int8_t* b_ptr = packB;
    // Fix up for odd lengths - set a flag if K is odd, but make
    // sure we round up the iteration count.
    int oddk = (K & 1);
    int k = (K + 1) / 2 - 1;
    LDC = LDC * sizeof(int32_t);

    int32_t* outptr0 = output;
    int32_t* outptr1;

    asm volatile(
            // load accumulator C
            "add %[outptr1], %[outptr0], %[LDC]\n"
            "cmp %[is_first_k], #1\n"
            "beq 1f\n"
            "vld1.32 {d8, d9}, [%[outptr0]]!\n"
            "vld1.32 {d10, d11}, [%[outptr0]]!\n"
            "vld1.32 {d12, d13}, [%[outptr0]]!\n"
            "vld1.32 {d14, d15}, [%[outptr0]]!\n"
            "vld1.32 {d16, d17}, [%[outptr0]]!\n"
            "vld1.32 {d18, d19}, [%[outptr0]]!\n"
            "vld1.32 {d20, d21}, [%[outptr1]]!\n"
            "vld1.32 {d22, d23}, [%[outptr1]]!\n"
            "vld1.32 {d24, d25}, [%[outptr1]]!\n"
            "vld1.32 {d26, d27}, [%[outptr1]]!\n"
            "vld1.32 {d28, d29}, [%[outptr1]]!\n"
            "vld1.32 {d30, d31}, [%[outptr1]]!\n"
            "b 2f\n"

            "1:\n"
            "veor.s32 q4, q4, q4\n"
            "veor.s32 q5, q5, q5\n"
            "veor.s32 q6, q6, q6\n"
            "veor.s32 q7, q7, q7\n"
            "veor.s32 q8, q8, q8\n"
            "veor.s32 q9, q9, q9\n"
            "veor.s32 q10, q10, q10\n"
            "veor.s32 q11, q11, q11\n"
            "veor.s32 q12, q12, q12\n"
            "veor.s32 q13, q13, q13\n"
            "veor.s32 q14, q14, q14\n"
            "veor.s32 q15, q15, q15\n"

            // preload first K-step of A and B before entering the main loop
            "2: \n"
            "vld1.s8 {q0}, [%[b_ptr]]!\n"
            "vld1.s8 {d2}, [%[b_ptr]]!\n"
            "vld1.s8 {q2}, [%[a_ptr]]!\n"
            "vld1.s8 {q3}, [%[a_ptr]]!\n"
            "cmp %[k], #0 \n"
            "beq 4f \n"

            // main loop: two K-steps (2x4 depth) per iteration
            "3:\n"
            "vsdot.s8 q4 , q2, d0[0]\n"
            "vsdot.s8 q5 , q2, d0[1]\n"
            "vsdot.s8 q6 , q2, d1[0]\n"
            "vsdot.s8 q7 , q2, d1[1]\n"
            "vsdot.s8 q8 , q2, d2[0]\n"
            "vsdot.s8 q9 , q2, d2[1]\n"
            "vsdot.s8 q10 , q3, d0[0]\n"
            "vsdot.s8 q11 , q3, d0[1]\n"
            "vsdot.s8 q12 , q3, d1[0]\n"
            "vsdot.s8 q13 , q3, d1[1]\n"
            "vsdot.s8 q14 , q3, d2[0]\n"
            "vsdot.s8 q15 , q3, d2[1]\n"
            "vld1.s8 {q0}, [%[b_ptr]]!\n"
            "vld1.s8 {d2}, [%[b_ptr]]!\n"
            "vld1.s8 {q2}, [%[a_ptr]]!\n"
            "vld1.s8 {q3}, [%[a_ptr]]!\n"
            "vsdot.s8 q4 , q2, d0[0]\n"
            "vsdot.s8 q5 , q2, d0[1]\n"
            "vsdot.s8 q6 , q2, d1[0]\n"
            "vsdot.s8 q7 , q2, d1[1]\n"
            "vsdot.s8 q8 , q2, d2[0]\n"
            "vsdot.s8 q9 , q2, d2[1]\n"
            "vsdot.s8 q10 , q3, d0[0]\n"
            "vsdot.s8 q11 , q3, d0[1]\n"
            "vsdot.s8 q12 , q3, d1[0]\n"
            "vsdot.s8 q13 , q3, d1[1]\n"
            "vsdot.s8 q14 , q3, d2[0]\n"
            "vsdot.s8 q15 , q3, d2[1]\n"
            "vld1.s8 {q0}, [%[b_ptr]]!\n"
            "vld1.s8 {d2}, [%[b_ptr]]!\n"
            "vld1.s8 {q2}, [%[a_ptr]]!\n"
            "vld1.s8 {q3}, [%[a_ptr]]!\n"
            "subs %[k], %[k], #1\n"
            "bne 3b\n"

            // Target to use when K is 1 or 2 (i.e. zero iterations of main
            // loop)
            "4:\n"
            "cmp %[oddk], #0 \n"
            "bne 5f \n"

            // even tail: two remaining K-steps, stores interleaved with math
            "vsdot.s8 q4 , q2, d0[0]\n"
            "vsdot.s8 q5 , q2, d0[1]\n"
            "vsdot.s8 q6 , q2, d1[0]\n"
            "vsdot.s8 q7 , q2, d1[1]\n"
            "vsdot.s8 q8 , q2, d2[0]\n"
            "vsdot.s8 q9 , q2, d2[1]\n"
            "vsdot.s8 q10 , q3, d0[0]\n"
            "vsdot.s8 q11 , q3, d0[1]\n"
            "vsdot.s8 q12 , q3, d1[0]\n"
            "vsdot.s8 q13 , q3, d1[1]\n"
            "vsdot.s8 q14 , q3, d2[0]\n"
            "vsdot.s8 q15 , q3, d2[1]\n"
            "vld1.s8 {q0}, [%[b_ptr]]!\n"
            "vld1.s8 {d2}, [%[b_ptr]]!\n"
            "vld1.s8 {q2}, [%[a_ptr]]!\n"
            "vld1.s8 {q3}, [%[a_ptr]]!\n"
            "vsdot.s8 q4 , q2, d0[0]\n"
            "vsdot.s8 q5 , q2, d0[1]\n"
            "vsdot.s8 q6 , q2, d1[0]\n"
            "vst1.32 {d8, d9}, [%[outptr0]]!\n"
            "vsdot.s8 q7 , q2, d1[1]\n"
            "vsdot.s8 q8 , q2, d2[0]\n"
            "vsdot.s8 q9 , q2, d2[1]\n"
            "vst1.32 {d10, d11}, [%[outptr0]]!\n"
            "vsdot.s8 q10 , q3, d0[0]\n"
            "vsdot.s8 q11 , q3, d0[1]\n"
            "vsdot.s8 q12 , q3, d1[0]\n"
            "vst1.32 {d12, d13}, [%[outptr0]]!\n"
            "vsdot.s8 q13 , q3, d1[1]\n"
            "vsdot.s8 q14 , q3, d2[0]\n"
            "vsdot.s8 q15 , q3, d2[1]\n"
            "b 6f\n"

            // odd tail: one remaining K-step
            "5: \n"
            "vsdot.s8 q4 , q2, d0[0]\n"
            "vsdot.s8 q5 , q2, d0[1]\n"
            "vsdot.s8 q6 , q2, d1[0]\n"
            "vst1.32 {d8, d9}, [%[outptr0]]!\n"
            "vsdot.s8 q7 , q2, d1[1]\n"
            "vsdot.s8 q8 , q2, d2[0]\n"
            "vsdot.s8 q9 , q2, d2[1]\n"
            "vst1.32 {d10, d11}, [%[outptr0]]!\n"
            "vsdot.s8 q10 , q3, d0[0]\n"
            "vsdot.s8 q11 , q3, d0[1]\n"
            "vsdot.s8 q12 , q3, d1[0]\n"
            "vst1.32 {d12, d13}, [%[outptr0]]!\n"
            "vsdot.s8 q13 , q3, d1[1]\n"
            "vsdot.s8 q14 , q3, d2[0]\n"
            "vsdot.s8 q15 , q3, d2[1]\n"

            // store the remaining accumulators of both row-blocks
            "6: \n"
            "vst1.32 {d14, d15}, [%[outptr0]]!\n"
            "vst1.32 {d16, d17}, [%[outptr0]]!\n"
            "vst1.32 {d18, d19}, [%[outptr0]]!\n"
            "vst1.32 {d20, d21}, [%[outptr1]]!\n"
            "vst1.32 {d22, d23}, [%[outptr1]]!\n"
            "vst1.32 {d24, d25}, [%[outptr1]]!\n"
            "vst1.32 {d26, d27}, [%[outptr1]]!\n"
            "vst1.32 {d28, d29}, [%[outptr1]]!\n"
            "vst1.32 {d30, d31}, [%[outptr1]]!\n"
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [LDC] "+r"(LDC),
              [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [k] "+r"(k),
              [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1)
            :
            // NOTE(fix): "q13" added — the asm writes q13 (vsdot / vld1 of
            // d26,d27) but it was missing from the clobber list, so the
            // compiler could assume q13 survives the asm block.
            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
              "q10", "q11", "q12", "q13", "q14", "q15", "cc", "memory");
}
namespace matmul_mk4_dot_8x4x4 {
// Overview of register layout:
//
......@@ -390,144 +192,6 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
#undef STORE_C
}
// Overview of register layout:
//
// A 1x6x4 pingpong cell of Rhs is stored in 8bit in q0-q3.
// A 1x1x4x4 pingpong cell of Lhs is stored in 8bit in q4-q5
// A 1x6x4 block of accumulators is stored in 32-bit in q10-q15
//
// +--------+
// Rhs |q0[0-16]|
// |q1[0-16]|
// +--------+
// Lhs | |
// +-------+-------+ - - - - +--------+
// | q4[0-16]| |q10[0-4]|
// | q5[0-16]| |q11[0-4]|
// +---------+ |q12[0-4]|
// |q13[0-4]|
// |q14[0-4]|
// |q15[0-4]|
// +--------+
// Accumulator
// Compute a 4x6 output tile (one MK4 4-row block x 6 columns) of
// C (+)= A * B with ARMv7 NEON dot-product (vsdot.s8) instructions.
// Uses a ping-pong scheme: q0/d2 and q2/d6 alternate for B, q4/q5 for A,
// so loads of the next K-step overlap with the current step's math.
//
// \param packA      packed LHS panel; one 4x4 int8 cell per K-step.
// \param packB      packed RHS panel; 6 columns x 4 K-values per step.
// \param K          GEMM depth; the packing guarantees a multiple of 4.
// \param output     pointer to the int32 C tile.
// \param LDC        leading dimension of C in elements (scaled to bytes
//                   below; kept for interface parity with kern_8x6).
// \param is_first_k when true the accumulators are zeroed instead of being
//                   loaded from C (first K-slice of a split-K run).
static void kern_4x6(const int8_t* packA, const int8_t* packB, int K,
                     int32_t* output, int LDC, bool is_first_k) {
    K /= 4;
    const int8_t* a_ptr = packA;
    const int8_t* b_ptr = packB;
    // Fix up for odd lengths - set a flag if K is odd, but make
    // sure we round up the iteration count.
    int oddk = (K & 1);
    int k = (K + 1) / 2 - 1;
    LDC = LDC * sizeof(int32_t);

    int32_t* outptr0 = output;

    asm volatile(
            // load accumulator C
            "cmp %[is_first_k], #1\n"
            "beq 1f\n"
            "vld1.32 {d20, d21}, [%[outptr0]]!\n"
            "vld1.32 {d22, d23}, [%[outptr0]]!\n"
            "vld1.32 {d24, d25}, [%[outptr0]]!\n"
            "vld1.32 {d26, d27}, [%[outptr0]]!\n"
            "vld1.32 {d28, d29}, [%[outptr0]]!\n"
            "vld1.32 {d30, d31}, [%[outptr0]]!\n"
            "b 2f\n"

            "1:\n"
            "veor.s32 q10, q10, q10\n"
            "veor.s32 q11, q11, q11\n"
            "veor.s32 q12, q12, q12\n"
            "veor.s32 q13, q13, q13\n"
            "veor.s32 q14, q14, q14\n"
            "veor.s32 q15, q15, q15\n"

            // preload first K-step of A and B before entering the main loop
            "2: \n"
            "vld1.s8 {q0}, [%[b_ptr]]!\n"
            "vld1.s8 {d2}, [%[b_ptr]]!\n"
            "vld1.s8 {q4}, [%[a_ptr]]!\n"
            "cmp %[k], #0 \n"
            "beq 4f \n"

            // main loop: two K-steps per iteration, ping-ponging q0/q2 (B)
            // and q4/q5 (A)
            "3:\n"
            "vsdot.s8 q10 , q4, d0[0]\n"
            "vsdot.s8 q11 , q4, d0[1]\n"
            "vsdot.s8 q12 , q4, d1[0]\n"
            "vld1.s8 {q2}, [%[b_ptr]]!\n"
            "vld1.s8 {d6}, [%[b_ptr]]!\n"
            "vld1.s8 {q5}, [%[a_ptr]]!\n"
            "vsdot.s8 q13 , q4, d1[1]\n"
            "vsdot.s8 q14 , q4, d2[0]\n"
            "vsdot.s8 q15 , q4, d2[1]\n"
            "vld1.s8 {q0}, [%[b_ptr]]!\n"
            "vsdot.s8 q10 , q5, d4[0]\n"
            "vsdot.s8 q11 , q5, d4[1]\n"
            "vsdot.s8 q12 , q5, d5[0]\n"
            "vld1.s8 {d2}, [%[b_ptr]]!\n"
            "vsdot.s8 q13 , q5, d5[1]\n"
            "vsdot.s8 q14 , q5, d6[0]\n"
            "vsdot.s8 q15 , q5, d6[1]\n"
            "vld1.s8 {q4}, [%[a_ptr]]!\n"
            "subs %[k], %[k], #1\n"
            "bne 3b\n"

            // Target to use when K is 1 or 2 (i.e. zero iterations of main
            // loop)
            "4:\n"
            "cmp %[oddk], #0 \n"
            "bne 5f \n"

            // even tail: two remaining K-steps, stores interleaved with math
            "vsdot.s8 q10 , q4, d0[0]\n"
            "vsdot.s8 q11 , q4, d0[1]\n"
            "vsdot.s8 q12 , q4, d1[0]\n"
            "vld1.s8 {q2}, [%[b_ptr]]!\n"
            "vld1.s8 {d6}, [%[b_ptr]]!\n"
            "vld1.s8 {q5}, [%[a_ptr]]!\n"
            "vsdot.s8 q13 , q4, d1[1]\n"
            "vsdot.s8 q14 , q4, d2[0]\n"
            "vsdot.s8 q15 , q4, d2[1]\n"
            "vsdot.s8 q10 , q5, d4[0]\n"
            "vsdot.s8 q11 , q5, d4[1]\n"
            "vsdot.s8 q12 , q5, d5[0]\n"
            "vst1.32 {d20, d21}, [%[outptr0]]!\n"
            "vsdot.s8 q13 , q5, d5[1]\n"
            "vsdot.s8 q14 , q5, d6[0]\n"
            "vsdot.s8 q15 , q5, d6[1]\n"
            "vst1.32 {d22, d23}, [%[outptr0]]!\n"
            "b 6f\n"

            // odd tail: one remaining K-step
            "5: \n"
            "vsdot.s8 q10 , q4, d0[0]\n"
            "vsdot.s8 q11 , q4, d0[1]\n"
            "vsdot.s8 q12 , q4, d1[0]\n"
            "vst1.32 {d20, d21}, [%[outptr0]]!\n"
            "vsdot.s8 q13 , q4, d1[1]\n"
            "vsdot.s8 q14 , q4, d2[0]\n"
            "vsdot.s8 q15 , q4, d2[1]\n"
            "vst1.32 {d22, d23}, [%[outptr0]]!\n"

            // store the remaining accumulators
            "6: \n"
            "vst1.32 {d24, d25}, [%[outptr0]]!\n"
            "vst1.32 {d26, d27}, [%[outptr0]]!\n"
            "vst1.32 {d28, d29}, [%[outptr0]]!\n"
            "vst1.32 {d30, d31}, [%[outptr0]]!\n"
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [LDC] "+r"(LDC),
              [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [k] "+r"(k),
              [outptr0] "+r"(outptr0)
            :
            // NOTE(fix): "q13" added — the asm writes q13 (vsdot / vld1 of
            // d26,d27) but it was missing from the clobber list, so the
            // compiler could assume q13 survives the asm block.
            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
              "q10", "q11", "q12", "q13", "q14", "q15", "cc", "memory");
}
// Overview of register layout:
//
// A 2x4x4 cell of Rhs is stored in 8bit in q1, q3.
......@@ -671,7 +335,7 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
#undef STORE_C
}
static void gemm_dots8_8x6_pack_A(dt_int8* outptr, const dt_int8* inptr,
static void gemm_dots8_8x4_pack_A(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
int y = y0, y_start = y0 / 4;
......@@ -692,14 +356,12 @@ static void gemm_dots8_8x6_pack_A(dt_int8* outptr, const dt_int8* inptr,
}
}
static void gemm_dots8_8x6_pack_B(dt_int8* out, const dt_int8* in, int ldin,
static void gemm_dots8_8x4_pack_B(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
const int ksize = kmax - k0;
const int ksize4 = ksize * 4;
const int ksize6 = ksize * 6;
int8_t* outptr = out;
int8_t* outptr_base = out;
int8_t* outptr_base4 = out + ((xmax - x0) / 6) * ksize6;
int k = k0;
for (; k + 3 < kmax; k += 4) {
......@@ -708,13 +370,6 @@ static void gemm_dots8_8x6_pack_B(dt_int8* out, const dt_int8* in, int ldin,
outptr = outptr_base;
int x = x0;
for (; x + 5 < xmax; x += 6) {
memcpy(outptr, inptr, sizeof(int8_t) * 24);
outptr += ksize6;
inptr += 24;
}
outptr = outptr_base4;
for (; x + 3 < xmax; x += 4) {
memcpy(outptr, inptr, sizeof(int8_t) * 16);
outptr += ksize4;
......@@ -735,12 +390,11 @@ static void gemm_dots8_8x6_pack_B(dt_int8* out, const dt_int8* in, int ldin,
*outptr++ = 0;
}
}
outptr_base += 24;
outptr_base4 += 16;
outptr_base += 16;
}
}
} // namespace matmul_mk4_dot_8x6x4
} // namespace matmul_mk4_dot_8x4x4
} // namespace armv7
} // namespace megdnn
#endif
......
......@@ -16,7 +16,7 @@
#include "src/armv7/matrix_mul/int8/kernel_4x8x8.h"
#include "src/armv7/matrix_mul/int8/kernel_6x8x4.h"
#include "src/armv7/matrix_mul/int8/kernel_mk4_4x2x16.h"
#include "src/armv7/matrix_mul/int8/kernel_mk4_dot_8x6x4.h"
#include "src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h"
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_common.h"
......@@ -254,10 +254,10 @@ void gemm_dots8_6x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
}
}
// ===========================gemm_mk4_dots8_8x6======================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_dots8_8x6);
// ===========================gemm_mk4_dots8_8x4======================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_dots8_8x4);
void gemm_mk4_dots8_8x6::pack_A(dt_int8* out, const dt_int8* in, int ldin,
void gemm_mk4_dots8_8x4::pack_A(dt_int8* out, const dt_int8* in, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
megdnn_assert(!transpose,
......@@ -266,49 +266,39 @@ void gemm_mk4_dots8_8x6::pack_A(dt_int8* out, const dt_int8* in, int ldin,
"mk4 format matmul with m is not times of 4.");
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0,
"mk4 format matmul with k is not times of 4.");
matmul_mk4_dot_8x6x4::gemm_dots8_8x6_pack_A(out, in, ldin, y0, ymax, k0,
matmul_mk4_dot_8x4x4::gemm_dots8_8x4_pack_A(out, in, ldin, y0, ymax, k0,
kmax);
}
void gemm_mk4_dots8_8x6::pack_B(dt_int8* out, const dt_int8* in, int ldin,
void gemm_mk4_dots8_8x4::pack_B(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax,
bool transpose) const {
megdnn_assert(!transpose,
"matrix mul mk4 with transposed matrix B is not supported");
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0,
"mk4 format matmul with k is not times of 4.");
matmul_mk4_dot_8x6x4::gemm_dots8_8x6_pack_B(out, in, ldin, x0, xmax, k0,
matmul_mk4_dot_8x4x4::gemm_dots8_8x4_pack_B(out, in, ldin, x0, xmax, k0,
kmax);
}
void gemm_mk4_dots8_8x6::kern(const dt_int8* packA, const dt_int8* packB,
void gemm_mk4_dots8_8x4::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int32* C,
size_t LDC, bool is_first_k, const dt_int32* bias,
dt_int32* workspace) const {
MEGDNN_MARK_USED_VAR(bias);
constexpr size_t A_INTERLEAVE = 8;
constexpr size_t B_INTERLEAVE = 6;
//! K is packed to times of 4
K = round_up<size_t>(K, 4);
const int K4 = K * 4;
const int K6 = K * 6;
const int K8 = K * 8;
size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
int32_t* output = C + ((m >> 2) * LDC);
const dt_int8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_mk4_dot_8x6x4::kern_8x6(packA, cur_packB, K, output, LDC,
is_first_k);
output += 24;
cur_packB += K6;
}
for (; n < N; n += 4) {
for (size_t n = 0; n < N; n += 4) {
size_t n_remain = std::min<size_t>(N - n, 4);
matmul_mk4_dot_8x6x4::kern_8x4(packA, cur_packB, K, output, LDC,
matmul_mk4_dot_8x4x4::kern_8x4(packA, cur_packB, K, output, LDC,
is_first_k, n_remain);
output += 16;
cur_packB += K4;
......@@ -318,16 +308,9 @@ void gemm_mk4_dots8_8x6::kern(const dt_int8* packA, const dt_int8* packB,
for (; m < M; m += 4) {
int32_t* output = C + ((m >> 2) * LDC);
const dt_int8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_mk4_dot_8x6x4::kern_4x6(packA, cur_packB, K, output, LDC,
is_first_k);
output += 24;
cur_packB += K6;
}
for (; n < N; n += 4) {
for (size_t n = 0; n < N; n += 4) {
size_t n_remain = std::min<size_t>(N - n, 4);
matmul_mk4_dot_8x6x4::kern_4x4(packA, cur_packB, K, output, LDC,
matmul_mk4_dot_8x4x4::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, n_remain);
output += 16;
cur_packB += K4;
......
......@@ -27,8 +27,8 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 2, 16, false, false,
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 6, 8, 4, false, false,
gemm_dots8_6x8);
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 6, 4, false, false,
gemm_mk4_dots8_8x6);
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 4, 4, false, false,
gemm_mk4_dots8_8x4);
#endif
} // namespace matmul
} // namespace armv7
......
......@@ -29,7 +29,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#if __ARM_FEATURE_DOTPROD
AlgoInt8x8x32K6x8x4 int8_k6x8x4;
AlgoQuint8DotK4x8x4 quint8_k4x8x4;
AlgoInt8x8x32MK4_8x6x4DotProd int8x8x32_mk4_8x6x4_dotprod;
AlgoInt8x8x32MK4_8x4x4DotProd int8x8x32_mk4_8x4x4_dotprod;
#endif
AlgoF32Gemv f32_gemv;
AlgoInt8x8x32MK4_4x2x16 int8x8x32_mk4_4x2x16;
......@@ -57,7 +57,7 @@ public:
all_algos.emplace_back(&f16_mk8_4x8);
#endif
#if __ARM_FEATURE_DOTPROD
all_algos.emplace_back(&int8x8x32_mk4_8x6x4_dotprod);
all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod);
all_algos.emplace_back(&int8_k6x8x4);
all_algos.emplace_back(&quint8_k4x8x4);
#endif
......
......@@ -42,7 +42,7 @@ private:
#if __ARM_FEATURE_DOTPROD
class AlgoInt8x8x32K6x8x4; // Armv7 Int8 Kernel 6x8x4
class AlgoQuint8DotK4x8x4; // Armv7 Quint8 Kernel 6x8x4
class AlgoInt8x8x32MK4_8x6x4DotProd; // Armv7 nchw44 Int8x8x32 Kernel 8x6x4
class AlgoInt8x8x32MK4_8x4x4DotProd; // Armv7 nchw44 Int8x8x32 Kernel 8x4x4
// DotProduct
#endif
class AlgoPack;
......
......@@ -94,7 +94,7 @@ TEST_F(ARMV7, MATRIX_MUL_MK4_DOT_INT8) {
for (size_t k : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 33, 34})
args.emplace_back(m, n, k, 0);
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
handle(), "AARCH32_INT8_MK4_8X6X4_DOTPROD",
handle(), "AARCH32_INT8_MK4_8X4X4_DOTPROD",
param::MatrixMul::Format::MK4_DOT, 1, 1e-3,
std::move(args));
}
......@@ -315,7 +315,7 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x32_MK4_DOT) {
param.format = MatrixMul::Param::Format::MK4_DOT;
Benchmarker<MatrixMul> benchmarker_mk4_dot(handle());
benchmarker_mk4_dot.set_before_exec_callback(
AlgoChecker<MatrixMul>("AARCH32_INT8_MK4_8X6X4_DOTPROD"));
AlgoChecker<MatrixMul>("AARCH32_INT8_MK4_8X4X4_DOTPROD"));
benchmarker_mk4_dot.set_param(param)
.set_dtype(0, dtype::Int8())
.set_dtype(1, dtype::Int8())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册