提交 27ef788f 编写于 作者: M Megvii Engine Team

feat(dnn/armv7): add armv7 mk4 matmul

GitOrigin-RevId: 8ef24bf53b19f7863b8a51da5a81135c941d1a72
上级 efb60be2
...@@ -707,7 +707,7 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output, ...@@ -707,7 +707,7 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output,
"cmp %w[n_remain], #3\n" \ "cmp %w[n_remain], #3\n" \
"blt 22f\n" \ "blt 22f\n" \
"st1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \ "st1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \
"b 23f\n" \ "b 24f\n" \
"22:\n" \ "22:\n" \
"cmp %w[n_remain], #2\n" \ "cmp %w[n_remain], #2\n" \
"blt 23f\n" \ "blt 23f\n" \
......
...@@ -85,6 +85,74 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32, megdnn_armv7_matmul_kern, ...@@ -85,6 +85,74 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32, megdnn_armv7_matmul_kern,
"AlgoF32Impl"_hash, "AlgoF32Impl"_hash,
armv7::matmul::sgemm_4x12, float, float); armv7::matmul::sgemm_4x12, float, float);
/* ===================== F32 algo mk4 K4x12 ===================== */
namespace {
//! Kernel entry point for the packed MK4 4x12 fp32 matmul.
//! Reads the problem description from \p kern_param, builds the
//! sgemm_mk4_pack_4x12 strategy and runs it through GemmInterleaved, using
//! the workspace pointer carried by \p kern_param.
void f32_mk4_pack_4x12_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_armv7_matmul_kern,
                 midout_iv("f32_mk4_pack_4x12_kern"_hash)) {
        using Strategy = armv7::matmul::sgemm_mk4_pack_4x12;

        // Problem size and transpose flags.
        auto M = kern_param.M;
        auto N = kern_param.N;
        auto K = kern_param.K;
        auto trA = kern_param.trA;
        auto trB = kern_param.trB;

        // Leading dimensions of the three matrices.
        auto LDA = kern_param.LDA;
        auto LDB = kern_param.LDB;
        auto LDC = kern_param.LDC;

        // Element types, forwarded to the strategy.
        auto A_type = kern_param.A_type;
        auto B_type = kern_param.B_type;
        auto C_type = kern_param.C_type;

        // Typed views of the operands.
        const auto Aptr = kern_param.A<float>();
        const auto Bptr = kern_param.B<float>();
        auto Cptr = kern_param.C<float>();

        Strategy strategy(M, N, K, A_type, B_type, C_type);
        megdnn::matmul::GemmInterleaved<Strategy> gemm(M, N, K, trA, trB,
                                                       strategy);
        gemm.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                     kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
bool MatrixMulImpl::AlgoF32MK4Pack4x12::usable(
const KernSizeParam& kern_size_param) const {
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::MK4 &&
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.C_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA &&
!kern_size_param.trB && kern_size_param.M % 4 == 0 &&
kern_size_param.K % 4 == 0 && !kern_size_param.trA &&
!kern_size_param.trB;
}
//! Workspace size needed for \p kern_size_param, as reported by
//! GemmInterleaved<sgemm_mk4_pack_4x12>::get_workspace_size().
size_t MatrixMulImpl::AlgoF32MK4Pack4x12::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_armv7_matmul_kern,
                 midout_iv("AlgoF32MK4Pack4x12::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        armv7::matmul::sgemm_mk4_pack_4x12 strategy(M, N, K, A_type, B_type,
                                                    C_type);
        return megdnn::matmul::GemmInterleaved<
                       armv7::matmul::sgemm_mk4_pack_4x12>(M, N, K, trA, trB,
                                                           strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Every path of a non-void function must return a value. Without this,
    // control would fall off the end if the MIDOUT region above is not
    // entered (NOTE(review): depends on how MIDOUT_BEGIN/END expand).
    return 0;
}
//! Returns the kernel entry point for this algorithm. The size parameter is
//! ignored: the same kernel is returned for every problem size.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4Pack4x12::get_kern(
        const KernSizeParam& /* unused */) const {
    kern_t kern = f32_mk4_pack_4x12_kern;
    return kern;
}
// Counterpart of the MEGDNN_REG_GEMM_FUNC_FOR_IM2COL() line in the
// AlgoF32MK4Pack4x12 class declaration, instantiated for the
// sgemm_mk4_pack_4x12 strategy with float input/output.
// NOTE(review): the macro body is not visible here; presumably it defines
// the im2col/conv1x1 helper members for this algorithm -- confirm.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4Pack4x12,
megdnn_armv7_matmul_kern,
"AlgoF32MK4Pack4x12"_hash,
armv7::matmul::sgemm_mk4_pack_4x12, float,
float);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/* ===================== F16 K4x16x1 algo ===================== */ /* ===================== F16 K4x16x1 algo ===================== */
namespace { namespace {
......
...@@ -29,6 +29,17 @@ public: ...@@ -29,6 +29,17 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };
//! Armv7 fp32 matmul algorithm for the MK4 format that packs its operands
//! and uses a 4x12 kernel (see name()).
class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase {
public:
//! Always reports the algorithm as reproducible.
bool is_reproducible() const override { return true; }
//! Algorithm name; tests select this algorithm by this string.
const char* name() const override { return "ARMV7_F32_MK4_PACK_4X12"; }
//! Whether the algorithm can handle the given problem (defined out of line).
bool usable(const KernSizeParam&) const override;
//! Workspace size required for the given problem (defined out of line).
size_t get_workspace(const KernSizeParam&) const override;
//! Kernel function to execute (defined out of line).
kern_t get_kern(const KernSizeParam&) const override;
//! Returns the sm_arm_common_algo_type tag.
void* type() const override { return sm_arm_common_algo_type; }
// NOTE(review): macro body not visible here; it is paired with a
// MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(...) invocation in the .cpp file.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoF32MK4_4x8 final : public AlgoBase { class MatrixMulImpl::AlgoF32MK4_4x8 final : public AlgoBase {
public: public:
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
......
...@@ -1120,6 +1120,62 @@ static inline void transpose_4x4_1_b(const T*& inptr0, const T*& inptr1, ...@@ -1120,6 +1120,62 @@ static inline void transpose_4x4_1_b(const T*& inptr0, const T*& inptr1,
: "q0", "q1", "q2", "memory"); : "q0", "q1", "q2", "memory");
} }
/*!
 * \brief Transpose 12 groups of 4 consecutive 32-bit elements into 4 rows
 *      of 12 elements.
 *
 * Reads 48 elements from \p inptr0 and writes 48 elements to \p outptr.
 * Output row j (j = 0..3) holds element j of each of the 12 input groups,
 * in group order.
 *
 * \param inptr0 source pointer; taken by reference and advanced past the 48
 *      elements read (post-incremented loads).
 * \param outptr destination pointer; taken by value, so the caller's
 *      pointer is not modified.
 */
template <typename T>
static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) {
static_assert(sizeof(T) == 4,
"transpose_1x12_4_s only support sizeof(T) == 4");
asm volatile(
// Each vld4.32 reads 8 elements and de-interleaves them with stride 4:
// the four d registers receive elements {0,4}, {1,5}, {2,6}, {3,7}.
"vld4.32 {d0-d3}, [%[inptr0]]!\n"
"vld4.32 {d4-d7}, [%[inptr0]]!\n"
"vld4.32 {d8-d11}, [%[inptr0]]!\n"
"vld4.32 {d12-d15}, [%[inptr0]]!\n"
"vld4.32 {d16-d19}, [%[inptr0]]!\n"
"vld4.32 {d20-d23}, [%[inptr0]]!\n"
// Swap halves between each pair of loads so that every q register
// holds the same element index of four consecutive groups.
"vswp d1, d4\n"
"vswp d3, d6\n"
"vswp d9, d12\n"
"vswp d11, d14\n"
"vswp d17, d20\n"
"vswp d19, d22\n"
// Row 0: element 0 of groups 0..11.
"vst1.32 {d0-d1}, [%[outptr]]! \n"
"vst1.32 {d8-d9}, [%[outptr]]! \n"
"vst1.32 {d16-d17}, [%[outptr]]! \n"
// Row 1: element 1 of groups 0..11.
"vst1.32 {d4-d5}, [%[outptr]]! \n"
"vst1.32 {d12-d13}, [%[outptr]]! \n"
"vst1.32 {d20-d21}, [%[outptr]]! \n"
// Row 2: element 2 of groups 0..11.
"vst1.32 {d2-d3}, [%[outptr]]! \n"
"vst1.32 {d10-d11}, [%[outptr]]! \n"
"vst1.32 {d18-d19}, [%[outptr]]! \n"
// Row 3: element 3 of groups 0..11.
"vst1.32 {d6-d7}, [%[outptr]]! \n"
"vst1.32 {d14-d15}, [%[outptr]]! \n"
"vst1.32 {d22-d23}, [%[outptr]]! \n"
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr)
:
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
"q11", "memory");
}
/*!
 * \brief Transpose 4 groups of 4 consecutive 32-bit elements (a 4x4 tile).
 *
 * Reads 16 elements from \p inptr0 and writes 16 elements to \p outptr.
 * Output row j (j = 0..3) holds element j of each of the 4 input groups.
 *
 * \param inptr0 source pointer; taken by reference and advanced past the 16
 *      elements read (post-incremented loads).
 * \param outptr destination pointer; taken by value, so the caller's
 *      pointer is not modified.
 */
template <typename T>
static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) {
static_assert(sizeof(T) == 4,
"transpose_1x4_4_s only support sizeof(T) == 4");
asm volatile(
// Each vld4.32 reads 8 elements and de-interleaves them with stride 4.
"vld4.32 {d0-d3}, [%[inptr0]]!\n"
"vld4.32 {d4-d7}, [%[inptr0]]!\n"
// After the swaps: q0 = element 0, q2 = element 1, q1 = element 2 and
// q3 = element 3 of the four groups.
"vswp d1, d4\n"
"vswp d3, d6\n"
// Store in element order 0, 1, 2, 3 (q0, q2, q1, q3).
"vst1.32 {d0-d1}, [%[outptr]]! \n"
"vst1.32 {d4-d5}, [%[outptr]]! \n"
"vst1.32 {d2-d3}, [%[outptr]]! \n"
"vst1.32 {d6-d7}, [%[outptr]]! \n"
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr)
:
: "q0", "q1", "q2", "q3", "memory");
}
template <typename T> template <typename T>
static inline void transpose_4(const T*& inptr0, const T*& inptr1, static inline void transpose_4(const T*& inptr0, const T*& inptr1,
const T*& inptr2, const T*& inptr3, T* outptr, const T*& inptr2, const T*& inptr3, T* outptr,
......
...@@ -18,6 +18,9 @@ namespace matmul { ...@@ -18,6 +18,9 @@ namespace matmul {
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, true, MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, true,
sgemm_4x12); sgemm_4x12);
//! Strategy for the packed MK4 fp32 kernel; its pack_A/pack_B/kern members
//! are implemented in strategy_mk_4x12.cpp.
//! NOTE(review): 4 and 12 presumably are the A/B interleave widths of the
//! 4x12 kernel -- confirm against the macro definition.
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, false,
sgemm_mk4_pack_4x12);
MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 8, 1, false, true, MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 8, 1, false, true,
sgemm_nopack_4x8); sgemm_nopack_4x8);
......
/**
* \file dnn/src/armv7/matrix_mul/fp32/strategy_mk_4x12.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/armv7/matrix_mul/fp32/strategy.h"
#include "src/armv7/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
using namespace megdnn;
using namespace armv7;
using namespace armv7::matmul;
namespace {
// Overview of register layout:
//
// A 1x12 cell of Rhs is stored in 32bit in q1-q3
// A 4x1 cell of Lhs is stored in 32bit in q0
// A 4x12 block of accumulators is stored in 32bit in q4-q15.
//
// +--------+--------+--------+
// | q1[0-3]| q2[0-3]| q3[0-3]|
// Rhs +--------+--------+--------+
//
// | | | |
//
// Lhs | | | |
//
// +--+ - - - - +--------+--------+--------+
// |q0| | q4[0-3]| q5[0-3]| q6[0-3]|
// |q0| | q7[0-3]| q8[0-3]| q9[0-3]|
// |q0| |q10[0-3]|q11[0-3]|q12[0-3]|
// |q0| |q13[0-3]|q14[0-3]|q15[0-3]|
// +--+ - - - - +--------+--------+--------+
//
// Accumulator
/*!
 * \brief Compute one 4x12 output tile from packed A and packed B.
 *
 * Per k step the asm consumes 4 floats from \p packA and 12 floats from
 * \p packB. The twelve accumulators q4..q15 each hold one 4-float output
 * column; they are stored to \p output as 48 contiguous floats.
 *
 * Register use in the asm: q0/q1 hold the A column for the two unrolled k
 * steps, q2/q3 (d4-d7) stream the B values, q4-q15 are the accumulators.
 * On entry to the main loop and to both tails, q0 and all of d4-d7 must
 * already be loaded.
 *
 * \param packA packed A panel
 * \param packB packed B panel
 * \param K number of k steps (expected >= 1; the loop is unrolled by two
 *      with separate even/odd tails)
 * \param output destination tile; also read as the initial accumulator
 *      value when \p is_first_k is false
 * \param LDC unused
 * \param is_first_k when true the accumulators start from zero, otherwise
 *      from the current contents of \p output
 */
void kern_4x12(const float* packA, const float* packB, int K, float* output,
               int LDC, bool is_first_k) {
    MEGDNN_MARK_USED_VAR(LDC);
    const float* a_ptr = packA;
    const float* b_ptr = packB;
    float* output0 = output;

    // The loop body handles two k steps; oddk selects the tail variant.
    int oddk = (K & 1);
    K = ((K + 1) / 2) - 1;

    asm volatile(
            "cmp %[is_first_k], #1\n"
            "beq 1f\n"
            // Accumulate path: start from the existing output tile.
            "mov r1, %[output0]\n"
            "vld1.32 {d8-d11}, [r1]!\n"
            "vld1.32 {d12-d15}, [r1]!\n"
            "vld1.32 {d16-d19}, [r1]!\n"
            "vld1.32 {d20-d23}, [r1]!\n"
            "vld1.32 {d24-d27}, [r1]!\n"
            "vld1.32 {d28-d31}, [r1]!\n"

            "vld1.32 {d0-d1}, [%[a_ptr]]!\n"
            // Fix: preload d4-d7 (8 B values), exactly like the first-k
            // path below. The original loaded only d4-d5 here, so the
            // loop/tails read an uninitialised q3 and b_ptr was 4 floats
            // behind.
            "vld1.32 {d4-d7}, [%[b_ptr]]!\n"
            "b 2f\n"

            "1:\n"
            // First-k path: zero the accumulators, interleaved with the
            // initial A/B loads.
            "veor.32 q4, q4, q4\n"
            "pld [%[output0]]\n"
            "veor.32 q5, q4, q4\n"
            "veor.32 q6, q4, q4\n"
            "veor.32 q7, q4, q4\n"
            "vld1.32 {d0-d1}, [%[a_ptr]]!\n"
            "veor.32 q8, q4, q4\n"
            "veor.32 q9, q4, q4\n"
            "veor.32 q10, q4, q4\n"
            "veor.32 q11, q4, q4\n"
            "vld1.32 {d4-d7}, [%[b_ptr]]!\n"
            "veor.32 q12, q4, q4\n"
            "veor.32 q13, q4, q4\n"
            "veor.32 q14, q4, q4\n"
            "veor.32 q15, q4, q4\n"

            "2: \n"
            "cmp %[K], #0\n"
            "beq 4f\n"

            // Main loop: two k steps per iteration.
            "3:\n"
            "vmla.f32 q4, q0, d4[0]\n"
            "vmla.f32 q5, q0, d4[1]\n"
            "vmla.f32 q6, q0, d5[0]\n"
            "vmla.f32 q7, q0, d5[1]\n"
            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
            "vmla.f32 q8, q0, d6[0]\n"
            "vmla.f32 q9, q0, d6[1]\n"
            "vmla.f32 q10, q0, d7[0]\n"
            "vld1.32 {d2-d3}, [%[a_ptr]]!\n"
            "vmla.f32 q11, q0, d7[1]\n"
            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
            "vmla.f32 q12, q0, d4[0]\n"
            "vmla.f32 q13, q0, d4[1]\n"
            "vmla.f32 q14, q0, d5[0]\n"
            "vmla.f32 q15, q0, d5[1]\n"

            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
            "vmla.f32 q4, q1, d6[0]\n"
            "subs %[K], %[K], #1\n"
            "vmla.f32 q5, q1, d6[1]\n"
            "vmla.f32 q6, q1, d7[0]\n"
            "vmla.f32 q7, q1, d7[1]\n"
            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
            "vmla.f32 q8, q1, d4[0]\n"
            "vmla.f32 q9, q1, d4[1]\n"
            "vld1.32 {d0-d1}, [%[a_ptr]]!\n"
            "vmla.f32 q10, q1, d5[0]\n"
            "vmla.f32 q11, q1, d5[1]\n"
            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
            "vmla.f32 q12, q1, d6[0]\n"
            "vmla.f32 q13, q1, d6[1]\n"
            "vmla.f32 q14, q1, d7[0]\n"
            "vmla.f32 q15, q1, d7[1]\n"

            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
            "bne 3b\n"

            "4:\n"
            "cmp %[oddk], #1\n"
            "beq 5f\n"

            // Even tail: two remaining k steps, stores interleaved with
            // the last multiply-accumulates.
            "vmla.f32 q4, q0, d4[0]\n"
            "vmla.f32 q5, q0, d4[1]\n"
            "vmla.f32 q6, q0, d5[0]\n"
            "vmla.f32 q7, q0, d5[1]\n"
            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
            "vmla.f32 q8, q0, d6[0]\n"
            "vmla.f32 q9, q0, d6[1]\n"
            "vmla.f32 q10, q0, d7[0]\n"
            "vld1.32 {d2-d3}, [%[a_ptr]]!\n"
            "vmla.f32 q11, q0, d7[1]\n"
            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
            "vmla.f32 q12, q0, d4[0]\n"
            "vmla.f32 q13, q0, d4[1]\n"
            "vmla.f32 q14, q0, d5[0]\n"
            "vmla.f32 q15, q0, d5[1]\n"

            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
            "vmla.f32 q4, q1, d6[0]\n"
            // Kept from the loop body; K and the flags are not read again
            // after this point.
            "subs %[K], %[K], #1\n"
            "vmla.f32 q5, q1, d6[1]\n"
            "vmla.f32 q6, q1, d7[0]\n"
            "vmla.f32 q7, q1, d7[1]\n"
            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
            "vmla.f32 q8, q1, d4[0]\n"
            "vmla.f32 q9, q1, d4[1]\n"
            "vst1.32 {d8-d11}, [%[output0]]!\n"
            "vmla.f32 q10, q1, d5[0]\n"
            "vmla.f32 q11, q1, d5[1]\n"
            "vst1.32 {d12-d15}, [%[output0]]!\n"
            "vmla.f32 q12, q1, d6[0]\n"
            "vmla.f32 q13, q1, d6[1]\n"
            "vst1.32 {d16-d19}, [%[output0]]!\n"
            "vmla.f32 q14, q1, d7[0]\n"
            "vmla.f32 q15, q1, d7[1]\n"
            "vst1.32 {d20-d23}, [%[output0]]!\n"
            "vst1.32 {d24-d27}, [%[output0]]!\n"
            "vst1.32 {d28-d31}, [%[output0]]!\n"

            "b 6f\n"

            // Odd tail: one remaining k step.
            "5:\n"
            "vmla.f32 q4, q0, d4[0]\n"
            "vmla.f32 q5, q0, d4[1]\n"
            "vmla.f32 q6, q0, d5[0]\n"
            "vmla.f32 q7, q0, d5[1]\n"
            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
            "vmla.f32 q8, q0, d6[0]\n"
            "vst1.32 {d8-d11}, [%[output0]]!\n"
            "vmla.f32 q9, q0, d6[1]\n"
            "vmla.f32 q10, q0, d7[0]\n"
            "vst1.32 {d12-d15}, [%[output0]]!\n"
            "vmla.f32 q11, q0, d7[1]\n"
            "vmla.f32 q12, q0, d4[0]\n"
            "vst1.32 {d16-d19}, [%[output0]]!\n"
            "vmla.f32 q13, q0, d4[1]\n"
            "vst1.32 {d20-d23}, [%[output0]]!\n"
            "vmla.f32 q14, q0, d5[0]\n"
            "vst1.32 {d24-d27}, [%[output0]]!\n"
            "vmla.f32 q15, q0, d5[1]\n"
            "vst1.32 {d28-d31}, [%[output0]]!\n"

            "6:\n"
            : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K),
              [ is_first_k ] "+r"(is_first_k), [ oddk ] "+r"(oddk),
              [ output0 ] "+r"(output0)
            :
            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
              "q11", "q12", "q13", "q14", "q15", "r1", "cc", "memory");
}
// Overview of register layout:
//
// A 2x4 cell of Rhs is stored in 32bit in q2 - q3
// A 4x2 cell of Lhs is stored in 32bit in q0 - q1
// A 4x4 block of accumulators is stored in 32bit in q4 - q7
//
// +--------+
// | q2[0-3]|
// | q5[0-3]|
// Rhs +--------+
//
// | |
//
// Lhs | |
//
// +--+ --- - +--------+
// |q0| | q8[0-3]|
// |q0| |q11[0-3]|
// |q0| |q14[0-3]|
// |q0| |q17[0-3]|
// +--+ --- - +--------+
//
// Accumulator
/*!
 * \brief Compute one 4x4 output tile (or its first n_remain columns) from
 *      packed A and packed B.
 *
 * Per k step the asm consumes 4 floats from \p packA and 4 floats from
 * \p packB. Accumulators q4..q7 each hold one 4-float output column.
 *
 * \param packA packed A panel
 * \param packB packed B panel
 * \param K number of k steps (the loop is unrolled by two with separate
 *      even/odd tails)
 * \param output destination tile; also read as the initial accumulator
 *      value when \p is_first_k is false
 * \param LDC unused
 * \param is_first_k when true the accumulators start from zero, otherwise
 *      from the current contents of \p output
 * \param n_remain number of valid output columns: >= 4 loads/stores all
 *      four, 3 and 2 load/store that many, anything smaller loads/stores one
 */
void kern_4x4(const float* packA, const float* packB, int K, float* output,
int LDC, bool is_first_k, int n_remain) {
MEGDNN_MARK_USED_VAR(LDC);
const float* a_ptr = packA;
const float* b_ptr = packB;
// The loop body handles two k steps; oddk selects the tail variant.
int oddk = (K & 1);
K = ((K + 1) / 2) - 1;
// clang-format off
// LOAD_C: load the first n_remain columns of the existing tile through r1
// into q4..q7 (labels 11-14).
#define LOAD_C \
"cmp %[n_remain], #4\n" \
"blt 11f\n" \
"vld1.32 {d8-d11}, [r1]!\n" \
"vld1.32 {d12-d15}, [r1]!\n" \
"b 14f\n" \
"11:\n" \
"cmp %[n_remain], #3\n" \
"blt 12f\n" \
"vld1.32 {d8-d11}, [r1]!\n" \
"vld1.32 {d12-d13}, [r1]!\n" \
"b 14f\n" \
"12:\n" \
"cmp %[n_remain], #2\n" \
"blt 13f\n" \
"vld1.32 {d8-d11}, [r1]\n" \
"b 14f\n" \
"13:\n" \
"vld1.32 {d8-d9}, [r1]\n" \
"14:\n"
// STORE_C: store the first n_remain columns from q4..q7 to output
// (labels 21-24).
#define STORE_C \
"cmp %[n_remain], #4\n" \
"blt 21f\n" \
"vst1.32 {d8-d11}, [%[output]]!\n" \
"vst1.32 {d12-d15}, [%[output]]!\n" \
"b 24f\n" \
"21:\n" \
"cmp %[n_remain], #3\n" \
"blt 22f\n" \
"vst1.32 {d8-d11}, [%[output]]!\n" \
"vst1.32 {d12-d13}, [%[output]]!\n" \
"b 24f\n" \
"22:\n" \
"cmp %[n_remain], #2\n" \
"blt 23f\n" \
"vst1.32 {d8-d11}, [%[output]]!\n" \
"b 24f\n" \
"23:\n" \
"vst1.32 {d8-d9}, [%[output]]!\n" \
"24:\n"
// clang-format on
asm volatile(
"cmp %[is_first_k], #1\n"
"beq 1f\n"
// Accumulate path: start from the existing output tile.
"mov r1, %[output]\n" LOAD_C
"vld1.32 {d0-d1}, [%[a_ptr]]!\n"
"vld1.32 {d4-d5}, [%[b_ptr]]!\n"
"b 2f\n"
"1:\n"
// First-k path: zero the accumulators, interleaved with the initial
// A/B loads.
"veor.32 q4, q4, q4\n"
"pld [%[output]]\n"
"veor.32 q5, q4, q4\n"
"vld1.32 {d0-d1}, [%[a_ptr]]!\n"
"veor.32 q6, q4, q4\n"
"vld1.32 {d4-d5}, [%[b_ptr]]!\n"
"veor.32 q7, q4, q4\n"
"2: \n"
"cmp %[K], #0\n"
"beq 4f\n"
// Main loop: two k steps per iteration (q0 x q2, then q1 x q3).
"3:\n"
"vmla.f32 q4, q0, d4[0]\n"
"vld1.32 {d2-d3}, [%[a_ptr]]!\n"
"vmla.f32 q5, q0, d4[1]\n"
"vld1.32 {d6-d7}, [%[b_ptr]]!\n"
"vmla.f32 q6, q0, d5[0]\n"
"vmla.f32 q7, q0, d5[1]\n"
"vld1.32 {d4-d5}, [%[b_ptr]]!\n"
"vmla.f32 q4, q1, d6[0]\n"
"subs %[K], %[K], #1\n"
"vmla.f32 q5, q1, d6[1]\n"
"vld1.32 {d0-d1}, [%[a_ptr]]!\n"
"vmla.f32 q6, q1, d7[0]\n"
"vmla.f32 q7, q1, d7[1]\n"
"bne 3b\n"
"4:\n"
"cmp %[oddk], #1\n"
"beq 5f\n"
// Even tail
"vmla.f32 q4, q0, d4[0]\n"
"vld1.32 {d2-d3}, [%[a_ptr]]!\n"
"vmla.f32 q5, q0, d4[1]\n"
"vld1.32 {d6-d7}, [%[b_ptr]]!\n"
"vmla.f32 q6, q0, d5[0]\n"
"vmla.f32 q7, q0, d5[1]\n"
"vmla.f32 q4, q1, d6[0]\n"
"vmla.f32 q5, q1, d6[1]\n"
"vmla.f32 q6, q1, d7[0]\n"
"vmla.f32 q7, q1, d7[1]\n"
"b 6f\n"
// odd tail
"5:\n"
"vmla.f32 q4, q0, d4[0]\n"
"vmla.f32 q5, q0, d4[1]\n"
"vmla.f32 q6, q0, d5[0]\n"
"vmla.f32 q7, q0, d5[1]\n"
"6:\n" STORE_C
: [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K),
[ is_first_k ] "+r"(is_first_k), [ oddk ] "+r"(oddk),
[ output ] "+r"(output), [ n_remain ] "+r"(n_remain)
:
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "r1", "cc",
"memory");
#undef LOAD_C
#undef STORE_C
}
} // namespace
// Implementation-side counterpart of the MEGDNN_REG_GEMM_STRATEGY(...)
// declaration of sgemm_mk4_pack_4x12 in strategy.h.
// NOTE(review): macro body not visible here -- confirm what it defines.
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_pack_4x12);
//! conv1x1 and im2col currently have no matmul mode that packs only B, so
//! pack_A simply copies the weight
//! Copy the [y0, ymax) x [k0, kmax) region of A into \p out.
//! For every group of 4 rows, the (kmax - k0) * 4 floats starting at
//! in + (m / 4) * ldin + k0 * 4 are copied verbatim and appended to \p out.
//! The trailing bool parameter is ignored.
void sgemm_mk4_pack_4x12::pack_A(float* out, const float* in, int ldin, int y0,
                                 int ymax, int k0, int kmax, bool) const {
    megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4");
    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");

    constexpr int PACK_C_SIZE = 4;
    // Number of floats copied per 4-row group.
    const size_t floats_per_group = (kmax - k0) * PACK_C_SIZE;

    float* dst = out;
    int m = y0;
    while (m < ymax) {
        const float* group_begin =
                in + (m / PACK_C_SIZE) * ldin + k0 * PACK_C_SIZE;
        memcpy(dst, group_begin, floats_per_group * sizeof(float));
        dst += floats_per_group;
        m += 4;
    }
}
//! Pack the [k0, kmax) x [x0, xmax) region of B for the 4x12 / 4x4 kernels.
//!
//! Columns are processed in panels of 12, then panels of 4; a final partial
//! panel (fewer than 4 columns) is staged through a zero-initialised
//! 16-float buffer and packed as a 4-wide panel. Every panel occupies
//! (kmax - k0) * width floats in \p out, with all 12-wide panels placed
//! before the 4-wide ones. \p transpose_B must be false.
void sgemm_mk4_pack_4x12::pack_B(float* out, const float* in, int ldin, int x0,
                                 int xmax, int k0, int kmax,
                                 bool transpose_B) const {
    megdnn_assert(!transpose_B);
    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");

    constexpr int PACK_C_SIZE = 4;
    // Staging buffer for a trailing partial panel. Only the leading floats
    // are ever overwritten, so the remainder stays zero.
    float staging[16] = {0.0f};

    const int k_extent = kmax - k0;
    const int panel12_stride = k_extent * 12;
    const int panel4_stride = k_extent * PACK_C_SIZE;

    // Where the current k group starts inside the first 12-wide / 4-wide
    // panel; the 4-wide panels live right after all 12-wide ones.
    float* wide_base = out;
    float* narrow_base = out + (xmax - x0) / 12 * panel12_stride;

    for (int k = k0; k + 3 < kmax; k += 4) {
        // The transpose helpers advance src as they consume input.
        const float* src = in + k / PACK_C_SIZE * ldin + x0 * PACK_C_SIZE;
        prefetch_3x(src);

        int x = x0;
        float* dst = wide_base;
        while (x + 12 <= xmax) {
            transpose_1x12_4_s(src, dst);
            dst += panel12_stride;
            x += 12;
        }

        dst = narrow_base;
        while (x + 4 <= xmax) {
            transpose_1x4_4_s(src, dst);
            dst += panel4_stride;
            x += 4;
        }

        if (x < xmax) {
            memcpy(staging, src, sizeof(float) * (xmax - x) * PACK_C_SIZE);
            const float* staged = &staging[0];
            transpose_1x4_4_s<float>(staged, dst);
        }

        wide_base += 12 * PACK_C_SIZE;
        narrow_base += 4 * PACK_C_SIZE;
    }
}
//! Multiply packed A by packed B into C, tile by tile.
//!
//! Rows are walked in steps of 4. Within a row group, columns are covered
//! by kern_4x12 while at least 12 remain, then by kern_4x4 in steps of 4
//! (the last call may cover fewer than 4 columns). The output pointer
//! advances by 4 floats per column, and each row group starts at
//! C + m / 4 * LDC. The two trailing pointer parameters are unused.
void sgemm_mk4_pack_4x12::kern(const float* packA, const float* packB, size_t M,
                               size_t N, size_t K, float* C, size_t LDC,
                               bool is_first_k, const float*, float*) const {
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                  A_dtype.enumv() == C_dtype.enumv() &&
                  A_dtype.enumv() == DTypeEnum::Float32);

    constexpr int PACK_C_SIZE = 4;
    constexpr size_t A_INTERLEAVE = 4;
    constexpr size_t B_INTERLEAVE = 12;
    // Floats per packed-B panel of width 12 / 4, and per packed-A row group.
    const int K12 = K * 12;
    const int K4 = K * 4;

    const float* a_panel = packA;
    for (size_t m = 0; m < M; m += A_INTERLEAVE) {
        float* c_tile = C + (m / 4 * LDC);
        const float* b_panel = packB;
        size_t n = 0;

        // Full 12-column tiles.
        while (n + B_INTERLEAVE - 1 < N) {
            kern_4x12(a_panel, b_panel, K, c_tile, LDC, is_first_k);
            c_tile += PACK_C_SIZE * B_INTERLEAVE;
            b_panel += K12;
            n += B_INTERLEAVE;
        }

        // Remaining columns, up to 4 at a time.
        while (n < N) {
            kern_4x4(a_panel, b_panel, K, c_tile, LDC, is_first_k,
                     std::min<size_t>(N - n, 4));
            c_tile += PACK_C_SIZE * 4;
            b_panel += K4;
            n += 4;
        }

        a_panel += K4;
    }
}
// vim: syntax=cpp.doxygen
...@@ -20,6 +20,7 @@ using namespace armv7; ...@@ -20,6 +20,7 @@ using namespace armv7;
class MatrixMulImpl::AlgoPack : NonCopyableObj { class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoF32 f32; AlgoF32 f32;
AlgoF32MK4Pack4x12 f32_mk4_pack_4x12;
AlgoF32MK4_4x8 f32_mk4_4x8; AlgoF32MK4_4x8 f32_mk4_4x8;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
AlgoF16K4x16x1 f16_k4x16x1; AlgoF16K4x16x1 f16_k4x16x1;
...@@ -48,6 +49,7 @@ public: ...@@ -48,6 +49,7 @@ public:
AlgoPack() { AlgoPack() {
all_algos.emplace_back(&f32_gemv); all_algos.emplace_back(&f32_gemv);
all_algos.emplace_back(&f32); all_algos.emplace_back(&f32);
all_algos.emplace_back(&f32_mk4_pack_4x12);
all_algos.emplace_back(&f32_mk4_4x8); all_algos.emplace_back(&f32_mk4_4x8);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
all_algos.emplace_back(&f16_k4x16x1); all_algos.emplace_back(&f16_k4x16x1);
......
...@@ -21,6 +21,7 @@ public: ...@@ -21,6 +21,7 @@ public:
SmallVector<AlgoBase*> algo_pack() override; SmallVector<AlgoBase*> algo_pack() override;
private: private:
class AlgoF32; // Armv7 F32 class AlgoF32; // Armv7 F32
class AlgoF32MK4Pack4x12; // Armv7 F32 Kernel 4x12 with pack
class AlgoF32MK4_4x8; // Armv7 F32 Kernel 4x8 nopack class AlgoF32MK4_4x8; // Armv7 F32 Kernel 4x8 nopack
class AlgoF32Gemv; // Armv7 F32 Gemv class AlgoF32Gemv; // Armv7 F32 Gemv
class AlgoInt8x8x32K4x8x8; // Armv7 Int8x8x32 Kernel 4x8x8 class AlgoInt8x8x32K4x8x8; // Armv7 Int8x8x32 Kernel 4x8x8
......
...@@ -1287,23 +1287,27 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) { ...@@ -1287,23 +1287,27 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
#undef cb #undef cb
} }
#if MEGDNN_AARCH64
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) { TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) {
using namespace conv_bias; using namespace conv_bias;
std::vector<conv_bias::TestArg> args = std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({2, 4, 7}, 1); get_nchw44_conv_bias_args({2, 4, 7}, 1);
#if MEGDNN_AARCH64
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
} #elif MEGDNN_ARMV7
check_conv_bias(args, handle(), "IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12");
#endif #endif
}
#if MEGDNN_AARCH64
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) { TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) {
using namespace conv_bias; using namespace conv_bias;
std::vector<conv_bias::TestArg> args = std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({3, 5, 6}, 2); get_nchw44_conv_bias_args({3, 5, 6}, 2);
#if MEGDNN_AARCH64
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
} #elif MEGDNN_ARMV7
check_conv_bias(args, handle(), "IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12");
#endif #endif
}
/***************************** Conv1x1 Algo Test ***********************/ /***************************** Conv1x1 Algo Test ***********************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) { TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) {
...@@ -1316,14 +1320,16 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) { ...@@ -1316,14 +1320,16 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) {
#endif #endif
} }
#if MEGDNN_AARCH64
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) { TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) {
using namespace conv_bias; using namespace conv_bias;
std::vector<conv_bias::TestArg> args = std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({1}, 1, true, false, false); get_nchw44_conv_bias_args({1}, 1, true, false, false);
#if MEGDNN_AARCH64
check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32_MK4_K8X12X1:24"); check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32_MK4_K8X12X1:24");
} #elif MEGDNN_ARMV7
check_conv_bias(args, handle(), "CONV1x1:ARMV7_F32_MK4_PACK_4X12:24");
#endif #endif
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_NO_PACK_F32) { TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_NO_PACK_F32) {
using namespace conv_bias; using namespace conv_bias;
......
...@@ -28,6 +28,12 @@ TEST_F(ARMV7, MATRIX_MUL_MK4) { ...@@ -28,6 +28,12 @@ TEST_F(ARMV7, MATRIX_MUL_MK4) {
"ARMV7_F32_MK4_4x8", param::MatrixMul::Format::MK4, 4); "ARMV7_F32_MK4_4x8", param::MatrixMul::Format::MK4, 4);
} }
// Checks the "ARMV7_F32_MK4_PACK_4X12" matmul algorithm with Float32
// operands in the MK4 format.
// NOTE(review): the meaning of the trailing argument (1) is defined by
// matrix_mul::check_matrix_mul, which is not visible here.
TEST_F(ARMV7, MATRIX_MUL_PACK_MK4) {
matrix_mul::check_matrix_mul(
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
"ARMV7_F32_MK4_PACK_4X12", param::MatrixMul::Format::MK4, 1);
}
TEST_F(ARMV7, MATRIX_MUL_MK4_INT8) { TEST_F(ARMV7, MATRIX_MUL_MK4_INT8) {
std::vector<matrix_mul::TestArg> args; std::vector<matrix_mul::TestArg> args;
for (size_t m : {1, 2, 3, 4, 5, 7, 10, 11}) for (size_t m : {1, 2, 3, 4, 5, 7, 10, 11})
...@@ -349,6 +355,15 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_MK4) { ...@@ -349,6 +355,15 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_MK4) {
dtype::Float32{}); dtype::Float32{});
} }
// Benchmarks the "ARMV7_F32_MK4_PACK_4X12" algorithm (MK4 format, Float32)
// through matrix_mul::benchmark_with_contrast.
// NOTE(review): the argument 8 and the contrast configuration are
// interpreted by helpers that are not visible here.
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_PACK_MK4) {
auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8);
matrix_mul::benchmark_with_contrast(
handle(), args, dtype::Float32{}, dtype::Float32{},
dtype::Float32{}, "ARMV7_F32_MK4_PACK_4X12",
param::MatrixMul::Format::MK4, dtype::Float32{}, dtype::Float32{},
dtype::Float32{});
}
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT16x16x32_MK8) { TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT16x16x32_MK8) {
auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(4); auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(4);
matrix_mul::benchmark_with_contrast( matrix_mul::benchmark_with_contrast(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册