提交 36e3bb6e 编写于 作者: M Megvii Engine Team

feat(mgb/dnn): add armv7 mk4_dot matmul

GitOrigin-RevId: d4206f8e21d1f58a7e07e1345d2738dc76e7bfbd
上级 580a2753
......@@ -706,6 +706,73 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8DotK4x8x4,
"AlgoQuint8DotK4x8x4"_hash,
armv7::matmul::gemm_dot_quint8_4x8,
uint8_t, int32_t);
/* ======================== Int8 MK4 8x6x4 dot algo ======================== */
namespace {
//! Kernel entry point for the Int8 MK4_DOT 8x6x4 matmul algorithm.
//! Unpacks the runtime parameters from \p kern_param and drives the
//! generic GemmInterleaved executor with the gemm_mk4_dots8_8x6 strategy.
void int8_mk4_8x6x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_armv7_matmul_kern,
                 midout_iv("int8_mk4_8x6x4_dotprod_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        // int8 inputs accumulated into int32 outputs.
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int32>();
        armv7::matmul::gemm_mk4_dots8_8x6 strategy(M, N, K, A_type, B_type,
                                                   C_type);
        megdnn::matmul::GemmInterleaved<armv7::matmul::gemm_mk4_dots8_8x6>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // namespace
bool MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::usable(
const KernSizeParam& kern_size_param) const {
return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
(kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
(kern_size_param.C_type.enumv() == DTypeEnum::Int32 ||
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32) &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::MK4_DOT &&
!kern_size_param.trA && !kern_size_param.trB;
}
//! Workspace size is delegated to the generic GemmInterleaved executor
//! instantiated with this algorithm's packing strategy.
size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(
            megdnn_armv7_matmul_kern,
            midout_iv("AlgoInt8x8x32MK4_8x6x4DotProd::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        armv7::matmul::gemm_mk4_dots8_8x6 strategy(M, N, K, A_type, B_type,
                                                   C_type);
        // NOTE: the return statement is inside the MIDOUT_BEGIN scope; the
        // code path after MIDOUT_END() is the project-wide midout pattern.
        return megdnn::matmul::GemmInterleaved<
                       armv7::matmul::gemm_mk4_dots8_8x6>(M, N, K, trA, trB,
                                                          strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
}
//! \return the type-erased kernel entry point; all sizes are re-read from
//! KernParam at execution time, so the KernSizeParam argument is unused.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::get_kern(
        const KernSizeParam&) const {
    return int8_mk4_8x6x4_dotprod_kern;
}
// Register the im2col-compatible GEMM entry points for this algorithm
// (int8_t inputs, int32_t accumulation), mirroring the registrations of
// the other dot-product algos above.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x6x4DotProd,
                                     megdnn_armv7_matmul_kern,
                                     "AlgoInt8x8x32MK4_8x6x4DotProd"_hash,
                                     armv7::matmul::gemm_mk4_dots8_8x6, int8_t,
                                     int32_t);
#endif
/* ===================== F32 algo K4x8 ===================== */
......
......@@ -93,6 +93,18 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
//! Int8x8x32 matmul in MK4_DOT format using the armv7 dot-product
//! instruction with an 8x6x4 kernel; results are bit-exact reproducible.
class MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override {
        return "AARCH32_INT8_MK4_8X6X4_DOTPROD";
    }
    bool usable(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
#endif
class MatrixMulImpl::AlgoF32Gemv final
......
......@@ -125,6 +125,20 @@ static inline void interleave_4x1_2_d(const int64_t*& inptr0,
: "q0", "q1", "q2", "q3", "cc", "memory");
}
//! Interleave two rows of 4 int32 values: copies 4 elements from inptr0
//! followed by 4 elements from inptr1 to outptr; the post-increment
//! addressing advances all three pointers past the consumed data.
static inline void interleave_2x1_4_s(const int32_t*& inptr0,
                                      const int32_t*& inptr1,
                                      int32_t*& outptr) {
    asm volatile(
            "vld1.32 {d0, d1}, [%[inptr0]]!\n"  // A0A1A2A3
            "vld1.32 {d2, d3}, [%[inptr1]]!\n"  // B0B1B2B3
            "vst1.32 {d0, d1}, [%[outptr]]!\n"
            "vst1.32 {d2, d3}, [%[outptr]]!\n"
            :
            [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr)
            :
            : "d0", "d1", "d2", "d3", "cc", "memory");
}
template <typename T>
static inline void interleave_8x8_1_b(const T*& inptr0, const T*& inptr1,
const T*& inptr2, const T*& inptr3,
......@@ -188,6 +202,17 @@ static inline void interleave_4x4_4_b(const T*& inptr0, const T*& inptr1,
: "q0", "q1", "q2", "q3", "memory");
}
//! Interleave two rows of 4x4 8-bit blocks (16 bytes per row) by viewing
//! each 4-byte group as one int32 lane and delegating to
//! interleave_2x1_4_s; restricted to int8_t/uint8_t element types.
template <typename T>
static inline void interleave_2x4_4_b(const T*& inptr0, const T*& inptr1,
                                      T*& outptr) {
    static_assert(
            std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
            "interleave_2x4_4_b only support uint8_t and int8_t");
    interleave_2x1_4_s(reinterpret_cast<const int32_t*&>(inptr0),
                       reinterpret_cast<const int32_t*&>(inptr1),
                       reinterpret_cast<int32_t*&>(outptr));
}
template <typename T>
static inline void interleave_6x4_4_b(const T*& inptr0, const T*& inptr1,
const T*& inptr2, const T*& inptr3,
......
/**
* \file dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x6x4.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#if __ARM_FEATURE_DOTPROD
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/armv7/matrix_mul/asm/common.h"
namespace megdnn {
namespace armv7 {
namespace matmul_mk4_dot_8x6x4 {
// Overview of register layout:
//
// A 1x6x4 cell of Rhs is stored in 8bit in q0, q1.
// A 2x1x4x4 cell of Lhs is stored in 8bit in q2, q3
// A 2x6x4 block of accumulators is stored in 8bit in q4-q15
//
// +--------+
// Rhs |q0[0-16]|
// |q1[0-16]|
// +--------+
// Lhs | |
// +-------+-------+ - - - - +--------+
// | q2[0-16]| | q4[0-4]|
// | q3[0-16]| | q5[0-4]|
// +---------+ | q6[0-4]|
// | q7[0-4]|
// | q8[0-4]|
// | q9[0-4]|
// |q10[0-4]|
// |q11[0-4]|
// |q12[0-4]|
// |q13[0-4]|
// |q14[0-4]|
// |q15[0-4]|
// +--------+
// Accumulator
//! Compute an 8x6 block of C (two MK4 row-groups of 4, six columns) with
//! the armv7 vsdot instruction.  A/B are pre-packed; K must be a multiple
//! of 4 (one dot-product consumes 4 values along K).
//!
//! Fix: the clobber list previously omitted "q13" even though the asm
//! writes q13 via vsdot — the compiler was free to assume q13 survived
//! the asm block, which is undefined behavior.
static void kern_8x6(const int8_t* packA, const int8_t* packB, int K,
                     int32_t* output, int LDC, bool is_first_k) {
    K /= 4;
    const int8_t* a_ptr = packA;
    const int8_t* b_ptr = packB;
    // Fix up for odd lengths - set a flag if K is odd, but make
    // sure we round up the iteration count.
    int oddk = (K & 1);
    int k = (K + 1) / 2 - 1;
    LDC = LDC * sizeof(int32_t);

    int32_t* outptr0 = output;
    int32_t* outptr1;

    asm volatile(
            // load accumulator C
            "add %[outptr1], %[outptr0], %[LDC]\n"
            "cmp %[is_first_k], #1\n"
            "beq 1f\n"
            "vld1.32 {d8, d9}, [%[outptr0]]!\n"
            "vld1.32 {d10, d11}, [%[outptr0]]!\n"
            "vld1.32 {d12, d13}, [%[outptr0]]!\n"
            "vld1.32 {d14, d15}, [%[outptr0]]!\n"
            "vld1.32 {d16, d17}, [%[outptr0]]!\n"
            "vld1.32 {d18, d19}, [%[outptr0]]!\n"
            "vld1.32 {d20, d21}, [%[outptr1]]!\n"
            "vld1.32 {d22, d23}, [%[outptr1]]!\n"
            "vld1.32 {d24, d25}, [%[outptr1]]!\n"
            "vld1.32 {d26, d27}, [%[outptr1]]!\n"
            "vld1.32 {d28, d29}, [%[outptr1]]!\n"
            "vld1.32 {d30, d31}, [%[outptr1]]!\n"
            "b 2f\n"

            "1:\n"
            "veor.s32 q4, q4, q4\n"
            "veor.s32 q5, q5, q5\n"
            "veor.s32 q6, q6, q6\n"
            "veor.s32 q7, q7, q7\n"
            "veor.s32 q8, q8, q8\n"
            "veor.s32 q9, q9, q9\n"
            "veor.s32 q10, q10, q10\n"
            "veor.s32 q11, q11, q11\n"
            "veor.s32 q12, q12, q12\n"
            "veor.s32 q13, q13, q13\n"
            "veor.s32 q14, q14, q14\n"
            "veor.s32 q15, q15, q15\n"

            "2: \n"
            "vld1.s8 {q0}, [%[b_ptr]]!\n"
            "vld1.s8 {d2}, [%[b_ptr]]!\n"
            "vld1.s8 {q2}, [%[a_ptr]]!\n"
            "vld1.s8 {q3}, [%[a_ptr]]!\n"
            "cmp %[k], #0 \n"
            "beq 4f \n"

            "3:\n"
            "vsdot.s8 q4 , q2, d0[0]\n"
            "vsdot.s8 q5 , q2, d0[1]\n"
            "vsdot.s8 q6 , q2, d1[0]\n"
            "vsdot.s8 q7 , q2, d1[1]\n"
            "vsdot.s8 q8 , q2, d2[0]\n"
            "vsdot.s8 q9 , q2, d2[1]\n"
            "vsdot.s8 q10 , q3, d0[0]\n"
            "vsdot.s8 q11 , q3, d0[1]\n"
            "vsdot.s8 q12 , q3, d1[0]\n"
            "vsdot.s8 q13 , q3, d1[1]\n"
            "vsdot.s8 q14 , q3, d2[0]\n"
            "vsdot.s8 q15 , q3, d2[1]\n"

            "vld1.s8 {q0}, [%[b_ptr]]!\n"
            "vld1.s8 {d2}, [%[b_ptr]]!\n"
            "vld1.s8 {q2}, [%[a_ptr]]!\n"
            "vld1.s8 {q3}, [%[a_ptr]]!\n"

            "vsdot.s8 q4 , q2, d0[0]\n"
            "vsdot.s8 q5 , q2, d0[1]\n"
            "vsdot.s8 q6 , q2, d1[0]\n"
            "vsdot.s8 q7 , q2, d1[1]\n"
            "vsdot.s8 q8 , q2, d2[0]\n"
            "vsdot.s8 q9 , q2, d2[1]\n"
            "vsdot.s8 q10 , q3, d0[0]\n"
            "vsdot.s8 q11 , q3, d0[1]\n"
            "vsdot.s8 q12 , q3, d1[0]\n"
            "vsdot.s8 q13 , q3, d1[1]\n"
            "vsdot.s8 q14 , q3, d2[0]\n"
            "vsdot.s8 q15 , q3, d2[1]\n"

            "vld1.s8 {q0}, [%[b_ptr]]!\n"
            "vld1.s8 {d2}, [%[b_ptr]]!\n"
            "vld1.s8 {q2}, [%[a_ptr]]!\n"
            "vld1.s8 {q3}, [%[a_ptr]]!\n"
            "subs %[k], %[k], #1\n"
            "bne 3b\n"

            // Target to use when K is 1 or 2 (i.e. zero iterations of main
            // loop)
            "4:\n"
            "cmp %[oddk], #0 \n"
            "bne 5f \n"

            "vsdot.s8 q4 , q2, d0[0]\n"
            "vsdot.s8 q5 , q2, d0[1]\n"
            "vsdot.s8 q6 , q2, d1[0]\n"
            "vsdot.s8 q7 , q2, d1[1]\n"
            "vsdot.s8 q8 , q2, d2[0]\n"
            "vsdot.s8 q9 , q2, d2[1]\n"
            "vsdot.s8 q10 , q3, d0[0]\n"
            "vsdot.s8 q11 , q3, d0[1]\n"
            "vsdot.s8 q12 , q3, d1[0]\n"
            "vsdot.s8 q13 , q3, d1[1]\n"
            "vsdot.s8 q14 , q3, d2[0]\n"
            "vsdot.s8 q15 , q3, d2[1]\n"

            "vld1.s8 {q0}, [%[b_ptr]]!\n"
            "vld1.s8 {d2}, [%[b_ptr]]!\n"
            "vld1.s8 {q2}, [%[a_ptr]]!\n"
            "vld1.s8 {q3}, [%[a_ptr]]!\n"

            "vsdot.s8 q4 , q2, d0[0]\n"
            "vsdot.s8 q5 , q2, d0[1]\n"
            "vsdot.s8 q6 , q2, d1[0]\n"
            "vst1.32 {d8, d9}, [%[outptr0]]!\n"
            "vsdot.s8 q7 , q2, d1[1]\n"
            "vsdot.s8 q8 , q2, d2[0]\n"
            "vsdot.s8 q9 , q2, d2[1]\n"
            "vst1.32 {d10, d11}, [%[outptr0]]!\n"
            "vsdot.s8 q10 , q3, d0[0]\n"
            "vsdot.s8 q11 , q3, d0[1]\n"
            "vsdot.s8 q12 , q3, d1[0]\n"
            "vst1.32 {d12, d13}, [%[outptr0]]!\n"
            "vsdot.s8 q13 , q3, d1[1]\n"
            "vsdot.s8 q14 , q3, d2[0]\n"
            "vsdot.s8 q15 , q3, d2[1]\n"
            "b 6f\n"

            "5: \n"
            "vsdot.s8 q4 , q2, d0[0]\n"
            "vsdot.s8 q5 , q2, d0[1]\n"
            "vsdot.s8 q6 , q2, d1[0]\n"
            "vst1.32 {d8, d9}, [%[outptr0]]!\n"
            "vsdot.s8 q7 , q2, d1[1]\n"
            "vsdot.s8 q8 , q2, d2[0]\n"
            "vsdot.s8 q9 , q2, d2[1]\n"
            "vst1.32 {d10, d11}, [%[outptr0]]!\n"
            "vsdot.s8 q10 , q3, d0[0]\n"
            "vsdot.s8 q11 , q3, d0[1]\n"
            "vsdot.s8 q12 , q3, d1[0]\n"
            "vst1.32 {d12, d13}, [%[outptr0]]!\n"
            "vsdot.s8 q13 , q3, d1[1]\n"
            "vsdot.s8 q14 , q3, d2[0]\n"
            "vsdot.s8 q15 , q3, d2[1]\n"

            "6: \n"
            "vst1.32 {d14, d15}, [%[outptr0]]!\n"
            "vst1.32 {d16, d17}, [%[outptr0]]!\n"
            "vst1.32 {d18, d19}, [%[outptr0]]!\n"
            "vst1.32 {d20, d21}, [%[outptr1]]!\n"
            "vst1.32 {d22, d23}, [%[outptr1]]!\n"
            "vst1.32 {d24, d25}, [%[outptr1]]!\n"
            "vst1.32 {d26, d27}, [%[outptr1]]!\n"
            "vst1.32 {d28, d29}, [%[outptr1]]!\n"
            "vst1.32 {d30, d31}, [%[outptr1]]!\n"
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [LDC] "+r"(LDC),
              [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [k] "+r"(k),
              [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1)
            :
            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
              "q10", "q11", "q12", "q13", "q14", "q15", "cc", "memory");
}
// Overview of register layout:
//
// A 2x4x4 cell of Rhs is stored in 8bit in q1, q3.
// A 2x2x4x4 ping-pong cell of Lhs is stored in 8bit in q5, q7, q9, q11
// A 2x4x4 block of accumulators is stored in 8bit in q0, q2, q4, q6, q8, q10,
// q12, q14
//
// +--------+
// Rhs |q1[0-16]|
// |q3[0-16]|
// +--------+
// Lhs | |
// +-------+-------+ - - - - +--------+
// | q5[0-16]| | q0[0-4]|
// | q7[0-16]| | q2[0-4]|
// | q9[0-16]| | q4[0-4]|
// |q11[0-16]| | q6[0-4]|
// +---------+ | q8[0-4]|
// |q10[0-4]|
// |q12[0-4]|
// |q14[0-4]|
// +--------+
// Accumulator
//! Compute an 8xN (N = n_remain <= 4) block of C using vsdot, with masked
//! loads/stores of the C columns driven by n_remain.
//!
//! Fix: the last LOAD_LINE inside LOAD_C previously ended with a line
//! continuation backslash, splicing the following "#define STORE_LINE"
//! into the LOAD_C replacement list (a preprocessor error).
static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
                     int32_t* output, int LDC, bool is_first_k, int n_remain) {
    K /= 4;
    const int8_t* a_ptr = packA;
    const int8_t* b_ptr = packB;
    // Fix up for odd lengths - set a flag if K is odd, but make
    // sure we round up the iteration count.
    int oddk = (K & 1);
    int k = K / 2;
    LDC = LDC * sizeof(int32_t);

    int32_t* outptr0 = output;
    int32_t* outptr1;
    size_t x0;

    // clang-format off
#define LOAD_LINE(dr0, dr1, dr2, dr3, dr4, dr5, dr6, dr7, n) \
    "mov %[x0], %[outptr" n "]\n"                            \
    "cmp %[n_remain], #4\n"                                  \
    "blt 100" n "f\n"                                        \
    "vld1.32 {d" dr0 ", d" dr1 "}, [%[x0]]!\n"               \
    "vld1.32 {d" dr2 ", d" dr3 "}, [%[x0]]!\n"               \
    "vld1.32 {d" dr4 ", d" dr5 "}, [%[x0]]!\n"               \
    "vld1.32 {d" dr6 ", d" dr7 "}, [%[x0]]!\n"               \
    "b 101" n "f\n"                                          \
    "100" n ":\n"                                            \
    "cmp %[n_remain], #0\n"                                  \
    "beq 101" n "f\n"                                        \
    "vld1.32 {d" dr0 ", d" dr1 "}, [%[x0]]!\n"               \
    "cmp %[n_remain], #1\n"                                  \
    "beq 101" n "f\n"                                        \
    "vld1.32 {d" dr2 ", d" dr3 "}, [%[x0]]!\n"               \
    "cmp %[n_remain], #2\n"                                  \
    "beq 101" n "f\n"                                        \
    "vld1.32 {d" dr4 ", d" dr5 "}, [%[x0]]!\n"               \
    "101" n ":\n"

#define LOAD_C                                                     \
    LOAD_LINE("0", "1", "4", "5", "8", "9", "12", "13", "0")       \
    LOAD_LINE("16", "17", "20", "21", "24", "25", "28", "29", "1")

#define STORE_LINE(dr0, dr1, dr2, dr3, dr4, dr5, dr6, dr7, n) \
    "mov %[x0], %[outptr" n "]\n"                             \
    "cmp %[n_remain], #4\n"                                   \
    "blt 102" n "f\n"                                         \
    "vst1.32 {d" dr0 ", d" dr1 "}, [%[x0]]!\n"                \
    "vst1.32 {d" dr2 ", d" dr3 "}, [%[x0]]!\n"                \
    "vst1.32 {d" dr4 ", d" dr5 "}, [%[x0]]!\n"                \
    "vst1.32 {d" dr6 ", d" dr7 "}, [%[x0]]!\n"                \
    "b 103" n "f\n"                                           \
    "102" n ":\n"                                             \
    "cmp %[n_remain], #0\n"                                   \
    "beq 103" n "f\n"                                         \
    "vst1.32 {d" dr0 ", d" dr1 "}, [%[x0]]!\n"                \
    "cmp %[n_remain], #1\n"                                   \
    "beq 103" n "f\n"                                         \
    "vst1.32 {d" dr2 ", d" dr3 "}, [%[x0]]!\n"                \
    "cmp %[n_remain], #2\n"                                   \
    "beq 103" n "f\n"                                         \
    "vst1.32 {d" dr4 ", d" dr5 "}, [%[x0]]!\n"                \
    "103" n ":\n"

#define STORE_C                                                    \
    STORE_LINE("0", "1", "4", "5", "8", "9", "12", "13", "0")      \
    STORE_LINE("16", "17", "20", "21", "24", "25", "28", "29", "1")
    // clang-format on

    asm volatile(
            // load accumulator C
            "add %[outptr1], %[outptr0], %[LDC]\n"
            "cmp %[is_first_k], #1\n"
            "beq 1f\n" LOAD_C

            "b 2f\n"

            "1:\n"
            "veor.s32 q0, q0, q0\n"
            "veor.s32 q2, q2, q2\n"
            "veor.s32 q4, q4, q4\n"
            "veor.s32 q6, q6, q6\n"
            "veor.s32 q8, q8, q8\n"
            "veor.s32 q10, q10, q10\n"
            "veor.s32 q12, q12, q12\n"
            "veor.s32 q14, q14, q14\n"

            "2: \n"
            "cmp %[oddk], #0 \n"
            "beq 3f \n"

            // parse the oddk
            "vld1.s8 {q1}, [%[b_ptr]]!\n"
            "vld1.s8 {q3}, [%[a_ptr]]!\n"
            "vld1.s8 {q5}, [%[a_ptr]]!\n"
            "vsdot.s8 q0 , q3, d2[0]\n"
            "vsdot.s8 q2 , q3, d2[1]\n"
            "vsdot.s8 q4 , q3, d3[0]\n"
            "vsdot.s8 q6 , q3, d3[1]\n"
            "vsdot.s8 q8 , q5, d2[0]\n"
            "vsdot.s8 q10 , q5, d2[1]\n"
            "vsdot.s8 q12 , q5, d3[0]\n"
            "vsdot.s8 q14 , q5, d3[1]\n"

            "cmp %[k], #0 \n"
            "beq 4f \n"
            // Loop proper
            "3:\n"
            "vld1.s8 {q1}, [%[b_ptr]]!\n"
            "vld1.s8 {q3}, [%[b_ptr]]!\n"
            "vld1.s8 {q5}, [%[a_ptr]]!\n"
            "vld1.s8 {q7}, [%[a_ptr]]!\n"
            "vsdot.s8 q0 , q5, d2[0]\n"
            "vsdot.s8 q2 , q5, d2[1]\n"
            "vsdot.s8 q4 , q5, d3[0]\n"
            "vsdot.s8 q6 , q5, d3[1]\n"
            "vld1.s8 {q9}, [%[a_ptr]]!\n"
            "vld1.s8 {q11}, [%[a_ptr]]!\n"
            "vsdot.s8 q8 , q7, d2[0]\n"
            "vsdot.s8 q10 , q7, d2[1]\n"
            "vsdot.s8 q12 , q7, d3[0]\n"
            "vsdot.s8 q14 , q7, d3[1]\n"
            "vsdot.s8 q0 , q9, d6[0]\n"
            "vsdot.s8 q2 , q9, d6[1]\n"
            "vsdot.s8 q4 , q9, d7[0]\n"
            "vsdot.s8 q6 , q9, d7[1]\n"
            "vsdot.s8 q8 , q11, d6[0]\n"
            "vsdot.s8 q10 , q11, d6[1]\n"
            "vsdot.s8 q12 , q11, d7[0]\n"
            "vsdot.s8 q14 , q11, d7[1]\n"

            "subs %[k], %[k], #1\n"
            "bne 3b\n"

            "4:\n" STORE_C
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [LDC] "+r"(LDC),
              [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k),
              [n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0),
              [outptr1] "+r"(outptr1), [x0] "+r"(x0)
            :
            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
              "q10", "q11", "q12", "q14", "cc", "memory");

#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}
// Overview of register layout:
//
// A 1x6x4 pingpong cell of Rhs is stored in 8bit in q0-q3.
// A 1x1x4x4 pingpong cell of Lhs is stored in 8bit in q4-q5
// A 2x6x4 block of accumulators is stored in 8bit in q10-q15
//
// +--------+
// Rhs |q0[0-16]|
// |q1[0-16]|
// +--------+
// Lhs | |
// +-------+-------+ - - - - +--------+
// | q4[0-16]| |q10[0-4]|
// | q5[0-16]| |q11[0-4]|
// +---------+ |q12[0-4]|
// |q13[0-4]|
// |q14[0-4]|
// |q15[0-4]|
// +--------+
// Accumulator
//! Compute a 4x6 block of C (one MK4 row-group, six columns) with vsdot,
//! ping-ponging A in q4/q5 and B in q0-q3.
//!
//! Fix: the clobber list previously omitted "q13" even though the asm
//! writes q13 via vsdot — the compiler was free to assume q13 survived
//! the asm block, which is undefined behavior.
static void kern_4x6(const int8_t* packA, const int8_t* packB, int K,
                     int32_t* output, int LDC, bool is_first_k) {
    K /= 4;
    const int8_t* a_ptr = packA;
    const int8_t* b_ptr = packB;
    // Fix up for odd lengths - set a flag if K is odd, but make
    // sure we round up the iteration count.
    int oddk = (K & 1);
    int k = (K + 1) / 2 - 1;
    LDC = LDC * sizeof(int32_t);

    int32_t* outptr0 = output;

    asm volatile(
            // load accumulator C
            "cmp %[is_first_k], #1\n"
            "beq 1f\n"
            "vld1.32 {d20, d21}, [%[outptr0]]!\n"
            "vld1.32 {d22, d23}, [%[outptr0]]!\n"
            "vld1.32 {d24, d25}, [%[outptr0]]!\n"
            "vld1.32 {d26, d27}, [%[outptr0]]!\n"
            "vld1.32 {d28, d29}, [%[outptr0]]!\n"
            "vld1.32 {d30, d31}, [%[outptr0]]!\n"

            "b 2f\n"

            "1:\n"
            "veor.s32 q10, q10, q10\n"
            "veor.s32 q11, q11, q11\n"
            "veor.s32 q12, q12, q12\n"
            "veor.s32 q13, q13, q13\n"
            "veor.s32 q14, q14, q14\n"
            "veor.s32 q15, q15, q15\n"

            "2: \n"
            "vld1.s8 {q0}, [%[b_ptr]]!\n"
            "vld1.s8 {d2}, [%[b_ptr]]!\n"
            "vld1.s8 {q4}, [%[a_ptr]]!\n"
            "cmp %[k], #0 \n"
            "beq 4f \n"

            "3:\n"
            "vsdot.s8 q10 , q4, d0[0]\n"
            "vsdot.s8 q11 , q4, d0[1]\n"
            "vsdot.s8 q12 , q4, d1[0]\n"
            "vld1.s8 {q2}, [%[b_ptr]]!\n"
            "vld1.s8 {d6}, [%[b_ptr]]!\n"
            "vld1.s8 {q5}, [%[a_ptr]]!\n"
            "vsdot.s8 q13 , q4, d1[1]\n"
            "vsdot.s8 q14 , q4, d2[0]\n"
            "vsdot.s8 q15 , q4, d2[1]\n"

            "vld1.s8 {q0}, [%[b_ptr]]!\n"
            "vsdot.s8 q10 , q5, d4[0]\n"
            "vsdot.s8 q11 , q5, d4[1]\n"
            "vsdot.s8 q12 , q5, d5[0]\n"
            "vld1.s8 {d2}, [%[b_ptr]]!\n"
            "vsdot.s8 q13 , q5, d5[1]\n"
            "vsdot.s8 q14 , q5, d6[0]\n"
            "vsdot.s8 q15 , q5, d6[1]\n"
            "vld1.s8 {q4}, [%[a_ptr]]!\n"
            "subs %[k], %[k], #1\n"
            "bne 3b\n"

            // Target to use when K is 1 or 2 (i.e. zero iterations of main
            // loop)
            "4:\n"
            "cmp %[oddk], #0 \n"
            "bne 5f \n"

            "vsdot.s8 q10 , q4, d0[0]\n"
            "vsdot.s8 q11 , q4, d0[1]\n"
            "vsdot.s8 q12 , q4, d1[0]\n"
            "vld1.s8 {q2}, [%[b_ptr]]!\n"
            "vld1.s8 {d6}, [%[b_ptr]]!\n"
            "vld1.s8 {q5}, [%[a_ptr]]!\n"
            "vsdot.s8 q13 , q4, d1[1]\n"
            "vsdot.s8 q14 , q4, d2[0]\n"
            "vsdot.s8 q15 , q4, d2[1]\n"

            "vsdot.s8 q10 , q5, d4[0]\n"
            "vsdot.s8 q11 , q5, d4[1]\n"
            "vsdot.s8 q12 , q5, d5[0]\n"
            "vst1.32 {d20, d21}, [%[outptr0]]!\n"
            "vsdot.s8 q13 , q5, d5[1]\n"
            "vsdot.s8 q14 , q5, d6[0]\n"
            "vsdot.s8 q15 , q5, d6[1]\n"
            "vst1.32 {d22, d23}, [%[outptr0]]!\n"
            "b 6f\n"

            "5: \n"
            "vsdot.s8 q10 , q4, d0[0]\n"
            "vsdot.s8 q11 , q4, d0[1]\n"
            "vsdot.s8 q12 , q4, d1[0]\n"
            "vst1.32 {d20, d21}, [%[outptr0]]!\n"
            "vsdot.s8 q13 , q4, d1[1]\n"
            "vsdot.s8 q14 , q4, d2[0]\n"
            "vsdot.s8 q15 , q4, d2[1]\n"
            "vst1.32 {d22, d23}, [%[outptr0]]!\n"

            "6: \n"
            "vst1.32 {d24, d25}, [%[outptr0]]!\n"
            "vst1.32 {d26, d27}, [%[outptr0]]!\n"
            "vst1.32 {d28, d29}, [%[outptr0]]!\n"
            "vst1.32 {d30, d31}, [%[outptr0]]!\n"
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [LDC] "+r"(LDC),
              [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [k] "+r"(k),
              [outptr0] "+r"(outptr0)
            :
            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
              "q10", "q11", "q12", "q13", "q14", "q15", "cc", "memory");
}
// Overview of register layout:
//
// A 2x4x4 cell of Rhs is stored in 8bit in q1, q3.
// A 1x2x4x4 ping-pong cell of Lhs is stored in 8bit in q5, q7
// A 1x4x4 block of accumulators is stored in 8bit in q0, q2, q4, q6
//
// +--------+
// Rhs |q1[0-16]|
// |q3[0-16]|
// +--------+
// Lhs | |
// +-------+-------+ - - - - +--------+
// | q5[0-16]| | q0[0-4]|
// | q7[0-16]| | q2[0-4]|
// +---------+ | q4[0-4]|
// | q6[0-4]|
// +--------+
// Accumulator
//! Compute a 4xN (N = n_remain <= 4) block of C with vsdot, using masked
//! loads/stores of C driven by n_remain.
//!
//! Fix: LOAD_C's and STORE_C's last lines previously ended with a line
//! continuation backslash, splicing the following "#define STORE_LINE" /
//! "// clang-format on" into the macro replacement lists.
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
                     int32_t* output, int LDC, bool is_first_k, int n_remain) {
    K /= 4;
    const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA);
    const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB);
    // Fix up for odd lengths - set a flag if K is odd, but make
    // sure we round up the iteration count.
    int oddk = (K & 1);
    int k = K / 2;
    LDC = LDC * sizeof(int32_t);

    int32_t* outptr0 = output;
    size_t x0;

    // clang-format off
#define LOAD_LINE(dr0, dr1, dr2, dr3, dr4, dr5, dr6, dr7, n) \
    "mov %[x0], %[outptr" n "]\n"                            \
    "cmp %[n_remain], #4\n"                                  \
    "blt 100" n "f\n"                                        \
    "vld1.32 {d" dr0 ", d" dr1 "}, [%[x0]]!\n"               \
    "vld1.32 {d" dr2 ", d" dr3 "}, [%[x0]]!\n"               \
    "vld1.32 {d" dr4 ", d" dr5 "}, [%[x0]]!\n"               \
    "vld1.32 {d" dr6 ", d" dr7 "}, [%[x0]]!\n"               \
    "b 101" n "f\n"                                          \
    "100" n ":\n"                                            \
    "cmp %[n_remain], #0\n"                                  \
    "beq 101" n "f\n"                                        \
    "vld1.32 {d" dr0 ", d" dr1 "}, [%[x0]]!\n"               \
    "cmp %[n_remain], #1\n"                                  \
    "beq 101" n "f\n"                                        \
    "vld1.32 {d" dr2 ", d" dr3 "}, [%[x0]]!\n"               \
    "cmp %[n_remain], #2\n"                                  \
    "beq 101" n "f\n"                                        \
    "vld1.32 {d" dr4 ", d" dr5 "}, [%[x0]]!\n"               \
    "101" n ":\n"

#define LOAD_C                                               \
    LOAD_LINE("0", "1", "4", "5", "8", "9", "12", "13", "0")

#define STORE_LINE(dr0, dr1, dr2, dr3, dr4, dr5, dr6, dr7, n) \
    "mov %[x0], %[outptr" n "]\n"                             \
    "cmp %[n_remain], #4\n"                                   \
    "blt 102" n "f\n"                                         \
    "vst1.32 {d" dr0 ", d" dr1 "}, [%[x0]]!\n"                \
    "vst1.32 {d" dr2 ", d" dr3 "}, [%[x0]]!\n"                \
    "vst1.32 {d" dr4 ", d" dr5 "}, [%[x0]]!\n"                \
    "vst1.32 {d" dr6 ", d" dr7 "}, [%[x0]]!\n"                \
    "b 103" n "f\n"                                           \
    "102" n ":\n"                                             \
    "cmp %[n_remain], #0\n"                                   \
    "beq 103" n "f\n"                                         \
    "vst1.32 {d" dr0 ", d" dr1 "}, [%[x0]]!\n"                \
    "cmp %[n_remain], #1\n"                                   \
    "beq 103" n "f\n"                                         \
    "vst1.32 {d" dr2 ", d" dr3 "}, [%[x0]]!\n"                \
    "cmp %[n_remain], #2\n"                                   \
    "beq 103" n "f\n"                                         \
    "vst1.32 {d" dr4 ", d" dr5 "}, [%[x0]]!\n"                \
    "103" n ":\n"

#define STORE_C                                               \
    STORE_LINE("0", "1", "4", "5", "8", "9", "12", "13", "0")
    // clang-format on

    asm volatile(
            // load accumulator C
            "cmp %[is_first_k], #1\n"
            "beq 1f\n"  //
            LOAD_C      //

            "b 2f\n"

            "1:\n"
            "veor.s32 q0, q0, q0\n"
            "veor.s32 q2, q2, q2\n"
            "veor.s32 q4, q4, q4\n"
            "veor.s32 q6, q6, q6\n"

            "2: \n"
            "cmp %[oddk], #0 \n"
            "beq 3f \n"

            // parse the oddk
            "vld1.s8 {q1}, [%[a_ptr]]!\n"
            "vld1.s8 {q3}, [%[b_ptr]]!\n"
            "vsdot.s8 q0 , q1, d6[0]\n"
            "vsdot.s8 q2 , q1, d6[1]\n"
            "vsdot.s8 q4 , q1, d7[0]\n"
            "vsdot.s8 q6 , q1, d7[1]\n"

            "cmp %[k], #0 \n"
            "beq 4f \n"
            // Loop proper
            "3:\n"
            "vld1.s8 {q1}, [%[b_ptr]]!\n"
            "vld1.s8 {q5}, [%[a_ptr]]!\n"
            "vsdot.s8 q0 , q5, d2[0]\n"
            "vsdot.s8 q2 , q5, d2[1]\n"
            "vld1.s8 {q3}, [%[b_ptr]]!\n"
            "vld1.s8 {q7}, [%[a_ptr]]!\n"
            "vsdot.s8 q4 , q5, d3[0]\n"
            "vsdot.s8 q6 , q5, d3[1]\n"
            "vsdot.s8 q0 , q7, d6[0]\n"
            "vsdot.s8 q2 , q7, d6[1]\n"
            "vsdot.s8 q4 , q7, d7[0]\n"
            "vsdot.s8 q6 , q7, d7[1]\n"

            "subs %[k], %[k], #1\n"
            "bne 3b\n"

            "4:\n" STORE_C
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk),
              [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain),
              [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), [k] "+r"(k),
              [x0] "+r"(x0)
            :
            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "cc", "memory");

#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}
//! Pack matrix A (MK4 layout: 4 rows per group, ldin bytes per group row)
//! for the 8x6 dot kernel: 8 rows (two MK4 groups) are interleaved 4
//! columns at a time; a trailing single group of 4 rows is copied through
//! unchanged since it is already in the packed layout.
//! NOTE(review): the inner loop assumes kmax - k0 is a multiple of 4; the
//! caller (gemm_mk4_dots8_8x6::pack_A) asserts this.
static void gemm_dots8_8x6_pack_A(dt_int8* outptr, const dt_int8* inptr,
                                  int ldin, int y0, int ymax, int k0,
                                  int kmax) {
    // y counts rows, y_start counts MK4 row-groups (4 rows per group).
    int y = y0, y_start = y0 / 4;
    for (; y + 7 < ymax; y += 8, y_start += 2) {
        const int8_t* inptr0 = inptr + y_start * ldin + k0 * 4;
        const int8_t* inptr1 = inptr0 + ldin;
        prefetch_2x(inptr0);
        prefetch_2x(inptr1);
        int K = kmax - k0;
        // interleave_2x4_4_b advances inptr0/inptr1/outptr by 16 bytes each.
        for (; K > 3; K -= 4) {
            interleave_2x4_4_b(inptr0, inptr1, outptr);
        }
    }
    // At most one remaining MK4 group (ymax - y < 8); copy verbatim.
    for (; y + 3 < ymax; y += 4, ++y_start) {
        int K = kmax - k0;
        const int8_t* inptr0 = inptr + y_start * ldin + k0 * 4;
        std::memcpy(outptr, inptr0, sizeof(dt_int8) * K * 4);
    }
}
//! Pack matrix B for the 8x6 dot kernel: full 6-column panels first, then
//! 4-column panels, then a zero-padded final panel when xmax - x is not a
//! multiple of 4.
//!
//! Fix: the padding loop previously kept reading "*inptr++" past the last
//! valid column, an out-of-bounds read at the end of the input row.  The
//! pad bytes feed only dot products whose results are never stored
//! (kern_8x4/kern_4x4 mask stores by n_remain), so zero-filling them
//! preserves all observable outputs while removing the OOB access.
static void gemm_dots8_8x6_pack_B(dt_int8* out, const dt_int8* in, int ldin,
                                  int x0, int xmax, int k0, int kmax) {
    const int ksize = kmax - k0;
    const int ksize4 = ksize * 4;  // bytes per packed 4-column panel
    const int ksize6 = ksize * 6;  // bytes per packed 6-column panel
    int8_t* outptr = out;
    int8_t* outptr_base = out;
    // 4-column panels are laid out after all full 6-column panels.
    int8_t* outptr_base4 = out + ((xmax - x0) / 6) * ksize6;

    int k = k0;
    for (; k + 3 < kmax; k += 4) {
        // One MK4 slice of B: 4 consecutive int8 per column.
        const int8_t* inptr = in + k / 4 * ldin + x0 * 4;
        prefetch_2x(inptr);

        outptr = outptr_base;
        int x = x0;
        for (; x + 5 < xmax; x += 6) {
            memcpy(outptr, inptr, sizeof(int8_t) * 24);
            outptr += ksize6;
            inptr += 24;
        }

        outptr = outptr_base4;
        for (; x + 3 < xmax; x += 4) {
            memcpy(outptr, inptr, sizeof(int8_t) * 16);
            outptr += ksize4;
            inptr += 16;
        }

        if (x < xmax) {
            int i = 0;
            // Copy the remaining valid columns (4 bytes each)...
            for (; i < xmax - x; i++) {
                *outptr++ = *inptr++;
                *outptr++ = *inptr++;
                *outptr++ = *inptr++;
                *outptr++ = *inptr++;
            }
            // ...and zero-pad up to a full 4-column panel instead of
            // reading past the end of the input row.
            for (; i < 4; i++) {
                *outptr++ = 0;
                *outptr++ = 0;
                *outptr++ = 0;
                *outptr++ = 0;
            }
        }

        outptr_base += 24;
        outptr_base4 += 16;
    }
}
} // namespace matmul_mk4_dot_8x6x4
} // namespace armv7
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
......@@ -16,6 +16,7 @@
#include "src/armv7/matrix_mul/int8/kernel_4x8x8.h"
#include "src/armv7/matrix_mul/int8/kernel_6x8x4.h"
#include "src/armv7/matrix_mul/int8/kernel_mk4_4x2x16.h"
#include "src/armv7/matrix_mul/int8/kernel_mk4_dot_8x6x4.h"
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_common.h"
......@@ -252,6 +253,89 @@ void gemm_dots8_6x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
}
}
}
// ===========================gemm_mk4_dots8_8x6======================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_dots8_8x6);
//! Pack A for the MK4_DOT 8x6 strategy.  MK4 layout requires m and k to be
//! multiples of 4 and does not support a transposed A.
void gemm_mk4_dots8_8x6::pack_A(dt_int8* out, const dt_int8* in, int ldin,
                                int y0, int ymax, int k0, int kmax,
                                bool transpose) const {
    megdnn_assert(!transpose,
                  "matrix mul mk4 with transposed matrix A is not supported.");
    megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0,
                  "mk4 format matmul with m is not times of 4.");
    megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0,
                  "mk4 format matmul with k is not times of 4.");
    matmul_mk4_dot_8x6x4::gemm_dots8_8x6_pack_A(out, in, ldin, y0, ymax, k0,
                                                kmax);
}
//! Pack B for the MK4_DOT 8x6 strategy.  k must be a multiple of 4; a
//! transposed B is not supported.  N may be arbitrary — the packer pads
//! the final column panel.
void gemm_mk4_dots8_8x6::pack_B(dt_int8* out, const dt_int8* in, int ldin,
                                int x0, int xmax, int k0, int kmax,
                                bool transpose) const {
    megdnn_assert(!transpose,
                  "matrix mul mk4 with transposed matrix B is not supported");
    megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0,
                  "mk4 format matmul with k is not times of 4.");
    matmul_mk4_dot_8x6x4::gemm_dots8_8x6_pack_B(out, in, ldin, x0, xmax, k0,
                                                kmax);
}
//! Drive the micro-kernels over the packed A/B panels: 8 rows of A (two
//! MK4 groups) x 6 columns in the main loops, with 4-row and <=4-column
//! tail kernels.  \p bias is unused by this strategy.
void gemm_mk4_dots8_8x6::kern(const dt_int8* packA, const dt_int8* packB,
                              size_t M, size_t N, size_t K, dt_int32* C,
                              size_t LDC, bool is_first_k, const dt_int32* bias,
                              dt_int32* workspace) const {
    MEGDNN_MARK_USED_VAR(bias);
    constexpr size_t A_INTERLEAVE = 8;
    constexpr size_t B_INTERLEAVE = 6;
    //! K is packed to times of 4
    K = round_up<size_t>(K, 4);
    // Bytes per packed panel of 4 / 6 / 8 rows-or-columns.
    const int K4 = K * 4;
    const int K6 = K * 6;
    const int K8 = K * 8;

    size_t m = 0;
    for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
        // MK4: C advances one row per group of 4 source rows, hence m >> 2.
        int32_t* output = C + ((m >> 2) * LDC);
        const dt_int8* cur_packB = packB;
        size_t n = 0;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_mk4_dot_8x6x4::kern_8x6(packA, cur_packB, K, output, LDC,
                                           is_first_k);
            output += 24;  // 6 columns * 4 int32 per MK4 group
            cur_packB += K6;
        }
        // Column tail: kern_8x4 masks loads/stores by n_remain.
        for (; n < N; n += 4) {
            size_t n_remain = std::min<size_t>(N - n, 4);
            matmul_mk4_dot_8x6x4::kern_8x4(packA, cur_packB, K, output, LDC,
                                           is_first_k, n_remain);
            output += 16;
            cur_packB += K4;
        }
        packA += K8;
    }

    // Row tail: one MK4 group (4 source rows) at a time.
    for (; m < M; m += 4) {
        int32_t* output = C + ((m >> 2) * LDC);
        const dt_int8* cur_packB = packB;
        size_t n = 0;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            matmul_mk4_dot_8x6x4::kern_4x6(packA, cur_packB, K, output, LDC,
                                           is_first_k);
            output += 24;
            cur_packB += K6;
        }
        for (; n < N; n += 4) {
            size_t n_remain = std::min<size_t>(N - n, 4);
            matmul_mk4_dot_8x6x4::kern_4x4(packA, cur_packB, K, output, LDC,
                                           is_first_k, n_remain);
            output += 16;
            cur_packB += K4;
        }
        packA += K4;
    }
}
#endif
// ===========================gemm_mk4_s8_4x2======================================
......
......@@ -26,6 +26,9 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 2, 16, false, false,
#if __ARM_FEATURE_DOTPROD
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 6, 8, 4, false, false,
gemm_dots8_6x8);
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 6, 4, false, false,
gemm_mk4_dots8_8x6);
#endif
} // namespace matmul
} // namespace armv7
......
......@@ -29,6 +29,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#if __ARM_FEATURE_DOTPROD
AlgoInt8x8x32K6x8x4 int8_k6x8x4;
AlgoQuint8DotK4x8x4 quint8_k4x8x4;
AlgoInt8x8x32MK4_8x6x4DotProd int8x8x32_mk4_8x6x4_dotprod;
#endif
AlgoF32Gemv f32_gemv;
AlgoInt8x8x32MK4_4x2x16 int8x8x32_mk4_4x2x16;
......@@ -56,6 +57,7 @@ public:
all_algos.emplace_back(&f16_mk8_4x8);
#endif
#if __ARM_FEATURE_DOTPROD
all_algos.emplace_back(&int8x8x32_mk4_8x6x4_dotprod);
all_algos.emplace_back(&int8_k6x8x4);
all_algos.emplace_back(&quint8_k4x8x4);
#endif
......
......@@ -42,6 +42,8 @@ private:
#if __ARM_FEATURE_DOTPROD
class AlgoInt8x8x32K6x8x4; // Armv7 Int8 Kernel 6x8x4
class AlgoQuint8DotK4x8x4; // Armv7 Quint8 Kernel 6x8x4
class AlgoInt8x8x32MK4_8x6x4DotProd; // Armv7 nchw44 Int8x8x32 Kernel 8x6x4
// DotProduct
#endif
class AlgoPack;
};
......
......@@ -86,6 +86,18 @@ TEST_F(ARMV7, MATRIX_MUL_UDOT) {
dtype::Quantized8Asymm(4.0f, static_cast<uint8_t>(10)), dtype::Quantized8Asymm(3.0f, static_cast<uint8_t>(54)),
dtype::QuantizedS32(12.0f), handle(), "AARCH32_QUINT8_K4X8X4");
}
// Correctness of the MK4_DOT int8 algorithm against the reference
// implementation over shapes that exercise every kernel variant
// (kern_8x6 / kern_8x4 / kern_4x6 / kern_4x4) and the n_remain / K-padding
// tail paths.
TEST_F(ARMV7, MATRIX_MUL_MK4_DOT_INT8) {
    std::vector<matrix_mul::TestArg> args;
    for (size_t m : {1, 2, 3, 4, 5, 7, 10, 11})
        for (size_t n : {1, 2, 3, 4, 5, 8, 16, 24, 25, 32})
            for (size_t k : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 33, 34})
                args.emplace_back(m, n, k, 0);
    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
                                 handle(), "AARCH32_INT8_MK4_8X6X4_DOTPROD",
                                 param::MatrixMul::Format::MK4_DOT, 1, 1e-3,
                                 std::move(args));
}
#endif
#if MEGDNN_WITH_BENCHMARK
......@@ -286,6 +298,53 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x32_K6x8x4) {
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_QUINT8x8x32_K4x8x4) {
run_8x8x32_quint_benchmark(handle());
}
// Benchmark the MK4_DOT kernel against the default-format dot kernel
// (AARCH32_INT8_K6X8X4) over a sweep of M/K powers of two and assorted N,
// reporting per-shape Gflops and speedup.
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x32_MK4_DOT) {
    constexpr size_t RUNS = 50;
    param::MatrixMul param;
    Benchmarker<MatrixMul> benchmarker_default(handle());
    benchmarker_default.set_times(RUNS)
            .set_dtype(0, dtype::Int8())
            .set_dtype(1, dtype::Int8())
            .set_dtype(2, dtype::Int32())
            .set_param(param)
            .set_display(false);
    benchmarker_default.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH32_INT8_K6X8X4"));

    param.format = MatrixMul::Param::Format::MK4_DOT;
    Benchmarker<MatrixMul> benchmarker_mk4_dot(handle());
    benchmarker_mk4_dot.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH32_INT8_MK4_8X6X4_DOTPROD"));
    benchmarker_mk4_dot.set_param(param)
            .set_dtype(0, dtype::Int8())
            .set_dtype(1, dtype::Int8())
            .set_dtype(2, dtype::Int32())
            .set_display(false)
            .set_times(RUNS);

    auto run = [&](size_t M, size_t N, size_t K) {
        auto default_used =
                benchmarker_default.exec({{M, K}, {K, N}, {}}) / RUNS;
        // MK4_DOT tensor layout: A is (M/4, K/4, 4, 4), B is (K/4, N, 4).
        auto mk4_dot_used = benchmarker_mk4_dot.exec(
                                    {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
                            RUNS;
        float computations = 2.f * M * K * N * 1e-6;
        printf("run: {%zu{M} %zu{K} %zu{N}} default: %f ms %f Gflops mk4_dot: "
               "%f ms "
               "%f Gflops speedup: %f\n",
               M, K, N, default_used, computations / default_used, mk4_dot_used,
               computations / mk4_dot_used, default_used / mk4_dot_used);
    };

    for (size_t M = 4; M < 512; M *= 2) {
        for (size_t K = 4; K < 512; K *= 2) {
            for (size_t N : {4, 8, 33, 113, 128}) {
                run(M, N, K);
            }
        }
    }
}
#endif
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x2x16) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册