diff --git a/dnn/src/arm_common/matrix_mul/algos.cpp b/dnn/src/arm_common/matrix_mul/algos.cpp index 41e3fc4a49811c4e2fc9553cabc858412ba01266..d939a2521fa8c9d6cd9b22d31167dba7f248fbd5 100644 --- a/dnn/src/arm_common/matrix_mul/algos.cpp +++ b/dnn/src/arm_common/matrix_mul/algos.cpp @@ -239,46 +239,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern(const KernSizeParam&) return f32_gemv_kern; } -/* ================== F32 Gemv MK4 algo ================== */ -namespace { -void f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_arm_exec_fp32, midout_iv("f32_gemv_mk4_kern"_hash)) { - auto M = kern_param.M, N = kern_param.N, K = kern_param.K; - auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; - const auto Aptr = kern_param.A(), Bptr = kern_param.B(); - auto Cptr = kern_param.C(); - gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); - } - MIDOUT_END(); -} -} // anonymous namespace - -bool MatrixMulImpl::AlgoF32GemvMK4::usable(const KernSizeParam& kern_size_param) const { - // enumerate the M, N, K, only usable when preferred - auto M = kern_size_param.M; - auto N = kern_size_param.N; - auto K = kern_size_param.K; - auto LDB = kern_size_param.LDB; - - return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && - kern_size_param.format == param::MatrixMul::Format::MK4 && - kern_size_param.B_type == kern_size_param.A_type && - kern_size_param.C_type == kern_size_param.A_type && - kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && - !kern_size_param.trB && M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4; -} - -bool MatrixMulImpl::AlgoF32GemvMK4::preferred( - const KernSizeParam& kern_size_param) const { - MEGDNN_MARK_USED_VAR(kern_size_param); - return true; -} - -MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GemvMK4::get_kern( - const KernSizeParam&) const { - return f32_gemv_mk4_kern; -} - /* ===================== F32 Gevm algo ===================== */ namespace { template 
diff --git a/dnn/src/arm_common/matrix_mul/algos.h b/dnn/src/arm_common/matrix_mul/algos.h index 2d9ac4bd5f5954e2c9465717ce5f307d36d78102..a3ab3e3e60d5e8f0e8fca9cbb5d3474bc893c9fa 100644 --- a/dnn/src/arm_common/matrix_mul/algos.h +++ b/dnn/src/arm_common/matrix_mul/algos.h @@ -95,22 +95,6 @@ public: MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) }; -class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase { -public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; - } - const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; } - bool usable(const KernSizeParam&) const override; - bool preferred(const KernSizeParam&) const override; - size_t get_workspace(const KernSizeParam&) const override { return 0; } - kern_t get_kern(const KernSizeParam&) const override; - AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } - PackMode packmode() const override { return PackMode::NO_PACK; } - MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4) - MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_F32_GEMV_MK4) -}; - #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC class MatrixMulImpl::AlgoF16Gemv : public AlgoBase { public: diff --git a/dnn/src/arm_common/matrix_mul/opr_impl.cpp b/dnn/src/arm_common/matrix_mul/opr_impl.cpp index 6cfef7f1ef6a42123afaa9aad71425d5c85997eb..4d3bfb1c621e3a53aa215110f32c2c8d7d26680d 100644 --- a/dnn/src/arm_common/matrix_mul/opr_impl.cpp +++ b/dnn/src/arm_common/matrix_mul/opr_impl.cpp @@ -26,7 +26,6 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot; #endif AlgoGevm gevm; - AlgoF32GemvMK4 f32_gemv_mk4; SmallVector m_all_algos; fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map; @@ -42,7 +41,6 @@ public: #endif m_all_algos.emplace_back(&int8x8x32_gemv); m_all_algos.emplace_back(&int8x8x32_gemv_mk4); - m_all_algos.emplace_back(&f32_gemv_mk4); m_all_algos.emplace_back(&gevm); for 
(auto&& algo : m_all_algos) { diff --git a/dnn/src/arm_common/matrix_mul/opr_impl.h b/dnn/src/arm_common/matrix_mul/opr_impl.h index e79aebe098622328a6c6d9e26ccd8088039f9307..e2a3c68c5ddce19dfe11b30f007bc59e1693bb14 100644 --- a/dnn/src/arm_common/matrix_mul/opr_impl.h +++ b/dnn/src/arm_common/matrix_mul/opr_impl.h @@ -34,7 +34,6 @@ public: protected: class AlgoF32Gemv; // Arm_common F32 Gemv - class AlgoF32GemvMK4; // Arm_common F32 Gemv NCHW44 class AlgoInt8x8x32Gemv; // Arm_common Int8x8x32 Gemv class AlgoInt8x8x32GemvMK4; // Arm_common Int8x8x32 Gemv NCHW44 class AlgoGevm; // Arm_common Gevm(support int8 and fp32) diff --git a/dnn/src/fallback/matrix_mul/algos.cpp b/dnn/src/fallback/matrix_mul/algos.cpp index da02af79375d78301d901faf83fd5c387caf98c7..f1d30d7a4d3f2c8da1c0a9e07f068462646a45c7 100644 --- a/dnn/src/fallback/matrix_mul/algos.cpp +++ b/dnn/src/fallback/matrix_mul/algos.cpp @@ -17,11 +17,15 @@ #include "src/naive/matrix_mul/matrix_mul_helper.h" +#include "src/fallback/matrix_mul/gi/fp32/exec_sgemv.h" + #include "midout.h" MIDOUT_DECL(megdnn_fb_matmul_f32_kern) MIDOUT_DECL(megdnn_fb_matmul_f32_gemm_gemv_like) MIDOUT_DECL(megdnn_fb_matmul_naive) +MIDOUT_DECL(megdnn_fb_gi_exec_fp32) +MIDOUT_DECL(megdnn_fb_gi_matmul_kern) using namespace megdnn; using namespace fallback; @@ -205,4 +209,99 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoNaive::get_kern(const KernSizeParam&) c return kern_naive; } +/* ================== F32 Gemv MK4 gi algo ================== */ +namespace { +void gi_f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_fb_gi_exec_fp32, midout_iv("f32_gemv_mk4_gi_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + gi_gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool 
MatrixMulImpl::AlgoF32GiGemvMK4::usable( + const KernSizeParam& kern_size_param) const { + // enumerate the M, N, K, only usable when preferred + auto M = kern_size_param.M; + auto N = kern_size_param.N; + auto K = kern_size_param.K; + auto LDB = kern_size_param.LDB; + + return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == param::MatrixMul::Format::MK4 && + kern_size_param.B_type == kern_size_param.A_type && + kern_size_param.C_type == kern_size_param.A_type && + kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && + !kern_size_param.trB && M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4; +} + +bool MatrixMulImpl::AlgoF32GiGemvMK4::preferred( + const KernSizeParam& kern_size_param) const { + MEGDNN_MARK_USED_VAR(kern_size_param); + return true; +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiGemvMK4::get_kern( + const KernSizeParam&) const { + return gi_f32_gemv_mk4_kern; +} + +/* ================== F32 Gemm MK4 gi algo ================== */ +namespace { +void gi_f32_mk4_4x8_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_fb_gi_matmul_kern, midout_iv("gi_f32_mk4_4x8_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A<float>(), Bptr = kern_param.B<float>(); + auto Cptr = kern_param.C<float>(); + + matmul::fallback::gi_sgemm_nopack_4x8 strategy(A_type, B_type, C_type); + megdnn::matmul::GemmInterleaved<matmul::fallback::gi_sgemm_nopack_4x8, false>( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); + } + MIDOUT_END(); +} + +} // anonymous namespace +bool MatrixMulImpl::AlgoF32GiMK4_4x8::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + 
kern_size_param.format == param::MatrixMul::Format::MK4 && + kern_size_param.B_type == kern_size_param.A_type && + kern_size_param.C_type == kern_size_param.A_type && + kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && + !kern_size_param.trB; +} + +size_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN( + megdnn_fb_gi_matmul_kern, + midout_iv("AlgoF32GiMK4_4x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + matmul::fallback::gi_sgemm_nopack_4x8 strategy(A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved< + matmul::fallback::gi_sgemm_nopack_4x8, false>( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); + return 0; +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_kern( + const KernSizeParam&) const { + return gi_f32_mk4_4x8_kern; +} // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/matrix_mul/algos.h b/dnn/src/fallback/matrix_mul/algos.h index 218d750ed26fea76514e6fd7a00addff1b141339..fe62c13f16a83f30f7902ad955852a3dd93c4e15 100644 --- a/dnn/src/fallback/matrix_mul/algos.h +++ b/dnn/src/fallback/matrix_mul/algos.h @@ -80,6 +80,34 @@ public: DEFAULT) }; +class MatrixMulImpl::AlgoF32GiGemvMK4 : public AlgoBase { +public: + AlgoAttribute attribute() const override { + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + } + const char* name() const override { return "FB_GI_F32_GEMV_MK4"; } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override { return 0; } + kern_t get_kern(const KernSizeParam&) const override; + AlgoSet algoset() const override { return 
AlgoSet::ALGO_TYPE_GEMV; } + PackMode packmode() const override { return PackMode::NO_PACK; } + MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4) + MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_GEMV_MK4) +}; + +class MatrixMulImpl::AlgoF32GiMK4_4x8 final : public AlgoBase { +public: + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } + const char* name() const override { return "FB_GI_F32_MK4_4x8"; } + bool usable(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + PackMode packmode() const override { return PackMode::NO_PACK; } + MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4) + MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_MK4_4x8) +}; + } // namespace fallback } // namespace megdnn diff --git a/dnn/src/fallback/matrix_mul/generic_strategy.h b/dnn/src/fallback/matrix_mul/generic_strategy.h index d06d10d6e3ca1ddd4583a65b605366d28b58114b..3129981880f42c3d7bd81f6969d25185c7e04308 100644 --- a/dnn/src/fallback/matrix_mul/generic_strategy.h +++ b/dnn/src/fallback/matrix_mul/generic_strategy.h @@ -16,6 +16,8 @@ namespace matmul { namespace fallback { MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12); +MEGDNN_REG_GEMM_STRATEGY_NOPACK( + float, float, float, 4, 8, 1, false, true, gi_sgemm_nopack_4x8); } // namespace fallback } // namespace matmul diff --git a/dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp b/dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2034a0fcd793719b49154b8edadfa1e6cd410598 --- /dev/null +++ b/dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp @@ -0,0 +1,101 @@ +/** + * \file dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2022 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/fallback/matrix_mul/gi/fp32/exec_sgemv.h" +#include "include/megdnn/oprs.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/general_intrinsic/gi_float.h" + +#include "midout.h" +MIDOUT_DECL(megdnn_fp32_gi_sgemv) + +using namespace megdnn; +using namespace fallback; + +namespace { + +void sgemv_gi_naive_n_mk4( + const float* __restrict A, const float* __restrict B, float* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { + constexpr size_t PACK_SIZE = 4; + megdnn_assert( + N == 1 && Bstride == PACK_SIZE && M % PACK_SIZE == 0 && K % PACK_SIZE == 0); + auto Aptr = A; + auto Cptr = C; + size_t m = 0; + while (m < M) { + auto Aptr0 = Aptr; + auto Cptr0 = Cptr; + GI_FLOAT32_t c[4]; +#define INIT(step) c[step] = GiBroadcastFloat32(0.0f); + UNROLL_CALL_RAW(4, INIT) +#undef INIT + auto Bptr = B; + size_t k = 0; + while (k < K) { + GI_FLOAT32_t b = GiLoadFloat32(Bptr); + GI_FLOAT32_V2_t a[2]; +#if defined(GI_TEST_NAIVE) +#define LOAD_A(step) \ + a[step].val[0] = GiLoadFloat32(Aptr0 + step * 8); \ + a[step].val[1] = GiLoadFloat32(Aptr0 + step * 8 + 4); +#elif defined(__arm__) || defined(__aarch64__) +#define LOAD_A(step) a[step] = vld1q_f32_x2(Aptr0 + step * 8); +#else +#define LOAD_A(step) \ + a[step].val[0] = GiLoadFloat32(Aptr0 + step * 8); \ + a[step].val[1] = GiLoadFloat32(Aptr0 + step * 8 + 4); +#endif + UNROLL_CALL_RAW(2, LOAD_A) +#undef LOAD_A + +#define COMPT(step) \ + c[step] = GiSimdFmaLane(c[step], a[step / 2].val[step % 2], b, step % 4); + UNROLL_CALL_RAW(4, COMPT) +#undef COMPT + Bptr += Bstride; + Aptr0 += PACK_SIZE * PACK_SIZE; + k += PACK_SIZE; + } + +#define ADD_C(step, stride) c[step] = GiAddFloat32(c[step], c[step + 
stride]); + UNROLL_CALL_RAW(2, ADD_C, 2) + UNROLL_CALL_RAW(1, ADD_C, 1) +#undef ADD_C + GiStoreFloat32(Cptr0, c[0]); + + Aptr += Astride; + Cptr += Cstride; + m += PACK_SIZE; + } +} + +} // namespace + +namespace megdnn { +namespace fallback { + +void gi_gemv_like_mk4( + const float* __restrict A, const float* __restrict B, float* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { + megdnn_assert(N == 1 && Bstride == 4); + MIDOUT_BEGIN(megdnn_fp32_gi_sgemv, midout_iv("F32_GEMV_NCHW_GI_44_N"_hash)) { + return sgemv_gi_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride); + } + MIDOUT_END(); +} + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.h b/dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.h new file mode 100644 index 0000000000000000000000000000000000000000..90f60829a93fb6b8cf43040addf45d5c13eccf45 --- /dev/null +++ b/dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.h @@ -0,0 +1,25 @@ +/** + * \file dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2022 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once + +#include <cstddef> + +namespace megdnn { +namespace fallback { + +void gi_gemv_like_mk4( + const float* __restrict A, const float* __restrict B, float* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride); + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk4_4x8.cpp b/dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk4_4x8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d7e91da5a2de326776c91a49bc7d2d8fc5b8c4c7 --- /dev/null +++ b/dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk4_4x8.cpp @@ -0,0 +1,349 @@ +/** + * \file dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk4_4x8.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2022 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/common/utils.h" +#include "src/fallback/general_intrinsic/gi_float.h" +#include "src/fallback/matrix_mul/generic_strategy.h" + +using namespace megdnn; +using namespace matmul::fallback; + +namespace { + +void kern_4x1(const float* A, const float* B, size_t LDB, size_t K, float* C) { + LDB = LDB - 4; + K = K - 4; + + GI_FLOAT32_t d8d9 = GiLoadFloat32(A); + A = A + 4; + GI_FLOAT32_t d10d11 = GiLoadFloat32(A); + A = A + 4; + GI_FLOAT32_t d12d13 = GiLoadFloat32(A); + A = A + 4; + GI_FLOAT32_t d14d15 = GiLoadFloat32(A); + A = A + 4; + GI_FLOAT32_t d16d17 = GiBroadcastFloat32(0.0f); + GI_FLOAT32_t d18d19 = GiBroadcastFloat32(0.0f); + GI_FLOAT32_t d20d21 = GiBroadcastFloat32(0.0f); + GI_FLOAT32_t d22d23 = GiBroadcastFloat32(0.0f); + + GI_FLOAT32_t d0d1 = GiLoadFloat32(B); + B = B + 4; + + d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0); + d18d19 = GiSimdFmaLane(d18d19, d10d11, d0d1, 1); + + for (; K > 0; K -= 4) { + d8d9 = GiLoadFloat32(A); + A = A + 4; + d10d11 = GiLoadFloat32(A); + A = A + 4; + d20d21 = GiSimdFmaLane(d20d21, d12d13, d0d1, 2); + d22d23 = GiSimdFmaLane(d22d23, d14d15, d0d1, 3); + + B = B + LDB; + d0d1 = GiLoadFloat32(B); + B = B + 4; + d12d13 = GiLoadFloat32(A); + A = A + 4; + d14d15 = GiLoadFloat32(A); + A = A + 4; + + d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0); + d18d19 = GiSimdFmaLane(d18d19, d10d11, d0d1, 1); + } + + d20d21 = GiSimdFmaLane(d20d21, d12d13, d0d1, 2); + d22d23 = GiSimdFmaLane(d22d23, d14d15, d0d1, 3); + d16d17 = GiAddFloat32(d16d17, d20d21); + d18d19 = GiAddFloat32(d18d19, d22d23); + d16d17 = GiAddFloat32(d16d17, d18d19); + + GiStoreFloat32(C, d16d17); + C = C + 4; +} + +void kern_4x4(const float* A, const float* B, size_t LDB, size_t K, float* C) { + LDB = (LDB - 16); + K = K - 4; + + GI_FLOAT32_t d8d9 = GiLoadFloat32(A); + A = A + 4; + GI_FLOAT32_t d10d11 = GiLoadFloat32(A); + A = A + 4; + GI_FLOAT32_t d12d13 = GiLoadFloat32(A); + A = A + 4; + GI_FLOAT32_t d14d15 = GiLoadFloat32(A); + A = A + 4; + + GI_FLOAT32_t 
d0d1 = GiLoadFloat32(B); + B = B + 4; + GI_FLOAT32_t d2d3 = GiLoadFloat32(B); + B = B + 4; + GI_FLOAT32_t d4d5 = GiLoadFloat32(B); + B = B + 4; + GI_FLOAT32_t d6d7 = GiLoadFloat32(B); + B = B + 4; + + GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0); + GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0); + GI_FLOAT32_t d20d21 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0); + GI_FLOAT32_t d22d23 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0); + + d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); + d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1); + d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1); + d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1); + + for (; K > 0; K -= 4) { + d8d9 = GiLoadFloat32(A); + A = A + 4; + d10d11 = GiLoadFloat32(A); + A = A + 4; + + d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2); + d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2); + d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2); + d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2); + + B = B + LDB; + + d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3); + d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3); + d0d1 = GiLoadFloat32(B); + B = B + 4; + d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3); + d2d3 = GiLoadFloat32(B); + B = B + 4; + d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3); + d4d5 = GiLoadFloat32(B); + B = B + 4; + + d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0); + d6d7 = GiLoadFloat32(B); + B = B + 4; + d18d19 = GiSimdFmaLane(d18d19, d8d9, d2d3, 0); + d20d21 = GiSimdFmaLane(d20d21, d8d9, d4d5, 0); + d22d23 = GiSimdFmaLane(d22d23, d8d9, d6d7, 0); + + d12d13 = GiLoadFloat32(A); + A = A + 4; + d14d15 = GiLoadFloat32(A); + A = A + 4; + + d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); + d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1); + d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1); + d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1); + } + + d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2); + d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2); + d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 
2); + d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2); + + d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3); + d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3); + d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3); + d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3); + + GiStoreFloat32(C, d16d17); + C = C + 4; + GiStoreFloat32(C, d18d19); + C = C + 4; + GiStoreFloat32(C, d20d21); + C = C + 4; + GiStoreFloat32(C, d22d23); + C = C + 4; +} + +void kern_4x8(const float* A, const float* B, size_t LDB, size_t K, float* C) { + LDB -= 32; + GI_FLOAT32_t d8d9 = GiLoadFloat32(A); + A = A + 4; + GI_FLOAT32_t d10d11 = GiLoadFloat32(A); + A = A + 4; + GI_FLOAT32_t d12d13 = GiLoadFloat32(A); + A = A + 4; + GI_FLOAT32_t d14d15 = GiLoadFloat32(A); + A = A + 4; + + GI_FLOAT32_t d0d1 = GiLoadFloat32(B); + B = B + 4; + GI_FLOAT32_t d2d3 = GiLoadFloat32(B); + B = B + 4; + GI_FLOAT32_t d4d5 = GiLoadFloat32(B); + B = B + 4; + GI_FLOAT32_t d6d7 = GiLoadFloat32(B); + B = B + 4; + GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0); + d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); + GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0); + d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2); + d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1); + d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3); + d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2); + d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3); + d0d1 = GiLoadFloat32(B); + B = B + 4; + d2d3 = GiLoadFloat32(B); + B = B + 4; + GI_FLOAT32_t d20d21 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0); + d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1); + GI_FLOAT32_t d22d23 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0); + d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2); + d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1); + d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3); + d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2); + d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3); + + d4d5 = GiLoadFloat32(B); + B = B + 4; + d6d7 = GiLoadFloat32(B); + B = B + 4; + GI_FLOAT32_t 
d24d25 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0); + d24d25 = GiSimdFmaLane(d24d25, d10d11, d0d1, 1); + GI_FLOAT32_t d26d27 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0); + d24d25 = GiSimdFmaLane(d24d25, d12d13, d0d1, 2); + d26d27 = GiSimdFmaLane(d26d27, d10d11, d2d3, 1); + d24d25 = GiSimdFmaLane(d24d25, d14d15, d0d1, 3); + d26d27 = GiSimdFmaLane(d26d27, d12d13, d2d3, 2); + d26d27 = GiSimdFmaLane(d26d27, d14d15, d2d3, 3); + GI_FLOAT32_t d28d29 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0); + d28d29 = GiSimdFmaLane(d28d29, d10d11, d4d5, 1); + GI_FLOAT32_t d30d31 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0); + d28d29 = GiSimdFmaLane(d28d29, d12d13, d4d5, 2); + d30d31 = GiSimdFmaLane(d30d31, d10d11, d6d7, 1); + d28d29 = GiSimdFmaLane(d28d29, d14d15, d4d5, 3); + d30d31 = GiSimdFmaLane(d30d31, d12d13, d6d7, 2); + d30d31 = GiSimdFmaLane(d30d31, d14d15, d6d7, 3); + + B = B + LDB; + K = K - 4; + for (; K > 0; K -= 4) { + d8d9 = GiLoadFloat32(A); + A = A + 4; + d10d11 = GiLoadFloat32(A); + A = A + 4; + d12d13 = GiLoadFloat32(A); + A = A + 4; + d14d15 = GiLoadFloat32(A); + A = A + 4; + + d0d1 = GiLoadFloat32(B); + B = B + 4; + d2d3 = GiLoadFloat32(B); + B = B + 4; + d4d5 = GiLoadFloat32(B); + B = B + 4; + d6d7 = GiLoadFloat32(B); + B = B + 4; + d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0); + d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); + d18d19 = GiSimdFmaLane(d18d19, d8d9, d2d3, 0); + d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2); + d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1); + d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3); + d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2); + d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3); + d0d1 = GiLoadFloat32(B); + B = B + 4; + d2d3 = GiLoadFloat32(B); + B = B + 4; + d20d21 = GiSimdFmaLane(d20d21, d8d9, d4d5, 0); + d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1); + d22d23 = GiSimdFmaLane(d22d23, d8d9, d6d7, 0); + d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2); + d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1); + d20d21 = GiSimdFmaLane(d20d21, 
d14d15, d4d5, 3); + d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2); + d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3); + + d4d5 = GiLoadFloat32(B); + B = B + 4; + d6d7 = GiLoadFloat32(B); + B = B + 4; + d24d25 = GiSimdFmaLane(d24d25, d8d9, d0d1, 0); + d24d25 = GiSimdFmaLane(d24d25, d10d11, d0d1, 1); + d26d27 = GiSimdFmaLane(d26d27, d8d9, d2d3, 0); + d24d25 = GiSimdFmaLane(d24d25, d12d13, d0d1, 2); + d26d27 = GiSimdFmaLane(d26d27, d10d11, d2d3, 1); + d24d25 = GiSimdFmaLane(d24d25, d14d15, d0d1, 3); + d26d27 = GiSimdFmaLane(d26d27, d12d13, d2d3, 2); + d26d27 = GiSimdFmaLane(d26d27, d14d15, d2d3, 3); + d28d29 = GiSimdFmaLane(d28d29, d8d9, d4d5, 0); + d28d29 = GiSimdFmaLane(d28d29, d10d11, d4d5, 1); + d30d31 = GiSimdFmaLane(d30d31, d8d9, d6d7, 0); + d28d29 = GiSimdFmaLane(d28d29, d12d13, d4d5, 2); + d30d31 = GiSimdFmaLane(d30d31, d10d11, d6d7, 1); + d28d29 = GiSimdFmaLane(d28d29, d14d15, d4d5, 3); + d30d31 = GiSimdFmaLane(d30d31, d12d13, d6d7, 2); + d30d31 = GiSimdFmaLane(d30d31, d14d15, d6d7, 3); + B = B + LDB; + } + GiStoreFloat32(C, d16d17); + C = C + 4; + GiStoreFloat32(C, d18d19); + C = C + 4; + GiStoreFloat32(C, d20d21); + C = C + 4; + GiStoreFloat32(C, d22d23); + C = C + 4; + GiStoreFloat32(C, d24d25); + C = C + 4; + GiStoreFloat32(C, d26d27); + C = C + 4; + GiStoreFloat32(C, d28d29); + C = C + 4; + GiStoreFloat32(C, d30d31); + C = C + 4; +} + +} // namespace + +MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gi_sgemm_nopack_4x8); + +void gi_sgemm_nopack_4x8::kern( + const float* A, size_t LDA, const float* B, size_t LDB, float* C, size_t LDC, + size_t M, size_t K, size_t N, const float*, void*, bool trA, bool trB) const { + constexpr size_t MB = 4; + constexpr size_t KB = 4; + constexpr size_t NB = 8; + constexpr size_t NB_HALF = 4; + + megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); + + for (size_t m = 0; m < M; m += MB) { + float* output = C + (m / MB) * LDC; + const float* cur_B = B; + size_t n = 0; + for (; n + NB - 1 < N; n += NB) { + kern_4x8(A, cur_B, 
LDB, K, output); + cur_B += KB * NB; + output += MB * NB; + } + if (N - n >= 4) { + kern_4x4(A, cur_B, LDB, K, output); + cur_B += KB * NB_HALF; + output += MB * NB_HALF; + n += 4; + } + while (n < N) { + kern_4x1(A, cur_B, LDB, K, output); + cur_B += KB; + output += MB; + n++; + } + A += LDA; + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/matrix_mul/opr_impl.cpp b/dnn/src/fallback/matrix_mul/opr_impl.cpp index 6ac7778cdb0493c3e91a5227416e850db5bd7df1..8db8f19e51312ea83f4b538694a9a1c6757018f7 100644 --- a/dnn/src/fallback/matrix_mul/opr_impl.cpp +++ b/dnn/src/fallback/matrix_mul/opr_impl.cpp @@ -36,6 +36,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoF32K8x12x1 f32_k8x12x1; AlgoGemv gemv; AlgoNaive naive; + AlgoF32GiGemvMK4 f32_gemv_mk4; + AlgoF32GiMK4_4x8 f32_mk4_4x8; SmallVector m_all_algos; AlgoBase::Mapper m_all_algos_map; @@ -44,6 +46,8 @@ public: m_all_algos.emplace_back(&gemv); m_all_algos.emplace_back(&f32_k8x12x1); m_all_algos.emplace_back(&naive); + m_all_algos.emplace_back(&f32_gemv_mk4); + m_all_algos.emplace_back(&f32_mk4_4x8); for (auto&& algo : m_all_algos) { m_all_algos_map.emplace(algo->info().desc, algo); } diff --git a/dnn/src/fallback/matrix_mul/opr_impl.h b/dnn/src/fallback/matrix_mul/opr_impl.h index b74de9369bfb7c4c3b49724e11d011d8b5e946d8..99868e0b8afbc8d2b9f9bc3ca0eb48bcceb18442 100644 --- a/dnn/src/fallback/matrix_mul/opr_impl.h +++ b/dnn/src/fallback/matrix_mul/opr_impl.h @@ -112,6 +112,8 @@ public: FB_F32K8x12x1 = 1 << 0, FB_GEMV, FB_NAIVE, + FB_GI_F32_GEMV_MK4, + FB_GI_F32_MK4_4x8, #if MEGDNN_X86 //! 
x86 @@ -131,7 +133,6 @@ public: ARM_COMMON_INT8X8X32_GEMV, ARM_COMMON_INT8X8X32_GEMV_MK4, ARM_COMMON_INT8X8X32_GEMV_MK4_DOT, - ARM_COMMON_F32_GEMV_MK4, ARM_COMMON_F16_GEMV, ARM_COMMON_GEVM, #if MEGDNN_AARCH64 @@ -236,7 +237,9 @@ public: }; private: - class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 + class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 + class AlgoF32GiGemvMK4; // fallback F32 gi Gemv NCHW44 + class AlgoF32GiMK4_4x8; // fallback F32 gi Gemm NCHW44 class AlgoGemv; class AlgoNaive; class AlgoPack; diff --git a/dnn/test/fallback/matrix_mul.cpp b/dnn/test/fallback/matrix_mul.cpp index 638146288ec0a84536cdbc57d2392e3d727cf70f..33b153d4c63646d8a55a9e77efc55914f87a19a3 100644 --- a/dnn/test/fallback/matrix_mul.cpp +++ b/dnn/test/fallback/matrix_mul.cpp @@ -45,6 +45,13 @@ TEST_F(FALLBACK, MATRIX_MUL) { checker.execl({AL, BL, CL}); } } + +TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) { + matrix_mul::check_matrix_mul( + dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), + "FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1); +} + TEST_F(FALLBACK, MATRIX_MUL_RECORD) { TaskRecordChecker checker(1); using Param = MatrixMul::Param;