提交 07d1d0ab 编写于 作者: M Megvii Engine Team

feat(dnn/arm64): add fp32 mk4 matmul

GitOrigin-RevId: f6df006547e08ba5b76be984a2fe87cf053c31de
上级 7ba641fe
......@@ -86,6 +86,67 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_aarch64_matmul_kern,
"AlgoF32K8x12x1Impl"_hash,
aarch64::matmul::sgemm_8x12, float, float);
/* ===================== F32_MK4_8X12X1 algo ===================== */
bool MatrixMulImpl::AlgoF32MK4_8x12x1::usable(
const KernSizeParam& kern_size_param) const {
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.C_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float32() &&
kern_size_param.format == param::MatrixMul::Format::MK4 &&
!kern_size_param.trA && !kern_size_param.trB &&
kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0;
}
size_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_workspace(
const KernSizeParam& kern_size_param) const {
MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
midout_iv("AlgoF32MK4_8x12x1::get_workspace"_hash)) {
auto M = kern_size_param.M, N = kern_size_param.N,
K = kern_size_param.K;
auto trA = kern_size_param.trA, trB = kern_size_param.trB;
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
C_type = kern_size_param.C_type;
aarch64::matmul::sgemm_mk4_8x12 strategy(M, N, K, A_type, B_type,
C_type);
return megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_mk4_8x12>(
M, N, K, trA, trB, strategy)
.get_workspace_size();
}
MIDOUT_END();
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_kern(
const KernSizeParam&) const {
auto f32_kern_mk4_8x12 = [](const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
midout_iv("AlgoF32MK4_8x12x1::get_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto trA = kern_param.trA, trB = kern_param.trB;
auto LDA = kern_param.LDA, LDB = kern_param.LDB,
LDC = kern_param.LDC;
auto A_type = kern_param.A_type, B_type = kern_param.B_type,
C_type = kern_param.C_type;
const auto Aptr = kern_param.A<float>(),
Bptr = kern_param.B<float>();
auto Cptr = kern_param.C<float>();
aarch64::matmul::sgemm_mk4_8x12 strategy(M, N, K, A_type, B_type,
C_type);
megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_mk4_8x12>(
M, N, K, trA, trB, strategy)
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
kern_param.workspace_ptr);
}
MIDOUT_END();
};
return f32_kern_mk4_8x12;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4_8x12x1,
megdnn_aarch64_matmul_kern,
"AlgoF32MK4_8x12x1Impl"_hash,
aarch64::matmul::sgemm_mk4_8x12, float,
float);
/* ===================== F32K4X16X1 algo ===================== */
bool MatrixMulImpl::AlgoF32K4x16x1::usable(
......
......@@ -29,6 +29,17 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
......
......@@ -1103,6 +1103,36 @@ static inline void interleave_4x4_1_s(const T*& inptr0, const T*& inptr1,
: "v0", "v1", "v2", "v3", "cc", "memory");
}
template <typename T>
static inline void interleave_2x4_4_s(const T*& inptr0, const T*& inptr1,
T* outptr) {
static_assert(sizeof(T) == 4, "interleave_2x4_4_s only support size == 4");
asm volatile(
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n"
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[inptr1]], #64\n"
"stp q0, q4, [%[outptr]]\n"
"stp q1, q5, [%[outptr], #32]\n"
"stp q2, q6, [%[outptr], #64]\n"
"stp q3, q7, [%[outptr], #96]\n"
: [ inptr0 ] "+r"(inptr0), [ inptr1 ] "+r"(inptr1),
[ outptr ] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "memory");
}
template <typename T>
static inline void interleave_1x4_4_s(const T*& inptr0, T* outptr) {
static_assert(sizeof(T) == 4, "interleave_1x4_4_s only support size == 4");
asm volatile(
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]]\n"
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "memory");
}
template <typename T>
static inline void interleave_4x8_2_b(const T*& inptr0, const T*& inptr1,
const T*& inptr2, const T*& inptr3,
......@@ -1479,6 +1509,41 @@ static inline void transpose_4x4_1_s(const T*& inptr0, const T*& inptr1,
"v11", "memory");
}
template <typename T>
static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) {
static_assert(sizeof(T) == 4,
"transpose_1x12_4_s only support sizeof(T) == 4");
asm volatile(
"ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n"
"ld4 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[inptr0]], #64\n"
"ld4 {v8.4s, v9.4s, v10.4s, v11.4s},[%[inptr0]], #64\n"
"stp q0, q4, [%[outptr]] \n"
"stp q8, q1, [%[outptr], #32] \n"
"stp q5, q9, [%[outptr], #64] \n"
"stp q2, q6, [%[outptr], #96] \n"
"stp q10, q3, [%[outptr], #128] \n"
"stp q7, q11, [%[outptr], #160] \n"
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "memory");
}
template <typename T>
static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) {
static_assert(sizeof(T) == 4,
"transpose_1x4_4_s only support sizeof(T) == 4");
asm volatile(
"ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]]\n"
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "memory");
}
template <typename T>
static inline void transpose_8x4_1_s(const T*& inptr0, const T*& inptr1,
const T*& inptr2, const T*& inptr3,
......
......@@ -899,6 +899,10 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output,
:
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2",
"x3", "x10", "cc", "memory");
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}
void sgemm_8x12_pack_A_n(float* outptr, const float* inptr, int ldin, int y0,
......
/**
* \file dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
namespace megdnn {
namespace aarch64 {
namespace matmul_mk4_8x12 {
// Overview of register layout:
//
// A 1x12 cell of Rhs is stored in 32bit in v2-v7
// A 8x1 cell of Lhs is stored in 32bit in (v0-v1)
// A 8x12 block of accumulators is stored in 32bit in v8-v31.
//
// +--------+--------+--------+
// | v2[0-3]| v3[0-3]| v4[0-3]|
// | v5[0-3]| v6[0-3]| v7[0-3]|
// Rhs +--------+--------+--------+
//
// | | | |
//
// Lhs | | | |
//
// +--+ --- - +--------+--------+--------+
// |v0| | v8[0-3]| v9[0-3]|v10[0-3]|
// |v0| |v11[0-3]|v12[0-3]|v13[0-3]|
// |v0| |v14[0-3]|v15[0-3]|v16[0-3]|
// |v0| |v17[0-3]|v18[0-3]|v19[0-3]|
// |v1| |v20[0-3]|v21[0-3]|v22[0-3]|
// |v1| |v23[0-3]|v24[0-3]|v25[0-3]|
// |v1| |v26[0-3]|v27[0-3]|v28[0-3]|
// |v1| |v29[0-3]|v30[0-3]|v31[0-3]|
// +--+ --- - +--------+--------+--------+
//
// Accumulator
void kern_8x12(const float* packA, const float* packB, int K, float* output,
int LDC, bool is_first_k) {
const float* a_ptr = packA;
const float* b_ptr = packB;
float* output0 = output;
float* output1 = output0 + LDC;
int oddk = (K & 1);
K = ((K + 1) / 2) - 1;
asm volatile(
"cmp %w[is_first_k], #1\n"
"beq 1f\n"
"mov x1, %[output0]\n"
"mov x2, %[output1]\n"
"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64\n"
"ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64\n"
"ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64\n"
"ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2], #64\n"
"ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64\n"
"ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64\n"
"ld1 {v0.4s}, [%[a_ptr]], #16\n"
"ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n"
"b 2f\n"
"1:\n"
"eor v8.16b, v8.16b, v8.16b\n"
"eor v9.16b, v9.16b, v9.16b\n"
"eor v10.16b, v10.16b, v10.16b\n"
"prfm pstl1keep, [%[output0]]\n"
"eor v11.16b, v11.16b, v11.16b\n"
"eor v12.16b, v12.16b, v12.16b\n"
"eor v13.16b, v13.16b, v13.16b\n"
"prfm pstl1keep, [%[output1]]\n"
"eor v14.16b, v14.16b, v14.16b\n"
"eor v15.16b, v15.16b, v15.16b\n"
"ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n"
"eor v16.16b, v16.16b, v16.16b\n"
"eor v17.16b, v17.16b, v17.16b\n"
"eor v18.16b, v18.16b, v18.16b\n"
"eor v19.16b, v19.16b, v19.16b\n"
"eor v20.16b, v20.16b, v20.16b\n"
"ld1 {v0.4s}, [%[a_ptr]], #16\n"
"eor v21.16b, v21.16b, v21.16b\n"
"eor v22.16b, v22.16b, v22.16b\n"
"eor v23.16b, v23.16b, v23.16b\n"
"eor v24.16b, v24.16b, v24.16b\n"
"eor v25.16b, v25.16b, v25.16b\n"
"eor v26.16b, v26.16b, v26.16b\n"
"eor v27.16b, v27.16b, v27.16b\n"
"eor v28.16b, v28.16b, v28.16b\n"
"eor v29.16b, v29.16b, v29.16b\n"
"eor v30.16b, v30.16b, v30.16b\n"
"eor v31.16b, v31.16b, v31.16b\n"
"2: \n"
"cmp %w[K], #0\n"
"beq 4f\n"
"3:\n"
"fmla v8.4s, v0.4s, v2.s[0]\n"
"fmla v9.4s, v0.4s, v2.s[1]\n"
"ld1 {v1.4s}, [%[a_ptr]], 16\n"
"fmla v10.4s, v0.4s, v2.s[2]\n"
"fmla v11.4s, v0.4s, v2.s[3]\n"
"fmla v12.4s, v0.4s, v3.s[0]\n"
"fmla v13.4s, v0.4s, v3.s[1]\n"
"fmla v14.4s, v0.4s, v3.s[2]\n"
"fmla v15.4s, v0.4s, v3.s[3]\n"
"fmla v16.4s, v0.4s, v4.s[0]\n"
"fmla v17.4s, v0.4s, v4.s[1]\n"
"fmla v18.4s, v0.4s, v4.s[2]\n"
"fmla v19.4s, v0.4s, v4.s[3]\n"
"fmla v20.4s, v1.4s, v2.s[0]\n"
"fmla v21.4s, v1.4s, v2.s[1]\n"
"fmla v22.4s, v1.4s, v2.s[2]\n"
"ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], #48\n"
"fmla v23.4s, v1.4s, v2.s[3]\n"
"fmla v24.4s, v1.4s, v3.s[0]\n"
"fmla v25.4s, v1.4s, v3.s[1]\n"
"ld1 {v0.4s}, [%[a_ptr]], 16\n"
"fmla v26.4s, v1.4s, v3.s[2]\n"
"fmla v27.4s, v1.4s, v3.s[3]\n"
"fmla v28.4s, v1.4s, v4.s[0]\n"
"fmla v29.4s, v1.4s, v4.s[1]\n"
"fmla v30.4s, v1.4s, v4.s[2]\n"
"fmla v31.4s, v1.4s, v4.s[3]\n"
"fmla v8.4s, v0.4s, v5.s[0]\n"
"fmla v9.4s, v0.4s, v5.s[1]\n"
"ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n"
"fmla v10.4s, v0.4s, v5.s[2]\n"
"fmla v11.4s, v0.4s, v5.s[3]\n"
"ld1 {v1.4s}, [%[a_ptr]], 16\n"
"fmla v12.4s, v0.4s, v6.s[0]\n"
"fmla v13.4s, v0.4s, v6.s[1]\n"
"fmla v14.4s, v0.4s, v6.s[2]\n"
"fmla v15.4s, v0.4s, v6.s[3]\n"
"fmla v16.4s, v0.4s, v7.s[0]\n"
"fmla v17.4s, v0.4s, v7.s[1]\n"
"fmla v18.4s, v0.4s, v7.s[2]\n"
"fmla v19.4s, v0.4s, v7.s[3]\n"
"fmla v20.4s, v1.4s, v5.s[0]\n"
"fmla v21.4s, v1.4s, v5.s[1]\n"
"ld1 {v0.4s}, [%[a_ptr]], 16\n"
"fmla v22.4s, v1.4s, v5.s[2]\n"
"fmla v23.4s, v1.4s, v5.s[3]\n"
"fmla v24.4s, v1.4s, v6.s[0]\n"
"subs %w[K], %w[K], #1\n"
"fmla v25.4s, v1.4s, v6.s[1]\n"
"fmla v26.4s, v1.4s, v6.s[2]\n"
"fmla v27.4s, v1.4s, v6.s[3]\n"
"fmla v28.4s, v1.4s, v7.s[0]\n"
"fmla v29.4s, v1.4s, v7.s[1]\n"
"fmla v30.4s, v1.4s, v7.s[2]\n"
"fmla v31.4s, v1.4s, v7.s[3]\n"
"bne 3b\n"
"4:\n"
"cmp %w[oddk], #1\n"
"beq 5f\n"
// Even tail
"fmla v8.4s, v0.4s, v2.s[0]\n"
"fmla v9.4s, v0.4s, v2.s[1]\n"
"ld1 {v1.4s}, [%[a_ptr]], 16\n"
"fmla v10.4s, v0.4s, v2.s[2]\n"
"fmla v11.4s, v0.4s, v2.s[3]\n"
"fmla v12.4s, v0.4s, v3.s[0]\n"
"fmla v13.4s, v0.4s, v3.s[1]\n"
"fmla v14.4s, v0.4s, v3.s[2]\n"
"fmla v15.4s, v0.4s, v3.s[3]\n"
"fmla v16.4s, v0.4s, v4.s[0]\n"
"fmla v17.4s, v0.4s, v4.s[1]\n"
"fmla v18.4s, v0.4s, v4.s[2]\n"
"fmla v19.4s, v0.4s, v4.s[3]\n"
"fmla v20.4s, v1.4s, v2.s[0]\n"
"fmla v21.4s, v1.4s, v2.s[1]\n"
"fmla v22.4s, v1.4s, v2.s[2]\n"
"ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], #48\n"
"fmla v23.4s, v1.4s, v2.s[3]\n"
"fmla v24.4s, v1.4s, v3.s[0]\n"
"fmla v25.4s, v1.4s, v3.s[1]\n"
"ld1 {v0.4s}, [%[a_ptr]], 16\n"
"fmla v26.4s, v1.4s, v3.s[2]\n"
"fmla v27.4s, v1.4s, v3.s[3]\n"
"fmla v28.4s, v1.4s, v4.s[0]\n"
"fmla v29.4s, v1.4s, v4.s[1]\n"
"fmla v30.4s, v1.4s, v4.s[2]\n"
"fmla v31.4s, v1.4s, v4.s[3]\n"
"fmla v8.4s, v0.4s, v5.s[0]\n"
"fmla v9.4s, v0.4s, v5.s[1]\n"
"fmla v10.4s, v0.4s, v5.s[2]\n"
"fmla v11.4s, v0.4s, v5.s[3]\n"
"ld1 {v1.4s}, [%[a_ptr]], 16\n"
"fmla v12.4s, v0.4s, v6.s[0]\n"
"fmla v13.4s, v0.4s, v6.s[1]\n"
"st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]], #64\n"
"fmla v14.4s, v0.4s, v6.s[2]\n"
"fmla v15.4s, v0.4s, v6.s[3]\n"
"fmla v16.4s, v0.4s, v7.s[0]\n"
"fmla v17.4s, v0.4s, v7.s[1]\n"
"st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[output0]], #64\n"
"fmla v18.4s, v0.4s, v7.s[2]\n"
"fmla v19.4s, v0.4s, v7.s[3]\n"
"fmla v20.4s, v1.4s, v5.s[0]\n"
"fmla v21.4s, v1.4s, v5.s[1]\n"
"fmla v22.4s, v1.4s, v5.s[2]\n"
"st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output0]], #64\n"
"fmla v23.4s, v1.4s, v5.s[3]\n"
"fmla v24.4s, v1.4s, v6.s[0]\n"
"fmla v25.4s, v1.4s, v6.s[1]\n"
"st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output1]], #64\n"
"fmla v26.4s, v1.4s, v6.s[2]\n"
"fmla v27.4s, v1.4s, v6.s[3]\n"
"fmla v28.4s, v1.4s, v7.s[0]\n"
"st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output1]], #64\n"
"fmla v29.4s, v1.4s, v7.s[1]\n"
"fmla v30.4s, v1.4s, v7.s[2]\n"
"fmla v31.4s, v1.4s, v7.s[3]\n"
"st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output1]], #64\n"
"b 6f\n"
// odd tail
"5:\n"
"fmla v8.4s, v0.4s, v2.s[0]\n"
"fmla v9.4s, v0.4s, v2.s[1]\n"
"fmla v10.4s, v0.4s, v2.s[2]\n"
"fmla v11.4s, v0.4s, v2.s[3]\n"
"fmla v12.4s, v0.4s, v3.s[0]\n"
"fmla v13.4s, v0.4s, v3.s[1]\n"
"ld1 {v1.4s}, [%[a_ptr]], 16\n"
"fmla v14.4s, v0.4s, v3.s[2]\n"
"st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]], #64\n"
"fmla v15.4s, v0.4s, v3.s[3]\n"
"fmla v16.4s, v0.4s, v4.s[0]\n"
"fmla v17.4s, v0.4s, v4.s[1]\n"
"fmla v18.4s, v0.4s, v4.s[2]\n"
"st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[output0]], #64\n"
"fmla v19.4s, v0.4s, v4.s[3]\n"
"fmla v20.4s, v1.4s, v2.s[0]\n"
"fmla v21.4s, v1.4s, v2.s[1]\n"
"fmla v22.4s, v1.4s, v2.s[2]\n"
"st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output0]], #64\n"
"fmla v23.4s, v1.4s, v2.s[3]\n"
"fmla v24.4s, v1.4s, v3.s[0]\n"
"fmla v25.4s, v1.4s, v3.s[1]\n"
"st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output1]], #64\n"
"fmla v26.4s, v1.4s, v3.s[2]\n"
"fmla v27.4s, v1.4s, v3.s[3]\n"
"fmla v28.4s, v1.4s, v4.s[0]\n"
"st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output1]], #64\n"
"fmla v29.4s, v1.4s, v4.s[1]\n"
"fmla v30.4s, v1.4s, v4.s[2]\n"
"fmla v31.4s, v1.4s, v4.s[3]\n"
"st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output1]], #64\n"
"6:\n"
: [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K),
[ is_first_k ] "+r"(is_first_k), [ oddk ] "+r"(oddk),
[ output0 ] "+r"(output0), [ output1 ] "+r"(output1)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "x1", "x2", "cc", "memory");
}
// Overview of register layout:
//
// A 1x12 cell of Rhs is stored in 32bit in v2-v7
// A 8x1 cell of Lhs is stored in 32bit in (v0-v1)
// A 8x12 block of accumulators is stored in 32bit in v8-v31.
//
// +--------+
// | v2[0-3]|
// | v3[0-3]|
// Rhs +--------+
//
// | |
//
// Lhs | |
//
// +--+ --- - +--------+
// |v0| | v8[0-3]|
// |v0| |v11[0-3]|
// |v0| |v14[0-3]|
// |v0| |v17[0-3]|
// |v1| |v20[0-3]|
// |v1| |v23[0-3]|
// |v1| |v26[0-3]|
// |v1| |v29[0-3]|
// +--+ --- - +--------+
//
// Accumulator
void kern_8x4(const float* packA, const float* packB, int K, float* output,
int LDC, bool is_first_k, int n_remain) {
const float* a_ptr = packA;
const float* b_ptr = packB;
float* output0 = output;
float* output1 = output0 + LDC;
int oddk = (K & 1);
K = ((K + 1) / 2) - 1;
//clang-format off
#define LOAD_C \
"cmp %w[n_remain], #4\n" \
"blt 11f\n" \
"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \
"ld1 {v12.4s, v13.4s, v14.4s, v15.4s},[%[output1]]\n" \
"b 14f\n" \
"11:\n" \
"cmp %w[n_remain], #3\n" \
"blt 12f\n" \
"ld1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \
"ld1 {v12.4s, v13.4s, v14.4s},[%[output1]]\n" \
"b 14f\n" \
"12:\n" \
"cmp %w[n_remain], #2\n" \
"blt 13f\n" \
"ld1 {v8.4s, v9.4s}, [%[output0]]\n" \
"ld1 {v12.4s, v13.4s},[%[output1]]\n" \
"b 14f\n" \
"13:\n" \
"ld1 {v8.4s}, [%[output0]]\n" \
"ld1 {v12.4s},[%[output1]]\n" \
"14:\n"
#define STORE_C \
"cmp %w[n_remain], #4\n" \
"blt 21f\n" \
"st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \
"st1 {v12.4s, v13.4s, v14.4s, v15.4s},[%[output1]]\n" \
"b 24f\n" \
"21:\n" \
"cmp %w[n_remain], #3\n" \
"blt 22f\n" \
"st1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \
"st1 {v12.4s, v13.4s, v14.4s},[%[output1]]\n" \
"b 23f\n" \
"22:\n" \
"cmp %w[n_remain], #2\n" \
"blt 23f\n" \
"st1 {v8.4s, v9.4s}, [%[output0]]\n" \
"st1 {v12.4s, v13.4s},[%[output1]]\n" \
"b 24f\n" \
"23:\n" \
"st1 {v8.4s}, [%[output0]]\n" \
"st1 {v12.4s},[%[output1]]\n" \
"24:\n"
//clang-format on
asm volatile(
// load accumulator C
"cmp %w[is_first_k], #1\n"
"beq 1f\n" LOAD_C
"ld1 {v0.4s}, [%[a_ptr]], #16\n"
"ld1 {v2.4s}, [%[b_ptr]], #16\n"
"b 2f\n"
"1:\n"
"eor v8.16b, v8.16b, v8.16b\n"
"ld1 {v0.4s}, [%[a_ptr]], #16\n"
"eor v9.16b, v9.16b, v9.16b\n"
"eor v10.16b, v10.16b, v10.16b\n"
"prfm pstl1keep, [%[output0]]\n"
"eor v11.16b, v11.16b, v11.16b\n"
"eor v12.16b, v12.16b, v12.16b\n"
"prfm pstl1keep, [%[output1]]\n"
"eor v13.16b, v13.16b, v13.16b\n"
"eor v14.16b, v14.16b, v14.16b\n"
"eor v15.16b, v15.16b, v15.16b\n"
"ld1 {v2.4s}, [%[b_ptr]], #16\n"
"2: \n"
"cmp %w[K], #0\n"
"beq 4f\n"
"3:\n"
"fmla v8.4s, v0.4s, v2.s[0]\n"
"ld1 {v1.4s}, [%[a_ptr]], #16\n"
"fmla v9.4s, v0.4s, v2.s[1]\n"
"fmla v10.4s, v0.4s, v2.s[2]\n"
"ld1 {v3.4s}, [%[b_ptr]], #16\n"
"fmla v11.4s, v0.4s, v2.s[3]\n"
"fmla v12.4s, v1.4s, v2.s[0]\n"
"ld1 {v0.4s}, [%[a_ptr]], #16\n"
"fmla v13.4s, v1.4s, v2.s[1]\n"
"fmla v14.4s, v1.4s, v2.s[2]\n"
"fmla v15.4s, v1.4s, v2.s[3]\n"
"fmla v8.4s, v0.4s, v3.s[0]\n"
"ld1 {v1.4s}, [%[a_ptr]], #16\n"
"fmla v9.4s, v0.4s, v3.s[1]\n"
"fmla v10.4s, v0.4s, v3.s[2]\n"
"fmla v11.4s, v0.4s, v3.s[3]\n"
"ld1 {v2.4s}, [%[b_ptr]], #16\n"
"fmla v12.4s, v1.4s, v3.s[0]\n"
"subs %w[K], %w[K], #1\n"
"fmla v13.4s, v1.4s, v3.s[1]\n"
"ld1 {v0.4s}, [%[a_ptr]], #16\n"
"fmla v14.4s, v1.4s, v3.s[2]\n"
"fmla v15.4s, v1.4s, v3.s[3]\n"
"bne 3b\n"
"4:\n"
"cmp %w[oddk], #1\n"
"beq 5f\n"
// Even tail
"fmla v8.4s, v0.4s, v2.s[0]\n"
"ld1 {v1.4s}, [%[a_ptr]], #16\n"
"fmla v9.4s, v0.4s, v2.s[1]\n"
"fmla v10.4s, v0.4s, v2.s[2]\n"
"ld1 {v3.4s}, [%[b_ptr]], #16\n"
"fmla v11.4s, v0.4s, v2.s[3]\n"
"fmla v12.4s, v1.4s, v2.s[0]\n"
"ld1 {v0.4s}, [%[a_ptr]], #16\n"
"fmla v13.4s, v1.4s, v2.s[1]\n"
"fmla v14.4s, v1.4s, v2.s[2]\n"
"fmla v15.4s, v1.4s, v2.s[3]\n"
"fmla v8.4s, v0.4s, v3.s[0]\n"
"ld1 {v1.4s}, [%[a_ptr]], #16\n"
"fmla v9.4s, v0.4s, v3.s[1]\n"
"fmla v10.4s, v0.4s, v3.s[2]\n"
"fmla v11.4s, v0.4s, v3.s[3]\n"
"fmla v12.4s, v1.4s, v3.s[0]\n"
"fmla v13.4s, v1.4s, v3.s[1]\n"
"fmla v14.4s, v1.4s, v3.s[2]\n"
"fmla v15.4s, v1.4s, v3.s[3]\n"
"b 6f\n"
// odd tail
"5:\n"
"fmla v8.4s, v0.4s, v2.s[0]\n"
"ld1 {v1.4s}, [%[a_ptr]], #16\n"
"fmla v9.4s, v0.4s, v2.s[1]\n"
"fmla v10.4s, v0.4s, v2.s[2]\n"
"fmla v11.4s, v0.4s, v2.s[3]\n"
"fmla v12.4s, v1.4s, v2.s[0]\n"
"fmla v13.4s, v1.4s, v2.s[1]\n"
"fmla v14.4s, v1.4s, v2.s[2]\n"
"fmla v15.4s, v1.4s, v2.s[3]\n"
"6:\n" STORE_C
: [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K),
[ is_first_k ] "+r"(is_first_k), [ oddk ] "+r"(oddk),
[ output0 ] "+r"(output0), [ output1 ] "+r"(output1),
[ n_remain ] "+r"(n_remain)
:
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13",
"v14", "v15", "cc", "memory");
#undef LOAD_C
#undef STORE_C
}
// Overview of register layout:
//
// A 1x12 cell of Rhs is stored in 32bit in v2-v7
// A 8x1 cell of Lhs is stored in 32bit in (v0-v1)
// A 8x12 block of accumulators is stored in 32bit in v8-v31.
//
// +--------+--------+--------+
// | v2[0-3]| v3[0-3]| v4[0-3]|
// | v5[0-3]| v6[0-3]| v7[0-3]|
// Rhs +--------+--------+--------+
//
// | | | |
//
// Lhs | | | |
//
// +--+ --- - +--------+--------+--------+
// |v0| | v8[0-3]| v9[0-3]|v10[0-3]|
// |v0| |v11[0-3]|v12[0-3]|v13[0-3]|
// |v0| |v14[0-3]|v15[0-3]|v16[0-3]|
// |v0| |v17[0-3]|v18[0-3]|v19[0-3]|
// +--+ --- - +--------+--------+--------+
//
// Accumulator
void kern_4x12(const float* packA, const float* packB, int K, float* output,
int LDC, bool is_first_k) {
MEGDNN_MARK_USED_VAR(LDC);
const float* a_ptr = packA;
const float* b_ptr = packB;
float* output0 = output;
int oddk = (K & 1);
K = ((K + 1) / 2) - 1;
asm volatile(
"cmp %w[is_first_k], #1\n"
"beq 1f\n"
"mov x1, %[output0]\n"
"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64\n"
"ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64\n"
"ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64\n"
"ld1 {v0.4s}, [%[a_ptr]], #16\n"
"ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n"
"b 2f\n"
"1:\n"
"eor v8.16b, v8.16b, v8.16b\n"
"eor v9.16b, v9.16b, v9.16b\n"
"eor v10.16b, v10.16b, v10.16b\n"
"prfm pstl1keep, [%[output0]]\n"
"eor v11.16b, v11.16b, v11.16b\n"
"eor v12.16b, v12.16b, v12.16b\n"
"eor v13.16b, v13.16b, v13.16b\n"
"eor v14.16b, v14.16b, v14.16b\n"
"eor v15.16b, v15.16b, v15.16b\n"
"ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n"
"eor v16.16b, v16.16b, v16.16b\n"
"eor v17.16b, v17.16b, v17.16b\n"
"ld1 {v0.4s}, [%[a_ptr]], #16\n"
"eor v18.16b, v18.16b, v18.16b\n"
"eor v19.16b, v19.16b, v19.16b\n"
"2: \n"
"cmp %w[K], #0\n"
"beq 4f\n"
"3:\n"
"fmla v8.4s, v0.4s, v2.s[0]\n"
"fmla v9.4s, v0.4s, v2.s[1]\n"
"ld1 {v1.4s}, [%[a_ptr]], 16\n"
"fmla v10.4s, v0.4s, v2.s[2]\n"
"fmla v11.4s, v0.4s, v2.s[3]\n"
"fmla v12.4s, v0.4s, v3.s[0]\n"
"fmla v13.4s, v0.4s, v3.s[1]\n"
"ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], #48\n"
"fmla v14.4s, v0.4s, v3.s[2]\n"
"fmla v15.4s, v0.4s, v3.s[3]\n"
"fmla v16.4s, v0.4s, v4.s[0]\n"
"fmla v17.4s, v0.4s, v4.s[1]\n"
"fmla v18.4s, v0.4s, v4.s[2]\n"
"fmla v19.4s, v0.4s, v4.s[3]\n"
"fmla v8.4s, v1.4s, v5.s[0]\n"
"fmla v9.4s, v1.4s, v5.s[1]\n"
"ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n"
"fmla v10.4s, v1.4s, v5.s[2]\n"
"fmla v11.4s, v1.4s, v5.s[3]\n"
"ld1 {v0.4s}, [%[a_ptr]], 16\n"
"fmla v12.4s, v1.4s, v6.s[0]\n"
"fmla v13.4s, v1.4s, v6.s[1]\n"
"subs %w[K], %w[K], #1\n"
"fmla v14.4s, v1.4s, v6.s[2]\n"
"fmla v15.4s, v1.4s, v6.s[3]\n"
"fmla v16.4s, v1.4s, v7.s[0]\n"
"fmla v17.4s, v1.4s, v7.s[1]\n"
"fmla v18.4s, v1.4s, v7.s[2]\n"
"fmla v19.4s, v1.4s, v7.s[3]\n"
"bne 3b\n"
"4:\n"
"cmp %w[oddk], #1\n"
"beq 5f\n"
// Even tail
"fmla v8.4s, v0.4s, v2.s[0]\n"
"fmla v9.4s, v0.4s, v2.s[1]\n"
"ld1 {v1.4s}, [%[a_ptr]], 16\n"
"fmla v10.4s, v0.4s, v2.s[2]\n"
"fmla v11.4s, v0.4s, v2.s[3]\n"
"ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], #48\n"
"fmla v12.4s, v0.4s, v3.s[0]\n"
"fmla v13.4s, v0.4s, v3.s[1]\n"
"fmla v14.4s, v0.4s, v3.s[2]\n"
"fmla v15.4s, v0.4s, v3.s[3]\n"
"fmla v16.4s, v0.4s, v4.s[0]\n"
"fmla v17.4s, v0.4s, v4.s[1]\n"
"fmla v18.4s, v0.4s, v4.s[2]\n"
"fmla v19.4s, v0.4s, v4.s[3]\n"
"fmla v8.4s, v1.4s, v5.s[0]\n"
"fmla v9.4s, v1.4s, v5.s[1]\n"
"fmla v10.4s, v1.4s, v5.s[2]\n"
"fmla v11.4s, v1.4s, v5.s[3]\n"
"ld1 {v0.4s}, [%[a_ptr]], 16\n"
"fmla v12.4s, v1.4s, v6.s[0]\n"
"fmla v13.4s, v1.4s, v6.s[1]\n"
"st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]], #64\n"
"fmla v14.4s, v1.4s, v6.s[2]\n"
"fmla v15.4s, v1.4s, v6.s[3]\n"
"fmla v16.4s, v1.4s, v7.s[0]\n"
"fmla v17.4s, v1.4s, v7.s[1]\n"
"st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[output0]], #64\n"
"fmla v18.4s, v1.4s, v7.s[2]\n"
"fmla v19.4s, v1.4s, v7.s[3]\n"
"st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output0]], #64\n"
"b 6f\n"
// odd tail
"5:\n"
"fmla v8.4s, v0.4s, v2.s[0]\n"
"fmla v9.4s, v0.4s, v2.s[1]\n"
"fmla v10.4s, v0.4s, v2.s[2]\n"
"fmla v11.4s, v0.4s, v2.s[3]\n"
"fmla v12.4s, v0.4s, v3.s[0]\n"
"fmla v13.4s, v0.4s, v3.s[1]\n"
"fmla v14.4s, v0.4s, v3.s[2]\n"
"st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]], #64\n"
"fmla v15.4s, v0.4s, v3.s[3]\n"
"fmla v16.4s, v0.4s, v4.s[0]\n"
"fmla v17.4s, v0.4s, v4.s[1]\n"
"fmla v18.4s, v0.4s, v4.s[2]\n"
"st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[output0]], #64\n"
"fmla v19.4s, v0.4s, v4.s[3]\n"
"st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output0]], #64\n"
"6:\n"
: [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K),
[ is_first_k ] "+r"(is_first_k), [ oddk ] "+r"(oddk),
[ output0 ] "+r"(output0)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"x1", "cc", "memory");
}
// Overview of register layout:
//
// A 2x4 cell of Rhs is stored in 32bit in v2 - v3
// A 4x2 cell of Lhs is stored in 32bit in v0 - v1
// A 4x4 block of accumulators is stored in 32bit in v4-v6
//
// +--------+
// | v2[0-3]|
// | v5[0-3]|
// Rhs +--------+
//
// | |
//
// Lhs | |
//
// +--+ --- - +--------+
// |v0| | v8[0-3]|
// |v0| |v11[0-3]|
// |v0| |v14[0-3]|
// |v0| |v17[0-3]|
// +--+ --- - +--------+
//
// Accumulator
void kern_4x4(const float* packA, const float* packB, int K, float* output,
int LDC, bool is_first_k, int n_remain) {
MEGDNN_MARK_USED_VAR(LDC);
const float* a_ptr = packA;
const float* b_ptr = packB;
float* output0 = output;
int oddk = (K & 1);
K = ((K + 1) / 2) - 1;
//clang-format off
#define LOAD_C \
"cmp %w[n_remain], #4\n" \
"blt 11f\n" \
"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \
"b 14f\n" \
"11:\n" \
"cmp %w[n_remain], #3\n" \
"blt 12f\n" \
"ld1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \
"b 14f\n" \
"12:\n" \
"cmp %w[n_remain], #2\n" \
"blt 13f\n" \
"ld1 {v8.4s, v9.4s}, [%[output0]]\n" \
"b 14f\n" \
"13:\n" \
"ld1 {v8.4s}, [%[output0]]\n" \
"14:\n"
#define STORE_C \
"cmp %w[n_remain], #4\n" \
"blt 21f\n" \
"st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \
"b 24f\n" \
"21:\n" \
"cmp %w[n_remain], #3\n" \
"blt 22f\n" \
"st1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \
"b 23f\n" \
"22:\n" \
"cmp %w[n_remain], #2\n" \
"blt 23f\n" \
"st1 {v8.4s, v9.4s}, [%[output0]]\n" \
"b 24f\n" \
"23:\n" \
"st1 {v8.4s}, [%[output0]]\n" \
"24:\n"
//clang-format on
asm volatile(
// load accumulator C
"cmp %w[is_first_k], #1\n"
"beq 1f\n" LOAD_C
"ld1 {v0.4s}, [%[a_ptr]], #16\n"
"ld1 {v2.4s}, [%[b_ptr]], #16\n"
"b 2f\n"
"1:\n"
"eor v8.16b, v8.16b, v8.16b\n"
"ld1 {v2.4s}, [%[b_ptr]], #16\n"
"eor v9.16b, v9.16b, v9.16b\n"
"ld1 {v0.4s}, [%[a_ptr]], #16\n"
"eor v10.16b, v10.16b, v10.16b\n"
"prfm pstl1keep, [%[output0]]\n"
"eor v11.16b, v11.16b, v11.16b\n"
"2: \n"
"cmp %w[K], #0\n"
"beq 4f\n"
"3:\n"
"fmla v8.4s, v0.4s, v2.s[0]\n"
"ld1 {v1.4s}, [%[a_ptr]], 16\n"
"fmla v9.4s, v0.4s, v2.s[1]\n"
"fmla v10.4s, v0.4s, v2.s[2]\n"
"ld1 {v3.4s}, [%[b_ptr]], 16\n"
"fmla v11.4s, v0.4s, v2.s[3]\n"
"fmla v8.4s, v1.4s, v3.s[0]\n"
"fmla v9.4s, v1.4s, v3.s[1]\n"
"ld1 {v0.4s}, [%[a_ptr]], 16\n"
"fmla v10.4s, v1.4s, v3.s[2]\n"
"fmla v11.4s, v1.4s, v3.s[3]\n"
"ld1 {v2.4s}, [%[b_ptr]], 16\n"
"subs %w[K], %w[K], #1\n"
"bne 3b\n"
"4:\n"
"cmp %w[oddk], #1\n"
"beq 5f\n"
// Even tail
"fmla v8.4s, v0.4s, v2.s[0]\n"
"ld1 {v1.4s}, [%[a_ptr]], 16\n"
"fmla v9.4s, v0.4s, v2.s[1]\n"
"fmla v10.4s, v0.4s, v2.s[2]\n"
"ld1 {v3.4s}, [%[b_ptr]], 16\n"
"fmla v11.4s, v0.4s, v2.s[3]\n"
"fmla v8.4s, v1.4s, v3.s[0]\n"
"fmla v9.4s, v1.4s, v3.s[1]\n"
"fmla v10.4s, v1.4s, v3.s[2]\n"
"fmla v11.4s, v1.4s, v3.s[3]\n"
"b 6f\n"
// odd tail
"5:\n"
"fmla v8.4s, v0.4s, v2.s[0]\n"
"fmla v9.4s, v0.4s, v2.s[1]\n"
"fmla v10.4s, v0.4s, v2.s[2]\n"
"fmla v11.4s, v0.4s, v2.s[3]\n"
"6:\n" STORE_C
: [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K),
[ is_first_k ] "+r"(is_first_k), [ oddk ] "+r"(oddk),
[ output0 ] "+r"(output0), [ n_remain ] "+r"(n_remain)
:
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory");
#undef LOAD_C
#undef STORE_C
}
void sgemm_8x12_pack_A(float* outptr, const float* inptr, int ldin, int y0,
int ymax, int k0, int kmax) {
megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4");
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
constexpr int PACK_SIZE_32 = 4 * 8;
constexpr int PACK_SIZE_16 = 4 * 4;
constexpr int PACK_C_SIZE = 4;
int y = y0;
for (; y + 7 < ymax; y += 8) {
const float* inptr0 = inptr + y / PACK_C_SIZE * ldin + k0;
const float* inptr1 = inptr0 + ldin;
prefetch_2x(inptr0);
prefetch_2x(inptr1);
int k = (kmax - k0);
for (; k > 3; k -= 4) {
interleave_2x4_4_s(inptr0, inptr1, outptr);
outptr += PACK_SIZE_32;
}
}
for (; y < ymax; y += 4) {
const float* inptr0 = inptr + y / PACK_C_SIZE * ldin + k0;
prefetch_2x(inptr0);
int K = (kmax - k0);
for (; K > 3; K -= 4) {
interleave_1x4_4_s(inptr0, outptr);
outptr += PACK_SIZE_16;
}
}
}
void sgemm_8x12_pack_B(float* out, const float* in, int ldin, int x0,
int xmax, int k0, int kmax) {
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
float tmpbuff[16] = {0.0f};
constexpr int PACK_C_SIZE = 4;
int ksize = kmax - k0;
int ksize12 = ksize * 12;
int ksize4 = (ksize << 2);
float* outptr_base = out;
float* outptr_base4 = outptr_base + (xmax - x0) / 12 * ksize12;
int k = k0;
for (; k + 3 < kmax; k += 4) {
const float* inptr = in + k / PACK_C_SIZE * ldin + x0 * PACK_C_SIZE;
prefetch_3x(inptr);
int x = x0;
auto outptr = outptr_base;
for (; x + 12 <= xmax; x += 12) {
auto outptr_interleave = outptr;
transpose_1x12_4_s(inptr, outptr_interleave);
outptr += ksize12;
}
outptr = outptr_base4;
for (; x + 4 <= xmax; x += 4) {
auto outptr_interleave = outptr;
transpose_1x4_4_s(inptr, outptr_interleave);
outptr += ksize4;
}
if (x < xmax) {
std::memcpy(tmpbuff, inptr,
sizeof(float) * (xmax - x) * PACK_C_SIZE);
auto outptr_interleave = outptr;
const float* tmp_ptr = &tmpbuff[0];
transpose_1x4_4_s<float>(tmp_ptr, outptr_interleave);
outptr += ksize4;
}
outptr_base += 12 * 4;
outptr_base4 += 4 * 4;
}
}
} // namespace matmul_mk4_8x12
} // aarch64
} // megdnn
// vim: syntax=cpp.doxygen
......@@ -12,6 +12,7 @@
#include "src/aarch64/matrix_mul/fp32/strategy.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h"
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h"
#include "src/common/utils.h"
using namespace megdnn;
......@@ -163,4 +164,80 @@ void sgemm_8x12::kern(const float* packA, const float* packB,
}
}
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_8x12);
void sgemm_mk4_8x12::pack_A(float* out, const float* in, int ldin, int y0,
int ymax, int k0, int kmax,
bool transpose_A) const {
megdnn_assert(!transpose_A, "mk4 float matmul not support transpose A");
matmul_mk4_8x12::sgemm_8x12_pack_A(out, in, ldin, y0, ymax, k0, kmax);
}
void sgemm_mk4_8x12::pack_B(float* out, const float* in, int ldin, int x0,
int xmax, int k0, int kmax,
bool transpose_B) const {
megdnn_assert(!transpose_B, "mk4 float matmul not support transpose B");
matmul_mk4_8x12::sgemm_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax);
}
void sgemm_mk4_8x12::kern(const float* packA, const float* packB,
size_t M, size_t N, size_t K, float* C, size_t LDC,
bool is_first_k, const float*, float*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4");
constexpr size_t PACK_C_SIZE = 4;
constexpr size_t A_INTERLEAVE = 8;
constexpr size_t A_INTERLEAVE4 = 4;
constexpr size_t B_INTERLEAVE = 12;
const int K12 = K * 12;
const int K8 = K * 8;
const int K4 = K * 4;
size_t m = 0;
for (; m + A_INTERLEAVE <= M; m += A_INTERLEAVE) {
float* output = C + (m / PACK_C_SIZE * LDC);
size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) {
matmul_mk4_8x12::kern_8x12(packA, cur_packB, K, output, LDC,
is_first_k);
output += B_INTERLEAVE * PACK_C_SIZE;
cur_packB += K12;
}
for (; n < N; n += 4) {
matmul_mk4_8x12::kern_8x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
output += 4 * PACK_C_SIZE;
cur_packB += K4;
}
packA += K8;
}
for (; m < M; m += A_INTERLEAVE4) {
float* output = C + (m / PACK_C_SIZE * LDC);
size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_mk4_8x12::kern_4x12(packA, cur_packB, K, output, LDC,
is_first_k);
output += B_INTERLEAVE * PACK_C_SIZE;
cur_packB += K12;
}
for (; n < N; n += 4) {
matmul_mk4_8x12::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
output += 4 * PACK_C_SIZE;
cur_packB += K4;
}
packA += K4;
}
}
// vim: syntax=cpp.doxygen
......@@ -20,6 +20,9 @@ MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true,
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true,
sgemm_4x16);
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, false,
sgemm_mk4_8x12);
MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 16, 1, false, true,
sgemm_nopack_4x16);
......
......@@ -18,6 +18,7 @@ using namespace aarch64;
class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoF32K8x12x1 f32K8x12x1;
AlgoF32MK4_8x12x1 f32_mk4_8x12x1;
AlgoF32K4x16x1 f32k4x16x1;
AlgoF32MK4_4x16 f32mk4_4x16;
AlgoF32Gemv f32_gemv;
......@@ -53,6 +54,7 @@ public:
AlgoPack() {
all_algos.emplace_back(&f32_gemv);
all_algos.emplace_back(&f32K8x12x1);
all_algos.emplace_back(&f32_mk4_8x12x1);
all_algos.emplace_back(&f32k4x16x1);
all_algos.emplace_back(&f32mk4_4x16);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
......
......@@ -22,6 +22,7 @@ public:
private:
class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1
class AlgoF32MK4_8x12x1; // Aarch64 F32 Kernel MK4 8x12x1
class AlgoF32K4x16x1; // Aarch64 F32 Kernel 4x16x1
class AlgoF32MK4_4x16; // Aarch64 F32 Format MK4 block 16x4
class AlgoF32Gemv; // Aarch64 F32 Gemv
......
......@@ -244,6 +244,7 @@ bool ConvBiasImpl::AlgoFP16WinogradF23_8x8::usable(
if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0)
return false;
using Strategy = winograd::winograd_2x3_8x8_f16;
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy,
......@@ -252,6 +253,7 @@ bool ConvBiasImpl::AlgoFP16WinogradF23_8x8::usable(
param.osz[1], param.filter_meta.ocpg)
.get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) &&
m_matmul_algo->packmode() == PackMode::NO_PACK &&
(opr->param().format == param::ConvBias::Format::NCHW ||
(opr->param().format ==
param::ConvBias::Format::NCHW_WINOGRAD &&
......
......@@ -38,6 +38,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable(
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0)
return false;
using Strategy = winograd::winograd_2x3_4x4_f;
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy,
......@@ -46,6 +47,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable(
param.osz[1], param.filter_meta.ocpg)
.get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) &&
m_matmul_algo->packmode() == PackMode::NO_PACK &&
(opr->param().format == param::ConvBias::Format::NCHW ||
(opr->param().format ==
param::ConvBias::Format::NCHW_WINOGRAD &&
......@@ -319,6 +321,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable(
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0)
return false;
using Strategy = winograd::winograd_6x3_4x4_f;
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy,
......@@ -327,6 +330,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable(
param.osz[1], param.filter_meta.ocpg)
.get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) &&
m_matmul_algo->packmode() == PackMode::NO_PACK &&
(opr->param().format == param::ConvBias::Format::NCHW ||
(opr->param().format ==
param::ConvBias::Format::NCHW_WINOGRAD &&
......
......@@ -217,6 +217,7 @@ bool ConvBiasImpl::AlgoS8WinogradF23_8x8::usable(
if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0)
return false;
using Strategy = winograd::winograd_2x3_8x8_s8;
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy, param::MatrixMul::Format::MK8>(
......@@ -224,6 +225,7 @@ bool ConvBiasImpl::AlgoS8WinogradF23_8x8::usable(
param.osz[1], param.filter_meta.ocpg)
.get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) &&
m_matmul_algo->packmode() == PackMode::NO_PACK &&
((opr->param().format == param::ConvBias::Format::NCHW &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8) ||
(opr->param().format == param::ConvBias::Format::NCHW_WINOGRAD &&
......
......@@ -31,6 +31,12 @@ TEST_F(AARCH64, MATRIX_MUL_FP32K4X16) {
"AARCH64_F32K4X16X1");
}
TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4) {
matrix_mul::check_matrix_mul(
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
"AARCH64_F32_MK4_K8X12X1", param::MatrixMul::Format::MK4, 1);
}
TEST_F(AARCH64, MATRIX_MUL_FP32_MK4) {
//! nbase should be 4 in order to test the last rest 4 in N dim
matrix_mul::check_matrix_mul(
......@@ -527,6 +533,15 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_FP32_MK4) {
dtype::Float32{});
}
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_FP32_PACK_MK4) {
auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(16);
matrix_mul::benchmark_with_contrast(
handle(), args, dtype::Float32{}, dtype::Float32{},
dtype::Float32{}, "AARCH64_F32_MK4_K8X12X1",
param::MatrixMul::Format::MK4, dtype::Float32{}, dtype::Float32{},
dtype::Float32{}, "AARCH64_F32K8X12X1");
}
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16x16x32_MK8) {
auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8);
matrix_mul::benchmark_with_contrast(
......
......@@ -40,8 +40,8 @@ std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_mk_packed_args(
size_t nbase) {
std::vector<TestArg> args;
for (size_t m : {1, 2, 3, 4, 5})
for (size_t n : {1, 2, 3, 4, 5, 8, 16, 24})
for (size_t k : {1, 2, 3, 4, 5})
for (size_t n : {1, 2, 3, 4, 5, 8, 12, 16, 24})
for (size_t k : {1, 2, 3, 4, 5, 9, 10})
args.emplace_back(m, n * nbase, k, 0);
return args;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册