Commit 5b62acfa authored by Megvii Engine Team

feat(dnn/armv7): add new matmul strategy k8x8x4

GitOrigin-RevId: 0c6b7fa1b2ad8724a5c68036d58b3c1e13c3bb42
Parent ad87f78a
@@ -541,6 +541,74 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x8x8,
armv7::matmul::gemm_s8x8x16_4x8, int8_t,
int16_t, AlgoDataType::INT8X8X16, DEFAULT);
/* ===================== Int8x8x16 Kernel 8x8x4 algo ===================== */
namespace {
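// Thin dispatch wrapper: unpacks the KernParam and runs the 8x8x4
// strategy through the generic GemmInterleaved driver (pack A/B, then
// invoke the micro-kernels on each tile).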
void kern_int8x8x16_k8x8x4(const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_armv7_matmul_kern,
midout_iv("kern_int8x8x16_k8x8x4"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int16>();
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
auto trA = kern_param.trA, trB = kern_param.trB;
armv7::matmul::gemm_s8x8x16_8x8 strategy(M, N, K, kern_param.A_type,
kern_param.B_type,
kern_param.C_type);
megdnn::matmul::GemmInterleaved<armv7::matmul::gemm_s8x8x16_8x8>(
M, N, K, trA, trB, strategy)
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
kern_param.workspace_ptr);
}
MIDOUT_END();
}
} // anonymous namespace
bool MatrixMulImpl::AlgoInt8x8x16K8x8x4::usable(
const KernSizeParam& kern_size_param) const {
return kern_size_param.A_type == kern_size_param.B_type &&
kern_size_param.A_type == dtype::Int8() &&
kern_size_param.C_type == dtype::Int16() &&
kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT;
}
size_t MatrixMulImpl::AlgoInt8x8x16K8x8x4::get_workspace(
const KernSizeParam& kern_size_param) const {
MIDOUT_BEGIN(megdnn_armv7_matmul_kern,
midout_iv("AlgoInt8x8x16K8x8x4::get_workspace"_hash)) {
auto M = kern_size_param.M, N = kern_size_param.N,
K = kern_size_param.K;
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
C_type = kern_size_param.C_type;
auto trA = kern_size_param.trA, trB = kern_size_param.trB;
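// The workspace sized here holds the packed A and B panels that
// GemmInterleaved produces before calling the micro-kernels.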
matmul::gemm_s8x8x16_8x8 strategy(M, N, K, A_type, B_type, C_type);
return megdnn::matmul::GemmInterleaved<matmul::gemm_s8x8x16_8x8>(
M, N, K, trA, trB, strategy)
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K8x8x4::get_kern(
const KernSizeParam&) const {
return kern_int8x8x16_k8x8x4;
}
bool MatrixMulImpl::AlgoInt8x8x16K8x8x4::preferred(
const KernSizeParam& kern_size_param) const {
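// Heuristic: this kernel is tuned for small-to-medium K; the contrast
// benchmarks below compare it against K4X8X8 in exactly that range.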
return kern_size_param.K >= 8 && kern_size_param.K <= 128;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K8x8x4,
megdnn_armv7_matmul_kern,
"AlgoInt8x8x16K8x8x4"_hash,
armv7::matmul::gemm_s8x8x16_8x8, int8_t,
int16_t, AlgoDataType::INT8X8X16, DEFAULT);
/* =================== Int8x8x16 Kernel MK4 8x8x4 algo ===================*/
namespace {
......
@@ -181,6 +181,18 @@ public:
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_K4X8X8)
};
class MatrixMulImpl::AlgoInt8x8x16K8x8x4 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "ARMV7_INT8X8X16_K8X8X4"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_K8X8X4)
};
class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
......
/**
* \file dnn/src/armv7/matrix_mul/int8x8x16/kernel_8x8x4.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/armv7/matrix_mul/asm/common.h"
namespace megdnn {
namespace armv7 {
namespace matmul_8x8x4 {
/* +--------+---------------------------------+
* | q4 | b00 b01 b02 b03 b04 b05 b06 b07 |
* +--------+---------------------------------+
* | q5 | b10 b11 b12 b13 b14 b15 b16 b17 |
* +--------+---------------------------------+
* | q6 | b20 b21 b22 b23 b24 b25 b26 b27 |
* +--------+---------------------------------+
* | q7 | b30 b31 b32 b33 b34 b35 b36 b37 |
* +--------+---------------------------------+
* +----+-----------------+ +--------+---------------------------------+
* | d0 | a00 a01 a02 a03 | | q8 | c00 c01 c02 c03 c04 c05 c06 c07 |
* | d1 | a10 a11 a12 a13 | | q9 | c10 c11 c12 c13 c14 c15 c16 c17 |
* | d2 | a20 a21 a22 a23 | | q10 | c20 c21 c22 c23 c24 c25 c26 c27 |
* | d3 | a30 a31 a32 a33 | | q11 | c30 c31 c32 c33 c34 c35 c36 c37 |
* | d4 | a40 a41 a42 a43 | | q12 | c40 c41 c42 c43 c44 c45 c46 c47 |
* | d5 | a50 a51 a52 a53 | | q13 | c50 c51 c52 c53 c54 c55 c56 c57 |
* | d6 | a60 a61 a62 a63 | | q14 | c60 c61 c62 c63 c64 c65 c66 c67 |
* | d7 | a70 a71 a72 a73 | | q15 | c70 c71 c72 c73 c74 c75 c76 c77 |
* +----+-----------------+ +--------+---------------------------------+
*
*/
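/* For reference, a scalar sketch (illustration only, not compiled) of what
 * one iteration of the K loop in kern_8x8 computes. Every int8 operand is
 * widened to int16 and the 8x8 C tile accumulates in int16, matching the
 * vmovl.s8 + vmla.s16 sequence below:
 *
 *     // a: packed 8 rows x 4 depth values; b: packed 4 depth x 8 columns
 *     for (int k = 0; k < 4; ++k)
 *         for (int m = 0; m < 8; ++m)
 *             for (int n = 0; n < 8; ++n)
 *                 c[m][n] += int16_t(a[m][k]) * int16_t(b[k][n]);
 */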
static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k,
size_t n_remain) {
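    // The packed operands advance in depth-4 slices, so the main loop
    // counts K in steps of 4 (the "x4" in 8x8x4).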
K /= 4;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
LDC = LDC * sizeof(int16_t);
size_t nr = n_remain;
// clang-format off
#define LOAD_C \
"mov r1, r0\n" \
"vld1.16 {d16, d17}, [r1], %[LDC]\n" \
"vld1.16 {d18, d19}, [r1], %[LDC]\n" \
"vld1.16 {d20, d21}, [r1], %[LDC]\n" \
"vld1.16 {d22, d23}, [r1], %[LDC]\n" \
"vld1.16 {d24, d25}, [r1], %[LDC]\n" \
"vld1.16 {d26, d27}, [r1], %[LDC]\n" \
"vld1.16 {d28, d29}, [r1], %[LDC]\n" \
"vld1.16 {d30, d31}, [r1], %[LDC]\n"
#define STORE_LINE(id1, id2) \
"mov r2, r1\n" \
"cmp %[nr], #8\n" \
"bne 100f\n" \
"vst1.16 {d" id1 ", d" id2 "}, [r2]!\n" \
"b 101f\n" \
"100:\n" \
"cmp %[nr], #0\n" \
"beq 101f\n" \
"vst1.16 {d" id1 "[0]}, [r2]!\n" \
"cmp %[nr], #1\n" \
"beq 101f\n" \
"vst1.16 {d" id1 "[1]}, [r2]!\n" \
"cmp %[nr], #2\n" \
"beq 101f\n" \
"vst1.16 {d" id1 "[2]}, [r2]!\n" \
"cmp %[nr], #3\n" \
"beq 101f\n" \
"vst1.16 {d" id1 "[3]}, [r2]!\n" \
"cmp %[nr], #4\n" \
"beq 101f\n" \
"vst1.16 {d" id2 "[0]}, [r2]!\n" \
"cmp %[nr], #5\n" \
"beq 101f\n" \
"vst1.16 {d" id2 "[1]}, [r2]!\n" \
"cmp %[nr], #6\n" \
"beq 101f\n" \
"vst1.16 {d" id2 "[2]}, [r2]!\n" \
"101:\n"
#define STORE_C \
"mov r1, r0\n" \
STORE_LINE("16", "17") \
"add r1, r1, %[LDC]\n" \
STORE_LINE("18", "19") \
"add r1, r1, %[LDC]\n" \
STORE_LINE("20", "21") \
"add r1, r1, %[LDC]\n" \
STORE_LINE("22", "23") \
"add r1, r1, %[LDC]\n" \
STORE_LINE("24", "25") \
"add r1, r1, %[LDC]\n" \
STORE_LINE("26", "27") \
"add r1, r1, %[LDC]\n" \
STORE_LINE("28", "29") \
"add r1, r1, %[LDC]\n" \
STORE_LINE("30", "31")
// clang-format on
register int16_t* outptr asm("r0") = output;
asm volatile(
"cmp %[is_first_k], #1\n"
"beq 1f\n" LOAD_C
"b 2f\n"
"1:\n"
"veor.s32 q8, q8, q8\n"
"veor.s32 q9, q9, q9\n"
"veor.s32 q10, q10, q10\n"
"veor.s32 q11, q11, q11\n"
"veor.s32 q12, q12, q12\n"
"veor.s32 q13, q13, q13\n"
"veor.s32 q14, q14, q14\n"
"veor.s32 q15, q15, q15\n"
"2:\n"
"vld1.8 {d0}, [%[a_ptr]]!\n"
"vld1.8 {d2}, [%[a_ptr]]!\n"
"vld1.8 {d4}, [%[a_ptr]]!\n"
"vld1.8 {d6}, [%[a_ptr]]!\n"
"vmovl.s8 q0, d0\n"
"vmovl.s8 q1, d2\n"
"vmovl.s8 q2, d4\n"
"vmovl.s8 q3, d6\n"
"vld1.8 {d8}, [%[b_ptr]]!\n"
"vld1.8 {d10}, [%[b_ptr]]!\n"
"vld1.8 {d12}, [%[b_ptr]]!\n"
"vld1.8 {d14}, [%[b_ptr]]!\n"
"vmovl.s8 q4, d8\n"
"vmovl.s8 q5, d10\n"
"vmovl.s8 q6, d12\n"
"vmovl.s8 q7, d14\n"
"vmla.s16 q8, q4, d0[0]\n"
"vmla.s16 q9, q4, d1[0]\n"
"vmla.s16 q10, q4, d2[0]\n"
"vmla.s16 q11, q4, d3[0]\n"
"vmla.s16 q12, q4, d4[0]\n"
"vmla.s16 q13, q4, d5[0]\n"
"vmla.s16 q14, q4, d6[0]\n"
"vmla.s16 q15, q4, d7[0]\n"
"vmla.s16 q8, q5, d0[1]\n"
"vmla.s16 q9, q5, d1[1]\n"
"vmla.s16 q10, q5, d2[1]\n"
"vmla.s16 q11, q5, d3[1]\n"
"vmla.s16 q12, q5, d4[1]\n"
"vmla.s16 q13, q5, d5[1]\n"
"vmla.s16 q14, q5, d6[1]\n"
"vmla.s16 q15, q5, d7[1]\n"
"vmla.s16 q8, q6, d0[2]\n"
"vmla.s16 q9, q6, d1[2]\n"
"vmla.s16 q10, q6, d2[2]\n"
"vmla.s16 q11, q6, d3[2]\n"
"vmla.s16 q12, q6, d4[2]\n"
"vmla.s16 q13, q6, d5[2]\n"
"vmla.s16 q14, q6, d6[2]\n"
"vmla.s16 q15, q6, d7[2]\n"
"vmla.s16 q8, q7, d0[3]\n"
"vmla.s16 q9, q7, d1[3]\n"
"vmla.s16 q10, q7, d2[3]\n"
"vmla.s16 q11, q7, d3[3]\n"
"vmla.s16 q12, q7, d4[3]\n"
"vmla.s16 q13, q7, d5[3]\n"
"vmla.s16 q14, q7, d6[3]\n"
"vmla.s16 q15, q7, d7[3]\n"
"subs %[K], %[K], #1\n"
"bne 2b\n"
"3:\n" STORE_C
: [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K),
[ LDC ] "+r"(LDC), [ is_first_k ] "+r"(is_first_k),
[ outptr ] "+r"(outptr), [ nr ] "+r"(nr)
:
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
"q11", "q12", "q13", "q14", "q15", "r1", "r2", "cc", "memory");
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}
/* +--------+---------------------------------+
* | q2 | b00 b01 b02 b03 b04 b05 b06 b07 |
* +--------+---------------------------------+
* | q3 | b10 b11 b12 b13 b14 b15 b16 b17 |
* +--------+---------------------------------+
* | q4 | b20 b21 b22 b23 b24 b25 b26 b27 |
* +--------+---------------------------------+
* | q5 | b30 b31 b32 b33 b34 b35 b36 b37 |
* +--------+---------------------------------+
* +----+-----------------+ +--------+---------------------------------+
* | d0 | a00 a01 a02 a03 | | q6 | c00 c01 c02 c03 c04 c05 c06 c07 |
* | d1 | a10 a11 a12 a13 | | q7 | c10 c11 c12 c13 c14 c15 c16 c17 |
* | d2 | a20 a21 a22 a23 | | q8 | c20 c21 c22 c23 c24 c25 c26 c27 |
* | d3 | a30 a31 a32 a33 | | q9 | c30 c31 c32 c33 c34 c35 c36 c37 |
* +----+-----------------+ +--------+---------------------------------+
*
*/
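/* kern_4x8 performs the same widening multiply-accumulate as kern_8x8 but
 * on a 4-row A tile; m_remain and n_remain select how many rows/columns of
 * the C tile are actually loaded and stored for edge blocks. */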
static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, size_t m_remain,
size_t n_remain) {
K /= 4;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
LDC = LDC * sizeof(int16_t);
size_t mr = m_remain;
size_t nr = n_remain;
// clang-format off
#define LOAD_C \
"cmp %[mr], #0\n" \
"beq 100f\n" \
"mov r1, r0\n" \
"vld1.16 {d12, d13}, [r1], %[LDC]\n" \
"cmp %[mr], #1\n" \
"beq 100f\n" \
"vld1.16 {d14, d15}, [r1], %[LDC]\n" \
"cmp %[mr], #2\n" \
"beq 100f\n" \
"vld1.16 {d16, d17}, [r1], %[LDC]\n" \
"cmp %[mr], #3\n" \
"beq 100f\n" \
"vld1.16 {d18, d19}, [r1], %[LDC]\n" \
"100:\n" \
#define STORE_LINE(id1, id2) \
"mov r2, r1\n" \
"cmp %[nr], #0\n" \
"beq 101f\n" \
"vst1.16 {d" id1 "[0]}, [r2]!\n" \
"cmp %[nr], #1\n" \
"beq 101f\n" \
"vst1.16 {d" id1 "[1]}, [r2]!\n" \
"cmp %[nr], #2\n" \
"beq 101f\n" \
"vst1.16 {d" id1 "[2]}, [r2]!\n" \
"cmp %[nr], #3\n" \
"beq 101f\n" \
"vst1.16 {d" id1 "[3]}, [r2]!\n" \
"cmp %[nr], #4\n" \
"beq 101f\n" \
"vst1.16 {d" id2 "[0]}, [r2]!\n" \
"cmp %[nr], #5\n" \
"beq 101f\n" \
"vst1.16 {d" id2 "[1]}, [r2]!\n" \
"cmp %[nr], #6\n" \
"beq 101f\n" \
"vst1.16 {d" id2 "[2]}, [r2]!\n" \
"cmp %[nr], #7\n" \
"beq 101f\n" \
"vst1.16 {d" id2 "[3]}, [r2]!\n" \
"101:\n"
#define STORE_C \
"cmp %[mr], #0\n" \
"beq 102f\n" \
"mov r1, r0\n" \
STORE_LINE("12", "13") \
"cmp %[mr], #1\n" \
"beq 102f\n" \
"add r1, r1, %[LDC]\n" \
STORE_LINE("14", "15") \
"cmp %[mr], #2\n" \
"beq 102f\n" \
"add r1, r1, %[LDC]\n" \
STORE_LINE("16", "17") \
"cmp %[mr], #3\n" \
"beq 102f\n" \
"add r1, r1, %[LDC]\n" \
STORE_LINE("18", "19") \
"102:\n"
// clang-format on
register int16_t* outptr asm("r0") = output;
asm volatile(
"cmp %[is_first_k], #1\n"
"beq 1f\n" LOAD_C
"b 2f\n"
"1:\n"
"veor.s32 q6, q6, q6\n"
"veor.s32 q7, q7, q7\n"
"veor.s32 q8, q8, q8\n"
"veor.s32 q9, q9, q9\n"
"2:\n"
"vld1.8 {d0}, [%[a_ptr]]!\n"
"vld1.8 {d2}, [%[a_ptr]]!\n"
"vmovl.s8 q0, d0\n"
"vmovl.s8 q1, d2\n"
"vld1.8 {d4}, [%[b_ptr]]!\n"
"vld1.8 {d6}, [%[b_ptr]]!\n"
"vld1.8 {d8}, [%[b_ptr]]!\n"
"vld1.8 {d10}, [%[b_ptr]]!\n"
"vmovl.s8 q2, d4\n"
"vmovl.s8 q3, d6\n"
"vmovl.s8 q4, d8\n"
"vmovl.s8 q5, d10\n"
"vmla.s16 q6, q2, d0[0]\n"
"vmla.s16 q7, q2, d1[0]\n"
"vmla.s16 q8, q2, d2[0]\n"
"vmla.s16 q9, q2, d3[0]\n"
"vmla.s16 q6, q3, d0[1]\n"
"vmla.s16 q7, q3, d1[1]\n"
"vmla.s16 q8, q3, d2[1]\n"
"vmla.s16 q9, q3, d3[1]\n"
"vmla.s16 q6, q4, d0[2]\n"
"vmla.s16 q7, q4, d1[2]\n"
"vmla.s16 q8, q4, d2[2]\n"
"vmla.s16 q9, q4, d3[2]\n"
"vmla.s16 q6, q5, d0[3]\n"
"vmla.s16 q7, q5, d1[3]\n"
"vmla.s16 q8, q5, d2[3]\n"
"vmla.s16 q9, q5, d3[3]\n"
"subs %[K], %[K], #1\n"
"bne 2b\n"
"3:\n" STORE_C
: [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K),
[ LDC ] "+r"(LDC), [ is_first_k ] "+r"(is_first_k),
[ outptr ] "+r"(outptr), [ mr ] "+r"(mr), [ nr ] "+r"(nr)
:
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "r1",
"r2", "cc", "memory");
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}
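// Pack A (row-major, not transposed): rows are grouped 8 at a time and
// interleaved in depth-4 blocks to match the layout kern_8x8 consumes;
// leftover rows are packed 4 at a time, substituting zerobuff for rows
// past ymax.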
static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* out, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);
int8_t* outptr = out;
int y = y0;
for (; y + 7 < ymax; y += 8) {
const int8_t* inptr0 = inptr + y * ldin + k0;
const int8_t* inptr1 = inptr0 + ldin;
const int8_t* inptr2 = inptr1 + ldin;
const int8_t* inptr3 = inptr2 + ldin;
const int8_t* inptr4 = inptr3 + ldin;
const int8_t* inptr5 = inptr4 + ldin;
const int8_t* inptr6 = inptr5 + ldin;
const int8_t* inptr7 = inptr6 + ldin;
prefetch_2x(inptr0);
prefetch_2x(inptr1);
prefetch_2x(inptr2);
prefetch_2x(inptr3);
prefetch_2x(inptr4);
prefetch_2x(inptr5);
prefetch_2x(inptr6);
prefetch_2x(inptr7);
int K = kmax - k0;
for (; K > 15; K -= 16) {
interleave_8x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr);
}
if (K > 0) {
for (; K > 0; K -= 4)
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr, 4, std::min(K, 4));
}
}
for (; y < ymax; y += 4) {
const int8_t* inptr0 = inptr + y * ldin + k0;
const int8_t* inptr1 = inptr0 + ldin;
const int8_t* inptr2 = inptr1 + ldin;
const int8_t* inptr3 = inptr2 + ldin;
prefetch_2x(inptr0);
prefetch_2x(inptr1);
prefetch_2x(inptr2);
prefetch_2x(inptr3);
int K = kmax - k0;
for (; K > 0; K -= 4) {
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
default:
megdnn_assert(0);
}
}
interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 4,
std::min(K, 4));
}
}
}
static void gemm_s8x8x16_8x8_pack_A_t(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);
int8_t* outbase = out;
size_t K = round_up(kmax - k0, 4);
int k = k0;
for (; k < kmax; k += 4) {
const int8_t* inptr0 = in + k * ldin + x0;
const int8_t* inptr1 = inptr0 + ldin;
const int8_t* inptr2 = inptr1 + ldin;
const int8_t* inptr3 = inptr2 + ldin;
prefetch_2x(inptr0);
prefetch_2x(inptr1);
prefetch_2x(inptr2);
prefetch_2x(inptr3);
int x = xmax - x0;
int8_t* outptr = outbase;
int8_t* out_tmp = outptr;
for (; x > 7; x -= 8) {
if (k + 3 >= kmax) {
switch (k + 3 - kmax) {
case 2:
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
default:
megdnn_assert(0);
}
}
out_tmp = outptr;
transpose_8x4_1_b(inptr0, inptr1, inptr2, inptr3, out_tmp);
outptr += (K - k) * 8 + (x > 15 ? 8 : 4) * k;
}
if (x > 0) {
if (k + 3 >= kmax) {
switch (k + 3 - kmax) {
case 2:
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
default:
megdnn_assert(0);
}
}
out_tmp = outptr;
if (x > 4) {
transpose_4(inptr0, inptr1, inptr2, inptr3, out_tmp, 4, 4);
x -= 4;
out_tmp = outptr + K * 4;
}
transpose_4(inptr0, inptr1, inptr2, inptr3, out_tmp, 4, x);
}
outbase += 4 * ((xmax - x0) > 7 ? 8 : 4);
}
}
static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);
int8_t* outbase = out;
int8_t* out_interleave = out;
const size_t K8 = round_up(kmax - k0, 4) * 8;
int k = k0;
for (; k < kmax; k += 4) {
const int8_t* inptr0 = in + k * ldin + x0;
const int8_t* inptr1 = inptr0 + ldin;
const int8_t* inptr2 = inptr1 + ldin;
const int8_t* inptr3 = inptr2 + ldin;
prefetch_2x(inptr0);
prefetch_2x(inptr1);
prefetch_2x(inptr2);
prefetch_2x(inptr3);
int x = xmax - x0;
int8_t* outptr = outbase;
for (; x > 7; x -= 8) {
if (k + 3 >= kmax) {
switch (k + 3 - kmax) {
case 2:
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
default:
megdnn_assert(0);
}
}
out_interleave = outptr;
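            // Copy a 4 (K) x 8 (N) block of B into the packed panel, one
            // 8-byte row per depth step, so each K-slice of 8 columns is
            // contiguous.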
asm volatile(
"vld1.32 {d0}, [%[inptr0]]!\n"
"vld1.32 {d1}, [%[inptr1]]!\n"
"vld1.32 {d2}, [%[inptr2]]!\n"
"vld1.32 {d3}, [%[inptr3]]!\n"
"vst1.32 {d0}, [%[out_interleave]]!\n"
"vst1.32 {d1}, [%[out_interleave]]!\n"
"vst1.32 {d2}, [%[out_interleave]]!\n"
"vst1.32 {d3}, [%[out_interleave]]!\n"
: [ inptr0 ] "+r"(inptr0), [ inptr1 ] "+r"(inptr1),
[ inptr2 ] "+r"(inptr2), [ inptr3 ] "+r"(inptr3),
[ out_interleave ] "+r"(out_interleave)
:
: "q0", "q1", "cc", "memory");
outptr += K8;
}
if (x > 0) {
if (k + 3 >= kmax) {
switch (k + 3 - kmax) {
case 2:
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
default:
megdnn_assert(0);
}
}
out_interleave = outptr;
interleave_4(inptr0, inptr1, inptr2, inptr3, out_interleave, 8, x);
}
outbase += 4 * 8;
}
}
static void gemm_s8x8x16_8x8_pack_B_t(dt_int8* out, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);
int8_t* outptr = out;
int y = y0;
for (; y < ymax; y += 8) {
const int8_t* inptr0 = inptr + y * ldin + k0;
const int8_t* inptr1 = inptr0 + ldin;
const int8_t* inptr2 = inptr1 + ldin;
const int8_t* inptr3 = inptr2 + ldin;
const int8_t* inptr4 = inptr3 + ldin;
const int8_t* inptr5 = inptr4 + ldin;
const int8_t* inptr6 = inptr5 + ldin;
const int8_t* inptr7 = inptr6 + ldin;
prefetch_2x(inptr0);
prefetch_2x(inptr1);
prefetch_2x(inptr2);
prefetch_2x(inptr3);
prefetch_2x(inptr4);
prefetch_2x(inptr5);
prefetch_2x(inptr6);
prefetch_2x(inptr7);
int k = k0;
for (; k + 3 < kmax; k += 4) {
if (y + 7 >= ymax) {
switch (y + 7 - ymax) {
case 6:
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;
default:
megdnn_assert(0);
}
}
transpose_4x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr);
outptr += 4 * 8;
}
if (k < kmax) {
if (y + 7 >= ymax) {
switch (y + 7 - ymax) {
case 6:
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;
default:
megdnn_assert(0);
}
}
transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 4, kmax - k);
outptr += 4 * 8;
}
}
}
} // namespace matmul_8x8x4
} // namespace armv7
} // namespace megdnn
// vim: syntax=cpp.doxygen
@@ -10,12 +10,13 @@
* implied.
*/
#include "src/armv7/matrix_mul/int8x8x16/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/armv7/matrix_mul/asm/common.h"
#include "src/armv7/matrix_mul/int8x8x16/kernel_4x2x16.h"
#include "src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h"
#include "src/armv7/matrix_mul/int8x8x16/kernel_8x8x4.h"
#include "src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h"
#include "src/armv7/matrix_mul/int8x8x16/strategy.h"
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_common.h"
@@ -181,6 +182,79 @@ void gemm_s8x8x16_4x8::kern(const dt_int8* packA, const dt_int8* packB,
}
}
// ===========================gemm_s8x8x16_8x8==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_8x8);
void gemm_s8x8x16_8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0,
int ymax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_8x8x4::gemm_s8x8x16_8x8_pack_A_t(out, in, ldin, y0, ymax, k0,
kmax);
} else {
matmul_8x8x4::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0,
kmax);
}
}
void gemm_s8x8x16_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_8x8x4::gemm_s8x8x16_8x8_pack_B_t(out, in, ldin, x0, xmax, k0,
kmax);
} else {
matmul_8x8x4::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0,
kmax);
}
}
void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int16* C,
size_t LDC, bool is_first_k, const dt_int16*,
dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int16)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
constexpr size_t A_INTERLEAVE = 8;
constexpr size_t B_INTERLEAVE = 8;
    //! K is padded to a multiple of 4
K = round_up<size_t>(K, 4);
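    // Walk M in 8-row blocks handled by kern_8x8; the M tail falls through
    // to 4-row blocks in kern_4x8. N always advances by 8, with edge
    // columns masked inside the kernels via n_remain.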
size_t m = 0;
for (; m + 7 < M; m += A_INTERLEAVE) {
int16_t* output = C + (m * LDC);
const dt_int8* cur_packB = packB;
size_t n = 0;
for (; n < N; n += B_INTERLEAVE) {
matmul_8x8x4::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 8));
output += B_INTERLEAVE;
cur_packB += K * 8;
}
packA += K * 8;
}
for (; m < M; m += 4) {
int16_t* output = C + (m * LDC);
const dt_int8* cur_packB = packB;
size_t n = 0;
for (; n < N; n += B_INTERLEAVE) {
matmul_8x8x4::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 8));
output += B_INTERLEAVE;
cur_packB += K * 8;
}
packA += K * 4;
}
}
// ===========================gemm_s8x8x16_mk4_8x8==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8);
......
@@ -22,6 +22,9 @@ MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 4, 2, 16, false, true,
MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 4, 8, 8, false, true,
gemm_s8x8x16_4x8);
MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 8, 8, 4, false, true,
gemm_s8x8x16_8x8);
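// gemm_s8x8x16_8x8 block sizes: 8 (M) x 8 (N) x 4 (K); see kernel_8x8x4.h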
MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(int8_t, int16_t, int16_t, int16_t, 8,
8, 4, false, false,
gemm_s8x8x16_mk4_8x8);
......
@@ -39,6 +39,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoQuint8K4x8x8 quint8_k4x8x8;
AlgoInt8x8x16K4x2x16 int8x8x16_k4x2x16;
AlgoInt8x8x16K4x8x8 int8x8x16_k4x8x8;
AlgoInt8x8x16K8x8x4 int8x8x16_k8x8x4;
AlgoInt8x8x16MK4_8x8x4 int8x8x16_mk4_8x8x4;
AlgoInt16x16x32K12x4x1 int16x16x32_k12x4x1;
AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8;
@@ -47,7 +48,6 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map;
public:
AlgoPack() {
m_all_algos.emplace_back(&f32_gemv);
m_all_algos.emplace_back(&f32);
@@ -69,6 +69,7 @@ public:
m_all_algos.emplace_back(&int8x8x16_mk4_8x8x4);
m_all_algos.emplace_back(&int8x8x16_k4x2x16);
m_all_algos.emplace_back(&int8x8x16_k4x8x8);
m_all_algos.emplace_back(&int8x8x16_k8x8x4);
m_all_algos.emplace_back(&int16x16x32_k12x4x1);
m_all_algos.emplace_back(&int16x16x32_mk8_4x8);
......
@@ -41,7 +41,8 @@ private:
class AlgoQuint8K4x8x8; // Armv7 Quint8 Kernel 4x8x8
class AlgoInt8x8x16K4x2x16; // Armv7 Int8x8x16 Kernel 4x2x16
class AlgoInt8x8x16K4x8x8; // Armv7 Int8x8x16 Kernel 4x8x8
class AlgoInt8x8x16MK4_8x8x4; // Armv7 Int8x8x16 Kernel 8x8x8
class AlgoInt8x8x16K8x8x4; // Armv7 Int8x8x16 Kernel 8x8x4
class AlgoInt8x8x16MK4_8x8x4; // Armv7 Int8x8x16 Kernel mk4_8x8x4
class AlgoInt16x16x32K12x4x1; // Armv7 Int16x16x32 Kernel 12x4x1
class AlgoInt16x16x32MK8_4x8; // Armv7 Int16x16x32 MK8 Format block 4x8
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
......
@@ -174,7 +174,8 @@ public:
ARMV7_INT8X8X16_MK4_K8X8X4,
ARMV7_INT16X16X32_K12X4X1,
ARMV7_INT16X16X32_MK8_4X8,
ARMV7_INT8X8X32_MK4_4X2X16
ARMV7_INT8X8X32_MK4_4X2X16,
ARMV7_INT8X8X16_K8X8X4
#endif
#endif
};
......
@@ -52,6 +52,12 @@ TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_K4x8x8) {
handle(), "ARMV7_INT8X8X16_K4X8X8");
}
TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_K8x8x4) {
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
handle(), "ARMV7_INT8X8X16_K8X8X4");
}
TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_MK4_K8x8x4) {
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
handle(), "ARMV7_INT8X8X16_MK4_K8X8X4",
@@ -183,6 +189,68 @@ void run_8x8x16_benchmark(
}
}
}
void run_8x8x16_contrast(
const char* algo0, const char* algo, Handle* handle,
MatrixMul::Param::Format format = MatrixMul::Param::Format::DEFAULT) {
constexpr size_t RUNS = 100;
param::MatrixMul param;
Benchmarker<MatrixMul> benchmarker_int(handle);
    Benchmarker<MatrixMul> benchmarker_int_kern(handle);
benchmarker_int.set_before_exec_callback(AlgoChecker<MatrixMul>(algo0));
benchmarker_int.set_times(RUNS)
.set_dtype(0, dtype::Int8{})
.set_dtype(1, dtype::Int8{})
.set_dtype(2, dtype::Int16{})
.set_param(param)
.set_display(false);
param::MatrixMul target_param;
target_param.format = format;
    benchmarker_int_kern.set_before_exec_callback(
            AlgoChecker<MatrixMul>(algo));
    benchmarker_int_kern.set_times(RUNS)
.set_dtype(0, dtype::Int8{})
.set_dtype(1, dtype::Int8{})
.set_dtype(2, dtype::Int16{})
.set_param(target_param)
.set_display(false);
auto run = [&](size_t M, size_t N, size_t K) {
auto int_used = benchmarker_int.exec({{M, K}, {K, N}, {}}) / RUNS;
auto int_kern_used = 1e10;
double computation = 2.0f * M * N * K * 1e-6;
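        // computation is the op count (2*M*N*K) in units of 1e6 ops, so
        // computation / time_in_ms yields Gop/s for the "GFlops" columns.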
if (format == MatrixMul::Param::Format::MK4) {
            int_kern_used = benchmarker_int_kern.exec(
                                    {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
                            RUNS;
} else {
            int_kern_used =
                    benchmarker_int_kern.exec({{M, K}, {K, N}, {}}) / RUNS;
}
printf(" %f(%f)\t %f(%f)\t %f\n", int_used, computation / int_used,
int_kern_used, computation / int_kern_used,
int_used / int_kern_used);
};
printf("\nN\t K\t M\t %s ms(GFlops)\t %s ms(GFlops)\t SPEEDUP\n", algo0,
algo);
for (size_t M : {8}) {
for (size_t K : {72}) {
for (size_t N : {8, 16, 32, 64, 72, 128, 256, 512, 1024, 4096, 8192,
16384, 32768, 65536}) {
printf("%zu\t %zu\t %zu\t", N, K, M);
run(M, N, K);
}
}
}
printf("512\t 512\t 512\t");
run(512, 512, 512);
}
void run_16x16x32_benchmark(const char* algo, Handle* handle) {
constexpr size_t RUNS = 50;
param::MatrixMul param;
@@ -383,6 +451,10 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x8x8) {
run_8x8x16_benchmark("ARMV7_INT8X8X16_K4X8X8", handle());
}
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K8x8x4) {
run_8x8x16_benchmark("ARMV7_INT8X8X16_K8X8X4", handle());
}
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_MK4_K4x8x8) {
run_8x8x16_benchmark("ARMV7_INT8X8X16_MK4_K8X8X4", handle(),
MatrixMul::Param::Format::MK4);
@@ -392,6 +464,21 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT16x16x32_K12x4x1) {
run_16x16x32_benchmark("ARMV7_INT16X16X32_K12X4X1", handle());
}
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K8x8x4_CONTRAST) {
run_8x8x16_contrast("ARM_COMMON_INT8X8X16", "ARMV7_INT8X8X16_K8X8X4",
handle());
}
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x8x8_CONTRAST) {
run_8x8x16_contrast("ARM_COMMON_INT8X8X16", "ARMV7_INT8X8X16_K4X8X8",
handle());
}
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x8x8_K8x8x4_CONTRAST) {
run_8x8x16_contrast("ARMV7_INT8X8X16_K4X8X8", "ARMV7_INT8X8X16_K8X8X4",
handle());
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_FP16) {
constexpr size_t RUNS = 50;
......
@@ -517,9 +517,18 @@ void convolution::test_conv_config_combinations(int k_size,
param.compute_mode = Param::ComputeMode::FLOAT32;
}
size_t IC = 6, OC = 9, G = 3, FH = ksize, FW = ksize;
TensorShape ishp = format ?
TensorShape{2, 18, 18, IC} : TensorShape{2, IC, 18, 18},
fshp;
TensorShape ishp = TensorShape{2, 18, 18, IC}, fshp;
if (format) {
ishp.shape[0] = 2;
ishp.shape[1] = 18;
ishp.shape[2] = 18;
ishp.shape[3] = IC;
} else {
ishp.shape[0] = 2;
ishp.shape[1] = IC;
ishp.shape[2] = 18;
ishp.shape[3] = 18;
}
if (padding) {
param.pad_h = 2 + non_square;
param.pad_w = 2 - non_square;
......