Commit 86cf7490 authored by Megvii Engine Team

feat(dnn/aarch64): add quantizeds4 matmul int4x4x16_k8x8x8

GitOrigin-RevId: 781290024466459c292cb46d7c97c5d39454525f
Parent bff0fc61
......@@ -88,6 +88,7 @@ enum class AlgoDataType : uint32_t {
QUINT8X8X32 = 1 << 3,
INT8X8X16 = 1 << 4,
INT16X16X32 = 1 << 5,
INT4X4X16 = 1 << 6,
};
/*!
......
......@@ -17,6 +17,7 @@
#include "src/aarch64/matrix_mul/int8/strategy.h"
#include "src/aarch64/matrix_mul/int8_dot/strategy.h"
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h"
#include "src/aarch64/matrix_mul/int4x4x16/strategy.h"
#include "src/aarch64/matrix_mul/quint8/strategy.h"
#include "src/aarch64/matrix_mul/quint8_dot/gemv.h"
#include "src/aarch64/matrix_mul/quint8_dot/strategy.h"
......@@ -1394,4 +1395,75 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_K8x8x8,
aarch64::matmul::gemm_s8x8x16_mk4_8x8x8,
int8_t, int16_t, AlgoDataType::INT8X8X16,
MK4);
/* ===================== Int4x4x16 K8x8x8 algo ===================== */
namespace {
void int4x4x16_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
midout_iv("int4x4x16_k8x8x8_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto trA = kern_param.trA, trB = kern_param.trB;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
auto A_type = kern_param.A_type, B_type = kern_param.B_type,
C_type = kern_param.C_type;
const auto Aptr = kern_param.A<dt_int8>(),
Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int16>();
aarch64::matmul::gemm_s4x4x16_s4_8x8x8 strategy(M, N, K, A_type, B_type,
C_type);
megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s4x4x16_s4_8x8x8>(
M, N, K, trA, trB, strategy)
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
kern_param.workspace_ptr);
}
MIDOUT_END();
}
} // anonymous namespace
bool MatrixMulImpl::AlgoInt4x4x16K8x8x8::usable(
const KernSizeParam& kern_size_param) const {
return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS4 &&
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16 &&
kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
(kern_size_param.K & 1) == 0 && (kern_size_param.N & 1) == 0;
}
bool MatrixMulImpl::AlgoInt4x4x16K8x8x8::preferred(
const KernSizeParam& kern_size_param) const {
MEGDNN_MARK_USED_VAR(kern_size_param);
return true;
}
size_t MatrixMulImpl::AlgoInt4x4x16K8x8x8::get_workspace(
const KernSizeParam& kern_size_param) const {
MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
midout_iv("AlgoInt4x4x16K8x8x8::get_workspace"_hash)) {
auto M = kern_size_param.M, N = kern_size_param.N,
K = kern_size_param.K;
auto trA = kern_size_param.trA, trB = kern_size_param.trB;
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
C_type = kern_size_param.C_type;
aarch64::matmul::gemm_s4x4x16_s4_8x8x8 strategy(M, N, K, A_type, B_type,
C_type);
return megdnn::matmul::GemmInterleaved<matmul::gemm_s4x4x16_s4_8x8x8>(
M, N, K, trA, trB, strategy)
.get_workspace_size();
}
MIDOUT_END();
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt4x4x16K8x8x8::get_kern(
const KernSizeParam&) const {
    return int4x4x16_k8x8x8_kern;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt4x4x16K8x8x8,
megdnn_aarch64_matmul_kern,
"AlgoInt4x4x16K8x8x8Impl"_hash,
aarch64::matmul::gemm_s4x4x16_s4_8x8x8,
int8_t, int16_t, AlgoDataType::INT4X4X16,
DEFAULT);
// vim: syntax=cpp.doxygen
......@@ -192,6 +192,19 @@ public:
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_K4X4X16)
};
class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
PackMode packmode() const override { return PackMode::DEFAULT; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT4X4X16_K8X8X8)
};
class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
......
......@@ -925,6 +925,42 @@ static inline void interleave_8x8_1_b(const T*& inptr0, const T*& inptr1,
: "v0", "v1", "v2", "v3", "memory");
}
template <typename T>
static inline void interleave_8x4_1_b_with_shift(
const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7,
T* outptr) {
static_assert(sizeof(T) == 1, "only support size == 1");
asm volatile(
"ld1 {v0.s}[0], [%[inptr0]], #4\n"
"ld1 {v0.s}[1], [%[inptr1]], #4\n"
"ld1 {v0.s}[2], [%[inptr2]], #4\n"
"ld1 {v0.s}[3], [%[inptr3]], #4\n"
"ld1 {v1.s}[0], [%[inptr4]], #4\n"
"ld1 {v1.s}[1], [%[inptr5]], #4\n"
"ld1 {v1.s}[2], [%[inptr6]], #4\n"
"ld1 {v1.s}[3], [%[inptr7]], #4\n"
"shl v2.16b, v0.16b, #4\n"
"shl v5.16b, v1.16b, #4\n"
"sshr v3.16b, v0.16b, #4\n" // hig
"sshr v4.16b, v2.16b, #4\n" // low
"sshr v6.16b, v1.16b, #4\n" // hig
"sshr v7.16b, v5.16b, #4\n" // low
"zip1 v8.16b, v4.16b, v3.16b\n"
"zip2 v9.16b, v4.16b, v3.16b\n"
"zip1 v10.16b, v7.16b, v6.16b\n"
"zip2 v11.16b, v7.16b, v6.16b\n"
"st1 {v8.16b-v11.16b},[%[outptr]],#64"
: [ inptr0 ] "+r"(inptr0), [ inptr1 ] "+r"(inptr1),
[ inptr2 ] "+r"(inptr2), [ inptr3 ] "+r"(inptr3),
[ inptr4 ] "+r"(inptr4), [ inptr5 ] "+r"(inptr5),
[ inptr6 ] "+r"(inptr6), [ inptr7 ] "+r"(inptr7),
[ outptr ] "+r"(outptr)
:
: "v0", "v1","v2","v3","v4","v5","v6","v7","v8","v9","v10","v11","memory");
}
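// interleave_8x4_1_b_with_shift above widens packed signed 4-bit data with a
// shl #4 / sshr #4 pair: the low nibble becomes the first widened element of
// each pair.  A minimal scalar sketch of that trick, for illustration only
// (`unpack_s4_pair` is a hypothetical helper, not used by the kernels):
static inline void unpack_s4_pair(int8_t packed, int8_t& low, int8_t& high) {
    int8_t lo4 = static_cast<int8_t>((packed & 0xF) << 4);  // low nibble moved to the top 4 bits
    low = static_cast<int8_t>(lo4 >> 4);      // arithmetic shift sign-extends it back down
    high = static_cast<int8_t>(packed >> 4);  // arithmetic shift sign-extends the high nibble
}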
template <typename T>
static inline void interleave_8x8_1_h(const T*& inptr0, const T*& inptr1,
const T*& inptr2, const T*& inptr3,
......@@ -1059,6 +1095,7 @@ static inline void interleave_4x16_1_b(const T*& inptr0, const T*& inptr1,
: "v0", "v1", "v2", "v3", "v4", "cc", "memory");
}
template <typename T>
static inline void interleave_4x16_1_s(const T*& inptr0, const T*& inptr1,
const T*& inptr2, const T*& inptr3,
......@@ -1772,6 +1809,54 @@ static inline void transpose_8x4_1_b(const T*& inptr0, const T*& inptr1,
: "v0", "v1", "v2", "v3", "v4", "v5", "memory");
}
template <typename T>
static inline void transpose_4x8_1_b_with_shift(const T*& inptr0, const T*& inptr1,
const T*& inptr2, const T*& inptr3,
const T*& inptr4, const T*& inptr5,
const T*& inptr6, const T*& inptr7,
T*& outptr) {
static int8x16_t shuffle_idx = {0, 4, 8, 12, 1, 5, 9, 13,
2, 6, 10, 14, 3, 7, 11, 15};
static_assert(
std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
"transpose_8x4_1_b only support uint8_t and int8_t");
asm volatile(
"ld1 {v0.s}[0], [%[inptr0]], #4\n" // A1A2A3A4
"ld1 {v0.s}[1], [%[inptr1]], #4\n" // B1B2B3B4
"ld1 {v0.s}[2], [%[inptr2]], #4\n" // C1C2C3C4
"ld1 {v0.s}[3], [%[inptr3]], #4\n" // D1D2D3D4
"ld1 {v1.s}[0], [%[inptr4]], #4\n" // E1E2E3E4
"ld1 {v1.s}[1], [%[inptr5]], #4\n" // F1F2F3F4
"ld1 {v1.s}[2], [%[inptr6]], #4\n" // G1G2G3G4
"ld1 {v1.s}[3], [%[inptr7]], #4\n" // H1H2H3H4
"tbl v2.16b, {v0.16b}, %[shuffle_idx].16b \n" // A1B1C1D1A2B2C2D2A3B3C3D3A4B4C4D4
"tbl v3.16b, {v1.16b}, %[shuffle_idx].16b \n" // E1F1G1H1E2F2G2H2E3F3G3H3E4F4G4H4
"zip1 v4.4s, v2.4s, v3.4s\n" // A1B1C1D1E1F1G1H1 A2B2C2D2E2F2G2H2
"zip2 v5.4s, v2.4s, v3.4s\n" // A3B3C3D3E3F3G3H3 A4B4C4D4E4F4G4H4
"shl v6.16b, v4.16b, #4\n"
"sshr v7.16b, v4.16b, #4\n" // hig
"sshr v8.16b, v6.16b, #4\n" // low
"shl v9.16b, v5.16b, #4\n"
"sshr v10.16b, v5.16b, #4\n" // hig
"sshr v11.16b, v9.16b, #4\n" // low
"zip1 v0.2d,v8.2d,v7.2d\n"
"zip2 v1.2d,v8.2d,v7.2d\n"
"zip1 v2.2d,v11.2d,v10.2d\n"
"zip2 v3.2d,v11.2d,v10.2d\n"
"st1 {v0.2d-v3.2d},[%[outptr]],#64\n"
: [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1),
[inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3),
[inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5),
[inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [shuffle_idx]"+w"(shuffle_idx),
[outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5","v6","v7","v8","v9","v10","v11","memory");
}
template <typename T>
static inline void transpose_8x8_1_b(const T*& inptr0, const T*& inptr1,
const T*& inptr2, const T*& inptr3,
......
/**
 * \file dnn/src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include <inttypes.h>
#include <cstring>
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
namespace megdnn {
namespace aarch64 {
namespace matmul_s4_4x4x16 {
/**
 * Overview of register layout:
 *
 * Lhs (packA) is loaded 16 bytes at a time into v20-v23 and each element is
 * broadcast into one of v0-v15 with dup; Rhs (packB) is loaded 8 bytes at a
 * time into v16-v19.  The 8x8 int16 accumulator tile lives in v24-v31, one
 * 8h register per output row, and is updated with smlal.
 *
 *                  Rhs +---------+---------+---------+---------+
 *                      | v16[0-7]| v17[0-7]| v18[0-7]| v19[0-7]|
 *                      +---------+---------+---------+---------+
 *   Lhs
 * +---------+ - - - -  +---------+
 * |v20[0-15]|          | v24[0-7]|
 * |v21[0-15]|          | v25[0-7]|
 * |v22[0-15]|          |   ...   |
 * |v23[0-15]|          | v31[0-7]|
 * +---------+ - - - -  +---------+
 *
 *                       Accumulator
 */
static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
K /= 8;
LDC = LDC * sizeof(int16_t);
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
// clang-format off
#define LOAD_LINE(reg_index, n) \
"cmp x8, #0 \n" \
"beq 105f\n" \
"cmp %w[n_remain], #4\n" \
"blt 100" n "f\n" \
"ld1 {v" reg_index ".8h}, [x" n "], #16\n" \
"b 101" n "f\n" \
"100" n ":\n" \
"cmp %w[n_remain], #0\n" \
"blt 101" n "f\n" \
"ld1 {v" reg_index ".h}[0], [x" n "], #2\n" \
"cmp %w[n_remain], #1\n" \
"beq 101" n "f\n" \
"ld1 {v" reg_index ".h}[1], [x" n "], #2\n" \
"cmp %w[n_remain], #2\n" \
"beq 101" n "f\n" \
"ld1 {v" reg_index ".h}[2], [x" n "], #2\n" \
"cmp %w[n_remain], #3\n" \
"beq 101" n "f\n" \
"ld1 {v" reg_index ".h}[3], [x" n "], #2\n" \
"cmp %w[n_remain], #4\n" \
"beq 101" n "f\n" \
"ld1 {v" reg_index ".h}[4], [x" n "], #2\n" \
"cmp %w[n_remain], #5\n" \
"beq 101" n "f\n" \
"ld1 {v" reg_index ".h}[5], [x" n "], #2\n" \
"cmp %w[n_remain], #6\n" \
"beq 101" n "f\n" \
"ld1 {v" reg_index ".h}[6], [x" n "], #2\n" \
"101" n ":\n" \
"sub x8, x8, #1\n"
#define LOAD_C \
"mov x8, %x[m_remain]\n" \
LOAD_LINE("24", "0") \
LOAD_LINE("25", "1") \
LOAD_LINE("26", "2") \
LOAD_LINE("27", "3") \
LOAD_LINE("28", "4") \
LOAD_LINE("29", "5") \
LOAD_LINE("30", "6") \
LOAD_LINE("31", "7") \
"105:\n"
#define STORE_LINE(reg_index, n) \
"cmp x8, #0 \n" \
"beq 105f\n" \
"cmp %w[n_remain], #8\n" \
"blt 102" n "f\n" \
"st1 {v" reg_index ".8h}, [x" n "], #16\n" \
"b 103" n "f\n" \
"102" n ":\n" \
"cmp %w[n_remain], #0\n" \
"beq 103" n "f\n" \
"st1 {v" reg_index ".h}[0], [x" n "], #2\n" \
"cmp %w[n_remain], #1\n" \
"beq 103" n "f\n" \
"st1 {v" reg_index ".h}[1], [x" n "], #2\n" \
"cmp %w[n_remain], #2\n" \
"beq 103" n "f\n" \
"st1 {v" reg_index ".h}[2], [x" n "], #2\n" \
"cmp %w[n_remain], #3\n" \
"beq 103" n "f\n" \
"st1 {v" reg_index ".h}[3], [x" n "], #2\n" \
"cmp %w[n_remain], #4\n" \
"beq 103" n "f\n" \
"st1 {v" reg_index ".h}[4], [x" n "], #2\n" \
"cmp %w[n_remain], #5\n" \
"beq 103" n "f\n" \
"st1 {v" reg_index ".h}[5], [x" n "], #2\n" \
"cmp %w[n_remain], #6\n" \
"beq 103" n "f\n" \
"st1 {v" reg_index ".h}[6], [x" n "], #2\n" \
"103" n ":\n" \
"sub x8, x8, #1\n"
#define STORE_C \
"mov x8, %x[m_remain]\n" \
STORE_LINE("24", "0") \
STORE_LINE("25", "1") \
STORE_LINE("26", "2") \
STORE_LINE("27", "3") \
STORE_LINE("28", "4") \
STORE_LINE("29", "5") \
STORE_LINE("30", "6") \
STORE_LINE("31", "7") \
"105:\n"
// clang-format on
register int16_t* outptr asm("x0") = output;
asm volatile(
"add x1, x0, %x[LDC]\n"
"add x2, x1, %x[LDC]\n"
"add x3, x2, %x[LDC]\n"
"add x4, x3, %x[LDC]\n"
"add x5, x4, %x[LDC]\n"
"add x6, x5, %x[LDC]\n"
"add x7, x6, %x[LDC]\n"
"cmp %w[is_first_k], #1\n"
"beq 2f\n" LOAD_C
"b 1f\n"
"2:\n" // Clear the C regs.
"eor v24.16b, v24.16b, v24.16b\n"
"eor v25.16b, v25.16b, v25.16b\n"
"eor v26.16b, v26.16b, v26.16b\n"
"eor v27.16b, v27.16b, v27.16b\n"
"eor v28.16b, v28.16b, v28.16b\n"
"eor v29.16b, v29.16b, v29.16b\n"
"eor v30.16b, v30.16b, v30.16b\n"
"eor v31.16b, v31.16b, v31.16b\n"
// General loop.
"1:\n"
"ld1 {v20.16b}, [%[a_ptr]],#16\n"
"ld1 {v21.16b}, [%[a_ptr]],#16\n"
"dup v0.8b,v20.b[0]\n"
"dup v1.8b,v20.b[1]\n"
"dup v2.8b,v20.b[2]\n"
"dup v3.8b,v20.b[3]\n"
"ld1 {v22.16b}, [%[a_ptr]],#16\n"
"ld1 {v23.16b}, [%[a_ptr]],#16\n"
"ld1 {v16.8b}, [%[b_ptr]], 8\n"
"dup v4.8b,v20.b[4]\n"
"dup v5.8b,v20.b[5]\n"
"dup v6.8b,v20.b[6]\n"
"dup v7.8b,v20.b[7]\n"
"ld1 {v17.8b}, [%[b_ptr]], 8\n"
"dup v8.8b,v20.b[8]\n"
"smlal v24.8h, v0.8b, v16.8b\n"
"dup v9.8b,v20.b[9]\n"
"smlal v25.8h, v1.8b, v16.8b\n"
"dup v10.8b,v20.b[10]\n"
"smlal v26.8h, v2.8b, v16.8b\n"
"dup v11.8b,v20.b[11]\n"
"smlal v27.8h, v3.8b, v16.8b\n"
"dup v12.8b,v20.b[12]\n"
"smlal v28.8h, v4.8b, v16.8b\n"
"dup v13.8b,v20.b[13]\n"
"smlal v29.8h, v5.8b, v16.8b\n"
"dup v14.8b,v20.b[14]\n"
"smlal v30.8h, v6.8b, v16.8b\n"
"dup v15.8b,v20.b[15]\n"
"smlal v31.8h, v7.8b, v16.8b\n"
"ld1 {v18.8b}, [%[b_ptr]], 8\n"
"dup v0.8b,v21.b[0]\n"
"smlal v24.8h, v8.8b, v17.8b\n"
"dup v1.8b,v21.b[1]\n"
"smlal v25.8h, v9.8b, v17.8b\n"
"dup v2.8b,v21.b[2]\n"
"smlal v26.8h, v10.8b, v17.8b\n"
"dup v3.8b,v21.b[3]\n"
"smlal v27.8h, v11.8b, v17.8b\n"
"dup v4.8b,v21.b[4]\n"
"smlal v28.8h, v12.8b, v17.8b\n"
"dup v5.8b,v21.b[5]\n"
"smlal v29.8h, v13.8b, v17.8b\n"
"dup v6.8b,v21.b[6]\n"
"smlal v30.8h, v14.8b, v17.8b\n"
"dup v7.8b,v21.b[7]\n"
"smlal v31.8h, v15.8b, v17.8b\n"
"ld1 {v19.8b}, [%[b_ptr]], 8\n"
"dup v8.8b,v21.b[8]\n"
"smlal v24.8h, v0.8b, v18.8b\n"
"dup v9.8b,v21.b[9]\n"
"smlal v25.8h, v1.8b, v18.8b\n"
"dup v10.8b,v21.b[10]\n"
"smlal v26.8h, v2.8b, v18.8b\n"
"dup v11.8b,v21.b[11]\n"
"smlal v27.8h, v3.8b, v18.8b\n"
"dup v12.8b,v21.b[12]\n"
"smlal v28.8h, v4.8b, v18.8b\n"
"dup v13.8b,v21.b[13]\n"
"smlal v29.8h, v5.8b, v18.8b\n"
"dup v14.8b,v21.b[14]\n"
"smlal v30.8h, v6.8b, v18.8b\n"
"dup v15.8b,v21.b[15]\n"
"smlal v31.8h, v7.8b, v18.8b\n"
"ld1 {v16.8b}, [%[b_ptr]], 8\n"
"dup v0.8b,v22.b[0]\n"
"smlal v24.8h, v8.8b, v19.8b\n"
"dup v1.8b,v22.b[1]\n"
"smlal v25.8h, v9.8b, v19.8b\n"
"dup v2.8b,v22.b[2]\n"
"smlal v26.8h, v10.8b, v19.8b\n"
"dup v3.8b,v22.b[3]\n"
"smlal v27.8h, v11.8b, v19.8b\n"
"dup v4.8b,v22.b[4]\n"
"smlal v28.8h, v12.8b, v19.8b\n"
"dup v5.8b,v22.b[5]\n"
"smlal v29.8h, v13.8b, v19.8b\n"
"dup v6.8b,v22.b[6]\n"
"smlal v30.8h, v14.8b, v19.8b\n"
"dup v7.8b,v22.b[7]\n"
"smlal v31.8h, v15.8b, v19.8b\n"
"ld1 {v17.8b}, [%[b_ptr]], 8\n"
"dup v8.8b,v22.b[8]\n"
"smlal v24.8h, v0.8b, v16.8b\n"
"dup v9.8b,v22.b[9]\n"
"smlal v25.8h, v1.8b, v16.8b\n"
"dup v10.8b,v22.b[10]\n"
"smlal v26.8h, v2.8b, v16.8b\n"
"dup v11.8b,v22.b[11]\n"
"smlal v27.8h, v3.8b, v16.8b\n"
"dup v12.8b,v22.b[12]\n"
"smlal v28.8h, v4.8b, v16.8b\n"
"dup v13.8b,v22.b[13]\n"
"smlal v29.8h, v5.8b, v16.8b\n"
"dup v14.8b,v22.b[14]\n"
"smlal v30.8h, v6.8b, v16.8b\n"
"dup v15.8b,v22.b[15]\n"
"smlal v31.8h, v7.8b, v16.8b\n"
"ld1 {v18.8b}, [%[b_ptr]], 8\n"
"dup v0.8b,v23.b[0]\n"
"smlal v24.8h, v8.8b, v17.8b\n"
"dup v1.8b,v23.b[1]\n"
"smlal v25.8h, v9.8b, v17.8b\n"
"dup v2.8b,v23.b[2]\n"
"smlal v26.8h, v10.8b, v17.8b\n"
"dup v3.8b,v23.b[3]\n"
"smlal v27.8h, v11.8b, v17.8b\n"
"dup v4.8b,v23.b[4]\n"
"smlal v28.8h, v12.8b, v17.8b\n"
"dup v5.8b,v23.b[5]\n"
"smlal v29.8h, v13.8b, v17.8b\n"
"dup v6.8b,v23.b[6]\n"
"smlal v30.8h, v14.8b, v17.8b\n"
"dup v7.8b,v23.b[7]\n"
"smlal v31.8h, v15.8b, v17.8b\n"
"ld1 {v19.8b}, [%[b_ptr]], 8\n"
"dup v8.8b,v23.b[8]\n"
"smlal v24.8h, v0.8b, v18.8b\n"
"dup v9.8b,v23.b[9]\n"
"smlal v25.8h, v1.8b, v18.8b\n"
"dup v10.8b,v23.b[10]\n"
"smlal v26.8h, v2.8b, v18.8b\n"
"dup v11.8b,v23.b[11]\n"
"smlal v27.8h, v3.8b, v18.8b\n"
"dup v12.8b,v23.b[12]\n"
"smlal v28.8h, v4.8b, v18.8b\n"
"dup v13.8b,v23.b[13]\n"
"smlal v29.8h, v5.8b, v18.8b\n"
"dup v14.8b,v23.b[14]\n"
"smlal v30.8h, v6.8b, v18.8b\n"
"dup v15.8b,v23.b[15]\n"
"smlal v31.8h, v7.8b, v18.8b\n"
"smlal v24.8h, v8.8b, v19.8b\n"
"smlal v25.8h, v9.8b, v19.8b\n"
"smlal v26.8h, v10.8b, v19.8b\n"
"smlal v27.8h, v11.8b, v19.8b\n"
"smlal v28.8h, v12.8b, v19.8b\n"
"smlal v29.8h, v13.8b, v19.8b\n"
"smlal v30.8h, v14.8b, v19.8b\n"
"smlal v31.8h, v15.8b, v19.8b\n"
"subs %w[K], %w[K], #1\n"
"cbnz %w[K], 1b\n"
"3:\n"
// Store back into memory
STORE_C
:
[ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr),
[ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC),
[ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain),
[ n_remain ] "+r"(n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1)
:
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31");
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}
static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
K /= 8;
LDC = LDC * sizeof(int16_t);
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
// clang-format off
#define LOAD_C_8 \
"ld1 {v24.8h}, [x0], #16\n" \
"ld1 {v25.8h}, [x1], #16\n" \
"ld1 {v26.8h}, [x2], #16\n" \
"ld1 {v27.8h}, [x3], #16\n" \
"ld1 {v28.8h}, [x4], #16\n" \
"ld1 {v29.8h}, [x5], #16\n" \
"ld1 {v30.8h}, [x6], #16\n" \
"ld1 {v31.8h}, [x7], #16\n" \
#define STORE_C_8 \
"st1 {v24.8h}, [x0], #16\n" \
"st1 {v25.8h}, [x1], #16\n" \
"st1 {v26.8h}, [x2], #16\n" \
"st1 {v27.8h}, [x3], #16\n" \
"st1 {v28.8h}, [x4], #16\n" \
"st1 {v29.8h}, [x5], #16\n" \
"st1 {v30.8h}, [x6], #16\n" \
"st1 {v31.8h}, [x7], #16\n" \
// clang-format on
register int16_t* outptr asm("x0") = output;
asm volatile(
"add x1, x0, %x[LDC]\n"
"add x2, x1, %x[LDC]\n"
"add x3, x2, %x[LDC]\n"
"add x4, x3, %x[LDC]\n"
"add x5, x4, %x[LDC]\n"
"add x6, x5, %x[LDC]\n"
"add x7, x6, %x[LDC]\n"
"cmp %w[is_first_k], #1\n"
"beq 2f\n" LOAD_C_8
"b 1f\n"
"2:\n" // Clear the C regs.
"eor v24.16b, v24.16b, v24.16b\n"
"eor v25.16b, v25.16b, v25.16b\n"
"eor v26.16b, v26.16b, v26.16b\n"
"eor v27.16b, v27.16b, v27.16b\n"
"eor v28.16b, v28.16b, v28.16b\n"
"eor v29.16b, v29.16b, v29.16b\n"
"eor v30.16b, v30.16b, v30.16b\n"
"eor v31.16b, v31.16b, v31.16b\n"
// General loop.
"ld1 {v20.16b}, [%[a_ptr]],#16\n"
"ld1 {v21.16b}, [%[a_ptr]],#16\n"
"PRFM PLDL1KEEP, [%[a_ptr], #512]\n"
"PRFM PLDL1KEEP, [%[b_ptr], #512]\n"
"1:\n"
// "ld1 {v20.16b}, [%[a_ptr]],#16\n"
// "ld1 {v21.16b}, [%[a_ptr]],#16\n"
"dup v0.8b,v20.b[0]\n"
"ld1 {v22.16b}, [%[a_ptr]],#16\n"
"dup v1.8b,v20.b[1]\n"
"ld1 {v23.16b}, [%[a_ptr]],#16\n"
"dup v2.8b,v20.b[2]\n"
"ld1 {v16.8b}, [%[b_ptr]], 8\n"
"dup v3.8b,v20.b[3]\n"
"dup v4.8b,v20.b[4]\n"
"ld1 {v17.8b}, [%[b_ptr]], 8\n"
"dup v5.8b,v20.b[5]\n"
"dup v6.8b,v20.b[6]\n"
"dup v7.8b,v20.b[7]\n"
"dup v8.8b,v20.b[8]\n"
"smlal v24.8h, v0.8b, v16.8b\n"
"dup v9.8b,v20.b[9]\n"
"smlal v25.8h, v1.8b, v16.8b\n"
"dup v10.8b,v20.b[10]\n"
"smlal v26.8h, v2.8b, v16.8b\n"
"dup v11.8b,v20.b[11]\n"
"smlal v27.8h, v3.8b, v16.8b\n"
"dup v12.8b,v20.b[12]\n"
"smlal v28.8h, v4.8b, v16.8b\n"
"dup v13.8b,v20.b[13]\n"
"smlal v29.8h, v5.8b, v16.8b\n"
"dup v14.8b,v20.b[14]\n"
"smlal v30.8h, v6.8b, v16.8b\n"
"dup v15.8b,v20.b[15]\n"
"smlal v31.8h, v7.8b, v16.8b\n"
"ld1 {v16.8b}, [%[b_ptr]], 8\n"
"dup v0.8b,v21.b[0]\n"
"smlal v24.8h, v8.8b, v17.8b\n"
"dup v1.8b,v21.b[1]\n"
"smlal v25.8h, v9.8b, v17.8b\n"
"dup v2.8b,v21.b[2]\n"
"smlal v26.8h, v10.8b, v17.8b\n"
"dup v3.8b,v21.b[3]\n"
"smlal v27.8h, v11.8b, v17.8b\n"
"dup v4.8b,v21.b[4]\n"
"smlal v28.8h, v12.8b, v17.8b\n"
"dup v5.8b,v21.b[5]\n"
"smlal v29.8h, v13.8b, v17.8b\n"
"dup v6.8b,v21.b[6]\n"
"smlal v30.8h, v14.8b, v17.8b\n"
"dup v7.8b,v21.b[7]\n"
"smlal v31.8h, v15.8b, v17.8b\n"
"ld1 {v17.8b}, [%[b_ptr]], 8\n"
"dup v8.8b,v21.b[8]\n"
"smlal v24.8h, v0.8b, v16.8b\n"
"dup v9.8b,v21.b[9]\n"
"smlal v25.8h, v1.8b, v16.8b\n"
"dup v10.8b,v21.b[10]\n"
"smlal v26.8h, v2.8b, v16.8b\n"
"dup v11.8b,v21.b[11]\n"
"smlal v27.8h, v3.8b, v16.8b\n"
"dup v12.8b,v21.b[12]\n"
"smlal v28.8h, v4.8b, v16.8b\n"
"dup v13.8b,v21.b[13]\n"
"smlal v29.8h, v5.8b, v16.8b\n"
"dup v14.8b,v21.b[14]\n"
"smlal v30.8h, v6.8b, v16.8b\n"
"dup v15.8b,v21.b[15]\n"
"smlal v31.8h, v7.8b, v16.8b\n"
"ld1 {v16.8b}, [%[b_ptr]], 8\n"
"dup v0.8b,v22.b[0]\n"
"smlal v24.8h, v8.8b, v17.8b\n"
"dup v1.8b,v22.b[1]\n"
"smlal v25.8h, v9.8b, v17.8b\n"
"dup v2.8b,v22.b[2]\n"
"smlal v26.8h, v10.8b, v17.8b\n"
"dup v3.8b,v22.b[3]\n"
"smlal v27.8h, v11.8b, v17.8b\n"
"dup v4.8b,v22.b[4]\n"
"smlal v28.8h, v12.8b, v17.8b\n"
"dup v5.8b,v22.b[5]\n"
"smlal v29.8h, v13.8b, v17.8b\n"
"dup v6.8b,v22.b[6]\n"
"smlal v30.8h, v14.8b, v17.8b\n"
"dup v7.8b,v22.b[7]\n"
"smlal v31.8h, v15.8b, v17.8b\n"
"ld1 {v17.8b}, [%[b_ptr]], 8\n"
"dup v8.8b,v22.b[8]\n"
"smlal v24.8h, v0.8b, v16.8b\n"
"dup v9.8b,v22.b[9]\n"
"smlal v25.8h, v1.8b, v16.8b\n"
"dup v10.8b,v22.b[10]\n"
"smlal v26.8h, v2.8b, v16.8b\n"
"dup v11.8b,v22.b[11]\n"
"smlal v27.8h, v3.8b, v16.8b\n"
"dup v12.8b,v22.b[12]\n"
"smlal v28.8h, v4.8b, v16.8b\n"
"dup v13.8b,v22.b[13]\n"
"smlal v29.8h, v5.8b, v16.8b\n"
"dup v14.8b,v22.b[14]\n"
"smlal v30.8h, v6.8b, v16.8b\n"
"dup v15.8b,v22.b[15]\n"
"smlal v31.8h, v7.8b, v16.8b\n"
"ld1 {v16.8b}, [%[b_ptr]], 8\n"
"dup v0.8b,v23.b[0]\n"
"smlal v24.8h, v8.8b, v17.8b\n"
"dup v1.8b,v23.b[1]\n"
"smlal v25.8h, v9.8b, v17.8b\n"
"dup v2.8b,v23.b[2]\n"
"smlal v26.8h, v10.8b, v17.8b\n"
"dup v3.8b,v23.b[3]\n"
"smlal v27.8h, v11.8b, v17.8b\n"
"dup v4.8b,v23.b[4]\n"
"smlal v28.8h, v12.8b, v17.8b\n"
"dup v5.8b,v23.b[5]\n"
"smlal v29.8h, v13.8b, v17.8b\n"
"dup v6.8b,v23.b[6]\n"
"smlal v30.8h, v14.8b, v17.8b\n"
"dup v7.8b,v23.b[7]\n"
"smlal v31.8h, v15.8b, v17.8b\n"
"ld1 {v17.8b}, [%[b_ptr]], 8\n"
"dup v8.8b,v23.b[8]\n"
"smlal v24.8h, v0.8b, v16.8b\n"
"dup v9.8b,v23.b[9]\n"
"smlal v25.8h, v1.8b, v16.8b\n"
"dup v10.8b,v23.b[10]\n"
"smlal v26.8h, v2.8b, v16.8b\n"
"dup v11.8b,v23.b[11]\n"
"smlal v27.8h, v3.8b, v16.8b\n"
"dup v12.8b,v23.b[12]\n"
"smlal v28.8h, v4.8b, v16.8b\n"
"dup v13.8b,v23.b[13]\n"
"smlal v29.8h, v5.8b, v16.8b\n"
"dup v14.8b,v23.b[14]\n"
"smlal v30.8h, v6.8b, v16.8b\n"
"dup v15.8b,v23.b[15]\n"
"smlal v31.8h, v7.8b, v16.8b\n"
"ld1 {v20.16b}, [%[a_ptr]],#16\n"
"smlal v24.8h, v8.8b, v17.8b\n"
"smlal v25.8h, v9.8b, v17.8b\n"
"smlal v26.8h, v10.8b, v17.8b\n"
"smlal v27.8h, v11.8b, v17.8b\n"
"ld1 {v21.16b}, [%[a_ptr]],#16\n"
"smlal v28.8h, v12.8b, v17.8b\n"
"smlal v29.8h, v13.8b, v17.8b\n"
"smlal v30.8h, v14.8b, v17.8b\n"
"smlal v31.8h, v15.8b, v17.8b\n"
//"ld1 {v20.16b}, [%[a_ptr]],#16\n"
//"ld1 {v21.16b}, [%[a_ptr]],#16\n"
"subs %w[K], %w[K], #1\n"
"cbnz %w[K], 1b\n"
"3:\n"
// Store back into memory
STORE_C_8
:
[ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr),
[ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC),
[ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain),
[ n_remain ] "+r"(n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1)
:
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31");
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}
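// A rough scalar model of the accumulation the two 8x8 micro-kernels above
// perform, for orientation only: it ignores the interleaved packed layouts and
// the m/n remain handling, and the name `s4_kern_8x8_reference`, the row-major
// A/B layout and the element-stride LDC are assumptions for illustration.
static inline void s4_kern_8x8_reference(const int8_t* A, const int8_t* B,
                                         int K, int16_t* C, int LDC,
                                         bool is_first_k) {
    for (int m = 0; m < 8; ++m) {
        for (int n = 0; n < 8; ++n) {
            int16_t acc = is_first_k ? int16_t(0) : C[m * LDC + n];
            for (int k = 0; k < K; ++k) {
                // same wrapping int16 accumulation as smlal
                acc = static_cast<int16_t>(
                        acc + static_cast<int16_t>(A[m * K + k]) *
                                      static_cast<int16_t>(B[k * 8 + n]));
            }
            C[m * LDC + n] = acc;
        }
    }
}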
//! transpose-pack: used by pack_A when A is not transposed and by pack_B when B is transposed
static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[8];
int8_t tmpbuff0[8];
int8_t tmpbuff1[8];
int8_t tmpbuff2[8];
int8_t tmpbuff3[8];
int8_t tmpbuff4[8];
int8_t tmpbuff5[8];
int8_t tmpbuff6[8];
int8_t tmpbuff7[8];
std::memset(zerobuff, 0, sizeof(int8_t) * 8);
std::memset(tmpbuff0, 0, sizeof(int8_t) * 8);
std::memset(tmpbuff1, 0, sizeof(int8_t) * 8);
std::memset(tmpbuff2, 0, sizeof(int8_t) * 8);
std::memset(tmpbuff3, 0, sizeof(int8_t) * 8);
std::memset(tmpbuff4, 0, sizeof(int8_t) * 8);
std::memset(tmpbuff5, 0, sizeof(int8_t) * 8);
std::memset(tmpbuff6, 0, sizeof(int8_t) * 8);
std::memset(tmpbuff7, 0, sizeof(int8_t) * 8);
ldin /= 2;
int y = y0;
for (; y + 7 < ymax; y += 8) {
const int8_t* inptr0 = inptr + y * ldin + k0;
const int8_t* inptr1 = inptr0 + ldin;
const int8_t* inptr2 = inptr1 + ldin;
const int8_t* inptr3 = inptr2 + ldin;
const int8_t* inptr4 = inptr3 + ldin;
const int8_t* inptr5 = inptr4 + ldin;
const int8_t* inptr6 = inptr5 + ldin;
const int8_t* inptr7 = inptr6 + ldin;
prefetch_2x(inptr0);
prefetch_2x(inptr1);
prefetch_2x(inptr2);
prefetch_2x(inptr3);
prefetch_2x(inptr4);
prefetch_2x(inptr5);
prefetch_2x(inptr6);
prefetch_2x(inptr7);
int K = (kmax - k0)/2;
        //! read 4 bytes (8 nibbles) from each of the 8 rows per iteration
for (; K > 3; K -= 4) {
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4,
inptr5, inptr6, inptr7, outptr);
}
if (K > 0) {
std::memcpy(tmpbuff0,inptr0,K);
std::memcpy(tmpbuff1,inptr1,K);
std::memcpy(tmpbuff2,inptr2,K);
std::memcpy(tmpbuff3,inptr3,K);
std::memcpy(tmpbuff4,inptr4,K);
std::memcpy(tmpbuff5,inptr5,K);
std::memcpy(tmpbuff6,inptr6,K);
std::memcpy(tmpbuff7,inptr7,K);
inptr0 = tmpbuff0;
inptr1 = tmpbuff1;
inptr2 = tmpbuff2;
inptr3 = tmpbuff3;
inptr4 = tmpbuff4;
inptr5 = tmpbuff5;
inptr6 = tmpbuff6;
inptr7 = tmpbuff7;
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4,
inptr5, inptr6, inptr7, outptr);
}
}
for (; y < ymax; y += 8) {
const int8_t* inptr0 = inptr + y * ldin + k0;
const int8_t* inptr1 = inptr0 + ldin;
const int8_t* inptr2 = inptr1 + ldin;
const int8_t* inptr3 = inptr2 + ldin;
const int8_t* inptr4 = inptr3 + ldin;
const int8_t* inptr5 = inptr4 + ldin;
const int8_t* inptr6 = inptr5 + ldin;
const int8_t* inptr7 = inptr6 + ldin;
int K = (kmax - k0)/2;
        //! read 4 bytes (8 nibbles) from each of the 8 rows per iteration
for (; K > 3; K -= 4) {
if (y + 7 >= ymax) {
switch (y + 7 - ymax) {
case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;
default:
megdnn_assert(0);
}
}
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4,
inptr5, inptr6, inptr7, outptr);
}
if (K > 0) {
if (y + 7 >= ymax) {
switch (y + 7 - ymax) {
case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;
default:
megdnn_assert(0);
}
}
std::memcpy(tmpbuff0,inptr0,K);
std::memcpy(tmpbuff1,inptr1,K);
std::memcpy(tmpbuff2,inptr2,K);
std::memcpy(tmpbuff3,inptr3,K);
std::memcpy(tmpbuff4,inptr4,K);
std::memcpy(tmpbuff5,inptr5,K);
std::memcpy(tmpbuff6,inptr6,K);
std::memcpy(tmpbuff7,inptr7,K);
inptr0 = tmpbuff0;
inptr1 = tmpbuff1;
inptr2 = tmpbuff2;
inptr3 = tmpbuff3;
inptr4 = tmpbuff4;
inptr5 = tmpbuff5;
inptr6 = tmpbuff6;
inptr7 = tmpbuff7;
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4,
inptr5, inptr6, inptr7, outptr);
}
}
}
//! interleave-pack: used by pack_B when B is not transposed and by pack_A when A is transposed
static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[8];
int8_t tmpbuff0[8];
int8_t tmpbuff1[8];
int8_t tmpbuff2[8];
int8_t tmpbuff3[8];
int8_t tmpbuff4[8];
int8_t tmpbuff5[8];
int8_t tmpbuff6[8];
int8_t tmpbuff7[8];
std::memset(zerobuff, 0, sizeof(int8_t) * 8);
std::memset(tmpbuff0, 0, sizeof(int8_t) * 8);
std::memset(tmpbuff1, 0, sizeof(int8_t) * 8);
std::memset(tmpbuff2, 0, sizeof(int8_t) * 8);
std::memset(tmpbuff3, 0, sizeof(int8_t) * 8);
std::memset(tmpbuff4, 0, sizeof(int8_t) * 8);
std::memset(tmpbuff5, 0, sizeof(int8_t) * 8);
std::memset(tmpbuff6, 0, sizeof(int8_t) * 8);
std::memset(tmpbuff7, 0, sizeof(int8_t) * 8);
const int ksize = kmax - k0;
    const int ksize8 = round_up(ksize, 8) * 8;  //! bytes per packed 8-column panel: round_up(K, 8) k-values * 8 widened int8
int8_t* outptr = out;
int8_t* outptr_interleave = nullptr;
int k = k0;
ldin /= 2;
xmax = xmax / 2;
for (; k + 7 < kmax; k += 8) {
const int8_t* inptr0 = in + k * ldin + x0;
const int8_t* inptr1 = inptr0 + ldin;
const int8_t* inptr2 = inptr1 + ldin;
const int8_t* inptr3 = inptr2 + ldin;
const int8_t* inptr4 = inptr3 + ldin;
const int8_t* inptr5 = inptr4 + ldin;
const int8_t* inptr6 = inptr5 + ldin;
const int8_t* inptr7 = inptr6 + ldin;
prefetch_2x(inptr0);
prefetch_2x(inptr1);
prefetch_2x(inptr2);
prefetch_2x(inptr3);
prefetch_2x(inptr4);
prefetch_2x(inptr5);
prefetch_2x(inptr6);
prefetch_2x(inptr7);
int x = x0;
int8_t* outptr_inner = outptr;
for (; x + 3 < xmax; x += 4) {
outptr_interleave = outptr_inner;
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr_interleave);
outptr_inner += ksize8;
}
if (x < xmax) {
int remainx = xmax - x;
std::memcpy(tmpbuff0,inptr0,remainx);
std::memcpy(tmpbuff1,inptr1,remainx);
std::memcpy(tmpbuff2,inptr2,remainx);
std::memcpy(tmpbuff3,inptr3,remainx);
std::memcpy(tmpbuff4,inptr4,remainx);
std::memcpy(tmpbuff5,inptr5,remainx);
std::memcpy(tmpbuff6,inptr6,remainx);
std::memcpy(tmpbuff7,inptr7,remainx);
inptr0 = tmpbuff0;
inptr1 = tmpbuff1;
inptr2 = tmpbuff2;
inptr3 = tmpbuff3;
inptr4 = tmpbuff4;
inptr5 = tmpbuff5;
inptr6 = tmpbuff6;
inptr7 = tmpbuff7;
outptr_interleave = outptr_inner;
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr_interleave);
outptr_inner += ksize8;
}
outptr += 64;
}
if (k < kmax) {
const int8_t* inptr0 = in + k * ldin + x0;
const int8_t* inptr1 = inptr0 + ldin;
const int8_t* inptr2 = inptr1 + ldin;
const int8_t* inptr3 = inptr2 + ldin;
const int8_t* inptr4 = inptr3 + ldin;
const int8_t* inptr5 = inptr4 + ldin;
const int8_t* inptr6 = inptr5 + ldin;
const int8_t* inptr7 = inptr6 + ldin;
int k_remain = kmax - k - 1;
int x = x0;
int8_t* outptr_inner = outptr;
for (; x + 3 < xmax; x += 4) {
switch (k_remain) {
case 0:
inptr1 = zerobuff;
MEGDNN_FALLTHRU;
case 1:
inptr2 = zerobuff;
MEGDNN_FALLTHRU;
case 2:
inptr3 = zerobuff;
MEGDNN_FALLTHRU;
case 3:
inptr4 = zerobuff;
MEGDNN_FALLTHRU;
case 4:
inptr5 = zerobuff;
MEGDNN_FALLTHRU;
case 5:
inptr6 = zerobuff;
MEGDNN_FALLTHRU;
case 6:
inptr7 = zerobuff;
break;
default:
megdnn_assert(0);
break;
}
outptr_interleave = outptr_inner;
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr_interleave);
outptr_inner += ksize8;
}
if (x < xmax) {
switch (k_remain) {
case 0:
inptr1 = zerobuff;
MEGDNN_FALLTHRU;
case 1:
inptr2 = zerobuff;
MEGDNN_FALLTHRU;
case 2:
inptr3 = zerobuff;
MEGDNN_FALLTHRU;
case 3:
inptr4 = zerobuff;
MEGDNN_FALLTHRU;
case 4:
inptr5 = zerobuff;
MEGDNN_FALLTHRU;
case 5:
inptr6 = zerobuff;
MEGDNN_FALLTHRU;
case 6:
inptr7 = zerobuff;
break;
default:
megdnn_assert(0);
break;
}
int remainx = xmax - x;
outptr_interleave = outptr_inner;
std::memcpy(tmpbuff0,inptr0,remainx);
std::memcpy(tmpbuff1,inptr1,remainx);
std::memcpy(tmpbuff2,inptr2,remainx);
std::memcpy(tmpbuff3,inptr3,remainx);
std::memcpy(tmpbuff4,inptr4,remainx);
std::memcpy(tmpbuff5,inptr5,remainx);
std::memcpy(tmpbuff6,inptr6,remainx);
std::memcpy(tmpbuff7,inptr7,remainx);
inptr0 = tmpbuff0;
inptr1 = tmpbuff1;
inptr2 = tmpbuff2;
inptr3 = tmpbuff3;
inptr4 = tmpbuff4;
inptr5 = tmpbuff5;
inptr6 = tmpbuff6;
inptr7 = tmpbuff7;
outptr_interleave = outptr_inner;
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr_interleave);
outptr_inner += ksize8;
}
}
}
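// Informal layout note for the two packers above (read off the code, not a
// separate spec): each interleave_8x4_1_b_with_shift call emits 64 bytes,
// i.e. 8 k-values x 8 widened int8 columns.  Within one 8-row K block the
// chunks for successive 8-column groups are ksize8 bytes apart, and successive
// K blocks of the same column group are 64 bytes apart, so one full 8-column
// panel occupies ksize8 = round_up(K, 8) * 8 bytes.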
} // namespace matmul_s4_4x4x16
} // namespace aarch64
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/aarch64/matrix_mul/int4x4x16/strategy.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h"
#include "src/aarch64/matrix_mul/int4x4x16/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_common.h"
using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;
// ===========================gemm_s4x4x16_s4_8x8x8==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s4x4x16_s4_8x8x8);
void gemm_s4x4x16_s4_8x8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0,
int ymax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, y0, ymax, k0,
kmax);
} else {
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, y0, ymax, k0,
kmax);
}
}
void gemm_s4x4x16_s4_8x8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, x0, xmax, k0,
kmax);
} else {
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, x0, xmax, k0,
kmax);
}
}
void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int16* C,
size_t LDC, bool is_first_k, const dt_int16*,
dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
(A_dtype.enumv() == DTypeEnum::QuantizedS4 &&
C_dtype.enumv() == DTypeEnum::QuantizedS16),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
constexpr size_t A_INTERLEAVE = 8;
constexpr size_t B_INTERLEAVE = 8;
    //! K is padded to a multiple of 8 by the packing routines
    K = round_up<size_t>(K, 8);
    const int K8 = K * 8;  // bytes in one packed 8-row (or 8-column) panel
size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
int16_t* output = C + (m * LDC);
size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_s4_4x4x16::s4_kern_8x8(packA, cur_packB, K, output, LDC,
is_first_k, A_INTERLEAVE, B_INTERLEAVE);
output += B_INTERLEAVE;
cur_packB += K8;
}
for (; n < N; n += B_INTERLEAVE) {
matmul_s4_4x4x16::s4_kern_8x8_remain(packA, cur_packB, K, output, LDC,
is_first_k, A_INTERLEAVE,
std::min<size_t>(N - n, B_INTERLEAVE));
output += B_INTERLEAVE;
cur_packB += K8;
}
packA += K8;
}
for (; m < M; m += A_INTERLEAVE) {
int16_t* output = C + (m * LDC);
size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n < N; n += B_INTERLEAVE) {
matmul_s4_4x4x16::s4_kern_8x8_remain(packA, cur_packB, K, output, LDC,
is_first_k,
std::min<size_t>(M - m, A_INTERLEAVE),
std::min<size_t>(N - n, B_INTERLEAVE));
output += B_INTERLEAVE;
cur_packB += K8;
}
packA += K8;
}
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/aarch64/matrix_mul/int4x4x16/strategy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"
namespace megdnn {
namespace aarch64 {
namespace matmul {
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true,
gemm_s4x4x16_s4_8x8x8);
} // namespace matmul
} // namespace aarch64
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -50,6 +50,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#else
AlgoQuint8K8x8x8 quint8_k8x8x8;
#endif
AlgoInt4x4x16K8x8x8 int4x4x16_k8x8x8;
SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos;
fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map;
......@@ -87,6 +88,7 @@ public:
#else
m_all_algos.emplace_back(&quint8_k8x8x8);
#endif
m_all_algos.emplace_back(&int4x4x16_k8x8x8);
for (auto&& algo : m_all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
......
......@@ -66,8 +66,8 @@ private:
#else
class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8
#endif
class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int4x4x16 Kernel 4x4x16
class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int8x8x16 Kernel 4x4x16
class AlgoInt4x4x16K8x8x8;        // Aarch64 Int4x4x16 Kernel 8x8x8
class AlgoPack;
public:
static const AlgoPack& algo_pack();
......
......@@ -33,6 +33,8 @@ void MatrixMulForward::deduce_dtype(DType A, DType B, DType& C) {
C_candi = dtype::QuantizedS32(mul_scale(A, B));
} else if (A.enumv() == DTypeEnum::Quantized4Asymm) {
C_candi = dtype::QuantizedS32(mul_scale(A, B));
} else if (A.enumv() == DTypeEnum::QuantizedS4) {
C_candi = dtype::QuantizedS16(mul_scale(A, B));
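        // e.g. A = QuantizedS4(0.6f) and B = QuantizedS4(0.5f) deduce
        // C = QuantizedS16(0.6f * 0.5f) = QuantizedS16(0.3f); mul_scale
        // multiplies the two input scales, as in the other quantized branches
        // above.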
}
if (!C.valid()) {
C = C_candi;
......@@ -169,6 +171,8 @@ void MatrixMulForward::check_exec(const TensorLayout& A, const TensorLayout& B,
A.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
A.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
megdnn_assert(C.dtype.enumv() == DTypeEnum::QuantizedS32);
} else if(A.dtype.enumv() == DTypeEnum::QuantizedS4){
megdnn_assert(C.dtype.enumv() == DTypeEnum::QuantizedS16);
}
megdnn_assert(param().compute_mode !=
Param::ComputeMode::FLOAT32 MEGDNN_INC_FLOAT16(
......
......@@ -154,6 +154,7 @@ public:
AARCH64_QUINT8_K8X8X4_DOTPROD,
AARCH64_QUINT8_GEMV_DOTPROD,
AARCH64_QUINT8_K8X8X8,
AARCH64_INT4X4X16_K8X8X8,
#else
ARMV7_F32 = 1 << 16,
ARMV7_F32_MK4_PACK_4X12,
......
......@@ -179,6 +179,42 @@ void exec_matrix_mul_quint4x4x32_helper(_megdnn_tensor_in A,
C.compatible_ptr<dt_int32>(), M, N, K, LDA, LDB, LDC,
nA.layout.dtype, nB.layout.dtype);
}
template <bool transA, bool transB>
void exec_matrix_mul_qint4x4x16_helper(_megdnn_tensor_in A, _megdnn_tensor_in B,
_megdnn_tensor_out C,
_megdnn_workspace workspace,
const param::MatrixMul& param) {
auto convert_layout = [](const TensorLayout& layout) {
auto ret = layout;
auto param = layout.dtype.param<dtype::QuantizedS4>();
ret.dtype = dtype::QuantizedS8(param.scale);
return ret;
};
TensorND nA = {workspace.raw_ptr, convert_layout(A.layout)};
TensorND nB = {workspace.raw_ptr + nA.layout.span().dist_byte(),
convert_layout(B.layout)};
auto convert_4to8 = [](const TensorND& in, const TensorND& out) {
auto ptr = static_cast<int8_t*>(in.raw_ptr) + in.layout.span().low_byte;
auto out_ptr =
out.compatible_ptr<int8_t>() + out.layout.span().low_byte;
for (size_t i = 0; i < in.layout.span().dist_elem(); i += 2) {
int8_t cur = ptr[i / 2];
out_ptr[i] = cur << 4;
out_ptr[i] = out_ptr[i] >> 4;
out_ptr[i + 1] = cur >> 4;
}
};
convert_4to8(A, nA);
convert_4to8(B, nB);
auto M = C.layout.shape[0], N = C.layout.shape[1];
auto K = A.layout.shape[param.transposeA ? 0 : 1];
auto LDA = A.layout.stride[0], LDB = B.layout.stride[0],
LDC = C.layout.stride[0];
run_matrix_mul_tpl<int8_t, dt_int16, transA, transB, dt_int16>(
nA.compatible_ptr<int8_t>(), nB.compatible_ptr<int8_t>(),
C.compatible_ptr<dt_int16>(), M, N, K, LDA, LDB, LDC,
nA.layout.dtype, nB.layout.dtype);
}
} // namespace naive
} // namespace megdnn
......
......@@ -26,7 +26,8 @@ size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A,
MIDOUT_BEGIN(
megdnn_naive_matmul,
midout_iv("MatrixMulForwardImpl::get_workspace_in_bytes"_hash)) {
if (A.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
if (A.dtype.enumv() == DTypeEnum::Quantized4Asymm ||
A.dtype.enumv() == DTypeEnum::QuantizedS4) {
return (A.span().dist_elem() + B.span().dist_elem()) *
sizeof(uint8_t);
}
......@@ -104,6 +105,11 @@ void dispatch_ta_tb(_megdnn_tensor_in A, _megdnn_tensor_in B,
param.format == param::MatrixMul::Format::DEFAULT) {
exec_matrix_mul_quint4x4x32_helper<TA, TB>(A, B, C, workspace, param);
return;
} else if (A.layout.dtype.enumv() == DTypeEnum::QuantizedS4 &&
C.layout.dtype.enumv() == DTypeEnum::QuantizedS16 &&
param.format == param::MatrixMul::Format::DEFAULT) {
exec_matrix_mul_qint4x4x16_helper<TA, TB>(A, B, C, workspace, param);
return;
}
#undef cb
megdnn_throw(ssprintf(
......
......@@ -164,6 +164,55 @@ TEST_F(AARCH64, MATRIX_MUL_INT8x8x16_K4x4x16) {
handle(), "AARCH64_INT8X8X16_K4X4X16");
}
TEST_F(AARCH64, MATRIX_MUL_INT4x4x16_K8x8x8_QUANTIZEDS4) {
param::MatrixMul param;
param.transposeA = false;
param.transposeB = false;
Checker<MatrixMul> checker(handle());
checker.set_dtype(0, dtype::QuantizedS4{0.6})
.set_dtype(1, dtype::QuantizedS4{0.5})
.set_dtype(2, dtype::QuantizedS16{0.6 * 0.5})
.set_param(param);
checker.set_before_exec_callback(
AlgoChecker<MatrixMul>("AARCH64_INT4X4X16_K8X8X8"));
auto run = [&](size_t M, size_t N, size_t K) {
printf("M N K %zu %zu %zu \n", M, N, K);
TensorShape A, B;
if (param.transposeA) {
A = TensorShape{K, M};
} else {
A = TensorShape{M, K};
}
if (param.transposeB) {
B = TensorShape{N, K};
} else {
B = TensorShape{K, N};
}
checker.exec({A, B, {}});
};
for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 16, 20})
for (size_t n : {2, 4, 6, 8, 10, 12, 14, 16, 24})
for (size_t k : {2, 4, 6, 8, 10, 12, 14, 16, 32})
run(m, n, k);
for (size_t k = 4; k <= 256; k *= 8) {
for (size_t m = 4; m <= 256; m *= 4) {
for (size_t n = 4; n <= 256; n *= 4) {
run(m, n, k);
}
}
}
param.transposeA = true;
run(8,8,8);
run(16,8,16);
param.transposeB = true;
run(8,8,8);
run(16,16,16);
}
TEST_F(AARCH64, MATRIX_MUL_INT16x16x32_K12X8X1) {
matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{},
handle(), "AARCH64_INT16X16X32_K12X8X1");
......@@ -410,6 +459,63 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x16) {
run(384, 384, 384);
}
TEST_F(AARCH64, BENCHMARK_4x4x16_vs_8x8x16) {
constexpr size_t RUNS = 50;
param::MatrixMul param;
param.transposeA = false;
param.transposeB = false;
Benchmarker<MatrixMul> benchmarker(handle());
Benchmarker<MatrixMul> benchmarker_int4_4x4x16(handle());
benchmarker_int4_4x4x16.set_times(RUNS)
.set_dtype(0, dtype::QuantizedS4{0.3})
.set_dtype(1, dtype::QuantizedS4{0.3})
.set_dtype(2, dtype::QuantizedS16{0.09})
.set_param(param)
.set_display(false);
benchmarker.set_times(RUNS)
.set_dtype(0, dtype::Int8{})
.set_dtype(1, dtype::Int8{})
.set_dtype(2, dtype::Int16{})
.set_param(param)
.set_display(false);
benchmarker.set_before_exec_callback(
AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_K4X4X16"));
auto run = [&](size_t M, size_t N, size_t K) {
auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS;
auto int4416_used =
benchmarker_int4_4x4x16.exec({{M, K}, {K, N}, {}}) / RUNS;
float computations = 2.f * M * K * N * 1e-6;
printf("run: {%zu{M} %zu{K} %zu{N}} normal 8x8x16 used: %f ms %f "
"Gflops int4416 used %f int4416_gflops %f speedup %f\n",
M, K, N, default_used, computations / default_used, int4416_used,
computations / int4416_used, default_used / int4416_used);
};
for (int m = 32; m <= 1024; m += 32)
for (int n = 32; n <= 1024; n += 32)
for (int k = 32; k <= 512; k += 32)
run(m, n, k);
run(32, 32, 32);
run(32, 32, 8);
run(32, 32, 16);
run(32, 32, 24);
run(32 * 2, 32 * 2, 32);
run(32 * 4, 32 * 4, 32);
run(32 * 6, 32 * 6, 32);
run(32 * 8, 32 * 8, 32);
run(32 * 2, 32 * 2, 32 * 2);
run(32 * 4, 32 * 4, 32 * 3);
run(32 * 6, 32 * 6, 32 * 4);
run(32 * 8, 32 * 8, 32 * 5);
run(32 * 10, 32 * 10, 32 * 10);
run(384, 384, 384);
run(256, 256, 384);
run(512, 512, 384);
run(1024, 1024, 384);
}
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x8_8x8x16_vs_4x4x16_8x8x16) {
constexpr size_t RUNS = 50;
param::MatrixMul param;
......
......@@ -183,6 +183,34 @@ void IIDRNG::gen(const TensorND& tensor) {
}
return;
}
if (tensor.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
auto ptr = static_cast<int8_t*>(tensor.raw_ptr);
if (output_is_float()) {
for (size_t i = 0; i < nr_elems; i += 2) {
int8_t val0 =
tensor.layout.dtype.param<dt_qint4>()
.quantize(static_cast<float>(gen_single_val()))
.as_int8();
int8_t val1 =
tensor.layout.dtype.param<dt_qint4>()
.quantize(static_cast<float>(gen_single_val()))
.as_int8();
ptr[(offset + i) / 2] = (val0 & 0xF) | (val1 << 4);
}
} else {
for (size_t i = 0; i < nr_elems; i += 2) {
int8_t val0 = static_cast<int8_t>(gen_single_val());
int8_t val1 = static_cast<int8_t>(gen_single_val());
val0 = std::min(val0,DTypeTrait<dtype::QuantizedS4>::max());
val0 = std::max(val0,DTypeTrait<dtype::QuantizedS4>::min());
val1 = std::min(val1,DTypeTrait<dtype::QuantizedS4>::max());
val1 = std::max(val1,DTypeTrait<dtype::QuantizedS4>::min());
ptr[(offset + i) / 2] = (val0 & 0xF) | (val1 << 4);
}
}
return;
}
megdnn_assert(0, "IIDRNG does not know how to generate value for DType %s",
tensor.layout.dtype.name());
}
......
......@@ -203,6 +203,67 @@ TEST_F(NAIVE, MATRIX_MUL_QUANTIZED4x4x32) {
});
}
TEST_F(NAIVE, MATRIX_MUL_QUANTIZEDS4_4x4x16) {
Checker<MatrixMul> checker(handle(), /* check_dispatch */ false);
    auto GenTensorValueQint4 = [](const TensorShape& shape,
dtype::QuantizedS4 dtype,
const std::vector<int>& values) {
TensorND tensor;
tensor.layout = {shape, dtype};
tensor.raw_ptr =
static_cast<dt_byte*>(malloc(tensor.layout.span().dist_byte()));
uint8_t* ptr = static_cast<uint8_t*>(tensor.raw_ptr);
megdnn_assert(values.size() == tensor.layout.span().dist_elem());
for (size_t i = 0; i < tensor.layout.span().dist_elem(); i += 2) {
int val0 = values[i], val1 = values[i + 1];
ptr[i / 2] =(val0 & 0xF) | (val1 << 4);
}
return tensor;
};
using Param = MatrixMul::Param;
Param param;
checker.set_param(param);
checker.set_dtype(2, dtype::QuantizedS16(0.3f * 0.3f));
checker.exect(
Testcase{
                    GenTensorValueQint4(
{8, 8}, dtype::QuantizedS4(0.3f),
{-8, 7, 2, 1, 2, 3, 2, 7,
2, 5, 3, 3, 7, 4, -7, 1,
-5, 7, -4, -1, -1, 2, 4, 1,
7, 2, -6, -2, -6, 3, 4, 4,
-2, 2, 3, 0, 6, 5, 3, 4,
-1, -1, -5, 5, 2, 5, 1, 4,
6, 2, 0, 0, 3, 2, 2, 1,
-4, -3, 7, 5, 0, 3, 2, 3}),
                    GenTensorValueQint4(
{8, 8}, dtype::QuantizedS4(0.3f),
{5, -8, -7, -6, 4, 7, -5, -5,
-4, 7, -3, -2, 5, 6, 4, 2,
3, -1, 2, 2, 7, 3, 6, 0,
5, 4, 0, 2, 2, 3, 3, 2,
1, -8, -7, -6, 0, -5, -4, 4,
-3, 7, 1, 6, -2, 2, -1, 5,
2, 0, 7, 6, 5, 4, 3, 2,
0, 0, 1, 0, 5, 2, 2, 6}),
{}},
Testcase{
{},
{},
TensorValue(
{8, 8}, dtype::QuantizedS16(0.3f * 0.3f),
{-60, 120, 49, 58, 58, 13, 92, 125,
-5, 0, -116, -70, 22, 9, -14, 46,
-69, 111, 44, 48, 6, 19, 42, 57,
-8, 25, 10, 16, 26, 97, -28, -12,
-12, 14, 2, 26, 48, 7, 24, 93,
-2, 45, 2, 32, -19, -1, -16, 72,
23, -44, -52, -34, 45, 53, -28, 6,
33, 45, 71, 84, 47, 10, 74, 61})
});
}
TEST_F(NAIVE, MATRIX_MUL_QUANTIZED8x8x32) {
Checker<MatrixMul> checker(handle(), /* check_dispatch */ false);
MatrixMul::Param param;
......