提交 73d84162 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

feat(dnn/aarch64): add matmul with dotprod for mk4

GitOrigin-RevId: feb391d635f5d19ff745e6426d385ebeedb1f0ef
上级 c1397792
......@@ -474,6 +474,76 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvDotProd::get_kern(
MIDOUT_END();
return nullptr;
}
/* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */
namespace {
void int8x8x32_mk4_8x12x4_dotprod_kern(
const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
midout_iv("int8x8x32_mk4_8x12x4_dotprod_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto trA = kern_param.trA, trB = kern_param.trB;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
auto A_type = kern_param.A_type, B_type = kern_param.B_type,
C_type = kern_param.C_type;
const auto Aptr = kern_param.A<dt_int8>(),
Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
aarch64::matmul::gemm_mk4_s8_8x12 strategy(M, N, K, A_type, B_type,
C_type);
megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_mk4_s8_8x12>(
M, N, K, trA, trB, strategy)
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
kern_param.workspace_ptr);
}
MIDOUT_END();
}
} // anonymous namespace
bool MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::usable(
const KernSizeParam& kern_size_param) const {
return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
(kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
(kern_size_param.C_type.enumv() == DTypeEnum::Int32 ||
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32) &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::MK4_DOT &&
!kern_size_param.trA && !kern_size_param.trB;
}
size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_workspace(
const KernSizeParam& kern_size_param) const {
MIDOUT_BEGIN(
megdnn_aarch64_matmul_kern,
midout_iv("AlgoInt8x8x32MK4_8x12x4DotProd::get_workspace"_hash)) {
auto M = kern_size_param.M, N = kern_size_param.N,
K = kern_size_param.K;
auto trA = kern_size_param.trA, trB = kern_size_param.trB;
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
C_type = kern_size_param.C_type;
aarch64::matmul::gemm_mk4_s8_8x12 strategy(M, N, K, A_type, B_type,
C_type);
return megdnn::matmul::GemmInterleaved<
aarch64::matmul::gemm_mk4_s8_8x12>(M, N, K, trA, trB,
strategy)
.get_workspace_size();
}
MIDOUT_END();
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_kern(
const KernSizeParam&) const {
return int8x8x32_mk4_8x12x4_dotprod_kern;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x12x4DotProd,
megdnn_aarch64_matmul_kern,
"AlgoInt8x8x32MK4_8x12x4DotProdImpl"_hash,
aarch64::matmul::gemm_mk4_s8_8x12, int8_t,
int32_t);
#else
/* ===================== Int8x8x32 MK4 4x4x16 algo ===================== */
......
......@@ -118,6 +118,19 @@ public:
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
};
class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override {
return "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD";
}
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
#else
class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase {
......
......@@ -615,6 +615,20 @@ static inline void interleave_12x4_4_b(const T*& inptr0, const T*& inptr1,
reinterpret_cast<int32_t*&>(outptr));
}
static inline void interleave_2x1_4_s(const int32_t*& inptr0,
const int32_t*& inptr1,
int32_t*& outptr) {
asm volatile(
"ld1 {v0.4s}, [%[inptr0]], #16\n" // d0 = A0A1A2A3
"ld1 {v1.4s}, [%[inptr1]], #16\n" // d1 = B0B1B2B3
"st1 {v0.4s}, [%[outptr]], #16\n"
"st1 {v1.4s}, [%[outptr]], #16\n"
:
[inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr)
:
: "v0", "v1", "cc", "memory");
}
static inline void interleave_8x1_4_s(
const int32_t*& inptr0, const int32_t*& inptr1, const int32_t*& inptr2,
const int32_t*& inptr3, const int32_t*& inptr4, const int32_t*& inptr5,
......@@ -752,6 +766,17 @@ static inline void interleave_8x2_2_d(
"v11", "v12", "v13", "v14", "v15", "cc", "memory");
}
template <typename T>
static inline void interleave_2x4_4_b(const T*& inptr0, const T*& inptr1,
T*& outptr) {
static_assert(
std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
"interleave_2x4_4_b only support uint8_t and int8_t");
interleave_2x1_4_s(reinterpret_cast<const int32_t*&>(inptr0),
reinterpret_cast<const int32_t*&>(inptr1),
reinterpret_cast<int32_t*&>(outptr));
}
template <typename T>
static inline void interleave_8x4_4_b(const T*& inptr0, const T*& inptr1,
const T*& inptr2, const T*& inptr3,
......
/**
* \file dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#if __ARM_FEATURE_DOTPROD
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
namespace megdnn {
namespace aarch64 {
namespace matmul_mk4_8x12x4 {
// Overview of register layout:
//
// A 12x4 cell of Rhs is stored in 8bit in q2-q4.
// A 8x4x2 cell of Lhs is stored in 8bit in q0-q1,q5-q6
// A 8x12 block of accumulators is stored in 8bit in q8--q31.
//
// +------------+------------+------------+
// | v2[0-16]| v3[0-16]| v4[0-16]|
// Rhs +------------+------------+------------+
//
// | | | |
//
// Lhs | | | |
//
// +--------+--------+ - - - - +------------+------------+------------+
// |v0[0-16]|v5[0-16]| | v8 v9v10v11|v16v17v18v19|v24v25v26v27|
// |v1[0-16]|v6[0-16]| |v12v13v14v15|v20v21v22v23|v28v29v30v31|
// +--------+--------+ - - - - +------------+------------+------------+
//
// Accumulator
static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k) {
K /= 4;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
// Fix up for odd lengths - set a flag if K is odd, but make
// sure we round up the iteration count.
int oddk = (K & 1);
int k = ((K + 1) / 2) - 1;
int32x4_t a0;
int32x4_t a1;
int32x4_t b0;
int32x4_t b1;
int32x4_t b2;
int32x4_t a0a;
int32x4_t a1a;
LDC = LDC * sizeof(int32_t);
int32_t* outptr0 = output;
int32_t* outptr1;
asm volatile (
// load accumulator C
"add %[outptr1], %[outptr0], %x[LDC]\n"
"cmp %w[is_first_k], #1\n"
"beq 5f\n"
// we can not use ld1, as it can not encode {v8, v16, v24}
"ldp q8, q9, [%[outptr0]]\n"
"ldp q10, q11, [%[outptr0], #32]\n"
"ldp q16, q17, [%[outptr0], #64]\n"
"ldp q18, q19, [%[outptr0], #96]\n"
"ldp q24, q25, [%[outptr0], #128]\n"
"ldp q26, q27, [%[outptr0], #160]\n"
"ldp q12, q13, [%[outptr1]]\n"
"ldp q14, q15, [%[outptr1], #32]\n"
"ldp q20, q21, [%[outptr1], #64]\n"
"ldp q22, q23, [%[outptr1], #96]\n"
"ldp q28, q29, [%[outptr1], #128]\n"
"ldp q30, q31, [%[outptr1], #160]\n"
"b 6f\n"
"5:\n"
"eor v8.16b, v8.16b, v8.16b\n"
"eor v9.16b, v9.16b, v9.16b\n"
"eor v10.16b, v10.16b, v10.16b\n"
"eor v11.16b, v11.16b, v11.16b\n"
"eor v12.16b, v12.16b, v12.16b\n"
"eor v13.16b, v13.16b, v13.16b\n"
"eor v14.16b, v14.16b, v14.16b\n"
"eor v15.16b, v15.16b, v15.16b\n"
"eor v16.16b, v16.16b, v16.16b\n"
"eor v17.16b, v17.16b, v17.16b\n"
"eor v18.16b, v18.16b, v18.16b\n"
"eor v19.16b, v19.16b, v19.16b\n"
"eor v20.16b, v20.16b, v20.16b\n"
"eor v21.16b, v21.16b, v21.16b\n"
"eor v22.16b, v22.16b, v22.16b\n"
"eor v23.16b, v23.16b, v23.16b\n"
"eor v24.16b, v24.16b, v24.16b\n"
"eor v25.16b, v25.16b, v25.16b\n"
"eor v26.16b, v26.16b, v26.16b\n"
"eor v27.16b, v27.16b, v27.16b\n"
"eor v28.16b, v28.16b, v28.16b\n"
"eor v29.16b, v29.16b, v29.16b\n"
"eor v30.16b, v30.16b, v30.16b\n"
"eor v31.16b, v31.16b, v31.16b\n"
"6: \n"
// Initialize result registers, load initial operands, prime prefetches.
"ldr %q[a0], [%[a_ptr]]\n"
"ldr %q[b0], [%[b_ptr]]\n"
"ldr %q[a1], [%[a_ptr], #16]\n"
"ldr %q[b1], [%[b_ptr], #16]\n"
ASM_PREFETCH("[%[b_ptr], #64]")
ASM_PREFETCH("[%[a_ptr], #64]")
ASM_PREFETCH("[%[b_ptr], #128]")
ASM_PREFETCH("[%[a_ptr], #128]")
ASM_PREFETCH("[%[b_ptr], #192]")
ASM_PREFETCH("[%[b_ptr], #256]")
ASM_PREFETCH("[%[a_ptr], #192]")
ASM_PREFETCH("[%[b_ptr], #320]")
ASM_PREFETCH("[%[a_ptr], #256]")
ASM_PREFETCH("[%[b_ptr], #384]")
// Skip loop if we are doing zero iterations of it.
"cbz %w[k], 4f\n"
// Loop proper
"1:\n"
"sdot v8.4s , %[a0].16b, %[b0].4b[0]\n"
"sdot v9.4s , %[a0].16b, %[b0].4b[1]\n"
"ldr %q[b2], [%[b_ptr], #32]\n"
"sdot v10.4s, %[a0].16b, %[b0].4b[2]\n"
"sdot v11.4s, %[a0].16b, %[b0].4b[3]\n"
"ldr %q[a0a], [%[a_ptr], #32]\n"
"sdot v12.4s, %[a1].16b, %[b0].4b[0]\n"
"sdot v13.4s, %[a1].16b, %[b0].4b[1]\n"
"ldr %q[a1a], [%[a_ptr], #48]\n"
"sdot v14.4s, %[a1].16b, %[b0].4b[2]\n"
"sdot v15.4s, %[a1].16b, %[b0].4b[3]\n"
"ldr %q[b0], [%[b_ptr], #48]\n"
"sdot v16.4s, %[a0].16b, %[b1].4b[0]\n"
"sdot v17.4s, %[a0].16b, %[b1].4b[1]\n"
ASM_PREFETCH("[%[a_ptr], #320]")
"sdot v18.4s, %[a0].16b, %[b1].4b[2]\n"
"sdot v19.4s, %[a0].16b, %[b1].4b[3]\n"
"sdot v20.4s, %[a1].16b, %[b1].4b[0]\n"
"sdot v21.4s, %[a1].16b, %[b1].4b[1]\n"
"sdot v22.4s, %[a1].16b, %[b1].4b[2]\n"
"sdot v23.4s, %[a1].16b, %[b1].4b[3]\n"
"ldr %q[b1], [%[b_ptr], #64]\n"
"sdot v24.4s, %[a0].16b, %[b2].4b[0]\n"
"sdot v25.4s, %[a0].16b, %[b2].4b[1]\n"
ASM_PREFETCH("[%[b_ptr], #448]")
"sdot v26.4s, %[a0].16b, %[b2].4b[2]\n"
"sdot v27.4s, %[a0].16b, %[b2].4b[3]\n"
"sdot v28.4s, %[a1].16b, %[b2].4b[0]\n"
"sdot v29.4s, %[a1].16b, %[b2].4b[1]\n"
"sdot v30.4s, %[a1].16b, %[b2].4b[2]\n"
"sdot v31.4s, %[a1].16b, %[b2].4b[3]\n"
"ldr %q[b2], [%[b_ptr], #80]\n"
"sdot v8.4s , %[a0a].16b, %[b0].4b[0]\n"
"sdot v9.4s , %[a0a].16b, %[b0].4b[1]\n"
"ldr %q[a0], [%[a_ptr], #64]\n"
"sdot v10.4s, %[a0a].16b, %[b0].4b[2]\n"
"sdot v11.4s, %[a0a].16b, %[b0].4b[3]\n"
"sdot v12.4s, %[a1a].16b, %[b0].4b[0]\n"
"ldr %q[a1], [%[a_ptr], #80]\n"
"sdot v13.4s, %[a1a].16b, %[b0].4b[1]\n"
"sdot v14.4s, %[a1a].16b, %[b0].4b[2]\n"
"sdot v15.4s, %[a1a].16b, %[b0].4b[3]\n"
"ldr %q[b0], [%[b_ptr], #96]\n"
"sdot v16.4s, %[a0a].16b, %[b1].4b[0]\n"
"sdot v17.4s, %[a0a].16b, %[b1].4b[1]\n"
ASM_PREFETCH("[%[b_ptr], #512]")
"sdot v18.4s, %[a0a].16b, %[b1].4b[2]\n"
"sdot v19.4s, %[a0a].16b, %[b1].4b[3]\n"
"sdot v20.4s, %[a1a].16b, %[b1].4b[0]\n"
"sdot v21.4s, %[a1a].16b, %[b1].4b[1]\n"
"sdot v22.4s, %[a1a].16b, %[b1].4b[2]\n"
"sdot v23.4s, %[a1a].16b, %[b1].4b[3]\n"
"ldr %q[b1], [%[b_ptr], #112]\n"
"sdot v24.4s, %[a0a].16b, %[b2].4b[0]\n"
"sdot v25.4s, %[a0a].16b, %[b2].4b[1]\n"
"add %[a_ptr], %[a_ptr], #64\n"
"sdot v26.4s, %[a0a].16b, %[b2].4b[2]\n"
"sdot v27.4s, %[a0a].16b, %[b2].4b[3]\n"
"add %[b_ptr], %[b_ptr], #96\n"
"sdot v28.4s, %[a1a].16b, %[b2].4b[0]\n"
"sdot v29.4s, %[a1a].16b, %[b2].4b[1]\n"
"subs %w[k], %w[k], #1\n"
"sdot v30.4s, %[a1a].16b, %[b2].4b[2]\n"
"sdot v31.4s, %[a1a].16b, %[b2].4b[3]\n"
"bne 1b\n"
// Target to use when K is 1 or 2 (i.e. zero iterations of main loop)
"4:\n"
// Branch to alternative tail for odd K
"cbnz %w[oddk], 2f\n"
// Detached final iteration (even K)
"sdot v8.4s , %[a0].16b, %[b0].4b[0]\n"
"sdot v9.4s , %[a0].16b, %[b0].4b[1]\n"
"ldr %q[b2], [%[b_ptr], #32]\n"
"sdot v10.4s, %[a0].16b, %[b0].4b[2]\n"
"sdot v11.4s, %[a0].16b, %[b0].4b[3]\n"
"ldr %q[a0a], [%[a_ptr], #32]\n"
"sdot v12.4s, %[a1].16b, %[b0].4b[0]\n"
"sdot v13.4s, %[a1].16b, %[b0].4b[1]\n"
"ldr %q[a1a], [%[a_ptr], #48]\n"
"sdot v14.4s, %[a1].16b, %[b0].4b[2]\n"
"sdot v15.4s, %[a1].16b, %[b0].4b[3]\n"
"ldr %q[b0], [%[b_ptr], #48]\n"
"sdot v16.4s, %[a0].16b, %[b1].4b[0]\n"
"sdot v17.4s, %[a0].16b, %[b1].4b[1]\n"
"sdot v18.4s, %[a0].16b, %[b1].4b[2]\n"
"sdot v19.4s, %[a0].16b, %[b1].4b[3]\n"
"sdot v20.4s, %[a1].16b, %[b1].4b[0]\n"
"sdot v21.4s, %[a1].16b, %[b1].4b[1]\n"
"sdot v22.4s, %[a1].16b, %[b1].4b[2]\n"
"sdot v23.4s, %[a1].16b, %[b1].4b[3]\n"
"ldr %q[b1], [%[b_ptr], #64]\n"
"sdot v24.4s, %[a0].16b, %[b2].4b[0]\n"
"sdot v25.4s, %[a0].16b, %[b2].4b[1]\n"
"add %[a_ptr], %[a_ptr], #64\n"
"sdot v26.4s, %[a0].16b, %[b2].4b[2]\n"
"sdot v27.4s, %[a0].16b, %[b2].4b[3]\n"
"sdot v28.4s, %[a1].16b, %[b2].4b[0]\n"
"sdot v29.4s, %[a1].16b, %[b2].4b[1]\n"
"sdot v30.4s, %[a1].16b, %[b2].4b[2]\n"
"sdot v31.4s, %[a1].16b, %[b2].4b[3]\n"
"ldr %q[b2], [%[b_ptr], #80]\n"
"sdot v8.4s , %[a0a].16b, %[b0].4b[0]\n"
"sdot v16.4s, %[a0a].16b, %[b1].4b[0]\n"
"add %[b_ptr], %[b_ptr], #96\n"
"sdot v9.4s , %[a0a].16b, %[b0].4b[1]\n"
"str q8, [%[outptr0], #0]\n"
"sdot v17.4s, %[a0a].16b, %[b1].4b[1]\n"
"str q16, [%[outptr0], #64]\n"
"sdot v24.4s, %[a0a].16b, %[b2].4b[0]\n"
"str q24, [%[outptr0], #128]\n"
"sdot v25.4s, %[a0a].16b, %[b2].4b[1]\n"
"str q9, [%[outptr0], #16]\n"
"sdot v10.4s, %[a0a].16b, %[b0].4b[2]\n"
"str q17, [%[outptr0], #80]\n"
"sdot v18.4s, %[a0a].16b, %[b1].4b[2]\n"
"str q25, [%[outptr0], #144]\n"
"sdot v26.4s, %[a0a].16b, %[b2].4b[2]\n"
"str q10, [%[outptr0], #32]\n"
"sdot v11.4s, %[a0a].16b, %[b0].4b[3]\n"
"str q18, [%[outptr0], #96]\n"
"sdot v19.4s, %[a0a].16b, %[b1].4b[3]\n"
"str q26, [%[outptr0], #160]\n"
"sdot v27.4s, %[a0a].16b, %[b2].4b[3]\n"
"str q11, [%[outptr0], #48]\n"
"sdot v12.4s, %[a1a].16b, %[b0].4b[0]\n"
"str q19, [%[outptr0], #112]\n"
"sdot v20.4s, %[a1a].16b, %[b1].4b[0]\n"
"str q27, [%[outptr0], #176]\n"
"sdot v28.4s, %[a1a].16b, %[b2].4b[0]\n"
"str q12, [%[outptr1], #0]\n"
"sdot v13.4s, %[a1a].16b, %[b0].4b[1]\n"
"str q20, [%[outptr1], #64]\n"
"sdot v21.4s, %[a1a].16b, %[b1].4b[1]\n"
"str q28, [%[outptr1], #128]\n"
"sdot v29.4s, %[a1a].16b, %[b2].4b[1]\n"
"str q13, [%[outptr1], #16]\n"
"sdot v14.4s, %[a1a].16b, %[b0].4b[2]\n"
"str q21, [%[outptr1], #80]\n"
"sdot v22.4s, %[a1a].16b, %[b1].4b[2]\n"
"str q29, [%[outptr1], #144]\n"
"sdot v30.4s, %[a1a].16b, %[b2].4b[2]\n"
"str q14, [%[outptr1], #32]\n"
"sdot v15.4s, %[a1a].16b, %[b0].4b[3]\n"
"str q22, [%[outptr1], #96]\n"
"sdot v23.4s, %[a1a].16b, %[b1].4b[3]\n"
"str q30, [%[outptr1], #160]\n"
"sdot v31.4s, %[a1a].16b, %[b2].4b[3]\n"
"str q15, [%[outptr1], #48]\n"
"b 3f\n"
// Detached final iteration (odd K)
"2:\n"
"sdot v8.4s , %[a0].16b, %[b0].4b[0]\n"
"ldr %q[b2], [%[b_ptr], #32]\n"
"sdot v16.4s, %[a0].16b, %[b1].4b[0]\n"
"sdot v9.4s , %[a0].16b, %[b0].4b[1]\n"
"str q8, [%[outptr0], #0]\n"
"sdot v17.4s, %[a0].16b, %[b1].4b[1]\n"
"str q16, [%[outptr0], #64]\n"
"sdot v24.4s, %[a0].16b, %[b2].4b[0]\n"
"add %[b_ptr], %[b_ptr], #48\n"
"add %[a_ptr], %[a_ptr], #32\n"
"str q24, [%[outptr0], #128]\n"
"sdot v25.4s, %[a0].16b, %[b2].4b[1]\n"
"str q9, [%[outptr0], #16]\n"
"sdot v10.4s, %[a0].16b, %[b0].4b[2]\n"
"str q17, [%[outptr0], #80]\n"
"sdot v18.4s, %[a0].16b, %[b1].4b[2]\n"
"str q25, [%[outptr0], #144]\n"
"sdot v26.4s, %[a0].16b, %[b2].4b[2]\n"
"str q10, [%[outptr0], #32]\n"
"sdot v11.4s, %[a0].16b, %[b0].4b[3]\n"
"str q18, [%[outptr0], #96]\n"
"sdot v19.4s, %[a0].16b, %[b1].4b[3]\n"
"str q26, [%[outptr0], #160]\n"
"sdot v27.4s, %[a0].16b, %[b2].4b[3]\n"
"str q11, [%[outptr0], #48]\n"
"sdot v12.4s, %[a1].16b, %[b0].4b[0]\n"
"str q19, [%[outptr0], #112]\n"
"sdot v20.4s, %[a1].16b, %[b1].4b[0]\n"
"str q27, [%[outptr0], #176]\n"
"sdot v28.4s, %[a1].16b, %[b2].4b[0]\n"
"str q12, [%[outptr1], #0]\n"
"sdot v13.4s, %[a1].16b, %[b0].4b[1]\n"
"str q20, [%[outptr1], #64]\n"
"sdot v21.4s, %[a1].16b, %[b1].4b[1]\n"
"str q28, [%[outptr1], #128]\n"
"sdot v29.4s, %[a1].16b, %[b2].4b[1]\n"
"str q13, [%[outptr1], #16]\n"
"sdot v14.4s, %[a1].16b, %[b0].4b[2]\n"
"str q21, [%[outptr1], #80]\n"
"sdot v22.4s, %[a1].16b, %[b1].4b[2]\n"
"str q29, [%[outptr1], #144]\n"
"sdot v30.4s, %[a1].16b, %[b2].4b[2]\n"
"str q14, [%[outptr1], #32]\n"
"sdot v15.4s, %[a1].16b, %[b0].4b[3]\n"
"str q22, [%[outptr1], #96]\n"
"sdot v23.4s, %[a1].16b, %[b1].4b[3]\n"
"str q30, [%[outptr1], #160]\n"
"sdot v31.4s, %[a1].16b, %[b2].4b[3]\n"
"str q15, [%[outptr1], #48]\n"
// Common tail
"3:\n"
"str q23, [%[outptr1], #112]\n"
"str q31, [%[outptr1], #176]\n"
:
[a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr),[oddk] "+r" (oddk),
[is_first_k] "+r" (is_first_k), [k] "+r" (k), [LDC] "+r" (LDC),
[a0] "=w" (a0), [a1] "=w" (a1), [a0a] "=w" (a0a), [a1a] "=w" (a1a),
[b0] "=w" (b0), [b1] "=w" (b1), [b2] "=w" (b2),
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1)
:
: "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",
"v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc",
"memory"
);
}
// Overview of register layout:
//
// A (12x4)x2 cell of Rhs is stored in 8bit in q2-q7.
// A (4x4)x2 cell of Lhs is stored in 8bit in q0-q1
// A 4x12 block of accumulators is stored in 8bit in q8--q19.
//
// +------------+------------+------------+
// | v2[0-16]| v3[0-16]| v4[0-16]|
// Rhs +------------+------------+------------+
// | v5[0-16]| v6[0-16]| v7[0-16]|
// +------------+------------+------------+
// Lhs | | | |
//
// +--------+--------+ - - - - +------------+------------+------------+
// |v0[0-16]|v1[0-16]| | v8 v9v10v11|v12v13v14v15|v16v17v18v19|
// +--------+--------+ - - - - +------------+------------+------------+
//
// Accumulator
static void kern_4x12(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k) {
K /= 4;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
// Fix up for odd lengths - set a flag if K is odd, but make
// sure we round up the iteration count.
int oddk = (K & 1);
int k = K / 2;
int32x4_t a0;
int32x4_t b0;
int32x4_t b1;
int32x4_t b2;
int32x4_t a0a;
int32x4_t b0a;
int32x4_t b1a;
int32x4_t b2a;
LDC = LDC * sizeof(int32_t);
int32_t* outptr0 = output;
asm volatile(
// load accumulator C
"cmp %w[is_first_k], #1\n"
"beq 1f\n"
"ldp q8, q9, [%[outptr0]]\n"
"ldp q10, q11, [%[outptr0], #32]\n"
"ldp q12, q13, [%[outptr0], #64]\n"
"ldp q14, q15, [%[outptr0], #96]\n"
"ldp q16, q17, [%[outptr0], #128]\n"
"ldp q18, q19, [%[outptr0], #160]\n"
"b 2f\n"
"1:\n"
"eor v8.16b, v8.16b, v8.16b\n"
"eor v9.16b, v9.16b, v9.16b\n"
"eor v10.16b, v10.16b, v10.16b\n"
"eor v11.16b, v11.16b, v11.16b\n"
"eor v12.16b, v12.16b, v12.16b\n"
"eor v13.16b, v13.16b, v13.16b\n"
"eor v14.16b, v14.16b, v14.16b\n"
"eor v15.16b, v15.16b, v15.16b\n"
"eor v16.16b, v16.16b, v16.16b\n"
"eor v17.16b, v17.16b, v17.16b\n"
"eor v18.16b, v18.16b, v18.16b\n"
"eor v19.16b, v19.16b, v19.16b\n"
"2: \n"
"cbz %w[oddk], 3f\n"
// parse the oddk
"ldr %q[a0], [%[a_ptr]], #16\n"
"ldr %q[b0], [%[b_ptr]], #16\n"
"ldr %q[b1], [%[b_ptr]], #16\n"
"ldr %q[b2], [%[b_ptr]], #16\n"
"sdot v8.4s, %[a0].16b, %[b0].4b[0]\n"
"sdot v9.4s, %[a0].16b, %[b0].4b[1]\n"
"sdot v10.4s, %[a0].16b, %[b0].4b[2]\n"
"sdot v11.4s, %[a0].16b, %[b0].4b[3]\n"
"sdot v12.4s, %[a0].16b, %[b1].4b[0]\n"
"sdot v13.4s, %[a0].16b, %[b1].4b[1]\n"
"sdot v14.4s, %[a0].16b, %[b1].4b[2]\n"
"sdot v15.4s, %[a0].16b, %[b1].4b[3]\n"
"sdot v16.4s, %[a0].16b, %[b2].4b[0]\n"
"sdot v17.4s, %[a0].16b, %[b2].4b[1]\n"
"sdot v18.4s, %[a0].16b, %[b2].4b[2]\n"
"sdot v19.4s, %[a0].16b, %[b2].4b[3]\n"
"cbz %w[k], 4f\n"
// Loop proper
"3:\n"
"ldr %q[a0], [%[a_ptr]], #16\n"
"ldr %q[b0], [%[b_ptr]], #16\n"
"ldr %q[b1], [%[b_ptr]], #16\n"
"ldr %q[b2], [%[b_ptr]], #16\n"
"ldr %q[a0a], [%[a_ptr]], #16\n"
"ldr %q[b0a], [%[b_ptr]], #16\n"
"ldr %q[b1a], [%[b_ptr]], #16\n"
"ldr %q[b2a], [%[b_ptr]], #16\n"
"sdot v8.4s, %[a0].16b, %[b0].4b[0]\n"
"sdot v9.4s, %[a0].16b, %[b0].4b[1]\n"
"sdot v10.4s, %[a0].16b, %[b0].4b[2]\n"
"sdot v11.4s, %[a0].16b, %[b0].4b[3]\n"
"sdot v12.4s, %[a0].16b, %[b1].4b[0]\n"
"sdot v13.4s, %[a0].16b, %[b1].4b[1]\n"
"sdot v14.4s, %[a0].16b, %[b1].4b[2]\n"
"sdot v15.4s, %[a0].16b, %[b1].4b[3]\n"
"sdot v16.4s, %[a0].16b, %[b2].4b[0]\n"
"sdot v17.4s, %[a0].16b, %[b2].4b[1]\n"
"sdot v18.4s, %[a0].16b, %[b2].4b[2]\n"
"sdot v19.4s, %[a0].16b, %[b2].4b[3]\n"
"sdot v8.4s , %[a0a].16b, %[b0a].4b[0]\n"
"sdot v9.4s , %[a0a].16b, %[b0a].4b[1]\n"
"sdot v10.4s, %[a0a].16b, %[b0a].4b[2]\n"
"sdot v11.4s, %[a0a].16b, %[b0a].4b[3]\n"
"sdot v12.4s, %[a0a].16b, %[b1a].4b[0]\n"
"sdot v13.4s, %[a0a].16b, %[b1a].4b[1]\n"
"sdot v14.4s, %[a0a].16b, %[b1a].4b[2]\n"
"sdot v15.4s, %[a0a].16b, %[b1a].4b[3]\n"
"sdot v16.4s, %[a0a].16b, %[b2a].4b[0]\n"
"sdot v17.4s, %[a0a].16b, %[b2a].4b[1]\n"
"sdot v18.4s, %[a0a].16b, %[b2a].4b[2]\n"
"sdot v19.4s, %[a0a].16b, %[b2a].4b[3]\n"
"subs %w[k], %w[k], #1\n"
"bne 3b\n"
"4:\n"
"stp q8, q9, [%[outptr0]]\n"
"stp q10, q11, [%[outptr0], #32]\n"
"stp q12, q13, [%[outptr0], #64]\n"
"stp q14, q15, [%[outptr0], #96]\n"
"stp q16, q17, [%[outptr0], #128]\n"
"stp q18, q19, [%[outptr0], #160]\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k),
[outptr0] "+r"(outptr0), [oddk] "+r"(oddk),
[is_first_k] "+r"(is_first_k), [a0] "=w"(a0), [a0a] "=w"(a0a),
[b0] "=w"(b0), [b1] "=w"(b1), [b2] "=w"(b2), [b0a] "=w"(b0a),
[b1a] "=w"(b1a), [b2a] "=w"(b2a)
:
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v17", "v18", "v19", "memory", "cc");
}
// Overview of register layout:
//
// A (4x4)x2 cell of Rhs is stored in 8bit in q2-q7.
// A (8x4)x2 cell of Lhs is stored in 8bit in q0-q1, q4-q5
// A 8x4 block of accumulators is stored in 8bit in q6-q13.
//
// +------------+
// | v2[0-16]|
// Rhs +------------+
// | v3[0-16]|
// +------------+
// Lhs | |
//
// +--------+--------+ - - - - +------------+
// |v0[0-16]|v4[0-16]| | v6 v7 v8 v9|
// +--------+--------+ - - - - +------------+
// |v1[0-16]|v5[0-16]| |v10v11v12v13|
// +--------+--------+ - - - - +------------+
// Accumulator
static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k, int n_remain) {
K /= 4;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
// Fix up for odd lengths - set a flag if K is odd, but make
// sure we round up the iteration count.
int oddk = (K & 1);
int k = K / 2;
int32x4_t a0;
int32x4_t a1;
int32x4_t b0;
int32x4_t b0a;
int32x4_t a0a;
int32x4_t a1a;
LDC = LDC * sizeof(int32_t);
int32_t* outptr0 = output;
int32_t* outptr1;
size_t x0;
// clang-format off
#define LOAD_LINE(v0, v1, v2, v3, n) \
"mov %[x0], %[outptr" n "]\n" \
"cmp %w[n_remain], #4\n" \
"blt 100" n "f\n" \
"ldr q" v0 ", [%[x0]] \n" \
"ldr q" v1 ", [%[x0], #16] \n" \
"ldr q" v2 ", [%[x0], #32] \n" \
"ldr q" v3 ", [%[x0], #48] \n" \
"b 101" n "f\n" \
"100" n ":\n" \
"cmp %w[n_remain], #0\n" \
"beq 101" n "f\n" \
"ldr q" v0 ", [%[x0]], #16\n" \
"cmp %w[n_remain], #1\n" \
"beq 101" n "f\n" \
"ldr q" v1 ", [%[x0]], #16\n" \
"cmp %w[n_remain], #2\n" \
"beq 101" n "f\n" \
"ldr q" v2 ", [%[x0]], #16\n" \
"101" n ":\n"
#define LOAD_C \
LOAD_LINE("6", "7", "8", "9", "0") \
LOAD_LINE("10", "11", "12", "13", "1") \
#define STORE_LINE(v0, v1, v2, v3, n) \
"mov %[x0], %[outptr" n "]\n" \
"cmp %w[n_remain], #4\n" \
"blt 102" n "f\n" \
"str q" v0 ", [%[x0]] \n" \
"str q" v1 ", [%[x0], #16] \n" \
"str q" v2 ", [%[x0], #32] \n" \
"str q" v3 ", [%[x0], #48] \n" \
"b 103" n "f\n" \
"102" n ":\n" \
"cmp %w[n_remain], #0\n" \
"beq 103" n "f\n" \
"str q" v0 ", [%[x0]], #16\n" \
"cmp %w[n_remain], #1\n" \
"beq 103" n "f\n" \
"str q" v1 ", [%[x0]], #16\n" \
"cmp %w[n_remain], #2\n" \
"beq 103" n "f\n" \
"str q" v2 ", [%[x0]], #16\n" \
"103" n ":\n"
#define STORE_C \
STORE_LINE("6", "7", "8", "9", "0") \
STORE_LINE("10", "11", "12", "13", "1")
// clang-format on
asm volatile(
// load accumulator C
"add %[outptr1], %[outptr0], %x[LDC]\n"
"cmp %w[is_first_k], #1\n"
"beq 1f\n" LOAD_C
"b 2f\n"
"1:\n"
"eor v6.16b, v6.16b, v6.16b\n"
"eor v7.16b, v7.16b, v7.16b\n"
"eor v8.16b, v8.16b, v8.16b\n"
"eor v9.16b, v9.16b, v9.16b\n"
"eor v10.16b, v10.16b, v10.16b\n"
"eor v11.16b, v11.16b, v11.16b\n"
"eor v12.16b, v12.16b, v12.16b\n"
"eor v13.16b, v13.16b, v13.16b\n"
"2: \n"
"cbz %w[oddk], 3f\n"
// parse the oddk
"ldr %q[a0], [%[a_ptr]], #16\n"
"ldr %q[b0], [%[b_ptr]], #16\n"
"ldr %q[a1], [%[a_ptr]], #16\n"
"sdot v6.4s , %[a0].16b, %[b0].4b[0]\n"
"sdot v7.4s , %[a0].16b, %[b0].4b[1]\n"
"sdot v8.4s, %[a0].16b, %[b0].4b[2]\n"
"sdot v9.4s, %[a0].16b, %[b0].4b[3]\n"
"sdot v10.4s, %[a1].16b, %[b0].4b[0]\n"
"sdot v11.4s, %[a1].16b, %[b0].4b[1]\n"
"sdot v12.4s, %[a1].16b, %[b0].4b[2]\n"
"sdot v13.4s, %[a1].16b, %[b0].4b[3]\n"
"cbz %w[k], 4f\n"
// Loop proper
"3:\n"
"ldr %q[a0], [%[a_ptr]], #16\n"
"ldr %q[b0], [%[b_ptr]], #16\n"
"ldr %q[a1], [%[a_ptr]], #16\n"
"ldr %q[a0a], [%[a_ptr]], #16\n"
"ldr %q[a1a], [%[a_ptr]], #16\n"
"ldr %q[b0a], [%[b_ptr]], #16\n"
"sdot v6.4s , %[a0].16b, %[b0].4b[0]\n"
"sdot v7.4s , %[a0].16b, %[b0].4b[1]\n"
"sdot v8.4s, %[a0].16b, %[b0].4b[2]\n"
"sdot v9.4s, %[a0].16b, %[b0].4b[3]\n"
"sdot v10.4s, %[a1].16b, %[b0].4b[0]\n"
"sdot v11.4s, %[a1].16b, %[b0].4b[1]\n"
"sdot v12.4s, %[a1].16b, %[b0].4b[2]\n"
"sdot v13.4s, %[a1].16b, %[b0].4b[3]\n"
"sdot v6.4s , %[a0a].16b, %[b0a].4b[0]\n"
"sdot v7.4s , %[a0a].16b, %[b0a].4b[1]\n"
"sdot v8.4s, %[a0a].16b, %[b0a].4b[2]\n"
"sdot v9.4s, %[a0a].16b, %[b0a].4b[3]\n"
"sdot v10.4s, %[a1a].16b, %[b0a].4b[0]\n"
"sdot v11.4s, %[a1a].16b, %[b0a].4b[1]\n"
"sdot v12.4s, %[a1a].16b, %[b0a].4b[2]\n"
"sdot v13.4s, %[a1a].16b, %[b0a].4b[3]\n"
"subs %w[k], %w[k], #1\n"
"bne 3b\n"
"4:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [LDC] "+r"(LDC),
[oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k),
[n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0),
[a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a),
[b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1),
[x0] "=r"(x0)
:
: "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory",
"cc");
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}
// Overview of register layout:
//
// A (4x4)x2 cell of Rhs is stored in 8bit in q2-q3.
// A (4x4)x2 cell of Lhs is stored in 8bit in q0-q1
// A 4x4 block of accumulators is stored in 8bit in q4-q7.
//
// +------------+
// | v2[0-16]|
// Rhs +------------+
// | v3[0-16]|
// +------------+
// Lhs | |
//
// +--------+--------+ - - - - +------------+
// |v0[0-16]|v4[0-16]| | v4 v5 v6 v7|
// +--------+--------+ - - - - +------------+
// Accumulator
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k, int n_remain) {
K /= 4;
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA);
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB);
// Fix up for odd lengths - set a flag if K is odd, but make
// sure we round up the iteration count.
int oddk = (K & 1);
int k = K / 2;
int32x4_t a0;
int32x4_t a0a;
int32x4_t b0;
int32x4_t b0a;
LDC = LDC * sizeof(int32_t);
int32_t* outptr0 = output;
size_t x0;
// clang-format off
#define LOAD_LINE(v0, v1, v2, v3, n) \
"mov %[x0], %[outptr" n "]\n" \
"cmp %w[n_remain], #4\n" \
"blt 100" n "f\n" \
"ldr q" v0 ", [%[x0]] \n" \
"ldr q" v1 ", [%[x0], #16] \n" \
"ldr q" v2 ", [%[x0], #32] \n" \
"ldr q" v3 ", [%[x0], #48] \n" \
"b 101" n "f\n" \
"100" n ":\n" \
"cmp %w[n_remain], #0\n" \
"beq 101" n "f\n" \
"ldr q" v0 ", [%[x0]], #16\n" \
"cmp %w[n_remain], #1\n" \
"beq 101" n "f\n" \
"ldr q" v1 ", [%[x0]], #16\n" \
"cmp %w[n_remain], #2\n" \
"beq 101" n "f\n" \
"ldr q" v2 ", [%[x0]], #16\n" \
"101" n ":\n"
#define LOAD_C \
LOAD_LINE("4", "5", "6", "7", "0")
#define STORE_LINE(v0, v1, v2, v3, n) \
"mov %[x0], %[outptr" n "]\n" \
"cmp %w[n_remain], #4\n" \
"blt 102" n "f\n" \
"str q" v0 ", [%[x0]] \n" \
"str q" v1 ", [%[x0], #16] \n" \
"str q" v2 ", [%[x0], #32] \n" \
"str q" v3 ", [%[x0], #48] \n" \
"b 103" n "f\n" \
"102" n ":\n" \
"cmp %w[n_remain], #0\n" \
"beq 103" n "f\n" \
"str q" v0 ", [%[x0]], #16\n" \
"cmp %w[n_remain], #1\n" \
"beq 103" n "f\n" \
"str q" v1 ", [%[x0]], #16\n" \
"cmp %w[n_remain], #2\n" \
"beq 103" n "f\n" \
"str q" v2 ", [%[x0]], #16\n" \
"103" n ":\n"
#define STORE_C \
STORE_LINE("4", "5", "6", "7", "0")
// clang-format on
asm volatile(
// load accumulator C
"cmp %w[is_first_k], #1\n"
"beq 1f\n" //
LOAD_C //
"b 2f\n"
"1:\n"
"eor v4.16b, v4.16b, v4.16b\n"
"eor v5.16b, v5.16b, v5.16b\n"
"eor v6.16b, v6.16b, v6.16b\n"
"eor v7.16b, v7.16b, v7.16b\n"
"2: \n"
"cbz %w[oddk], 3f\n"
// parse the oddk
"ldr %q[a0], [%[a_ptr]], #16\n"
"ldr %q[b0], [%[b_ptr]], #16\n"
"sdot v4.4s , %[a0].16b, %[b0].4b[0]\n"
"sdot v5.4s , %[a0].16b, %[b0].4b[1]\n"
"sdot v6.4s, %[a0].16b, %[b0].4b[2]\n"
"sdot v7.4s, %[a0].16b, %[b0].4b[3]\n"
"cbz %w[k], 4f\n"
// Loop proper
"3:\n"
"ldr %q[a0], [%[a_ptr]], #16\n"
"ldr %q[b0], [%[b_ptr]], #16\n"
"ldr %q[a0a], [%[a_ptr]], #16\n"
"ldr %q[b0a], [%[b_ptr]], #16\n"
"sdot v4.4s , %[a0].16b, %[b0].4b[0]\n"
"sdot v5.4s , %[a0].16b, %[b0].4b[1]\n"
"sdot v6.4s, %[a0].16b, %[b0].4b[2]\n"
"sdot v7.4s, %[a0].16b, %[b0].4b[3]\n"
"sdot v4.4s , %[a0a].16b, %[b0a].4b[0]\n"
"sdot v5.4s , %[a0a].16b, %[b0a].4b[1]\n"
"sdot v6.4s, %[a0a].16b, %[b0a].4b[2]\n"
"sdot v7.4s, %[a0a].16b, %[b0a].4b[3]\n"
"subs %w[k], %w[k], #1\n"
"bne 3b\n"
"4:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk),
[is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain),
[LDC] "+r"(LDC), [outptr0] "+r"(outptr0), [k] "+r"(k),
[a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), [b0a] "=w"(b0a),
[x0] "=r"(x0)
:
: "v4", "v5", "v6", "v7", "memory", "cc");
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}
static void gemm_mk4_s8_8x12_pack_A(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0,
"mk4 matmul with m is not times of 4");
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0,
"mk4 matmul with k is not times of 4");
int y = y0;
int start_y = y0 / 4;
for (; y + 7 < ymax; y += 8, start_y += 2) {
const int8_t* inptr0 = inptr + start_y * ldin + (k0 << 2);
const int8_t* inptr1 = inptr0 + ldin;
prefetch_2x(inptr0);
prefetch_2x(inptr1);
int K = kmax - k0;
//! read 2 * 4 in each row
for (; K > 3; K -= 4) {
interleave_2x4_4_b(inptr0, inptr1, outptr);
}
}
for (; y + 3 < ymax; y += 4, start_y ++) {
int K = kmax - k0;
const int8_t* inptr0 = inptr + start_y * ldin + (k0 << 2);
std::memcpy(outptr, inptr0, sizeof(dt_int8) * K * 4);
}
}
static void gemm_mk4_s8_8x12_pack_B(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
const int ksize = kmax - k0;
const int ksize12 = ksize * 12;
const int ksize4 = ksize * 4;
int8_t* outptr = out;
int8_t* outptr_base = out;
//! 4x4 block output start pos
int8_t* outptr_base4 = out + ((xmax - x0) / 12) * ksize12;
int k = k0;
for (; k + 3 < kmax; k += 4) {
const int8_t* inptr = in + (k >> 2) * ldin + (x0 << 2);
prefetch_2x(inptr);
int x = x0;
outptr = outptr_base;
for (; x + 11 < xmax; x += 12) {
std::memcpy(outptr, inptr, 48);
outptr += ksize12;
inptr += 48;
}
outptr = outptr_base4;
for (; x + 3 < xmax; x += 4) {
std::memcpy(outptr, inptr, 16);
outptr += ksize4;
inptr += 16;
}
if (x < xmax) {
int i = 0;
for (; i < xmax - x; i++) {
*outptr++ = *inptr++;
*outptr++ = *inptr++;
*outptr++ = *inptr++;
*outptr++ = *inptr++;
}
for (; i < 4; i++) {
*outptr++ = *inptr++;
*outptr++ = *inptr++;
*outptr++ = *inptr++;
*outptr++ = *inptr++;
}
}
outptr_base += 48;
outptr_base4 += 16;
}
}
} // namespace matmul_mk4_8x12x4
} // namespace aarch64
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
......@@ -14,12 +14,14 @@
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h"
#include "src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h"
#if __ARM_FEATURE_DOTPROD
using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;
/* ====================== gemm_s8_8x12 ===========================*/
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12);
void gemm_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
......@@ -109,5 +111,91 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
packA += K4;
}
}
/* ====================== gemm_mk4_s8_8x12 ===========================*/
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_8x12);
void gemm_mk4_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix A is not supported");
matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_A(outptr, inptr, ldin, y0, ymax, k0,
kmax);
}
void gemm_mk4_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax,
bool transpose) const {
megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix B is not supported");
matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax);
}
void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int32* C,
size_t LDC, bool is_first_k, const dt_int32*,
dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
constexpr size_t A_INTERLEAVE = 8;
constexpr size_t B_INTERLEAVE = 12;
//! K is packed to times of 4
K = round_up<size_t>(K, 4);
const int K8 = (K << 3);
const int K12 = K * 12;
const int K4 = K * 4;
size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
int32_t* output = C + ((m >> 2) * LDC);
size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_mk4_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC,
is_first_k);
output += (B_INTERLEAVE << 2);
cur_packB += K12;
}
for (; n < N; n += 4) {
matmul_mk4_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
output += 16;
cur_packB += K4;
}
packA += K8;
}
for (; m < M; m += 4) {
int32_t* output = C + ((m >> 2) * LDC);
const dt_int8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_mk4_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC,
is_first_k);
output += (B_INTERLEAVE << 2);
cur_packB += K12;
}
for (; n < N; n += 4) {
matmul_mk4_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
output += 16;
cur_packB += K4;
}
packA += K4;
}
}
#endif
// vim: syntax=cpp.doxygen
......@@ -19,6 +19,9 @@ namespace matmul {
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true,
gemm_s8_8x12);
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true,
gemm_mk4_s8_8x12);
} // namespace aarch64
} // namespace matmul
} // namespace megdnn
......
......@@ -29,6 +29,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#if __ARM_FEATURE_DOTPROD
AlgoInt8x8x32K8x12x4DotProd int8x8x32_k8x12x4_dotprod;
AlgoInt8x8x32GemvDotProd int8x8x32_gemv_dotprod;
AlgoInt8x8x32MK4_8x12x4DotProd int8x8x32_mk4_8x12x4_dotprod;
#else
AlgoInt8x8x32MK4_4x4x16 int8x8x32_mk4_4x4x16;
AlgoInt8x8x32K4x4x16 int8x8x32_k4x4x16;
......@@ -64,6 +65,7 @@ public:
#if __ARM_FEATURE_DOTPROD
all_algos.emplace_back(&int8x8x32_gemv_dotprod);
all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod);
all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod);
#else
all_algos.emplace_back(&int8x8x32_gemv);
all_algos.emplace_back(&int8x8x32_k4x4x16);
......
......@@ -35,6 +35,8 @@ private:
class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel
// 8x12x4 DotProduct
class AlgoInt8x8x32GemvDotProd; // Aarch64 Int8x8x32 Gemv DotProduct
class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel
// 8x12x4 DotProduct
#else
class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16
class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16
......
......@@ -64,6 +64,18 @@ TEST_F(AARCH64, MATRIX_MUL_INT8X8X32_K8X12X4_DOTPROD) {
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
handle(), "AARCH64_INT8X8X32_K8X12X4_DOTPROD");
}
TEST_F(AARCH64, MATRIX_MUL_INT8X8X32_MK4_8X12X4_DOTPROD) {
std::vector<matrix_mul::TestArg> args;
for (size_t m : {1, 2, 3, 4, 5, 6, 7, 10, 11})
for (size_t n : {2, 3, 4, 5, 8, 12, 13, 14, 15, 16, 31})
for (size_t k : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 33, 34})
args.emplace_back(m, n, k, 0);
matrix_mul::check_matrix_mul(
dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(),
"AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD",
param::MatrixMul::Format::MK4_DOT, 1, 1e-3, std::move(args));
}
#else
TEST_F(AARCH64, MATRIX_MUL_INT8X8X32_K4X4X16) {
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
......@@ -460,6 +472,54 @@ TEST_F(AARCH64, BENCHMARK_GEMV_INT_8X8X32) {
run(M, N, K);
}
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT8X8X32_MK4_8X12X4) {
constexpr size_t RUNS = 50;
param::MatrixMul param;
param.transposeA = false;
param.transposeB = false;
Benchmarker<MatrixMul> benchmarker(handle());
Benchmarker<MatrixMul> benchmarker_mk4(handle());
benchmarker.set_times(RUNS)
.set_dtype(0, dtype::Int8{})
.set_dtype(1, dtype::Int8{})
.set_dtype(2, dtype::Int32{})
.set_param(param)
.set_display(false);
benchmarker.set_before_exec_callback(
AlgoChecker<MatrixMul>("AARCH64_INT8X8X32_K8X12X4"));
param.format = MatrixMul::Param::Format::MK4_DOT;
benchmarker_mk4.set_before_exec_callback(
AlgoChecker<MatrixMul>("AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD"));
benchmarker_mk4.set_times(RUNS)
.set_dtype(0, dtype::Int8{})
.set_dtype(1, dtype::Int8{})
.set_dtype(2, dtype::Int32{})
.set_param(param)
.set_display(false);
auto run = [&](size_t M, size_t N, size_t K) {
auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS;
auto mk_used = benchmarker_mk4.exec(
{{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
RUNS;
float computations = 2.f * M * K * N * 1e-6;
printf("run: {%zu{M} %zu{K} %zu{N}} normal: %f ms %f Gflops mk4: %f ms "
"%f Gflops speedup_vs_normal: %f\n",
M, K, N, default_used, computations / default_used, mk_used,
computations / mk_used, default_used / mk_used);
};
run(256, 256, 128);
for (size_t k = 4; k <= 512; k *= 2) {
for (size_t m = 4; m <= 512; m *= 2) {
for (size_t n = 4; n <= 512; n *= 2) {
run(m, n, k);
}
}
std::cout << std::endl;
}
}
#endif // __ARM_FEATURE_DOTPROD
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册