提交 92b12685 编写于 作者: M Megvii Engine Team

feat(dnn/aarch64): add aarch64 int8x8x16_mk4_k8x8x8 matmul with better performance

GitOrigin-RevId: b6af21e8e314b4edd62f0fddcf8578d2eaa0fc2a
上级 5ee1a1c4
......@@ -1310,4 +1310,74 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x8,
int32_t);
#endif
/* ===================== Int8x8x16 K8x8x8 algo ===================== */
namespace {
/// Kernel entry point for the int8x8x16 MK4 8x8x8 matmul algorithm.
/// Unpacks the runtime parameters and dispatches to the GemmInterleaved
/// driver instantiated with the gemm_s8x8x16_mk4_8x8x8 strategy.
void int8x8x16_mk4_8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x16_mk4_8x8x8_kern"_hash)) {
        const auto M = kern_param.M;
        const auto N = kern_param.N;
        const auto K = kern_param.K;
        // int8 operands accumulated into int16 output.
        const auto* lhs = kern_param.A<dt_int8>();
        const auto* rhs = kern_param.B<dt_int8>();
        auto* dst = kern_param.C<dt_int16>();
        aarch64::matmul::gemm_s8x8x16_mk4_8x8x8 strategy(
                M, N, K, kern_param.A_type, kern_param.B_type,
                kern_param.C_type);
        megdnn::matmul::GemmInterleaved<
                aarch64::matmul::gemm_s8x8x16_mk4_8x8x8>(
                M, N, K, kern_param.trA, kern_param.trB, strategy)
                .execute(lhs, kern_param.LDA, rhs, kern_param.LDB, dst,
                         kern_param.LDC, kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
bool MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::usable(
        const KernSizeParam& kern_size_param) const {
    // The kernel only handles non-transposed MK4-format int8x8x16 problems
    // whose M and K are multiples of the MK4 block size (4).
    if (!can_be_treated_as_int8x8x16(kern_size_param))
        return false;
    if (kern_size_param.format != param::MatrixMul::Format::MK4)
        return false;
    if (kern_size_param.compute_mode != Param::ComputeMode::DEFAULT)
        return false;
    if (kern_size_param.trA || kern_size_param.trB)
        return false;
    return kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0;
}
// Always report this algorithm as preferred whenever it is usable;
// size-based tuning is left to the dispatcher/heuristics above.
bool MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::preferred(
        const KernSizeParam&) const {
    return true;
}
// Returns the packing workspace (in bytes) required by the interleaved GEMM
// driver for the given problem size. Mirrors the construction done in
// int8x8x16_mk4_8x8x8_kern so the two always agree.
size_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoInt8x8x16_MK4_8x8x8::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_s8x8x16_mk4_8x8x8 strategy(M, N, K, A_type,
                                                         B_type, C_type);
        // Fully qualify the strategy type for consistency with the kern
        // function above (previously relied on the enclosing namespace).
        return megdnn::matmul::GemmInterleaved<
                       aarch64::matmul::gemm_s8x8x16_mk4_8x8x8>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Unreachable when midout is enabled; keeps the non-midout build valid.
    return 0;
}
// Returns the file-local kernel entry point; the problem size does not
// affect which kernel is used for this algorithm.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_kern(
        const KernSizeParam&) const {
    return int8x8x16_mk4_8x8x8_kern;
}
// Register the pack/kern helpers used by the im2col convolution path for
// this matmul algorithm (int8 inputs, int16 output).
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_K8x8x8,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoInt8x8x16MK4_K8x8x8Impl"_hash,
                                     aarch64::matmul::gemm_s8x8x16_mk4_8x8x8, int8_t,
                                     int16_t);
// vim: syntax=cpp.doxygen
......@@ -202,6 +202,22 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
// Algorithm wrapper for the aarch64 int8x8x16 MK4-format K8x8x8 GEMM kernel.
class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase {
public:
    // Integer arithmetic only — results are bit-exact across runs.
    bool is_reproducible() const override { return true; }
    const char* name() const override {
        return "AARCH64_INT8X8X16_MK4_K8X8X8";
    }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    // Uses the default pack-A/pack-B flow of the interleaved GEMM driver.
    PackMode packmode() const override { return PackMode::DEFAULT; }
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
......
......@@ -2101,6 +2101,62 @@ static inline void transpos_12x4_s8(const int8_t* inptr0, int8_t* outptr) {
vreinterpretq_s32_s8(input2), 3);
}
template <typename T>
static inline void interleave_8x8_mk4_b(const T*& inptr0, const T*& inptr1,
T*& outptr) {
static_assert(
std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
"transpose_8x4_1_b only support uint8_t and int8_t");
asm volatile(
"ld1 {v0.4s}, [%[inptr0]], #16\n"
"ld1 {v1.4s}, [%[inptr1]], #16\n"
"ld1 {v2.4s}, [%[inptr0]], #16\n"
"ld1 {v3.4s}, [%[inptr1]], #16\n"
"zip1 v4.4s, v0.4s, v1.4s \n"
"zip2 v5.4s, v0.4s, v1.4s \n"
"zip1 v6.4s, v2.4s, v3.4s\n"
"zip2 v7.4s, v2.4s, v3.4s\n"
"st1 {v4.4s},[%[outptr]],#16\n"
"st1 {v5.4s},[%[outptr]],#16\n"
"st1 {v6.4s},[%[outptr]],#16\n"
"st1 {v7.4s},[%[outptr]],#16\n"
: [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1),
[outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5","v6","v7","memory");
}
template <typename T>
static inline void transpose_8x8_mk4_b(const T*& inptr0, const T*& inptr1,
T* outptr) {
static_assert(
std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
"transpose_8x4_1_b only support uint8_t and int8_t");
asm volatile(
"ld4 {v0.8b-v3.8b}, [%[inptr0]], #32\n"
"ld4 {v4.8b-v7.8b}, [%[inptr1]], #32\n"
"st1 {v0.2s},[%[outptr]],#8\n"
"st1 {v1.2s},[%[outptr]],#8\n"
"st1 {v2.2s},[%[outptr]],#8\n"
"st1 {v3.2s},[%[outptr]],#8\n"
"st1 {v4.2s},[%[outptr]],#8\n"
"st1 {v5.2s},[%[outptr]],#8\n"
"st1 {v6.2s},[%[outptr]],#8\n"
"st1 {v7.2s},[%[outptr]],#8\n"
: [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1),
[outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5","v6","v7","memory");
}
} // namespace aarch64
} // namespace megdnn
......
/**
* \file dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include <inttypes.h>
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
namespace megdnn {
namespace aarch64 {
namespace matmul_mk4_8x8x8 {
/**
* Overview of register layout:
*
* A 8x8 cell of Lhs is stored in 8bit in v16-v17
* B 8x8 cell of Rhs is stored in 8bit in v0-v15, v20-v23
* C 8x8 block of accumulators is stored in 16bit in v24-v31
*
* +---------------------------------+
* | v0 ------------------------ v7 |
* | v8 ------------------------ v15|
* Rhs +---------------------------------+
* Lhs | |
* +--------+ - - - - +---------------------------------+
* | v16 | | v24 |
* | v17 | | v25 |
* | v16 | | v26 |
* | v17 | | v27 |
* | v16 | | v28 |
* | v17 | | v29 |
* | v16 | | v30 |
* | v17 | | v31 |
* +--------+ - - - - +---------------------------------+
*
* Accumulator
*/
// Compute one full 8x8 output tile of C (int16) from int8 packed operands,
// with K unrolled by 8 (caller guarantees K % 8 == 0 after packing).
// If is_first_k, the tile is overwritten; otherwise it is accumulated onto
// the existing contents of `output`.
// NOTE(review): a_ptr is deliberately bound to packB and b_ptr to packA —
// the packing routines swap operand roles; confirm against the pack code.
static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
                     int16_t* output, int LDC, bool is_first_k, int m_remain,
                     int n_remain) {
    K /= 8;
    LDC = LDC * sizeof(int16_t);
    const int8_t* a_ptr = packB;//packA;
    const int8_t* b_ptr = packA;//packB;
    // clang-format off
// Load the current 8x8 int16 tile: rows 0-3 via x0, rows 4-7 via x1.
#define LOAD_C_8 \
    "ld1 {v0.8h}, [x0], #16\n" \
    "ld1 {v1.8h}, [x0], #16\n" \
    "ld1 {v2.8h}, [x0], #16\n" \
    "ld1 {v3.8h}, [x0], #16\n" \
    "ld1 {v4.8h}, [x1], #16\n" \
    "ld1 {v5.8h}, [x1], #16\n" \
    "ld1 {v6.8h}, [x1], #16\n" \
    "ld1 {v7.8h}, [x1], #16\n" \
// (comment line also terminates the macro's trailing backslash safely)
#define STORE_C_8 \
    "st1 {v0.8h}, [x0], #16\n" \
    "st1 {v1.8h}, [x0], #16\n" \
    "st1 {v2.8h}, [x0], #16\n" \
    "st1 {v3.8h}, [x0], #16\n" \
    "st1 {v4.8h}, [x1], #16\n" \
    "st1 {v5.8h}, [x1], #16\n" \
    "st1 {v6.8h}, [x1], #16\n" \
    "st1 {v7.8h}, [x1], #16\n" \
// (comment line also terminates the macro's trailing backslash safely)
    // Pin the output pointer to x0 so the LOAD/STORE macros can use x0/x1.
    register int16_t* outptr asm("x0") = output;
    asm volatile(
            // x1 = pointer to the second half (rows 4-7) of the tile.
            "add x1, x0, %x[LDC]\n"
            // Clear the int16 accumulators v24-v31 (one per output row).
            "eor v24.16b, v24.16b, v24.16b\n"
            "PRFM PLDL1KEEP, [%[a_ptr], #512]\n"
            "eor v25.16b, v25.16b, v25.16b\n"
            "PRFM PLDL1KEEP, [%[b_ptr], #512]\n"
            "eor v26.16b, v26.16b, v26.16b\n"
            "ld1 {v20.16b}, [%[a_ptr]],#16\n"
            "eor v27.16b, v27.16b, v27.16b\n"
            "ld1 {v21.16b}, [%[a_ptr]],#16\n"
            "eor v28.16b, v28.16b, v28.16b\n"
            "eor v29.16b, v29.16b, v29.16b\n"
            "eor v30.16b, v30.16b, v30.16b\n"
            "eor v31.16b, v31.16b, v31.16b\n"
            // General loop: one iteration consumes 8 K-steps.
            // v20-v23 hold packed B data; bytes are broadcast (dup) into
            // v0-v15 and multiply-accumulated against A rows in v16/v17.
            "1:\n"
            "dup v0.8b,v20.b[0]\n"
            "ld1 {v22.16b}, [%[a_ptr]],#16\n"
            "dup v1.8b,v20.b[1]\n"
            "ld1 {v23.16b}, [%[a_ptr]],#16\n"
            "dup v2.8b,v20.b[2]\n"
            "ld1 {v16.8b}, [%[b_ptr]], 8\n"
            "dup v3.8b,v20.b[3]\n"
            "dup v4.8b,v20.b[4]\n"
            "ld1 {v17.8b}, [%[b_ptr]], 8\n"
            "dup v5.8b,v20.b[5]\n"
            "dup v6.8b,v20.b[6]\n"
            "dup v7.8b,v20.b[7]\n"
            "dup v8.8b,v20.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b,v20.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b,v20.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b,v20.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b,v20.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b,v20.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b,v20.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b,v20.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            "ld1 {v16.8b}, [%[b_ptr]], 8\n"
            "dup v0.8b,v21.b[0]\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "dup v1.8b,v21.b[1]\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "dup v2.8b,v21.b[2]\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "dup v3.8b,v21.b[3]\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "dup v4.8b,v21.b[4]\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "dup v5.8b,v21.b[5]\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "dup v6.8b,v21.b[6]\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "dup v7.8b,v21.b[7]\n"
            "smlal v31.8h, v15.8b, v17.8b\n"
            "ld1 {v17.8b}, [%[b_ptr]], 8\n"
            "dup v8.8b,v21.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b,v21.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b,v21.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b,v21.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b,v21.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b,v21.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b,v21.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b,v21.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            "ld1 {v16.8b}, [%[b_ptr]], 8\n"
            "dup v0.8b,v22.b[0]\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "dup v1.8b,v22.b[1]\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "dup v2.8b,v22.b[2]\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "dup v3.8b,v22.b[3]\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "dup v4.8b,v22.b[4]\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "dup v5.8b,v22.b[5]\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "dup v6.8b,v22.b[6]\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "dup v7.8b,v22.b[7]\n"
            "smlal v31.8h, v15.8b, v17.8b\n"
            "ld1 {v17.8b}, [%[b_ptr]], 8\n"
            "dup v8.8b,v22.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b,v22.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b,v22.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b,v22.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b,v22.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b,v22.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b,v22.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b,v22.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            "ld1 {v16.8b}, [%[b_ptr]], 8\n"
            "dup v0.8b,v23.b[0]\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "dup v1.8b,v23.b[1]\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "dup v2.8b,v23.b[2]\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "dup v3.8b,v23.b[3]\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "dup v4.8b,v23.b[4]\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "dup v5.8b,v23.b[5]\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "dup v6.8b,v23.b[6]\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "dup v7.8b,v23.b[7]\n"
            "smlal v31.8h, v15.8b, v17.8b\n"
            "ld1 {v17.8b}, [%[b_ptr]], 8\n"
            "dup v8.8b,v23.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b,v23.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b,v23.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b,v23.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b,v23.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b,v23.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b,v23.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b,v23.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            // Preload B data for the next iteration while draining v17.
            "ld1 {v20.16b}, [%[a_ptr]],#16\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "ld1 {v21.16b}, [%[a_ptr]],#16\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "smlal v31.8h, v15.8b, v17.8b\n"
            "subs %w[K], %w[K], #1\n"
            "cbnz %w[K], 1b\n"
            // Epilogue: load existing C (accumulate) or zero it (first k).
            "cmp %w[is_first_k], #1\n"
            "beq 2f\n" LOAD_C_8
            "b 3f \n"
            "2: \n"
            "eor v0.16b, v0.16b, v0.16b\n"
            "eor v1.16b, v1.16b, v1.16b\n"
            "eor v2.16b, v2.16b, v2.16b\n"
            "eor v3.16b, v3.16b, v3.16b\n"
            "eor v4.16b, v4.16b, v4.16b\n"
            "eor v5.16b, v5.16b, v5.16b\n"
            "eor v6.16b, v6.16b, v6.16b\n"
            "eor v7.16b, v7.16b, v7.16b\n"
            // Re-pair the 64-bit accumulator halves into output row order,
            // then add onto the loaded/zeroed C tile.
            "3:\n"
            "zip1 v8.2d, v24.2d, v25.2d\n"
            "zip2 v9.2d, v24.2d, v25.2d\n"
            "zip1 v10.2d, v26.2d, v27.2d\n"
            "zip2 v11.2d, v26.2d, v27.2d\n"
            "zip1 v12.2d, v28.2d, v29.2d\n"
            "zip2 v13.2d, v28.2d, v29.2d\n"
            "zip1 v14.2d, v30.2d, v31.2d\n"
            "zip2 v15.2d, v30.2d, v31.2d\n"
            "add v0.8h, v0.8h, v8.8h\n"
            "add v1.8h, v1.8h, v10.8h\n"
            "add v2.8h, v2.8h, v12.8h\n"
            "add v3.8h, v3.8h, v14.8h\n"
            "add v4.8h, v4.8h, v9.8h\n"
            "add v5.8h, v5.8h, v11.8h\n"
            "add v6.8h, v6.8h, v13.8h\n"
            "add v7.8h, v7.8h, v15.8h\n"
            // Store back into memory
            STORE_C_8
            :
            [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr),
            [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC),
            [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain),
            [ n_remain ] "+r"(n_remain)
            :
            : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
              "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
              "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
              "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
              "v29", "v30", "v31");
    // clang-format on
}
// Edge-tile variant of kern_8x8: same inner product loop, but the epilogue
// dispatches on m_remain (8 or 4 valid rows) and n_remain (1..8 valid
// columns) so that only the valid portion of the C tile is loaded/stored.
// NOTE(review): a_ptr/b_ptr are swapped relative to the parameter names,
// matching kern_8x8 — confirm against the packing routines.
static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K,
                    int16_t* output, int LDC, bool is_first_k, int m_remain,
                    int n_remain) {
    K /= 8;
    LDC = LDC * sizeof(int16_t);
    const int8_t* a_ptr = packB;
    const int8_t* b_ptr = packA;
    // clang-format off
    // Pin the output pointer to x0; x1 addresses rows 4-7 of the tile.
    register int16_t* outptr asm("x0") = output;
    asm volatile(
            "add x1, x0, %x[LDC]\n"
            // Clear int16 accumulators v24-v31.
            "eor v24.16b, v24.16b, v24.16b\n"
            "eor v25.16b, v25.16b, v25.16b\n"
            "eor v26.16b, v26.16b, v26.16b\n"
            "eor v27.16b, v27.16b, v27.16b\n"
            "eor v28.16b, v28.16b, v28.16b\n"
            "eor v29.16b, v29.16b, v29.16b\n"
            "eor v30.16b, v30.16b, v30.16b\n"
            "eor v31.16b, v31.16b, v31.16b\n"
            // General loop.
            "ld1 {v20.16b}, [%[a_ptr]],#16\n"
            "ld1 {v21.16b}, [%[a_ptr]],#16\n"
            "PRFM PLDL1KEEP, [%[a_ptr], #512]\n"
            "PRFM PLDL1KEEP, [%[b_ptr], #512]\n"
            "1:\n"
            "dup v0.8b,v20.b[0]\n"
            "ld1 {v22.16b}, [%[a_ptr]],#16\n"
            "dup v1.8b,v20.b[1]\n"
            "ld1 {v23.16b}, [%[a_ptr]],#16\n"
            "dup v2.8b,v20.b[2]\n"
            "ld1 {v16.8b}, [%[b_ptr]], 8\n"
            "dup v3.8b,v20.b[3]\n"
            "dup v4.8b,v20.b[4]\n"
            "ld1 {v17.8b}, [%[b_ptr]], 8\n"
            "dup v5.8b,v20.b[5]\n"
            "dup v6.8b,v20.b[6]\n"
            "dup v7.8b,v20.b[7]\n"
            "dup v8.8b,v20.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b,v20.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b,v20.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b,v20.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b,v20.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b,v20.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b,v20.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b,v20.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            "ld1 {v16.8b}, [%[b_ptr]], 8\n"
            "dup v0.8b,v21.b[0]\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "dup v1.8b,v21.b[1]\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "dup v2.8b,v21.b[2]\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "dup v3.8b,v21.b[3]\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "dup v4.8b,v21.b[4]\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "dup v5.8b,v21.b[5]\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "dup v6.8b,v21.b[6]\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "dup v7.8b,v21.b[7]\n"
            "smlal v31.8h, v15.8b, v17.8b\n"
            "ld1 {v17.8b}, [%[b_ptr]], 8\n"
            "dup v8.8b,v21.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b,v21.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b,v21.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b,v21.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b,v21.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b,v21.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b,v21.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b,v21.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            "ld1 {v16.8b}, [%[b_ptr]], 8\n"
            "dup v0.8b,v22.b[0]\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "dup v1.8b,v22.b[1]\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "dup v2.8b,v22.b[2]\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "dup v3.8b,v22.b[3]\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "dup v4.8b,v22.b[4]\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "dup v5.8b,v22.b[5]\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "dup v6.8b,v22.b[6]\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "dup v7.8b,v22.b[7]\n"
            "smlal v31.8h, v15.8b, v17.8b\n"
            "ld1 {v17.8b}, [%[b_ptr]], 8\n"
            "dup v8.8b,v22.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b,v22.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b,v22.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b,v22.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b,v22.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b,v22.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b,v22.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b,v22.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            "ld1 {v16.8b}, [%[b_ptr]], 8\n"
            "dup v0.8b,v23.b[0]\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "dup v1.8b,v23.b[1]\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "dup v2.8b,v23.b[2]\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "dup v3.8b,v23.b[3]\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "dup v4.8b,v23.b[4]\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "dup v5.8b,v23.b[5]\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "dup v6.8b,v23.b[6]\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "dup v7.8b,v23.b[7]\n"
            "smlal v31.8h, v15.8b, v17.8b\n"
            "ld1 {v17.8b}, [%[b_ptr]], 8\n"
            "dup v8.8b,v23.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b,v23.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b,v23.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b,v23.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b,v23.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b,v23.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b,v23.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b,v23.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            "ld1 {v20.16b}, [%[a_ptr]],#16\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "ld1 {v21.16b}, [%[a_ptr]],#16\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "smlal v31.8h, v15.8b, v17.8b\n"
            "subs %w[K], %w[K], #1\n"
            "cbnz %w[K], 1b\n"
            // Epilogue: on the first K-slice skip the load and zero C instead.
            "cmp %w[is_first_k], #1\n"
            "beq 2f\n"
            // Load path: dispatch on m_remain (8 rows at label 8, 4 rows at
            // label 9), then on n_remain for the partial-width loads.
            "cmp %x[m_remain], #8 \n"
            "beq 8f \n"
            "cmp %x[m_remain], #4 \n"
            "beq 9f \n"
            "8: \n"
            "cmp %x[n_remain], #8\n"
            "beq 200f \n"
            "cmp %x[n_remain], #7\n"
            "beq 201f \n"
            "cmp %x[n_remain], #6\n"
            "beq 202f \n"
            "cmp %x[n_remain], #5\n"
            "beq 203f \n"
            "cmp %x[n_remain], #4\n"
            "beq 204f \n"
            "cmp %x[n_remain], #3\n"
            "beq 205f \n"
            "cmp %x[n_remain], #2\n"
            "beq 206f \n"
            "cmp %x[n_remain], #1\n"
            "beq 207f \n"
            "200: \n"
            "ld1 {v0.8h}, [x0], #16\n"
            "ld1 {v1.8h}, [x0], #16\n"
            "ld1 {v2.8h}, [x0], #16\n"
            "ld1 {v3.8h}, [x0], #16\n"
            "ld1 {v4.8h}, [x1], #16\n"
            "ld1 {v5.8h}, [x1], #16\n"
            "ld1 {v6.8h}, [x1], #16\n"
            "ld1 {v7.8h}, [x1], #16\n"
            "b 3f \n"
            "201: \n"
            "ld1 {v0.8h}, [x0], #16\n"
            "ld1 {v1.8h}, [x0], #16\n"
            "ld1 {v2.8h}, [x0], #16\n"
            "ld1 {v3.d}[0], [x0], #8\n"
            "ld1 {v4.8h}, [x1], #16\n"
            "ld1 {v5.8h}, [x1], #16\n"
            "ld1 {v6.8h}, [x1], #16\n"
            "ld1 {v7.d}[0], [x1], #8\n"
            "b 3f \n"
            "202: \n"
            "ld1 {v0.8h}, [x0], #16\n"
            "ld1 {v1.8h}, [x0], #16\n"
            "ld1 {v2.8h}, [x0], #16\n"
            "ld1 {v4.8h}, [x1], #16\n"
            "ld1 {v5.8h}, [x1], #16\n"
            "ld1 {v6.8h}, [x1], #16\n"
            "b 3f \n"
            "203: \n"
            "ld1 {v0.8h}, [x0], #16\n"
            "ld1 {v1.8h}, [x0], #16\n"
            "ld1 {v2.d}[0], [x0], #8\n"
            "ld1 {v4.8h}, [x1], #16\n"
            "ld1 {v5.8h}, [x1], #16\n"
            "ld1 {v6.d}[0], [x1], #8\n"
            "b 3f \n"
            "204: \n"
            "ld1 {v0.8h}, [x0], #16\n"
            "ld1 {v1.8h}, [x0], #16\n"
            "ld1 {v4.8h}, [x1], #16\n"
            "ld1 {v5.8h}, [x1], #16\n"
            "b 3f \n"
            "205: \n"
            "ld1 {v0.8h}, [x0], #16\n"
            "ld1 {v1.d}[0], [x0], #8\n"
            "ld1 {v4.8h}, [x1], #16\n"
            "ld1 {v5.d}[0], [x1], #8\n"
            "b 3f \n"
            "206: \n"
            "ld1 {v0.8h}, [x0], #16\n"
            "ld1 {v4.8h}, [x1], #16\n"
            "b 3f \n"
            "207: \n"
            "ld1 {v0.d}[0], [x0], #8\n"
            "ld1 {v4.d}[0], [x1], #8\n"
            "b 3f \n"
            // m_remain == 4: only rows 0-3 (x0) are valid.
            "9: \n"
            "cmp %x[n_remain], #8\n"
            "beq 300f \n"
            "cmp %x[n_remain], #7\n"
            "beq 301f \n"
            "cmp %x[n_remain], #6\n"
            "beq 302f \n"
            "cmp %x[n_remain], #5\n"
            "beq 303f \n"
            "cmp %x[n_remain], #4\n"
            "beq 304f \n"
            "cmp %x[n_remain], #3\n"
            "beq 305f \n"
            "cmp %x[n_remain], #2\n"
            "beq 306f \n"
            "cmp %x[n_remain], #1\n"
            "beq 307f \n"
            "300: \n"
            "ld1 {v0.8h}, [x0], #16\n"
            "ld1 {v1.8h}, [x0], #16\n"
            "ld1 {v2.8h}, [x0], #16\n"
            "ld1 {v3.8h}, [x0], #16\n"
            "b 3f \n"
            "301: \n"
            "ld1 {v0.8h}, [x0], #16\n"
            "ld1 {v1.8h}, [x0], #16\n"
            "ld1 {v2.8h}, [x0], #16\n"
            "ld1 {v3.d}[0], [x0], #8\n"
            "b 3f \n"
            "302: \n"
            "ld1 {v0.8h}, [x0], #16\n"
            "ld1 {v1.8h}, [x0], #16\n"
            "ld1 {v2.8h}, [x0], #16\n"
            "b 3f \n"
            "303: \n"
            "ld1 {v0.8h}, [x0], #16\n"
            "ld1 {v1.8h}, [x0], #16\n"
            "ld1 {v2.d}[0], [x0], #8\n"
            "b 3f \n"
            "304: \n"
            "ld1 {v0.8h}, [x0], #16\n"
            "ld1 {v1.8h}, [x0], #16\n"
            "b 3f \n"
            "305: \n"
            "ld1 {v0.8h}, [x0], #16\n"
            "ld1 {v1.d}[0], [x0], #8\n"
            "b 3f \n"
            "306: \n"
            "ld1 {v0.8h}, [x0], #16\n"
            "b 3f \n"
            "307: \n"
            "ld1 {v0.d}[0], [x0], #8\n"
            "b 3f \n"
            // First K-slice: zero the C registers instead of loading.
            "2: \n"
            "eor v0.16b, v0.16b, v0.16b\n"
            "eor v1.16b, v1.16b, v1.16b\n"
            "eor v2.16b, v2.16b, v2.16b\n"
            "eor v3.16b, v3.16b, v3.16b\n"
            "eor v4.16b, v4.16b, v4.16b\n"
            "eor v5.16b, v5.16b, v5.16b\n"
            "eor v6.16b, v6.16b, v6.16b\n"
            "eor v7.16b, v7.16b, v7.16b\n"
            // Re-pair accumulator halves into row order and accumulate.
            "3:\n"
            "zip1 v8.2d, v24.2d, v25.2d\n"
            "zip1 v10.2d, v26.2d, v27.2d\n"
            "add v0.8h, v0.8h, v8.8h \n"
            "zip1 v12.2d, v28.2d, v29.2d\n"
            "add v1.8h, v1.8h, v10.8h \n"
            "zip1 v14.2d, v30.2d, v31.2d\n"
            "add v2.8h, v2.8h, v12.8h \n"
            "add v3.8h, v3.8h, v14.8h \n"
            "zip2 v9.2d, v24.2d, v25.2d\n"
            "zip2 v11.2d, v26.2d, v27.2d \n"
            "add v4.8h, v4.8h, v9.8h \n"
            "zip2 v13.2d, v28.2d, v29.2d \n"
            "add v5.8h, v5.8h, v11.8h \n"
            "zip2 v15.2d, v30.2d, v31.2d \n"
            "add v6.8h, v6.8h, v13.8h \n"
            "add v7.8h, v7.8h, v15.8h \n"
            //save to memory
            // Store path: same m_remain/n_remain dispatch as the loads.
            "cmp %x[m_remain], #8 \n"
            "beq 4f \n"
            "cmp %x[m_remain], #4 \n"
            "beq 5f \n"
            "4: \n"
            "cmp %x[n_remain], #8\n"
            "beq 100f \n"
            "cmp %x[n_remain], #7\n"
            "beq 101f \n"
            "cmp %x[n_remain], #6\n"
            "beq 102f \n"
            "cmp %x[n_remain], #5\n"
            "beq 103f \n"
            "cmp %x[n_remain], #4\n"
            "beq 104f \n"
            "cmp %x[n_remain], #3\n"
            "beq 105f \n"
            "cmp %x[n_remain], #2\n"
            "beq 106f \n"
            "cmp %x[n_remain], #1\n"
            "beq 107f \n"
            "100: \n"
            "st1 {v0.8h}, [x0], #16\n"
            "st1 {v1.8h}, [x0], #16\n"
            "st1 {v2.8h}, [x0], #16\n"
            "st1 {v3.8h}, [x0], #16\n"
            "st1 {v4.8h}, [x1], #16\n"
            "st1 {v5.8h}, [x1], #16\n"
            "st1 {v6.8h}, [x1], #16\n"
            "st1 {v7.8h}, [x1], #16\n"
            "b 1000f \n"
            "101: \n"
            "st1 {v0.8h}, [x0], #16\n"
            "st1 {v1.8h}, [x0], #16\n"
            "st1 {v2.8h}, [x0], #16\n"
            "st1 {v3.d}[0], [x0], #8\n"
            "st1 {v4.8h}, [x1], #16\n"
            "st1 {v5.8h}, [x1], #16\n"
            "st1 {v6.8h}, [x1], #16\n"
            "st1 {v7.d}[0], [x1], #8\n"
            "b 1000f \n"
            "102: \n"
            "st1 {v0.8h}, [x0], #16\n"
            "st1 {v1.8h}, [x0], #16\n"
            "st1 {v2.8h}, [x0], #16\n"
            "st1 {v4.8h}, [x1], #16\n"
            "st1 {v5.8h}, [x1], #16\n"
            "st1 {v6.8h}, [x1], #16\n"
            "b 1000f \n"
            "103: \n"
            "st1 {v0.8h}, [x0], #16\n"
            "st1 {v1.8h}, [x0], #16\n"
            "st1 {v2.d}[0], [x0], #8\n"
            "st1 {v4.8h}, [x1], #16\n"
            "st1 {v5.8h}, [x1], #16\n"
            "st1 {v6.d}[0], [x1], #8\n"
            "b 1000f \n"
            "104: \n"
            "st1 {v0.8h}, [x0], #16\n"
            "st1 {v1.8h}, [x0], #16\n"
            "st1 {v4.8h}, [x1], #16\n"
            "st1 {v5.8h}, [x1], #16\n"
            "b 1000f \n"
            "105: \n"
            "st1 {v0.8h}, [x0], #16\n"
            "st1 {v1.d}[0], [x0], #8\n"
            "st1 {v4.8h}, [x1], #16\n"
            "st1 {v5.d}[0], [x1], #8\n"
            "b 1000f \n"
            "106: \n"
            "st1 {v0.8h}, [x0], #16\n"
            "st1 {v4.8h}, [x1], #16\n"
            "b 1000f \n"
            "107: \n"
            "st1 {v0.d}[0], [x0], #8\n"
            "st1 {v4.d}[0], [x1], #8\n"
            "b 1000f \n"
            // m_remain == 4: store only rows 0-3.
            "5: \n"
            "cmp %x[n_remain], #8\n"
            "beq 200f \n"
            "cmp %x[n_remain], #7\n"
            "beq 201f \n"
            "cmp %x[n_remain], #6\n"
            "beq 202f \n"
            "cmp %x[n_remain], #5\n"
            "beq 203f \n"
            "cmp %x[n_remain], #4\n"
            "beq 204f \n"
            "cmp %x[n_remain], #3\n"
            "beq 205f \n"
            "cmp %x[n_remain], #2\n"
            "beq 206f \n"
            "cmp %x[n_remain], #1\n"
            "beq 207f \n"
            "200: \n"
            "st1 {v0.8h}, [x0], #16\n"
            "st1 {v1.8h}, [x0], #16\n"
            "st1 {v2.8h}, [x0], #16\n"
            "st1 {v3.8h}, [x0], #16\n"
            "b 1000f \n"
            "201: \n"
            "st1 {v0.8h}, [x0], #16\n"
            "st1 {v1.8h}, [x0], #16\n"
            "st1 {v2.8h}, [x0], #16\n"
            "st1 {v3.d}[0], [x0], #8\n"
            "b 1000f \n"
            "202: \n"
            "st1 {v0.8h}, [x0], #16\n"
            "st1 {v1.8h}, [x0], #16\n"
            "st1 {v2.8h}, [x0], #16\n"
            "b 1000f \n"
            "203: \n"
            "st1 {v0.8h}, [x0], #16\n"
            "st1 {v1.8h}, [x0], #16\n"
            "st1 {v2.d}[0], [x0], #8\n"
            "b 1000f \n"
            "204: \n"
            "st1 {v0.8h}, [x0], #16\n"
            "st1 {v1.8h}, [x0], #16\n"
            "b 1000f \n"
            "205: \n"
            "st1 {v0.8h}, [x0], #16\n"
            "st1 {v1.d}[0], [x0], #8\n"
            "b 1000f \n"
            "206: \n"
            "st1 {v0.8h}, [x0], #16\n"
            "b 1000f \n"
            "207: \n"
            "st1 {v0.d}[0], [x0], #8\n"
            "b 1000f \n"
            "1000: \n"
            :
            [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr),
            [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC),
            [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain),
            [ n_remain ] "+r"(n_remain)
            :
            : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
              "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
              "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
              "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
              "v29", "v30", "v31");
    // clang-format on
// Macros defined in kern_8x8 above; retired here after their last use.
#undef LOAD_C_8
#undef STORE_C_8
}
// Compute a full 4x8 output tile (4 valid rows): same inner loop as
// kern_8x8, but only the zip1 halves of the accumulators (rows 0-3) are
// combined with C and stored; the zip2 halves are discarded.
// NOTE(review): a_ptr/b_ptr are swapped relative to the parameter names,
// matching kern_8x8 — confirm against the packing routines.
static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
                     int16_t* output, int LDC, bool is_first_k, int m_remain,
                     int n_remain) {
    K /= 8;
    LDC = LDC * sizeof(int16_t);
    const int8_t* a_ptr = packB;//packA;
    const int8_t* b_ptr = packA;//packB;
    // clang-format off
// Load the current 4x8 int16 tile through x0.
#define LOAD_C_4 \
    "ld1 {v0.8h}, [x0], #16\n" \
    "ld1 {v1.8h}, [x0], #16\n" \
    "ld1 {v2.8h}, [x0], #16\n" \
    "ld1 {v3.8h}, [x0], #16\n" \
// (comment line also terminates the macro's trailing backslash safely)
#define STORE_C_4 \
    "st1 {v0.8h}, [x0], #16\n" \
    "st1 {v1.8h}, [x0], #16\n" \
    "st1 {v2.8h}, [x0], #16\n" \
    "st1 {v3.8h}, [x0], #16\n" \
// (comment line also terminates the macro's trailing backslash safely)
    // Pin the output pointer to x0 so the LOAD/STORE macros can use it.
    register int16_t* outptr asm("x0") = output;
    asm volatile(
            // Clear int16 accumulators v24-v31.
            "eor v24.16b, v24.16b, v24.16b\n"
            "eor v25.16b, v25.16b, v25.16b\n"
            "eor v26.16b, v26.16b, v26.16b\n"
            "eor v27.16b, v27.16b, v27.16b\n"
            "eor v28.16b, v28.16b, v28.16b\n"
            "eor v29.16b, v29.16b, v29.16b\n"
            "eor v30.16b, v30.16b, v30.16b\n"
            "eor v31.16b, v31.16b, v31.16b\n"
            // General loop.
            "ld1 {v20.16b}, [%[a_ptr]],#16\n"
            "ld1 {v21.16b}, [%[a_ptr]],#16\n"
            "PRFM PLDL1KEEP, [%[a_ptr], #512]\n"
            "PRFM PLDL1KEEP, [%[b_ptr], #512]\n"
            "1:\n"
            "dup v0.8b,v20.b[0]\n"
            "ld1 {v22.16b}, [%[a_ptr]],#16\n"
            "dup v1.8b,v20.b[1]\n"
            "ld1 {v23.16b}, [%[a_ptr]],#16\n"
            "dup v2.8b,v20.b[2]\n"
            "ld1 {v16.8b}, [%[b_ptr]], 8\n"
            "dup v3.8b,v20.b[3]\n"
            "dup v4.8b,v20.b[4]\n"
            "ld1 {v17.8b}, [%[b_ptr]], 8\n"
            "dup v5.8b,v20.b[5]\n"
            "dup v6.8b,v20.b[6]\n"
            "dup v7.8b,v20.b[7]\n"
            "dup v8.8b,v20.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b,v20.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b,v20.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b,v20.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b,v20.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b,v20.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b,v20.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b,v20.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            "ld1 {v16.8b}, [%[b_ptr]], 8\n"
            "dup v0.8b,v21.b[0]\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "dup v1.8b,v21.b[1]\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "dup v2.8b,v21.b[2]\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "dup v3.8b,v21.b[3]\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "dup v4.8b,v21.b[4]\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "dup v5.8b,v21.b[5]\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "dup v6.8b,v21.b[6]\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "dup v7.8b,v21.b[7]\n"
            "smlal v31.8h, v15.8b, v17.8b\n"
            "ld1 {v17.8b}, [%[b_ptr]], 8\n"
            "dup v8.8b,v21.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b,v21.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b,v21.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b,v21.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b,v21.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b,v21.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b,v21.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b,v21.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            "ld1 {v16.8b}, [%[b_ptr]], 8\n"
            "dup v0.8b,v22.b[0]\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "dup v1.8b,v22.b[1]\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "dup v2.8b,v22.b[2]\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "dup v3.8b,v22.b[3]\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "dup v4.8b,v22.b[4]\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "dup v5.8b,v22.b[5]\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "dup v6.8b,v22.b[6]\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "dup v7.8b,v22.b[7]\n"
            "smlal v31.8h, v15.8b, v17.8b\n"
            "ld1 {v17.8b}, [%[b_ptr]], 8\n"
            "dup v8.8b,v22.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b,v22.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b,v22.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b,v22.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b,v22.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b,v22.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b,v22.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b,v22.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            "ld1 {v16.8b}, [%[b_ptr]], 8\n"
            "dup v0.8b,v23.b[0]\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "dup v1.8b,v23.b[1]\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "dup v2.8b,v23.b[2]\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "dup v3.8b,v23.b[3]\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "dup v4.8b,v23.b[4]\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "dup v5.8b,v23.b[5]\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "dup v6.8b,v23.b[6]\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "dup v7.8b,v23.b[7]\n"
            "smlal v31.8h, v15.8b, v17.8b\n"
            "ld1 {v17.8b}, [%[b_ptr]], 8\n"
            "dup v8.8b,v23.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b,v23.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b,v23.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b,v23.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b,v23.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b,v23.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b,v23.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b,v23.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            "ld1 {v20.16b}, [%[a_ptr]],#16\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "ld1 {v21.16b}, [%[a_ptr]],#16\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "smlal v31.8h, v15.8b, v17.8b\n"
            "subs %w[K], %w[K], #1\n"
            "cbnz %w[K], 1b\n"
            // Epilogue: load existing C (accumulate) or zero it (first k).
            "cmp %w[is_first_k], #1\n"
            "beq 2f\n" LOAD_C_4
            "b 3f \n"
            "2: \n"
            "eor v0.16b, v0.16b, v0.16b\n"
            "eor v1.16b, v1.16b, v1.16b\n"
            "eor v2.16b, v2.16b, v2.16b\n"
            "eor v3.16b, v3.16b, v3.16b\n"
            "eor v4.16b, v4.16b, v4.16b\n"
            "eor v5.16b, v5.16b, v5.16b\n"
            "eor v6.16b, v6.16b, v6.16b\n"
            "eor v7.16b, v7.16b, v7.16b\n"
            // Only the low (zip1) halves — rows 0-3 — are accumulated.
            "3:\n"
            "zip1 v8.2d, v24.2d, v25.2d\n"
            "zip1 v10.2d, v26.2d, v27.2d\n"
            "add v0.8h, v0.8h, v8.8h\n"
            "zip1 v12.2d, v28.2d, v29.2d\n"
            "add v1.8h, v1.8h, v10.8h\n"
            "zip1 v14.2d, v30.2d, v31.2d\n"
            "add v2.8h, v2.8h, v12.8h\n"
            "add v3.8h, v3.8h, v14.8h\n"
            // Store back into memory
            STORE_C_4
            :
            [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr),
            [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC),
            [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain),
            [ n_remain ] "+r"(n_remain)
            :
            : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
              "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
              "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
              "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
              "v29", "v30", "v31");
    // clang-format on
#undef LOAD_C_4
#undef STORE_C_4
}
//! Tail micro-kernel: computes a 4-row (one MK4 block) x n_remain(<8)-column
//! int8x8x16 output tile.  Eight int16 accumulators v24-v31 each hold half a
//! row pair; after the K loop they are zipped pairwise into v0-v3 (four 8h
//! vectors) and added to / stored into C under an n_remain-dispatched
//! load/store ladder.
static void kern_4x8_remain(const int8_t* packA, const int8_t* packB, int K,
                            int16_t* output, int LDC, bool is_first_k,
                            int m_remain, int n_remain) {
    // One iteration of asm loop "1:" consumes 8 depth elements.
    K /= 8;
    // LDC arrives in elements; the asm addresses C in bytes.
    LDC = LDC * sizeof(int16_t);
    // NOTE(review): the panels are deliberately cross-assigned relative to
    // the parameter names: the packed B panel is the one broadcast
    // lane-by-lane via dup (v20-v23), while the packed A panel streams
    // through v16/v17 -- confirm against the pack_A/pack_B layouts.
    const int8_t* a_ptr = packB;
    const int8_t* b_ptr = packA;
    // clang-format off
    // Pin the output pointer to x0: the 100/200-series dispatch below
    // addresses C through x0 directly.
    register int16_t* outptr asm("x0") = output;
    asm volatile(
            // Clear the eight int16 accumulators.
            "eor v24.16b, v24.16b, v24.16b\n"
            "eor v25.16b, v25.16b, v25.16b\n"
            "eor v26.16b, v26.16b, v26.16b\n"
            "eor v27.16b, v27.16b, v27.16b\n"
            "eor v28.16b, v28.16b, v28.16b\n"
            "eor v29.16b, v29.16b, v29.16b\n"
            "eor v30.16b, v30.16b, v30.16b\n"
            "eor v31.16b, v31.16b, v31.16b\n"
            // General loop.  Preload the first two 16-byte broadcast vectors.
            "ld1 {v20.16b}, [%[a_ptr]],#16\n"
            "ld1 {v21.16b}, [%[a_ptr]],#16\n"
            "PRFM PLDL1KEEP, [%[a_ptr], #512]\n"
            "PRFM PLDL1KEEP, [%[b_ptr], #512]\n"
            "1:\n"
            // Depth steps 0-1: lanes of v20 broadcast into v0-v15,
            // multiplied against the streamed v16/v17 vectors.
            "dup v0.8b,v20.b[0]\n"
            "ld1 {v22.16b}, [%[a_ptr]],#16\n"
            "dup v1.8b,v20.b[1]\n"
            "ld1 {v23.16b}, [%[a_ptr]],#16\n"
            "dup v2.8b,v20.b[2]\n"
            "ld1 {v16.8b}, [%[b_ptr]], 8\n"
            "dup v3.8b,v20.b[3]\n"
            "dup v4.8b,v20.b[4]\n"
            "ld1 {v17.8b}, [%[b_ptr]], 8\n"
            "dup v5.8b,v20.b[5]\n"
            "dup v6.8b,v20.b[6]\n"
            "dup v7.8b,v20.b[7]\n"
            "dup v8.8b,v20.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b,v20.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b,v20.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b,v20.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b,v20.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b,v20.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b,v20.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b,v20.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            "ld1 {v16.8b}, [%[b_ptr]], 8\n"
            // Depth steps 2-3: lanes of v21.
            "dup v0.8b,v21.b[0]\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "dup v1.8b,v21.b[1]\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "dup v2.8b,v21.b[2]\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "dup v3.8b,v21.b[3]\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "dup v4.8b,v21.b[4]\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "dup v5.8b,v21.b[5]\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "dup v6.8b,v21.b[6]\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "dup v7.8b,v21.b[7]\n"
            "smlal v31.8h, v15.8b, v17.8b\n"
            "ld1 {v17.8b}, [%[b_ptr]], 8\n"
            "dup v8.8b,v21.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b,v21.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b,v21.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b,v21.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b,v21.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b,v21.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b,v21.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b,v21.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            "ld1 {v16.8b}, [%[b_ptr]], 8\n"
            // Depth steps 4-5: lanes of v22.
            "dup v0.8b,v22.b[0]\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "dup v1.8b,v22.b[1]\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "dup v2.8b,v22.b[2]\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "dup v3.8b,v22.b[3]\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "dup v4.8b,v22.b[4]\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "dup v5.8b,v22.b[5]\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "dup v6.8b,v22.b[6]\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "dup v7.8b,v22.b[7]\n"
            "smlal v31.8h, v15.8b, v17.8b\n"
            "ld1 {v17.8b}, [%[b_ptr]], 8\n"
            "dup v8.8b,v22.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b,v22.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b,v22.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b,v22.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b,v22.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b,v22.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b,v22.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b,v22.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            "ld1 {v16.8b}, [%[b_ptr]], 8\n"
            // Depth steps 6-7: lanes of v23; v20/v21 are reloaded for the
            // next loop iteration while the last multiplies drain.
            "dup v0.8b,v23.b[0]\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "dup v1.8b,v23.b[1]\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "dup v2.8b,v23.b[2]\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "dup v3.8b,v23.b[3]\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "dup v4.8b,v23.b[4]\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "dup v5.8b,v23.b[5]\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "dup v6.8b,v23.b[6]\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "dup v7.8b,v23.b[7]\n"
            "smlal v31.8h, v15.8b, v17.8b\n"
            "ld1 {v17.8b}, [%[b_ptr]], 8\n"
            "dup v8.8b,v23.b[8]\n"
            "smlal v24.8h, v0.8b, v16.8b\n"
            "dup v9.8b,v23.b[9]\n"
            "smlal v25.8h, v1.8b, v16.8b\n"
            "dup v10.8b,v23.b[10]\n"
            "smlal v26.8h, v2.8b, v16.8b\n"
            "dup v11.8b,v23.b[11]\n"
            "smlal v27.8h, v3.8b, v16.8b\n"
            "dup v12.8b,v23.b[12]\n"
            "smlal v28.8h, v4.8b, v16.8b\n"
            "dup v13.8b,v23.b[13]\n"
            "smlal v29.8h, v5.8b, v16.8b\n"
            "dup v14.8b,v23.b[14]\n"
            "smlal v30.8h, v6.8b, v16.8b\n"
            "dup v15.8b,v23.b[15]\n"
            "smlal v31.8h, v7.8b, v16.8b\n"
            "ld1 {v20.16b}, [%[a_ptr]],#16\n"
            "smlal v24.8h, v8.8b, v17.8b\n"
            "smlal v25.8h, v9.8b, v17.8b\n"
            "smlal v26.8h, v10.8b, v17.8b\n"
            "smlal v27.8h, v11.8b, v17.8b\n"
            "ld1 {v21.16b}, [%[a_ptr]],#16\n"
            "smlal v28.8h, v12.8b, v17.8b\n"
            "smlal v29.8h, v13.8b, v17.8b\n"
            "smlal v30.8h, v14.8b, v17.8b\n"
            "smlal v31.8h, v15.8b, v17.8b\n"
            "subs %w[K], %w[K], #1 \n"
            "cbnz %w[K], 1b \n"
            // First-k pass: skip the C load and start from zero (label 2).
            // Otherwise load the existing n_remain columns of C (200-206),
            // widest case first; caller guarantees 1 <= n_remain <= 7.
            "cmp %w[is_first_k], #1 \n"
            "beq 2f \n"
            "cmp %w[n_remain],#7 \n"
            "beq 200f \n"
            "cmp %w[n_remain],#6 \n"
            "beq 201f \n"
            "cmp %w[n_remain],#5 \n"
            "beq 202f \n"
            "cmp %w[n_remain],#4 \n"
            "beq 203f \n"
            "cmp %w[n_remain],#3 \n"
            "beq 204f \n"
            "cmp %w[n_remain],#2 \n"
            "beq 205f \n"
            "cmp %w[n_remain],#1 \n"
            "beq 206f \n"
            "200: \n"
            "ld1 {v0.8h}, [x0],#16 \n"
            "ld1 {v1.8h}, [x0],#16 \n"
            "ld1 {v2.8h}, [x0],#16 \n"
            "ld1 {v3.d}[0], [x0],#8 \n"
            "b 3f \n"
            "201: \n"
            "ld1 {v0.8h}, [x0],#16 \n"
            "ld1 {v1.8h}, [x0],#16 \n"
            "ld1 {v2.8h}, [x0],#16 \n"
            "b 3f \n"
            "202: \n"
            "ld1 {v0.8h}, [x0],#16 \n"
            "ld1 {v1.8h}, [x0],#16 \n"
            "ld1 {v2.d}[0], [x0],#8 \n"
            "b 3f \n"
            "203: \n"
            "ld1 {v0.8h}, [x0],#16 \n"
            "ld1 {v1.8h}, [x0],#16 \n"
            "b 3f \n"
            "204: \n"
            "ld1 {v0.8h}, [x0],#16 \n"
            "ld1 {v1.d}[0], [x0],#8 \n"
            "b 3f \n"
            "205: \n"
            "ld1 {v0.8h}, [x0],#16 \n"
            "b 3f \n"
            "206: \n"
            "ld1 {v0.d}[0], [x0],#8 \n"
            "b 3f \n"
            "2: \n"
            "eor v0.16b, v0.16b, v0.16b\n"
            "eor v1.16b, v1.16b, v1.16b\n"
            "eor v2.16b, v2.16b, v2.16b\n"
            "eor v3.16b, v3.16b, v3.16b\n"
            "eor v4.16b, v4.16b, v4.16b\n"
            "eor v5.16b, v5.16b, v5.16b\n"
            "eor v6.16b, v6.16b, v6.16b\n"
            "eor v7.16b, v7.16b, v7.16b\n"
            // Fold accumulator pairs into the four result vectors and add C.
            "3: \n"
            "zip1 v8.2d, v24.2d, v25.2d\n"
            "zip1 v10.2d, v26.2d, v27.2d\n"
            "add v0.8h, v0.8h, v8.8h \n"
            "zip1 v12.2d, v28.2d, v29.2d\n"
            "add v1.8h, v1.8h, v10.8h\n"
            "zip1 v14.2d, v30.2d, v31.2d\n"
            "add v2.8h, v2.8h, v12.8h\n"
            "add v3.8h, v3.8h, v14.8h\n"
            // Store back into memory, again dispatched on n_remain.
            // NOTE(review): x0 was advanced by the load ladder above when
            // is_first_k == 0 -- presumably the kernel is only ever invoked
            // with is_first_k == 1 (the strategy asserts this); confirm.
            "cmp %w[n_remain],#7 \n"
            "beq 100f \n"
            "cmp %w[n_remain],#6 \n"
            "beq 101f \n"
            "cmp %w[n_remain],#5 \n"
            "beq 102f \n"
            "cmp %w[n_remain],#4 \n"
            "beq 103f \n"
            "cmp %w[n_remain],#3 \n"
            "beq 104f \n"
            "cmp %w[n_remain],#2 \n"
            "beq 105f \n"
            "cmp %w[n_remain],#1 \n"
            "beq 106f \n"
            "100: \n"
            "st1 {v0.8h}, [x0],#16 \n"
            "st1 {v1.8h}, [x0],#16 \n"
            "st1 {v2.8h}, [x0],#16 \n"
            "st1 {v3.d}[0], [x0],#8 \n"
            "b 1000f \n"
            "101: \n"
            "st1 {v0.8h}, [x0],#16 \n"
            "st1 {v1.8h}, [x0],#16 \n"
            "st1 {v2.8h}, [x0],#16 \n"
            "b 1000f \n"
            "102: \n"
            "st1 {v0.8h}, [x0],#16 \n"
            "st1 {v1.8h}, [x0],#16 \n"
            "st1 {v2.d}[0], [x0],#8 \n"
            "b 1000f \n"
            "103: \n"
            "st1 {v0.8h}, [x0],#16 \n"
            "st1 {v1.8h}, [x0],#16 \n"
            "b 1000f \n"
            "104: \n"
            "st1 {v0.8h}, [x0],#16 \n"
            "st1 {v1.d}[0], [x0],#8 \n"
            "b 1000f \n"
            "105: \n"
            "st1 {v0.8h}, [x0],#16 \n"
            "b 1000f \n"
            "106: \n"
            "st1 {v0.d}[0], [x0],#8 \n"
            "b 1000f \n"
            "1000: \n"
            :
            [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr),
            [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC),
            // m_remain is wired into the operand list but never referenced in
            // the template (the tile is fixed at 4 rows); kept for interface
            // symmetry with kern_8x8_remain.
            [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain),
            [ n_remain ] "+r"(n_remain)
            :
            : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
              "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
              "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
              "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
              "v29", "v30", "v31");
    // clang-format on
// LOAD_C_4 / STORE_C_4 are helper macros defined earlier in this header.
#undef LOAD_C_4
#undef STORE_C_4
}
//! Pack the A (weight) panel into ic x oc order:
//! (M/4, K/4, 4(K), 4(M)) -> (M/8, K/8, 8(K: ic0~3,ic4~7), 8(M: oc0~3,oc4~7)).
//! When M or K is not a multiple of 8 the missing lanes are zero padded.
static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr,
                                          const dt_int8* inptr, int ldin,
                                          int m0, int mmax, int k0, int kmax) {
    megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4");
    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
    constexpr int pack_m = 8;     // output rows interleaved per tile
    constexpr int pack_k = 8;     // depth elements consumed per tile
    constexpr int pack_size = 4;  // MK4 block size
    // Zero-initialized staging buffers: a partial (<8) depth tail is copied
    // in here so the interleave always reads a full, zero-padded tile.
    int8_t tmpbuff0[pack_m * pack_size] = {0};
    int8_t tmpbuff1[pack_m * pack_size] = {0};
    int8_t zerobuff[pack_m * pack_size] = {0};
    const int m_size = mmax - m0;
    // m_end: last row index reachable with full 8-row tiles.
    const int m_end = m_size / pack_m * pack_m + m0;
    int remain_m = mmax - m_end;  // 0 or 4, since M is a multiple of 4
    for (int m_idx = m0; m_idx < m_end; m_idx += pack_m) {
        // inptr0/inptr1: two consecutive 4-row MK4 blocks (8 rows total);
        // ldin is the stride between 4-row blocks.
        const int8_t* inptr0 = inptr + m_idx / pack_size * ldin + k0;
        const int8_t* inptr1 = inptr0 + ldin;
        prefetch_2x(inptr0);
        prefetch_2x(inptr1);
        int k_idx = k0;
        for ( ; k_idx + 7 < kmax; k_idx += pack_k) {
            // NOTE: interleave_8x8_mk4_b presumably advances inptr0/inptr1
            // and outptr in place (they are re-used across iterations) --
            // confirm in asm/common.h.
            interleave_8x8_mk4_b(inptr0,inptr1,outptr);
        }
        if (k_idx < kmax) {
            // Partial depth tail (K % 8 == 4): stage into the zero-padded
            // buffers so the packed tile is padded with zeros.
            memcpy(tmpbuff0, inptr0, sizeof(int8_t) * (kmax - k_idx) * pack_size);
            memcpy(tmpbuff1, inptr1, sizeof(int8_t) * (kmax - k_idx) * pack_size);
            inptr0 = tmpbuff0;
            inptr1 = tmpbuff1;
            interleave_8x8_mk4_b(inptr0, inptr1, outptr);
        }
    }
    int m_idx = m_end;
    if (remain_m == 4) {
        // Single trailing 4-row block: the second 4-row input is fed from
        // zerobuff (re-assigned every iteration because the interleave
        // advances its arguments), zero-padding rows 4~7 of the tile.
        const int8_t* inptr0 = inptr + m_idx / pack_size * ldin + k0;
        const int8_t* inptr1 = inptr0 + ldin;
        prefetch_2x(inptr0);
        prefetch_2x(inptr1);
        int k_idx = k0;
        for ( ; k_idx + 7 < kmax; k_idx += pack_k) {
            inptr1 = zerobuff;
            interleave_8x8_mk4_b(inptr0,inptr1,outptr);
        }
        if (k_idx < kmax) {
            memcpy(tmpbuff0, inptr0, sizeof(int8_t) * (kmax - k_idx) * pack_size);
            inptr0 = tmpbuff0;
            inptr1 = zerobuff;
            interleave_8x8_mk4_b(inptr0, inptr1, outptr);
        }
    }
}
//! Pack the B (data) panel into n x ic order:
//! (K/4, N, 4) -> (K/8, N, 8(ic0~7)).  When K is not a multiple of 8 the
//! missing depth lanes are zero padded so the kernel always reads full
//! 8-deep tiles.
static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in,
                                          int ldin, int n0, int nmax, int k0,
                                          int kmax) {
    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
    constexpr int pack_n = 8;     // columns per packed tile
    constexpr int pack_k = 8;     // depth elements per packed tile
    constexpr int pack_size = 4;  // MK4 block size
    // Zero-initialized staging buffers: a ragged (<8 column) tail is copied
    // in here so the transpose always reads a full, zero-padded tile.
    int8_t tmpbuff0[pack_n * pack_size] = {0};
    int8_t tmpbuff1[pack_n * pack_size] = {0};
    int8_t zerobuff[pack_n * pack_size] = {0};
    const int ksize = round_up<int>((kmax - k0), 8);
    const int nsize = nmax - n0;
    const int n_end = nsize / pack_n * pack_n + n0;
    const int remain_n = nsize % pack_n;
    // One fully packed n-tile occupies ksize * pack_n bytes of output.
    int output_stride = ksize * pack_n;
    int8_t* outptr_base = out;
    int k_idx = k0;
    for (; k_idx + 7 < kmax; k_idx += pack_k) {
        // inptr0/inptr1: two consecutive groups of 4 input channels;
        // ldin is the stride between 4-channel row blocks.
        const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size;
        const int8_t* inptr1 = inptr0 + ldin;
        prefetch_3x(inptr0);
        prefetch_3x(inptr1);
        auto outptr = outptr_base;
        for (int n_idx = n0; n_idx < n_end; n_idx += pack_n) {
            // NOTE: transpose_8x8_mk4_b presumably advances its input
            // pointers in place -- confirm in asm/common.h.
            transpose_8x8_mk4_b(inptr0, inptr1, outptr);
            outptr += output_stride;
        }
        if (remain_n > 0) {
            // Ragged column tail: stage into the zero-padded buffers.
            memcpy(tmpbuff0, inptr0, sizeof(int8_t) * remain_n * pack_size);
            memcpy(tmpbuff1, inptr1, sizeof(int8_t) * remain_n * pack_size);
            inptr0 = tmpbuff0;
            inptr1 = tmpbuff1;
            transpose_8x8_mk4_b(inptr0, inptr1, outptr);
            outptr += output_stride;
        }
        // Advance past this k-block's slice of every n-tile.
        outptr_base += pack_n * pack_k;
    }
    if (k_idx < kmax) {
        // Final partial depth block (K % 8 == 4): only one group of 4 input
        // channels remains; the second group is fed from zerobuff
        // (re-assigned every iteration, since the transpose advances its
        // arguments), so channels 4~7 of the tile are zero.
        const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size;
        const int8_t* inptr1 = nullptr;
        prefetch_3x(inptr0);
        auto outptr = outptr_base;
        for (int n_idx = n0; n_idx < n_end; n_idx += pack_n) {
            inptr1 = zerobuff;
            transpose_8x8_mk4_b(inptr0, inptr1, outptr);
            outptr += output_stride;
        }
        if (remain_n > 0) {
            memcpy(tmpbuff0, inptr0, sizeof(int8_t) * remain_n * pack_size);
            inptr1 = zerobuff;
            inptr0 = tmpbuff0;
            transpose_8x8_mk4_b(inptr0, inptr1, outptr);
            outptr += output_stride;
        }
        // Fix: advance by a full k-block (pack_n * pack_k), matching the
        // main loop above.  The previous pack_n * pack_size stride was
        // inconsistent; the value is unused after the final block, but the
        // invariant should hold.
        outptr_base += pack_n * pack_k;
    }
}
}  // namespace matmul_mk4_8x8x8
} // namespace aarch64
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -13,6 +13,7 @@
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h"
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h"
......@@ -357,4 +358,81 @@ void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB,
}
}
// ===========================gemm_s8x8x16_mk4_8x8x8==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8x8);
//! Forward A-panel packing to the MK4 8x8x8 packing routine.
//! The unnamed trailing bool (presumably a transpose flag) is ignored --
//! this strategy is registered as non-transposed only; confirm at call site.
void gemm_s8x8x16_mk4_8x8x8::pack_A(dt_int8* out, const dt_int8* in,
                                    int ldin, int y0, int ymax, int k0,
                                    int kmax, bool) const {
    matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_A(out, in, ldin, y0,
                                                    ymax, k0, kmax);
}
//! Forward B-panel packing to the MK4 8x8x8 packing routine.
//! The unnamed trailing bool (presumably a transpose flag) is ignored --
//! this strategy is registered as non-transposed only; confirm at call site.
void gemm_s8x8x16_mk4_8x8x8::pack_B(dt_int8* out, const dt_int8* in,
                                    int ldin, int x0, int xmax, int k0,
                                    int kmax, bool) const {
    matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_B(out, in, ldin, x0,
                                                    xmax, k0, kmax);
}
//! Drive the int8x8x16 MK4 8x8x8 micro-kernels over the packed panels.
//! The main loop covers 8 output rows (two MK4 blocks) per iteration; a
//! single trailing 4-row block, if any, is handled by the kern_4x8 variants.
//! Only is_first_k == true is supported (asserted below).
void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB,
                                  size_t M, size_t N, size_t K, dt_int16* C,
                                  size_t LDC, bool is_first_k, const dt_int16*,
                                  dt_int16*) const {
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                  C_dtype.enumv() == DTypeEnum::Int16 &&
                  A_dtype.enumv() == DTypeEnum::Int8);
    megdnn_assert(is_first_k == true, "only impl is_first_k");
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);
    megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4");
    constexpr size_t pack_size = 4;  // MK4 block height
    constexpr size_t pack_m = 8;     // rows per main-kernel tile
    constexpr size_t pack_n = 8;     // columns per kernel tile
    const size_t remain_n = N % pack_n;
    const size_t remain_m = M % pack_m;  // 0 or 4, since M % 4 == 0
    // The pack routines zero-pad depth to a multiple of 8; mirror that here
    // so panel strides match the packed layout.
    K = round_up<size_t>(K, 8);
    // Bytes occupied by one packed 8-wide panel (used for both A and B).
    const size_t KSIZE8 = K * pack_n;
    size_t m_idx = 0;
    for (; m_idx + pack_m <= M; m_idx += pack_m) {
        int16_t* output = C + (m_idx / pack_size * LDC);
        size_t n_idx = 0;
        const int8_t* cur_packB = packB;
        for (; n_idx + pack_n <= N; n_idx += pack_n) {
            matmul_mk4_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC,
                                       is_first_k, pack_m, pack_n);
            output += pack_n * pack_size;
            cur_packB += KSIZE8;
        }
        if (remain_n > 0) {
            matmul_mk4_8x8x8::kern_8x8_remain(packA, cur_packB, K, output,
                                              LDC, is_first_k, pack_m,
                                              remain_n);
            output += remain_n * pack_size;
            cur_packB += KSIZE8;
        }
        packA += KSIZE8;
    }
    if (remain_m == 4) {
        int16_t* output = C + (m_idx / pack_size * LDC);
        size_t n_idx = 0;
        const int8_t* cur_packB = packB;
        for (; n_idx + pack_n <= N; n_idx += pack_n) {
            matmul_mk4_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC,
                                       is_first_k, 4, pack_n);
            output += pack_n * pack_size;
            // Consistency: was `pack_n * K` -- same value as KSIZE8, now
            // spelled the same way as in the main loop.
            cur_packB += KSIZE8;
        }
        if (remain_n > 0) {
            matmul_mk4_8x8x8::kern_4x8_remain(packA, cur_packB, K, output,
                                              LDC, is_first_k, 4, remain_n);
            output += remain_n * pack_size;
            cur_packB += KSIZE8;
        }
    }
}
// vim: syntax=cpp.doxygen
......@@ -26,6 +26,8 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 8, false, false,
MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int16, dt_int16,
16, 12, 4, false, false,
gemm_s8x8x16_mk4_16x12_a53);
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, false,
gemm_s8x8x16_mk4_8x8x8);
} // namespace matmul
} // namespace aarch64
......
......@@ -39,6 +39,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16;
AlgoInt8x8x16MK4_16x12x4 int8x8x16_mk4_16x12x4;
AlgoInt8x8x16MK4_4x4x8 int8x8x16_mk4_4x4x8;
AlgoInt8x8x16MK4_K8x8x8 int8x8x16_mk4_k8x8x8;
AlgoInt16x16x32K12x8x1 int16x16x32_k12x8x1;
AlgoInt16x16x32MK8_8x8 int16x16x32_mk8_8x8;
......@@ -73,6 +74,7 @@ public:
#endif
all_algos.emplace_back(&int8x8x16_k4x4x16);
all_algos.emplace_back(&int8x8x16_k8x8x8);
all_algos.emplace_back(&int8x8x16_mk4_k8x8x8);
all_algos.emplace_back(&int8x8x16_mk4_4x4x8);
all_algos.emplace_back(&int8x8x16_mk4_16x12x4);
......
......@@ -57,6 +57,7 @@ private:
#else
class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8
#endif
class AlgoInt8x8x16MK4_K8x8x8;  // Aarch64 Int8x8x16 MK4 Kernel 8x8x8
class AlgoPack;
};
......
......@@ -122,6 +122,20 @@ TEST_F(AARCH64, MATRIX_MUL_INT8_MK4) {
std::move(args));
}
TEST_F(AARCH64, MATRIX_MUL_INT8x8x16_MK4) {
    //! Exercise the MK4 int8x8x16 8x8x8 kernel across shapes that hit every
    //! remainder path (M % 8, N % 8 and K % 8 all covered).
    std::vector<matrix_mul::TestArg> args;
    for (size_t m :
         {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}) {
        for (size_t n :
             {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 24}) {
            for (size_t k :
                 {2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14,
                  15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
                  28, 29}) {
                args.emplace_back(m, n, k, 0);
            }
        }
    }
    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
                                 handle(), "AARCH64_INT8X8X16_MK4_K8X8X8",
                                 param::MatrixMul::Format::MK4, 1, 1e-3,
                                 std::move(args));
}
TEST_F(AARCH64, MATRIX_MUL_MK4_8x8x16_4x4) {
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
handle(), "AARCH64_INT8X8X16_MK4_4X4X8",
......@@ -396,6 +410,71 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x16) {
run(384, 384, 384);
}
//! Benchmark: compare the new MK4 8x8x8 int8x8x16 kernel against the plain
//! K4X4X16 kernel and the older MK4 4X4X8 kernel over a range of shapes.
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x8_8x8x16_vs_4x4x16_8x8x16) {
    constexpr size_t RUNS = 50;
    param::MatrixMul param;
    param.transposeA = false;
    param.transposeB = false;
    // Three benchmarkers, one per algorithm under comparison.
    Benchmarker<MatrixMul> benchmarker(handle());
    Benchmarker<MatrixMul> benchmarker_mk4(handle());
    Benchmarker<MatrixMul> benchmarker_mk4_4x4x8(handle());
    // Baseline: default-format K4X4X16 kernel.
    benchmarker.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_param(param)
            .set_display(false);
    benchmarker.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_K4X4X16"));
    // Switch to MK4 format for the remaining two benchmarkers.
    param.format = MatrixMul::Param::Format::MK4;
    benchmarker_mk4.set_before_exec_callback(
            AlgoChecker<MatrixMul>(
                    "AARCH64_INT8X8X16_MK4_K8X8X8"
                    ));
    benchmarker_mk4.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_param(param)
            .set_display(false);
    benchmarker_mk4_4x4x8.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_MK4_4X4X8"));
    benchmarker_mk4_4x4x8.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_param(param)
            .set_display(false);
    // Run one (M, N, K) shape on all three algorithms; times are per run (ms).
    auto run = [&](size_t M, size_t N, size_t K) {
        auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS;
        auto mk_used = benchmarker_mk4.exec(
                               {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
                       RUNS;
        auto mk4_4x4x8_used =
                benchmarker_mk4_4x4x8.exec(
                        {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
                RUNS;
        float computations = 2.f * M * K * N * 1e-6;
        // For the mk4_4x4x8 section the label order is "Gflops" then "ms",
        // matching the argument order below (unlike the earlier sections).
        printf("run: {%zu{M} %zu{K} %zu{N}} normal: %f ms %f Gflops mk4: %f ms "
               "%f Gflops speedup: %f, mk4_4x4x8 %f Gflops %f ms speedup: %f\n",
               M, K, N, default_used, computations / default_used, mk_used,
               computations / mk_used, default_used / mk_used,
               computations / mk4_4x4x8_used, mk4_4x4x8_used , mk4_4x4x8_used/mk_used);
    };
    run(384, 384, 384);
    run(512, 512, 512);
    run(1024, 1024, 384);
    run(256, 256, 384);
    // NOTE(review): the k loop uses `<` (stops at 256) while m and n use
    // `<=` (reach 512) -- possibly unintended; confirm before relying on
    // k = 512 numbers.
    for(int m = 32; m <= 512;m*=2)
        for(int n = 32; n <= 512;n*=2)
            for(int k = 32; k < 512;k*=2){
                run(m,n,k);
            }
}
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_4X4X16) {
constexpr size_t RUNS = 50;
param::MatrixMul param;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册