diff --git a/dnn/include/megdnn/oprs/base.h b/dnn/include/megdnn/oprs/base.h index aef6de4233748b11db7abb86ec63293223d5e4f7..d870edfef09094b007a773f1c5e564eb4f5cc45f 100644 --- a/dnn/include/megdnn/oprs/base.h +++ b/dnn/include/megdnn/oprs/base.h @@ -88,6 +88,7 @@ enum class AlgoDataType : uint32_t { QUINT8X8X32 = 1 << 3, INT8X8X16 = 1 << 4, INT16X16X32 = 1 << 5, + INT4X4X16 = 1 << 6, }; /*! diff --git a/dnn/src/aarch64/matrix_mul/algos.cpp b/dnn/src/aarch64/matrix_mul/algos.cpp index e360ced9659d01d87763c4c677d357b41346b1c8..0df002323e4931bfe5ac4d850ccc1368d47130af 100644 --- a/dnn/src/aarch64/matrix_mul/algos.cpp +++ b/dnn/src/aarch64/matrix_mul/algos.cpp @@ -17,6 +17,7 @@ #include "src/aarch64/matrix_mul/int8/strategy.h" #include "src/aarch64/matrix_mul/int8_dot/strategy.h" #include "src/aarch64/matrix_mul/int8x8x16/strategy.h" +#include "src/aarch64/matrix_mul/int4x4x16/strategy.h" #include "src/aarch64/matrix_mul/quint8/strategy.h" #include "src/aarch64/matrix_mul/quint8_dot/gemv.h" #include "src/aarch64/matrix_mul/quint8_dot/strategy.h" @@ -1394,4 +1395,75 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_K8x8x8, aarch64::matmul::gemm_s8x8x16_mk4_8x8x8, int8_t, int16_t, AlgoDataType::INT8X8X16, MK4); +/* ===================== Int4x4x16 K8x8x8 algo ===================== */ +namespace { +void int4x4x16_k8x8x16_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("int4x4x16_k8x8x8_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + aarch64::matmul::gemm_s4x4x16_s4_8x8x8 strategy(M, N, K, A_type, B_type, + C_type); + + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt4x4x16K8x8x8::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && + kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS4 && + kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16 && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && + kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + (kern_size_param.K & 1) == 0 && (kern_size_param.N & 1) == 0; +} + +bool MatrixMulImpl::AlgoInt4x4x16K8x8x8::preferred( + const KernSizeParam& kern_size_param) const { + MEGDNN_MARK_USED_VAR(kern_size_param); + return true; +} + +size_t MatrixMulImpl::AlgoInt4x4x16K8x8x8::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt4x4x16K8x8x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + aarch64::matmul::gemm_s4x4x16_s4_8x8x8 strategy(M, N, K, A_type, B_type, + C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt4x4x16K8x8x8::get_kern( + const KernSizeParam&) const { + return int4x4x16_k8x8x16_kern; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt4x4x16K8x8x8, + megdnn_aarch64_matmul_kern, + "AlgoInt4x4x16K8x8x8Impl"_hash, + aarch64::matmul::gemm_s4x4x16_s4_8x8x8, + int8_t, int16_t, AlgoDataType::INT4X4X16, + DEFAULT); // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/algos.h b/dnn/src/aarch64/matrix_mul/algos.h index aa9e7e0eee13c78d94b376c24ddc45b19b9f8cb0..fcb7188f30533018bba48d6621c4e0f44b344ffd 100644 --- a/dnn/src/aarch64/matrix_mul/algos.h +++ b/dnn/src/aarch64/matrix_mul/algos.h @@ -192,6 +192,19 @@ public: MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_K4X4X16) }; +class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + PackMode packmode() const override { return PackMode::DEFAULT; } + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(AARCH64_INT4X4X16_K8X8X8) +}; + class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { public: bool is_reproducible() const override { return true; } diff --git a/dnn/src/aarch64/matrix_mul/asm/common.h b/dnn/src/aarch64/matrix_mul/asm/common.h index 369810fbddc2eeedcf91794728eba87a81ff0f38..5c4ed664c8b1fb2a67d69871f9ac86581928425b 100644 --- a/dnn/src/aarch64/matrix_mul/asm/common.h +++ b/dnn/src/aarch64/matrix_mul/asm/common.h @@ -925,6 +925,42 @@ static inline void interleave_8x8_1_b(const T*& inptr0, const T*& inptr1, : "v0", "v1", "v2", "v3", "memory"); } + +template +static inline void interleave_8x4_1_b_with_shift( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T* outptr) { + static_assert(sizeof(T) == 1, "only support size == 1"); + asm volatile( + "ld1 {v0.s}[0], [%[inptr0]], #4\n" + "ld1 {v0.s}[1], [%[inptr1]], #4\n" + "ld1 {v0.s}[2], [%[inptr2]], #4\n" + "ld1 {v0.s}[3], [%[inptr3]], #4\n" + "ld1 {v1.s}[0], [%[inptr4]], #4\n" + "ld1 {v1.s}[1], [%[inptr5]], #4\n" + "ld1 {v1.s}[2], [%[inptr6]], #4\n" + "ld1 {v1.s}[3], [%[inptr7]], #4\n" + "shl v2.16b, v0.16b, #4\n" + "shl v5.16b, v1.16b, #4\n" + "sshr v3.16b, v0.16b, #4\n" // hig + "sshr v4.16b, v2.16b, #4\n" // low + "sshr v6.16b, v1.16b, #4\n" // hig + "sshr v7.16b, v5.16b, #4\n" // low + "zip1 v8.16b, v4.16b, v3.16b\n" + "zip2 v9.16b, v4.16b, v3.16b\n" + "zip1 v10.16b, v7.16b, v6.16b\n" + "zip2 v11.16b, v7.16b, v6.16b\n" + "st1 {v8.16b-v11.16b},[%[outptr]],#64" + : [ inptr0 ] "+r"(inptr0), [ inptr1 ] "+r"(inptr1), + [ inptr2 ] "+r"(inptr2), [ inptr3 ] "+r"(inptr3), + [ inptr4 ] "+r"(inptr4), [ inptr5 ] "+r"(inptr5), + [ inptr6 ] "+r"(inptr6), [ inptr7 ] "+r"(inptr7), + [ outptr ] "+r"(outptr) + : + : "v0", "v1","v2","v3","v4","v5","v6","v7","v8","v9","v10","v11","memory"); +} + template static inline void interleave_8x8_1_h(const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, @@ -1059,6 +1095,7 @@ static inline void interleave_4x16_1_b(const T*& inptr0, const T*& inptr1, : "v0", "v1", "v2", "v3", "v4", "cc", "memory"); } + template static inline void interleave_4x16_1_s(const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, @@ -1772,6 +1809,54 @@ static inline void transpose_8x4_1_b(const T*& inptr0, const T*& inptr1, : "v0", "v1", "v2", "v3", "v4", "v5", "memory"); } +template +static inline void transpose_4x8_1_b_with_shift(const T*& inptr0, const T*& inptr1, + const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, + const T*& inptr6, const T*& inptr7, + T*& outptr) { + + static int8x16_t shuffle_idx = {0, 4, 8, 12, 1, 5, 9, 13, + 2, 6, 10, 14, 3, 7, 11, 15}; + static_assert( + std::is_same::value || std::is_same::value, + "transpose_8x4_1_b only support uint8_t and int8_t"); + asm volatile( + "ld1 {v0.s}[0], [%[inptr0]], #4\n" // A1A2A3A4 + "ld1 {v0.s}[1], [%[inptr1]], #4\n" // B1B2B3B4 + "ld1 {v0.s}[2], [%[inptr2]], #4\n" // C1C2C3C4 + "ld1 {v0.s}[3], [%[inptr3]], #4\n" // D1D2D3D4 + "ld1 {v1.s}[0], [%[inptr4]], #4\n" // E1E2E3E4 + "ld1 {v1.s}[1], [%[inptr5]], #4\n" // F1F2F3F4 + "ld1 {v1.s}[2], [%[inptr6]], #4\n" // G1G2G3G4 + "ld1 {v1.s}[3], [%[inptr7]], #4\n" // H1H2H3H4 + + "tbl v2.16b, {v0.16b}, %[shuffle_idx].16b \n" // A1B1C1D1A2B2C2D2A3B3C3D3A4B4C4D4 + "tbl v3.16b, {v1.16b}, %[shuffle_idx].16b \n" // E1F1G1H1E2F2G2H2E3F3G3H3E4F4G4H4 + + "zip1 v4.4s, v2.4s, v3.4s\n" // A1B1C1D1E1F1G1H1 A2B2C2D2E2F2G2H2 + "zip2 v5.4s, v2.4s, v3.4s\n" // A3B3C3D3E3F3G3H3 A4B4C4D4E4F4G4H4 + + "shl v6.16b, v4.16b, #4\n" + "sshr v7.16b, v4.16b, #4\n" // hig + "sshr v8.16b, v6.16b, #4\n" // low + "shl v9.16b, v5.16b, #4\n" + "sshr v10.16b, v5.16b, #4\n" // hig + "sshr v11.16b, v9.16b, #4\n" // low + "zip1 v0.2d,v8.2d,v7.2d\n" + "zip2 v1.2d,v8.2d,v7.2d\n" + "zip1 v2.2d,v11.2d,v10.2d\n" + "zip2 v3.2d,v11.2d,v10.2d\n" + "st1 {v0.2d-v3.2d},[%[outptr]],#64\n" + + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [shuffle_idx]"+w"(shuffle_idx), + [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5","v6","v7","v8","v9","v10","v11","memory"); +} template static inline void transpose_8x8_1_b(const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, diff --git a/dnn/src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h b/dnn/src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h new file mode 100644 index 0000000000000000000000000000000000000000..35ea47ef3e413052cc61e0fbc7ff616f43c9d40b --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h @@ -0,0 +1,913 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int4x4x16/kernel_8x8x8.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include +#include +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul_s4_4x4x16 { + +/** + * Overview of register layout: + * + * +---------+---------+---------+---------+ + * |v20[0-15]|v21[0-15]|v22[0-15]|v23[0-15]| + * Rhs +---------+---------+---------+---------+ + * Lhs | | | + * + * +--------+ - - - - +---------+---------+---------+---------+ + * |v0[0-15]| | v4[0-8] | v8[0-8]| v12[0-8]| v16[0-8]| + * |v1[0-15]| | v5[0-8] | v9[0-8]| v13[0-8]| v17[0-8]| + * |v2[0-15]| | v6[0-8] | v10[0-8]| v14[0-8]| v18[0-8]| + * |v3[0-15]| | v7[0-8] | v11[0-8]| v15[0-8]| v19[0-8]| + * +--------+ - - - - +---------+---------+---------+---------+ + * + * Accumulator + */ + +static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, + int16_t* output, int LDC, bool is_first_k, int m_remain, + int n_remain) { + K /= 8; + LDC = LDC * sizeof(int16_t); + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; +// clang-format off +#define LOAD_LINE(reg_index, n) \ + "cmp x8, #0 \n" \ + "beq 105f\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ld1 {v" reg_index ".8h}, [x" n "], #16\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "blt 101" n "f\n" \ + "ld1 {v" reg_index ".h}[0], [x" n "], #2\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".h}[1], [x" n "], #2\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".h}[2], [x" n "], #2\n" \ + "cmp %w[n_remain], #3\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".h}[3], [x" n "], #2\n" \ + "cmp %w[n_remain], #4\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".h}[4], [x" n "], #2\n" \ + "cmp %w[n_remain], #5\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".h}[5], [x" n "], #2\n" \ + "cmp %w[n_remain], #6\n" \ + "beq 101" n "f\n" \ + "ld1 {v" reg_index ".h}[6], [x" n "], #2\n" \ + "101" n ":\n" \ + "sub x8, x8, #1\n" + +#define LOAD_C \ + "mov x8, %x[m_remain]\n" \ + LOAD_LINE("24", "0") \ + LOAD_LINE("25", "1") \ + LOAD_LINE("26", "2") \ + LOAD_LINE("27", "3") \ + LOAD_LINE("28", "4") \ + LOAD_LINE("29", "5") \ + LOAD_LINE("30", "6") \ + LOAD_LINE("31", "7") \ + "105:\n" + +#define STORE_LINE(reg_index, n) \ + "cmp x8, #0 \n" \ + "beq 105f\n" \ + "cmp %w[n_remain], #8\n" \ + "blt 102" n "f\n" \ + "st1 {v" reg_index ".8h}, [x" n "], #16\n" \ + "b 103" n "f\n" \ + "102" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".h}[0], [x" n "], #2\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".h}[1], [x" n "], #2\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".h}[2], [x" n "], #2\n" \ + "cmp %w[n_remain], #3\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".h}[3], [x" n "], #2\n" \ + "cmp %w[n_remain], #4\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".h}[4], [x" n "], #2\n" \ + "cmp %w[n_remain], #5\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".h}[5], [x" n "], #2\n" \ + "cmp %w[n_remain], #6\n" \ + "beq 103" n "f\n" \ + "st1 {v" reg_index ".h}[6], [x" n "], #2\n" \ + "103" n ":\n" \ + "sub x8, x8, #1\n" + +#define STORE_C \ + "mov x8, %x[m_remain]\n" \ + STORE_LINE("24", "0") \ + STORE_LINE("25", "1") \ + STORE_LINE("26", "2") \ + STORE_LINE("27", "3") \ + STORE_LINE("28", "4") \ + STORE_LINE("29", "5") \ + STORE_LINE("30", "6") \ + STORE_LINE("31", "7") \ + "105:\n" + // clang-format on + register int16_t* outptr asm("x0") = output; + asm volatile( + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + "add x4, x3, %x[LDC]\n" + "add x5, x4, %x[LDC]\n" + "add x6, x5, %x[LDC]\n" + "add x7, x6, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 2f\n" LOAD_C + "b 1f\n" + + "2:\n" // Clear the C regs. + "eor v24.16b, v24.16b, v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + "eor v30.16b, v30.16b, v30.16b\n" + "eor v31.16b, v31.16b, v31.16b\n" + // General loop. + "1:\n" + "ld1 {v20.16b}, [%[a_ptr]],#16\n" + "ld1 {v21.16b}, [%[a_ptr]],#16\n" + "dup v0.8b,v20.b[0]\n" + "dup v1.8b,v20.b[1]\n" + "dup v2.8b,v20.b[2]\n" + "dup v3.8b,v20.b[3]\n" + "ld1 {v22.16b}, [%[a_ptr]],#16\n" + "ld1 {v23.16b}, [%[a_ptr]],#16\n" + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + "dup v4.8b,v20.b[4]\n" + "dup v5.8b,v20.b[5]\n" + "dup v6.8b,v20.b[6]\n" + "dup v7.8b,v20.b[7]\n" + + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + + "dup v8.8b,v20.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v20.b[9]\n" + "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v20.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v20.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v20.b[12]\n" + "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v20.b[13]\n" + "smlal v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v20.b[14]\n" + "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v20.b[15]\n" + "smlal v31.8h, v7.8b, v16.8b\n" + + "ld1 {v18.8b}, [%[b_ptr]], 8\n" + + "dup v0.8b,v21.b[0]\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "dup v1.8b,v21.b[1]\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "dup v2.8b,v21.b[2]\n" + "smlal v26.8h, v10.8b, v17.8b\n" + "dup v3.8b,v21.b[3]\n" + "smlal v27.8h, v11.8b, v17.8b\n" + "dup v4.8b,v21.b[4]\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "dup v5.8b,v21.b[5]\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "dup v6.8b,v21.b[6]\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "dup v7.8b,v21.b[7]\n" + "smlal v31.8h, v15.8b, v17.8b\n" + + "ld1 {v19.8b}, [%[b_ptr]], 8\n" + + "dup v8.8b,v21.b[8]\n" + "smlal v24.8h, v0.8b, v18.8b\n" + "dup v9.8b,v21.b[9]\n" + "smlal v25.8h, v1.8b, v18.8b\n" + "dup v10.8b,v21.b[10]\n" + "smlal v26.8h, v2.8b, v18.8b\n" + "dup v11.8b,v21.b[11]\n" + "smlal v27.8h, v3.8b, v18.8b\n" + "dup v12.8b,v21.b[12]\n" + "smlal v28.8h, v4.8b, v18.8b\n" + "dup v13.8b,v21.b[13]\n" + "smlal v29.8h, v5.8b, v18.8b\n" + "dup v14.8b,v21.b[14]\n" + "smlal v30.8h, v6.8b, v18.8b\n" + "dup v15.8b,v21.b[15]\n" + "smlal v31.8h, v7.8b, v18.8b\n" + + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + "dup v0.8b,v22.b[0]\n" + "smlal v24.8h, v8.8b, v19.8b\n" + "dup v1.8b,v22.b[1]\n" + "smlal v25.8h, v9.8b, v19.8b\n" + "dup v2.8b,v22.b[2]\n" + "smlal v26.8h, v10.8b, v19.8b\n" + "dup v3.8b,v22.b[3]\n" + "smlal v27.8h, v11.8b, v19.8b\n" + "dup v4.8b,v22.b[4]\n" + "smlal v28.8h, v12.8b, v19.8b\n" + "dup v5.8b,v22.b[5]\n" + "smlal v29.8h, v13.8b, v19.8b\n" + "dup v6.8b,v22.b[6]\n" + "smlal v30.8h, v14.8b, v19.8b\n" + "dup v7.8b,v22.b[7]\n" + "smlal v31.8h, v15.8b, v19.8b\n" + + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + + "dup v8.8b,v22.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v22.b[9]\n" + "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v22.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v22.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v22.b[12]\n" + "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v22.b[13]\n" + "smlal v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v22.b[14]\n" + "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v22.b[15]\n" + "smlal v31.8h, v7.8b, v16.8b\n" + + "ld1 {v18.8b}, [%[b_ptr]], 8\n" + "dup v0.8b,v23.b[0]\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "dup v1.8b,v23.b[1]\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "dup v2.8b,v23.b[2]\n" + "smlal v26.8h, v10.8b, v17.8b\n" + "dup v3.8b,v23.b[3]\n" + "smlal v27.8h, v11.8b, v17.8b\n" + "dup v4.8b,v23.b[4]\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "dup v5.8b,v23.b[5]\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "dup v6.8b,v23.b[6]\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "dup v7.8b,v23.b[7]\n" + "smlal v31.8h, v15.8b, v17.8b\n" + + "ld1 {v19.8b}, [%[b_ptr]], 8\n" + "dup v8.8b,v23.b[8]\n" + "smlal v24.8h, v0.8b, v18.8b\n" + "dup v9.8b,v23.b[9]\n" + "smlal v25.8h, v1.8b, v18.8b\n" + "dup v10.8b,v23.b[10]\n" + "smlal v26.8h, v2.8b, v18.8b\n" + "dup v11.8b,v23.b[11]\n" + "smlal v27.8h, v3.8b, v18.8b\n" + "dup v12.8b,v23.b[12]\n" + "smlal v28.8h, v4.8b, v18.8b\n" + "dup v13.8b,v23.b[13]\n" + "smlal v29.8h, v5.8b, v18.8b\n" + "dup v14.8b,v23.b[14]\n" + "smlal v30.8h, v6.8b, v18.8b\n" + "dup v15.8b,v23.b[15]\n" + "smlal v31.8h, v7.8b, v18.8b\n" + + "smlal v24.8h, v8.8b, v19.8b\n" + "smlal v25.8h, v9.8b, v19.8b\n" + "smlal v26.8h, v10.8b, v19.8b\n" + "smlal v27.8h, v11.8b, v19.8b\n" + "smlal v28.8h, v12.8b, v19.8b\n" + "smlal v29.8h, v13.8b, v19.8b\n" + "smlal v30.8h, v14.8b, v19.8b\n" + "smlal v31.8h, v15.8b, v19.8b\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 1b\n" + + "3:\n" + // Store back into memory + STORE_C + + : + [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), + [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), + [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), + [ n_remain ] "+r"(n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) + : + : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, + int16_t* output, int LDC, bool is_first_k, int m_remain, + int n_remain) { + K /= 8; + LDC = LDC * sizeof(int16_t); + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; +// clang-format off + +#define LOAD_C_8 \ + "ld1 {v24.8h}, [x0], #16\n" \ + "ld1 {v25.8h}, [x1], #16\n" \ + "ld1 {v26.8h}, [x2], #16\n" \ + "ld1 {v27.8h}, [x3], #16\n" \ + "ld1 {v28.8h}, [x4], #16\n" \ + "ld1 {v29.8h}, [x5], #16\n" \ + "ld1 {v30.8h}, [x6], #16\n" \ + "ld1 {v31.8h}, [x7], #16\n" \ + + +#define STORE_C_8 \ + "st1 {v24.8h}, [x0], #16\n" \ + "st1 {v25.8h}, [x1], #16\n" \ + "st1 {v26.8h}, [x2], #16\n" \ + "st1 {v27.8h}, [x3], #16\n" \ + "st1 {v28.8h}, [x4], #16\n" \ + "st1 {v29.8h}, [x5], #16\n" \ + "st1 {v30.8h}, [x6], #16\n" \ + "st1 {v31.8h}, [x7], #16\n" \ + +// clang-format on + register int16_t* outptr asm("x0") = output; + asm volatile( + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + "add x4, x3, %x[LDC]\n" + "add x5, x4, %x[LDC]\n" + "add x6, x5, %x[LDC]\n" + "add x7, x6, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 2f\n" LOAD_C_8 + "b 1f\n" + + "2:\n" // Clear the C regs. + "eor v24.16b, v24.16b, v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + "eor v30.16b, v30.16b, v30.16b\n" + "eor v31.16b, v31.16b, v31.16b\n" + // General loop. + "ld1 {v20.16b}, [%[a_ptr]],#16\n" + "ld1 {v21.16b}, [%[a_ptr]],#16\n" + "PRFM PLDL1KEEP, [%[a_ptr], #512]\n" + "PRFM PLDL1KEEP, [%[b_ptr], #512]\n" + "1:\n" + // "ld1 {v20.16b}, [%[a_ptr]],#16\n" + // "ld1 {v21.16b}, [%[a_ptr]],#16\n" + "dup v0.8b,v20.b[0]\n" + "ld1 {v22.16b}, [%[a_ptr]],#16\n" + "dup v1.8b,v20.b[1]\n" + "ld1 {v23.16b}, [%[a_ptr]],#16\n" + "dup v2.8b,v20.b[2]\n" + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + "dup v3.8b,v20.b[3]\n" + "dup v4.8b,v20.b[4]\n" + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + "dup v5.8b,v20.b[5]\n" + "dup v6.8b,v20.b[6]\n" + "dup v7.8b,v20.b[7]\n" + + + "dup v8.8b,v20.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v20.b[9]\n" + "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v20.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v20.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v20.b[12]\n" + "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v20.b[13]\n" + "smlal v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v20.b[14]\n" + "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v20.b[15]\n" + "smlal v31.8h, v7.8b, v16.8b\n" + + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + + "dup v0.8b,v21.b[0]\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "dup v1.8b,v21.b[1]\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "dup v2.8b,v21.b[2]\n" + "smlal v26.8h, v10.8b, v17.8b\n" + "dup v3.8b,v21.b[3]\n" + "smlal v27.8h, v11.8b, v17.8b\n" + "dup v4.8b,v21.b[4]\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "dup v5.8b,v21.b[5]\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "dup v6.8b,v21.b[6]\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "dup v7.8b,v21.b[7]\n" + "smlal v31.8h, v15.8b, v17.8b\n" + + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + + "dup v8.8b,v21.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v21.b[9]\n" + "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v21.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v21.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v21.b[12]\n" + "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v21.b[13]\n" + "smlal v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v21.b[14]\n" + "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v21.b[15]\n" + "smlal v31.8h, v7.8b, v16.8b\n" + + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + "dup v0.8b,v22.b[0]\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "dup v1.8b,v22.b[1]\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "dup v2.8b,v22.b[2]\n" + "smlal v26.8h, v10.8b, v17.8b\n" + "dup v3.8b,v22.b[3]\n" + "smlal v27.8h, v11.8b, v17.8b\n" + "dup v4.8b,v22.b[4]\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "dup v5.8b,v22.b[5]\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "dup v6.8b,v22.b[6]\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "dup v7.8b,v22.b[7]\n" + "smlal v31.8h, v15.8b, v17.8b\n" + + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + + "dup v8.8b,v22.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v22.b[9]\n" + "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v22.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v22.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v22.b[12]\n" + "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v22.b[13]\n" + "smlal v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v22.b[14]\n" + "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v22.b[15]\n" + "smlal v31.8h, v7.8b, v16.8b\n" + + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + "dup v0.8b,v23.b[0]\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "dup v1.8b,v23.b[1]\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "dup v2.8b,v23.b[2]\n" + "smlal v26.8h, v10.8b, v17.8b\n" + "dup v3.8b,v23.b[3]\n" + "smlal v27.8h, v11.8b, v17.8b\n" + "dup v4.8b,v23.b[4]\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "dup v5.8b,v23.b[5]\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "dup v6.8b,v23.b[6]\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "dup v7.8b,v23.b[7]\n" + "smlal v31.8h, v15.8b, v17.8b\n" + + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + "dup v8.8b,v23.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v23.b[9]\n" + "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v23.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v23.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v23.b[12]\n" + "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v23.b[13]\n" + "smlal v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v23.b[14]\n" + "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v23.b[15]\n" + "smlal v31.8h, v7.8b, v16.8b\n" + + "ld1 {v20.16b}, [%[a_ptr]],#16\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "smlal v26.8h, v10.8b, v17.8b\n" + "smlal v27.8h, v11.8b, v17.8b\n" + "ld1 {v21.16b}, [%[a_ptr]],#16\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "smlal v31.8h, v15.8b, v17.8b\n" + //"ld1 {v20.16b}, [%[a_ptr]],#16\n" + //"ld1 {v21.16b}, [%[a_ptr]],#16\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 1b\n" + + "3:\n" + // Store back into memory + STORE_C_8 + + : + [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), + [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), + [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), + [ n_remain ] "+r"(n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) + : + : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} +//packa +static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, + int kmax) { + int8_t zerobuff[8]; + int8_t tmpbuff0[8]; + int8_t tmpbuff1[8]; + int8_t tmpbuff2[8]; + int8_t tmpbuff3[8]; + int8_t tmpbuff4[8]; + int8_t tmpbuff5[8]; + int8_t tmpbuff6[8]; + int8_t tmpbuff7[8]; + std::memset(zerobuff, 0, sizeof(int8_t) * 8); + std::memset(tmpbuff0, 0, sizeof(int8_t) * 8); + std::memset(tmpbuff1, 0, sizeof(int8_t) * 8); + std::memset(tmpbuff2, 0, sizeof(int8_t) * 8); + std::memset(tmpbuff3, 0, sizeof(int8_t) * 8); + std::memset(tmpbuff4, 0, sizeof(int8_t) * 8); + std::memset(tmpbuff5, 0, sizeof(int8_t) * 8); + std::memset(tmpbuff6, 0, sizeof(int8_t) * 8); + std::memset(tmpbuff7, 0, sizeof(int8_t) * 8); + ldin /= 2; + int y = y0; + for (; y + 7 < ymax; y += 8) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + int K = (kmax - k0)/2; + //! read 4 * 16 in each row + for (; K > 3; K -= 4) { + transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, + inptr5, inptr6, inptr7, outptr); + } + + if (K > 0) { + std::memcpy(tmpbuff0,inptr0,K); + std::memcpy(tmpbuff1,inptr1,K); + std::memcpy(tmpbuff2,inptr2,K); + std::memcpy(tmpbuff3,inptr3,K); + std::memcpy(tmpbuff4,inptr4,K); + std::memcpy(tmpbuff5,inptr5,K); + std::memcpy(tmpbuff6,inptr6,K); + std::memcpy(tmpbuff7,inptr7,K); + inptr0 = tmpbuff0; + inptr1 = tmpbuff1; + inptr2 = tmpbuff2; + inptr3 = tmpbuff3; + inptr4 = tmpbuff4; + inptr5 = tmpbuff5; + inptr6 = tmpbuff6; + inptr7 = tmpbuff7; + transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, + inptr5, inptr6, inptr7, outptr); + } + } + for (; y < ymax; y += 8) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + + int K = (kmax - k0)/2; + //! read 4 * 16 in each row + for (; K > 3; K -= 4) { + if (y + 7 >= ymax) { + switch (y + 7 - ymax) { + case 6: + inptr1 = zerobuff; MEGDNN_FALLTHRU + case 5: + inptr2 = zerobuff; MEGDNN_FALLTHRU + case 4: + inptr3 = zerobuff; MEGDNN_FALLTHRU + case 3: + inptr4 = zerobuff; MEGDNN_FALLTHRU + case 2: + inptr5 = zerobuff; MEGDNN_FALLTHRU + case 1: + inptr6 = zerobuff; MEGDNN_FALLTHRU + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, + inptr5, inptr6, inptr7, outptr); + } + if (K > 0) { + if (y + 7 >= ymax) { + switch (y + 7 - ymax) { + case 6: + inptr1 = zerobuff; MEGDNN_FALLTHRU + case 5: + inptr2 = zerobuff; MEGDNN_FALLTHRU + case 4: + inptr3 = zerobuff; MEGDNN_FALLTHRU + case 3: + inptr4 = zerobuff; MEGDNN_FALLTHRU + case 2: + inptr5 = zerobuff; MEGDNN_FALLTHRU + case 1: + inptr6 = zerobuff; MEGDNN_FALLTHRU + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + std::memcpy(tmpbuff0,inptr0,K); + std::memcpy(tmpbuff1,inptr1,K); + std::memcpy(tmpbuff2,inptr2,K); + std::memcpy(tmpbuff3,inptr3,K); + std::memcpy(tmpbuff4,inptr4,K); + std::memcpy(tmpbuff5,inptr5,K); + std::memcpy(tmpbuff6,inptr6,K); + std::memcpy(tmpbuff7,inptr7,K); + inptr0 = tmpbuff0; + inptr1 = tmpbuff1; + inptr2 = tmpbuff2; + inptr3 = tmpbuff3; + inptr4 = tmpbuff4; + inptr5 = tmpbuff5; + inptr6 = tmpbuff6; + inptr7 = tmpbuff7; + transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, + inptr5, inptr6, inptr7, outptr); + } + } +} +//packb +static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int8_t zerobuff[8]; + int8_t tmpbuff0[8]; + int8_t tmpbuff1[8]; + int8_t tmpbuff2[8]; + int8_t tmpbuff3[8]; + int8_t tmpbuff4[8]; + int8_t tmpbuff5[8]; + int8_t tmpbuff6[8]; + int8_t tmpbuff7[8]; + std::memset(zerobuff, 0, sizeof(int8_t) * 8); + std::memset(tmpbuff0, 0, sizeof(int8_t) * 8); + std::memset(tmpbuff1, 0, sizeof(int8_t) * 8); + std::memset(tmpbuff2, 0, sizeof(int8_t) * 8); + std::memset(tmpbuff3, 0, sizeof(int8_t) * 8); + std::memset(tmpbuff4, 0, sizeof(int8_t) * 8); + std::memset(tmpbuff5, 0, sizeof(int8_t) * 8); + std::memset(tmpbuff6, 0, sizeof(int8_t) * 8); + std::memset(tmpbuff7, 0, sizeof(int8_t) * 8); + const int ksize = kmax - k0; + const int ksize8 = round_up(ksize, 8) * 8; //pack to int8 *8 packto s4 *4 + int8_t* outptr = out; + int8_t* outptr_interleave = nullptr; + + int k = k0; + ldin /= 2; + xmax = xmax / 2; + for (; k + 7 < kmax; k += 8) { + const int8_t* inptr0 = in + k * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + int x = x0; + int8_t* outptr_inner = outptr; + for (; x + 3 < xmax; x += 4) { + outptr_interleave = outptr_inner; + interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr_interleave); + outptr_inner += ksize8; + } + + if (x < xmax) { + int remainx = xmax - x; + std::memcpy(tmpbuff0,inptr0,remainx); + std::memcpy(tmpbuff1,inptr1,remainx); + std::memcpy(tmpbuff2,inptr2,remainx); + std::memcpy(tmpbuff3,inptr3,remainx); + std::memcpy(tmpbuff4,inptr4,remainx); + std::memcpy(tmpbuff5,inptr5,remainx); + std::memcpy(tmpbuff6,inptr6,remainx); + std::memcpy(tmpbuff7,inptr7,remainx); + inptr0 = tmpbuff0; + inptr1 = tmpbuff1; + inptr2 = tmpbuff2; + inptr3 = tmpbuff3; + inptr4 = tmpbuff4; + inptr5 = tmpbuff5; + inptr6 = tmpbuff6; + inptr7 = tmpbuff7; + + outptr_interleave = outptr_inner; + interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr_interleave); + outptr_inner += ksize8; + } + outptr += 64; + } + if (k < kmax) { + const int8_t* inptr0 = in + k * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + int k_remain = kmax - k - 1; + int x = x0; + int8_t* outptr_inner = outptr; + for (; x + 3 < xmax; x += 4) { + switch (k_remain) { + case 0: + inptr1 = zerobuff; + MEGDNN_FALLTHRU; + case 1: + inptr2 = zerobuff; + MEGDNN_FALLTHRU; + case 2: + inptr3 = zerobuff; + MEGDNN_FALLTHRU; + case 3: + inptr4 = zerobuff; + MEGDNN_FALLTHRU; + case 4: + inptr5 = zerobuff; + MEGDNN_FALLTHRU; + case 5: + inptr6 = zerobuff; + MEGDNN_FALLTHRU; + case 6: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + break; + } + outptr_interleave = outptr_inner; + interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr_interleave); + outptr_inner += ksize8; + } + if (x < xmax) { + switch (k_remain) { + case 0: + inptr1 = zerobuff; + MEGDNN_FALLTHRU; + case 1: + inptr2 = zerobuff; + MEGDNN_FALLTHRU; + case 2: + inptr3 = zerobuff; + MEGDNN_FALLTHRU; + case 3: + inptr4 = zerobuff; + MEGDNN_FALLTHRU; + case 4: + inptr5 = zerobuff; + MEGDNN_FALLTHRU; + case 5: + inptr6 = zerobuff; + MEGDNN_FALLTHRU; + case 6: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + break; + } + int remainx = xmax - x; + outptr_interleave = outptr_inner; + std::memcpy(tmpbuff0,inptr0,remainx); + std::memcpy(tmpbuff1,inptr1,remainx); + std::memcpy(tmpbuff2,inptr2,remainx); + std::memcpy(tmpbuff3,inptr3,remainx); + std::memcpy(tmpbuff4,inptr4,remainx); + std::memcpy(tmpbuff5,inptr5,remainx); + std::memcpy(tmpbuff6,inptr6,remainx); + std::memcpy(tmpbuff7,inptr7,remainx); + inptr0 = tmpbuff0; + inptr1 = tmpbuff1; + inptr2 = tmpbuff2; + inptr3 = tmpbuff3; + inptr4 = tmpbuff4; + inptr5 = tmpbuff5; + inptr6 = tmpbuff6; + inptr7 = tmpbuff7; + + outptr_interleave = outptr_inner; + interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr_interleave); + outptr_inner += ksize8; + } + } +} + +} // namespace matmul_4x4x16 +} // namespace aarch64 +} // namespace megdnn + + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int4x4x16/strategy.cpp b/dnn/src/aarch64/matrix_mul/int4x4x16/strategy.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3b91f61d0d3222503e9a97f9725d8d490f2477c3 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int4x4x16/strategy.cpp @@ -0,0 +1,109 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int4x4x16/strategy.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h" +#include "src/aarch64/matrix_mul/int4x4x16/strategy.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/fallback/matrix_mul/gemm_common.h" + +using namespace megdnn; +using namespace aarch64; +using namespace aarch64::matmul; + +// ===========================gemm_s4x4x16_s4_8x8x8================================== +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s4x4x16_s4_8x8x8); +void gemm_s4x4x16_s4_8x8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, + int ymax, int k0, int kmax, + bool transpose) const { + if (transpose) { + matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, y0, ymax, k0, + kmax); + } else { + matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, y0, ymax, k0, + kmax); + } +} + +void gemm_s4x4x16_s4_8x8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, + int xmax, int k0, int kmax, + bool transpose) const { + if (transpose) { + matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, x0, xmax, k0, + kmax); + } else { + matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, x0, xmax, k0, + kmax); + } +} + +void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, + size_t M, size_t N, size_t K, dt_int16* C, + size_t LDC, bool is_first_k, const dt_int16*, + dt_int16*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + (A_dtype.enumv() == DTypeEnum::QuantizedS4 && + C_dtype.enumv() == DTypeEnum::QuantizedS16), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + + constexpr size_t A_INTERLEAVE = 8; + constexpr size_t B_INTERLEAVE = 8; + //! K is packed to times of 8 + K = round_up(K, 8); + const int K8 = K * 8; + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + int16_t* output = C + (m * LDC); + + size_t n = 0; + const dt_int8* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_s4_4x4x16::s4_kern_8x8(packA, cur_packB, K, output, LDC, + is_first_k, A_INTERLEAVE, B_INTERLEAVE); + output += B_INTERLEAVE; + cur_packB += K8; + } + + for (; n < N; n += B_INTERLEAVE) { + matmul_s4_4x4x16::s4_kern_8x8_remain(packA, cur_packB, K, output, LDC, + is_first_k, A_INTERLEAVE, + std::min(N - n, B_INTERLEAVE)); + output += B_INTERLEAVE; + cur_packB += K8; + } + + packA += K8; + } + + for (; m < M; m += A_INTERLEAVE) { + int16_t* output = C + (m * LDC); + size_t n = 0; + const dt_int8* cur_packB = packB; + for (; n < N; n += B_INTERLEAVE) { + matmul_s4_4x4x16::s4_kern_8x8_remain(packA, cur_packB, K, output, LDC, + is_first_k, + std::min(M - m, A_INTERLEAVE), + std::min(N - n, B_INTERLEAVE)); + output += B_INTERLEAVE; + cur_packB += K8; + } + packA += K8; + } +} + + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int4x4x16/strategy.h b/dnn/src/aarch64/matrix_mul/int4x4x16/strategy.h new file mode 100644 index 0000000000000000000000000000000000000000..52762c24de4dc95997e8a631527fd47425ca3b52 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int4x4x16/strategy.h @@ -0,0 +1,26 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int4x4x16/strategy.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once + +#include "src/fallback/matrix_mul/gemm_common.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul { + +MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, + gemm_s4x4x16_s4_8x8x8); + +} // namespace matmul +} // namespace aarch64 +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.cpp b/dnn/src/aarch64/matrix_mul/opr_impl.cpp index f4ee77e977d6fc44e9db1daca8062835947c130e..c673c855f8fa1ddcb89b77e91142f522956b257b 100644 --- a/dnn/src/aarch64/matrix_mul/opr_impl.cpp +++ b/dnn/src/aarch64/matrix_mul/opr_impl.cpp @@ -50,6 +50,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { #else AlgoQuint8K8x8x8 quint8_k8x8x8; #endif + AlgoInt4x4x16K8x8x8 int4x4x16_k8x8x8; SmallVector m_all_algos; fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map; @@ -87,6 +88,7 @@ public: #else m_all_algos.emplace_back(&quint8_k8x8x8); #endif + m_all_algos.emplace_back(&int4x4x16_k8x8x8); for (auto&& algo : m_all_algos) { m_all_algos_map.emplace(algo->info().desc, algo); diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.h b/dnn/src/aarch64/matrix_mul/opr_impl.h index 0e4a5fa920ae767c69e2481a381ed32eb8ced98f..fce50f989e70a5cdb3d69c77c362fdb3f7ebce51 100644 --- a/dnn/src/aarch64/matrix_mul/opr_impl.h +++ b/dnn/src/aarch64/matrix_mul/opr_impl.h @@ -66,8 +66,8 @@ private: #else class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8 #endif - class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int4x4x16 Kernel 4x4x16 - + class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int8x8x16 Kernel 4x4x16 + class AlgoInt4x4x16K8x8x8; // Aarch64 Int4x4x16 Kernel 4x4x16 class AlgoPack; public: static const AlgoPack& algo_pack(); diff --git a/dnn/src/common/matrix_mul.cpp b/dnn/src/common/matrix_mul.cpp index 484d3371225dd26be964d39f69d6af31f8883166..c765e10ebda38016248d5fff60df34454075114d 100644 --- a/dnn/src/common/matrix_mul.cpp +++ b/dnn/src/common/matrix_mul.cpp @@ -33,6 +33,8 @@ void MatrixMulForward::deduce_dtype(DType A, DType B, DType& C) { C_candi = dtype::QuantizedS32(mul_scale(A, B)); } else if (A.enumv() == DTypeEnum::Quantized4Asymm) { C_candi = dtype::QuantizedS32(mul_scale(A, B)); + } else if (A.enumv() == DTypeEnum::QuantizedS4) { + C_candi = dtype::QuantizedS16(mul_scale(A, B)); } if (!C.valid()) { C = C_candi; @@ -169,6 +171,8 @@ void MatrixMulForward::check_exec(const TensorLayout& A, const TensorLayout& B, A.dtype.enumv() == DTypeEnum::Quantized8Asymm || A.dtype.enumv() == DTypeEnum::Quantized4Asymm) { megdnn_assert(C.dtype.enumv() == DTypeEnum::QuantizedS32); + } else if(A.dtype.enumv() == DTypeEnum::QuantizedS4){ + megdnn_assert(C.dtype.enumv() == DTypeEnum::QuantizedS16); } megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32 MEGDNN_INC_FLOAT16( diff --git a/dnn/src/fallback/matrix_mul/opr_impl.h b/dnn/src/fallback/matrix_mul/opr_impl.h index f740b43575fd7c709392951352094f217cc66ae8..64bc738ff23dc3e99c6e1cdb6f626070358e57dc 100644 --- a/dnn/src/fallback/matrix_mul/opr_impl.h +++ b/dnn/src/fallback/matrix_mul/opr_impl.h @@ -154,6 +154,7 @@ public: AARCH64_QUINT8_K8X8X4_DOTPROD, AARCH64_QUINT8_GEMV_DOTPROD, AARCH64_QUINT8_K8X8X8, + AARCH64_INT4X4X16_K8X8X8, #else ARMV7_F32 = 1 << 16, ARMV7_F32_MK4_PACK_4X12, diff --git a/dnn/src/naive/matrix_mul/matrix_mul_helper.h b/dnn/src/naive/matrix_mul/matrix_mul_helper.h index 7ea7606bfcf1efcc91ec754d01ddca2168bc02fc..843f754a57e3e941e74c6071c27983bee1f4a072 100644 --- a/dnn/src/naive/matrix_mul/matrix_mul_helper.h +++ b/dnn/src/naive/matrix_mul/matrix_mul_helper.h @@ -179,6 +179,42 @@ void exec_matrix_mul_quint4x4x32_helper(_megdnn_tensor_in A, C.compatible_ptr(), M, N, K, LDA, LDB, LDC, nA.layout.dtype, nB.layout.dtype); } +template +void exec_matrix_mul_qint4x4x16_helper(_megdnn_tensor_in A, _megdnn_tensor_in B, + _megdnn_tensor_out C, + _megdnn_workspace workspace, + const param::MatrixMul& param) { + auto convert_layout = [](const TensorLayout& layout) { + auto ret = layout; + auto param = layout.dtype.param(); + ret.dtype = dtype::QuantizedS8(param.scale); + return ret; + }; + TensorND nA = {workspace.raw_ptr, convert_layout(A.layout)}; + TensorND nB = {workspace.raw_ptr + nA.layout.span().dist_byte(), + convert_layout(B.layout)}; + auto convert_4to8 = [](const TensorND& in, const TensorND& out) { + auto ptr = static_cast(in.raw_ptr) + in.layout.span().low_byte; + auto out_ptr = + out.compatible_ptr() + out.layout.span().low_byte; + for (size_t i = 0; i < in.layout.span().dist_elem(); i += 2) { + int8_t cur = ptr[i / 2]; + out_ptr[i] = cur << 4; + out_ptr[i] = out_ptr[i] >> 4; + out_ptr[i + 1] = cur >> 4; + } + }; + convert_4to8(A, nA); + convert_4to8(B, nB); + auto M = C.layout.shape[0], N = C.layout.shape[1]; + auto K = A.layout.shape[param.transposeA ? 0 : 1]; + auto LDA = A.layout.stride[0], LDB = B.layout.stride[0], + LDC = C.layout.stride[0]; + run_matrix_mul_tpl( + nA.compatible_ptr(), nB.compatible_ptr(), + C.compatible_ptr(), M, N, K, LDA, LDB, LDC, + nA.layout.dtype, nB.layout.dtype); +} } // namespace naive } // namespace megdnn diff --git a/dnn/src/naive/matrix_mul/opr_impl.cpp b/dnn/src/naive/matrix_mul/opr_impl.cpp index 4195e40cdc4ae2572c50be96d76ccdbb0c514a08..fd5128e5082734f8487ee152d9987b2b64d94276 100644 --- a/dnn/src/naive/matrix_mul/opr_impl.cpp +++ b/dnn/src/naive/matrix_mul/opr_impl.cpp @@ -26,7 +26,8 @@ size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A, MIDOUT_BEGIN( megdnn_naive_matmul, midout_iv("MatrixMulForwardImpl::get_workspace_in_bytes"_hash)) { - if (A.dtype.enumv() == DTypeEnum::Quantized4Asymm) { + if (A.dtype.enumv() == DTypeEnum::Quantized4Asymm || + A.dtype.enumv() == DTypeEnum::QuantizedS4) { return (A.span().dist_elem() + B.span().dist_elem()) * sizeof(uint8_t); } @@ -104,6 +105,11 @@ void dispatch_ta_tb(_megdnn_tensor_in A, _megdnn_tensor_in B, param.format == param::MatrixMul::Format::DEFAULT) { exec_matrix_mul_quint4x4x32_helper(A, B, C, workspace, param); return; + } else if (A.layout.dtype.enumv() == DTypeEnum::QuantizedS4 && + C.layout.dtype.enumv() == DTypeEnum::QuantizedS16 && + param.format == param::MatrixMul::Format::DEFAULT) { + exec_matrix_mul_qint4x4x16_helper(A, B, C, workspace, param); + return; } #undef cb megdnn_throw(ssprintf( diff --git a/dnn/test/aarch64/matrix_mul.cpp b/dnn/test/aarch64/matrix_mul.cpp index bc9169bd5998e56c47f0ebb85acf803ceccb2d61..27404d044d7676014f49f088ba43f6a4eb8355e5 100644 --- a/dnn/test/aarch64/matrix_mul.cpp +++ b/dnn/test/aarch64/matrix_mul.cpp @@ -164,6 +164,55 @@ TEST_F(AARCH64, MATRIX_MUL_INT8x8x16_K4x4x16) { handle(), "AARCH64_INT8X8X16_K4X4X16"); } +TEST_F(AARCH64, MATRIX_MUL_INT4x4x16_K8x8x8_QUANTIZEDS4) { + param::MatrixMul param; + param.transposeA = false; + param.transposeB = false; + + Checker checker(handle()); + checker.set_dtype(0, dtype::QuantizedS4{0.6}) + .set_dtype(1, dtype::QuantizedS4{0.5}) + .set_dtype(2, dtype::QuantizedS16{0.6 * 0.5}) + .set_param(param); + checker.set_before_exec_callback( + AlgoChecker("AARCH64_INT4X4X16_K8X8X8")); + + auto run = [&](size_t M, size_t N, size_t K) { + printf("M N K %zu %zu %zu \n", M, N, K); + TensorShape A, B; + if (param.transposeA) { + A = TensorShape{K, M}; + } else { + A = TensorShape{M, K}; + } + if (param.transposeB) { + B = TensorShape{N, K}; + } else { + B = TensorShape{K, N}; + } + checker.exec({A, B, {}}); + }; + + for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 16, 20}) + for (size_t n : {2, 4, 6, 8, 10, 12, 14, 16, 24}) + for (size_t k : {2, 4, 6, 8, 10, 12, 14, 16, 32}) + run(m, n, k); + + for (size_t k = 4; k <= 256; k *= 8) { + for (size_t m = 4; m <= 256; m *= 4) { + for (size_t n = 4; n <= 256; n *= 4) { + run(m, n, k); + } + } + } + param.transposeA = true; + run(8,8,8); + run(16,8,16); + param.transposeB = true; + run(8,8,8); + run(16,16,16); +} + TEST_F(AARCH64, MATRIX_MUL_INT16x16x32_K12X8X1) { matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, handle(), "AARCH64_INT16X16X32_K12X8X1"); @@ -410,6 +459,63 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x16) { run(384, 384, 384); } +TEST_F(AARCH64, BENCHMARK_4x4x16_vs_8x8x16) { + constexpr size_t RUNS = 50; + param::MatrixMul param; + param.transposeA = false; + param.transposeB = false; + Benchmarker benchmarker(handle()); + Benchmarker benchmarker_int4_4x4x16(handle()); + benchmarker_int4_4x4x16.set_times(RUNS) + .set_dtype(0, dtype::QuantizedS4{0.3}) + .set_dtype(1, dtype::QuantizedS4{0.3}) + .set_dtype(2, dtype::QuantizedS16{0.09}) + .set_param(param) + .set_display(false); + benchmarker.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int16{}) + .set_param(param) + .set_display(false); + benchmarker.set_before_exec_callback( + AlgoChecker("AARCH64_INT8X8X16_K4X4X16")); + + auto run = [&](size_t M, size_t N, size_t K) { + auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS; + auto int4416_used = + benchmarker_int4_4x4x16.exec({{M, K}, {K, N}, {}}) / RUNS; + float computations = 2.f * M * K * N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} normal 8x8x16 used: %f ms %f " + "Gflops int4416 used %f int4416_gflops %f speedup %f\n", + M, K, N, default_used, computations / default_used, int4416_used, + computations / int4416_used, default_used / int4416_used); + }; + + for (int m = 32; m <= 1024; m += 32) + for (int n = 32; n <= 1024; n += 32) + for (int k = 32; k <= 512; k += 32) + run(m, n, k); + + run(32, 32, 32); + run(32, 32, 8); + run(32, 32, 16); + run(32, 32, 24); + run(32 * 2, 32 * 2, 32); + run(32 * 4, 32 * 4, 32); + run(32 * 6, 32 * 6, 32); + run(32 * 8, 32 * 8, 32); + run(32 * 2, 32 * 2, 32 * 2); + run(32 * 4, 32 * 4, 32 * 3); + run(32 * 6, 32 * 6, 32 * 4); + run(32 * 8, 32 * 8, 32 * 5); + run(32 * 10, 32 * 10, 32 * 10); + run(384, 384, 384); + run(256, 256, 384); + run(512, 512, 384); + run(1024, 1024, 384); +} + TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x8_8x8x16_vs_4x4x16_8x8x16) { constexpr size_t RUNS = 50; param::MatrixMul param; diff --git a/dnn/test/common/rng.cpp b/dnn/test/common/rng.cpp index 10c7f213a169f411e230016c7acc9c8293ecfc48..f3e25f327f240706bfd810cba083daa17b25b343 100644 --- a/dnn/test/common/rng.cpp +++ b/dnn/test/common/rng.cpp @@ -183,6 +183,34 @@ void IIDRNG::gen(const TensorND& tensor) { } return; } + if (tensor.layout.dtype.enumv() == DTypeEnum::QuantizedS4) { + auto ptr = static_cast(tensor.raw_ptr); + if (output_is_float()) { + for (size_t i = 0; i < nr_elems; i += 2) { + int8_t val0 = + tensor.layout.dtype.param() + .quantize(static_cast(gen_single_val())) + .as_int8(); + int8_t val1 = + tensor.layout.dtype.param() + .quantize(static_cast(gen_single_val())) + .as_int8(); + ptr[(offset + i) / 2] = (val0 & 0xF) | (val1 << 4); + } + } else { + for (size_t i = 0; i < nr_elems; i += 2) { + int8_t val0 = static_cast(gen_single_val()); + int8_t val1 = static_cast(gen_single_val()); + + val0 = std::min(val0,DTypeTrait::max()); + val0 = std::max(val0,DTypeTrait::min()); + val1 = std::min(val1,DTypeTrait::max()); + val1 = std::max(val1,DTypeTrait::min()); + ptr[(offset + i) / 2] = (val0 & 0xF) | (val1 << 4); + } + } + return; + } megdnn_assert(0, "IIDRNG does not know how to generate value for DType %s", tensor.layout.dtype.name()); } diff --git a/dnn/test/naive/matrix_mul.cpp b/dnn/test/naive/matrix_mul.cpp index 4bfa4596d8fd32e46f9fef7e8cc41882ddbdf79d..b708d7f730f2810d09c47a60603c84bf15271df6 100644 --- a/dnn/test/naive/matrix_mul.cpp +++ b/dnn/test/naive/matrix_mul.cpp @@ -203,6 +203,67 @@ TEST_F(NAIVE, MATRIX_MUL_QUANTIZED4x4x32) { }); } +TEST_F(NAIVE, MATRIX_MUL_QUANTIZEDS4_4x4x16) { + Checker checker(handle(), /* check_dispatch */ false); + auto GenTensorValueQuint4 = [](const TensorShape& shape, + dtype::QuantizedS4 dtype, + const std::vector& values) { + TensorND tensor; + tensor.layout = {shape, dtype}; + tensor.raw_ptr = + static_cast(malloc(tensor.layout.span().dist_byte())); + uint8_t* ptr = static_cast(tensor.raw_ptr); + megdnn_assert(values.size() == tensor.layout.span().dist_elem()); + for (size_t i = 0; i < tensor.layout.span().dist_elem(); i += 2) { + int val0 = values[i], val1 = values[i + 1]; + ptr[i / 2] =(val0 & 0xF) | (val1 << 4); + } + return tensor; + }; + using Param = MatrixMul::Param; + Param param; + checker.set_param(param); + checker.set_dtype(2, dtype::QuantizedS16(0.3f * 0.3f)); + checker.exect( + Testcase{ + GenTensorValueQuint4( + {8, 8}, dtype::QuantizedS4(0.3f), + {-8, 7, 2, 1, 2, 3, 2, 7, + 2, 5, 3, 3, 7, 4, -7, 1, + -5, 7, -4, -1, -1, 2, 4, 1, + 7, 2, -6, -2, -6, 3, 4, 4, + -2, 2, 3, 0, 6, 5, 3, 4, + -1, -1, -5, 5, 2, 5, 1, 4, + 6, 2, 0, 0, 3, 2, 2, 1, + -4, -3, 7, 5, 0, 3, 2, 3}), + GenTensorValueQuint4( + {8, 8}, dtype::QuantizedS4(0.3f), + {5, -8, -7, -6, 4, 7, -5, -5, + -4, 7, -3, -2, 5, 6, 4, 2, + 3, -1, 2, 2, 7, 3, 6, 0, + 5, 4, 0, 2, 2, 3, 3, 2, + 1, -8, -7, -6, 0, -5, -4, 4, + -3, 7, 1, 6, -2, 2, -1, 5, + 2, 0, 7, 6, 5, 4, 3, 2, + 0, 0, 1, 0, 5, 2, 2, 6}), + {}}, + Testcase{ + {}, + {}, + TensorValue( + {8, 8}, dtype::QuantizedS16(0.3f * 0.3f), + {-60, 120, 49, 58, 58, 13, 92, 125, + -5, 0, -116, -70, 22, 9, -14, 46, + -69, 111, 44, 48, 6, 19, 42, 57, + -8, 25, 10, 16, 26, 97, -28, -12, + -12, 14, 2, 26, 48, 7, 24, 93, + -2, 45, 2, 32, -19, -1, -16, 72, + 23, -44, -52, -34, 45, 53, -28, 6, + 33, 45, 71, 84, 47, 10, 74, 61}) + + }); +} + TEST_F(NAIVE, MATRIX_MUL_QUANTIZED8x8x32) { Checker checker(handle(), /* check_dispatch */ false); MatrixMul::Param param;