From a450d0f5b8be8d5530d04069e3dc1da84c86df60 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Tue, 31 Jan 2023 17:07:19 +0800
Subject: [PATCH] feat(fallback): add fp16 mk8 8x8 matmul

GitOrigin-RevId: 1a50a8a7be433f3ba688eeaa1af460458e5f0b53
---
 dnn/src/fallback/matrix_mul/algos.cpp         |  59 +++
 dnn/src/fallback/matrix_mul/algos.h           |  15 +
 .../fallback/matrix_mul/generic_strategy.h    |   6 +
 .../matrix_mul/gi/fp16/strategy_mk8_8x8.cpp   | 485 ++++++++++++++++++
 dnn/src/fallback/matrix_mul/opr_impl.cpp      |   7 +
 dnn/src/fallback/matrix_mul/opr_impl.h        |   2 +
 dnn/test/fallback/matrix_mul.cpp              |  18 +
 7 files changed, 592 insertions(+)
 create mode 100644 dnn/src/fallback/matrix_mul/gi/fp16/strategy_mk8_8x8.cpp

diff --git a/dnn/src/fallback/matrix_mul/algos.cpp b/dnn/src/fallback/matrix_mul/algos.cpp
index 52f104a2b..a27086272 100644
--- a/dnn/src/fallback/matrix_mul/algos.cpp
+++ b/dnn/src/fallback/matrix_mul/algos.cpp
@@ -297,7 +297,66 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_kern(
         const KernSizeParam&) const {
     return gi_f32_mk4_4x8_kern;
 }
+#if defined(GI_SUPPORT_F16)
+/* ================== F16 Gemm MK8 gi algo ================== */
+namespace {
+void gi_f16_mk8_8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
+    MIDOUT_BEGIN(megdnn_fb_gi_matmul_kern, midout_iv("gi_f16_mk8_8x8_kern"_hash)) {
+        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
+        auto trA = kern_param.trA, trB = kern_param.trB;
+        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
+        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
+             C_type = kern_param.C_type;
+        const auto Aptr = kern_param.A<dt_float16>(), Bptr = kern_param.B<dt_float16>();
+        auto Cptr = kern_param.C<dt_float16>();
+
+        matmul::fallback::gi_sgemm_nopack_mk8_8x8_fp16 strategy(A_type, B_type, C_type);
+        megdnn::matmul::GemmInterleaved<
+                matmul::fallback::gi_sgemm_nopack_mk8_8x8_fp16, false>(
+                M, N, K, trA, trB, strategy)
+                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
+    }
+    MIDOUT_END();
+}
+
+}  // anonymous namespace
+bool MatrixMulImpl::AlgoF16GiMK8_8x8::usable(
+        const KernSizeParam& kern_size_param) const {
+    constexpr size_t MB = 8;
+    constexpr size_t KB = 8;
+    return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
+           kern_size_param.format == param::MatrixMul::Format::MK8 &&
+           kern_size_param.B_type == kern_size_param.A_type &&
+           kern_size_param.C_type == kern_size_param.A_type &&
+           kern_size_param.A_type == dtype::Float16() && !kern_size_param.trA &&
+           !kern_size_param.trB && kern_size_param.M % MB == 0 &&
+           kern_size_param.K % KB == 0;
+}
+size_t MatrixMulImpl::AlgoF16GiMK8_8x8::get_workspace(
+        const KernSizeParam& kern_size_param) const {
+    MIDOUT_BEGIN(
+            megdnn_fb_gi_matmul_kern,
+            midout_iv("AlgoF16GiMK8_8x8::get_workspace"_hash)) {
+        auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
+        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
+        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
+             C_type = kern_size_param.C_type;
+        matmul::fallback::gi_sgemm_nopack_mk8_8x8_fp16 strategy(A_type, B_type, C_type);
+        return megdnn::matmul::GemmInterleaved<
+                       matmul::fallback::gi_sgemm_nopack_mk8_8x8_fp16, false>(
+                       M, N, K, trA, trB, strategy)
+                .get_workspace_size();
+    }
+    MIDOUT_END();
+    return 0;
+}
+
+MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16GiMK8_8x8::get_kern(
+        const KernSizeParam&) const {
+    return gi_f16_mk8_8x8_kern;
+}
+#endif
 
 /* ===================== F32 algo gi mk4 pack K4x12 ===================== */
 namespace {
 void f32_gi_mk4_pack_4x12_kern(const MatrixMulImpl::KernParam& kern_param) {
diff --git a/dnn/src/fallback/matrix_mul/algos.h b/dnn/src/fallback/matrix_mul/algos.h
index fcf0ff37b..7f29a44d4 100644
--- a/dnn/src/fallback/matrix_mul/algos.h
+++ b/dnn/src/fallback/matrix_mul/algos.h
@@ -2,6 +2,7 @@
 #include
 #include "src/common/algo_base.h"
+#include "src/fallback/general_intrinsic/gi_common.h"
 #include "src/fallback/matrix_mul/gemm_common.h"
 #include "src/fallback/matrix_mul/opr_impl.h"
 
@@ -97,6 +98,20 @@ public:
     MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_MK4_4x8)
 };
 
+#if defined(GI_SUPPORT_F16)
+class MatrixMulImpl::AlgoF16GiMK8_8x8 final : public AlgoBase {
+public:
+    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
+    const char* name() const override { return "FB_GI_F16_MK8_8x8"; }
+    bool usable(const KernSizeParam&) const override;
+    size_t get_workspace(const KernSizeParam&) const override;
+    kern_t get_kern(const KernSizeParam&) const override;
+    PackMode packmode() const override { return PackMode::NO_PACK; }
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 8, AlgoDataType::FLOAT16, MK8)
+    MEGDNN_DECL_ALGO_TYPE(FB_GI_F16_MK8_8x8)
+};
+#endif
+
 class MatrixMulImpl::AlgoF32GiMK4Pack4x12 final : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
diff --git a/dnn/src/fallback/matrix_mul/generic_strategy.h b/dnn/src/fallback/matrix_mul/generic_strategy.h
index 3b28c95e5..fbc7d098a 100644
--- a/dnn/src/fallback/matrix_mul/generic_strategy.h
+++ b/dnn/src/fallback/matrix_mul/generic_strategy.h
@@ -1,4 +1,5 @@
 #pragma once
+#include "src/fallback/general_intrinsic/gi_common.h"
 #include "src/fallback/matrix_mul/gemm_common.h"
 
 namespace megdnn {
@@ -8,6 +9,11 @@ namespace fallback {
 MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12);
 MEGDNN_REG_GEMM_STRATEGY_NOPACK(
         float, float, float, 4, 8, 1, false, true, gi_sgemm_nopack_4x8);
+#if defined(GI_SUPPORT_F16)
+MEGDNN_REG_GEMM_STRATEGY_NOPACK(
+        dt_float16, dt_float16, dt_float16, 8, 8, 1, false, true,
+        gi_sgemm_nopack_mk8_8x8_fp16);
+#endif
 MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, true, gi_sgemm_4x12);
 MEGDNN_REG_GEMM_STRATEGY(
         float, float, float, 4, 12, 1, false, false, gi_sgemm_mk4_pack_4x12);
diff --git a/dnn/src/fallback/matrix_mul/gi/fp16/strategy_mk8_8x8.cpp b/dnn/src/fallback/matrix_mul/gi/fp16/strategy_mk8_8x8.cpp
new file mode 100644
index 000000000..03378ddfb
--- /dev/null
+++ b/dnn/src/fallback/matrix_mul/gi/fp16/strategy_mk8_8x8.cpp
@@ -0,0 +1,485 @@
+#include "src/fallback/general_intrinsic/gi_float16.h"
+
+#if defined(GI_SUPPORT_F16)
+#include "src/common/utils.h"
+#include "src/fallback/matrix_mul/generic_strategy.h"
+
+using namespace megdnn;
+using namespace matmul::fallback;
+
+namespace {
+
+#define MLA GiMultiplyAddScalarFloat16
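+
+// kern_8x1: one 8-row M-block of C times a single output column over the full
+// depth K. Each GiLoadFloat16(A) pulls the 8 rows of the block at one depth k,
+// and MLA(acc, dk, B[k]) accumulates sum_k A[:, k] * B[k][n] into d8.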
+void kern_8x1(
+        const gi_float16_t* A, const gi_float16_t* B, size_t LDB, size_t K,
+        gi_float16_t* C) {
+    LDB = LDB - 8;
+    K = K - 8;
+
+    GI_FLOAT16_t d0 = GiLoadFloat16(A);
+    A = A + 8;
+    GI_FLOAT16_t d1 = GiLoadFloat16(A);
+    A = A + 8;
+    GI_FLOAT16_t d2 = GiLoadFloat16(A);
+    A = A + 8;
+    GI_FLOAT16_t d3 = GiLoadFloat16(A);
+    A = A + 8;
+    GI_FLOAT16_t d4 = GiLoadFloat16(A);
+    A = A + 8;
+    GI_FLOAT16_t d5 = GiLoadFloat16(A);
+    A = A + 8;
+    GI_FLOAT16_t d6 = GiLoadFloat16(A);
+    A = A + 8;
+    GI_FLOAT16_t d7 = GiLoadFloat16(A);
+    A = A + 8;
+
+    GI_FLOAT16_t vfzero = GiBroadcastFloat16(0.0);
+
+    GI_FLOAT16_t d8 = MLA(vfzero, d0, *(B));
+    d8 = MLA(d8, d1, *(B + 1));
+    d8 = MLA(d8, d2, *(B + 2));
+    d8 = MLA(d8, d3, *(B + 3));
+    d8 = MLA(d8, d4, *(B + 4));
+    d8 = MLA(d8, d5, *(B + 5));
+    d8 = MLA(d8, d6, *(B + 6));
+    d8 = MLA(d8, d7, *(B + 7));
+    B += 8;
+
+    B += LDB;
+
+    for (; K > 0; K -= 8) {
+        d0 = GiLoadFloat16(A);
+        A = A + 8;
+        d1 = GiLoadFloat16(A);
+        A = A + 8;
+        d2 = GiLoadFloat16(A);
+        A = A + 8;
+        d3 = GiLoadFloat16(A);
+        A = A + 8;
+        d4 = GiLoadFloat16(A);
+        A = A + 8;
+        d5 = GiLoadFloat16(A);
+        A = A + 8;
+        d6 = GiLoadFloat16(A);
+        A = A + 8;
+        d7 = GiLoadFloat16(A);
+        A = A + 8;
+
+        d8 = MLA(d8, d0, *(B));
+        d8 = MLA(d8, d1, *(B + 1));
+        d8 = MLA(d8, d2, *(B + 2));
+        d8 = MLA(d8, d3, *(B + 3));
+        d8 = MLA(d8, d4, *(B + 4));
+        d8 = MLA(d8, d5, *(B + 5));
+        d8 = MLA(d8, d6, *(B + 6));
+        d8 = MLA(d8, d7, *(B + 7));
+        B += 8;
+
+        B += LDB;
+    }
+
+    GiStoreFloat16(C, d8);
+}
+
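+// kern_8x4: the same K chain with four live column accumulators (d8..d11),
+// so each reloaded A block is reused across four output columns.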
+void kern_8x4(
+        const gi_float16_t* A, const gi_float16_t* B, size_t LDB, size_t K,
+        gi_float16_t* C) {
+    LDB = LDB - 32;
+    K = K - 8;
+
+    GI_FLOAT16_t d0 = GiLoadFloat16(A);
+    A = A + 8;
+    GI_FLOAT16_t d1 = GiLoadFloat16(A);
+    A = A + 8;
+    GI_FLOAT16_t d2 = GiLoadFloat16(A);
+    A = A + 8;
+    GI_FLOAT16_t d3 = GiLoadFloat16(A);
+    A = A + 8;
+    GI_FLOAT16_t d4 = GiLoadFloat16(A);
+    A = A + 8;
+    GI_FLOAT16_t d5 = GiLoadFloat16(A);
+    A = A + 8;
+    GI_FLOAT16_t d6 = GiLoadFloat16(A);
+    A = A + 8;
+    GI_FLOAT16_t d7 = GiLoadFloat16(A);
+    A = A + 8;
+
+    GI_FLOAT16_t vfzero = GiBroadcastFloat16(0.0);
+
+    GI_FLOAT16_t d8 = MLA(vfzero, d0, *(B));
+    d8 = MLA(d8, d1, *(B + 1));
+    d8 = MLA(d8, d2, *(B + 2));
+    d8 = MLA(d8, d3, *(B + 3));
+    d8 = MLA(d8, d4, *(B + 4));
+    d8 = MLA(d8, d5, *(B + 5));
+    d8 = MLA(d8, d6, *(B + 6));
+    d8 = MLA(d8, d7, *(B + 7));
+    B += 8;
+
+    GI_FLOAT16_t d9 = MLA(vfzero, d0, *(B));
+    d9 = MLA(d9, d1, *(B + 1));
+    d9 = MLA(d9, d2, *(B + 2));
+    d9 = MLA(d9, d3, *(B + 3));
+    d9 = MLA(d9, d4, *(B + 4));
+    d9 = MLA(d9, d5, *(B + 5));
+    d9 = MLA(d9, d6, *(B + 6));
+    d9 = MLA(d9, d7, *(B + 7));
+    B += 8;
+
+    GI_FLOAT16_t d10 = MLA(vfzero, d0, *(B));
+    d10 = MLA(d10, d1, *(B + 1));
+    d10 = MLA(d10, d2, *(B + 2));
+    d10 = MLA(d10, d3, *(B + 3));
+    d10 = MLA(d10, d4, *(B + 4));
+    d10 = MLA(d10, d5, *(B + 5));
+    d10 = MLA(d10, d6, *(B + 6));
+    d10 = MLA(d10, d7, *(B + 7));
+    B += 8;
+
+    GI_FLOAT16_t d11 = MLA(vfzero, d0, *(B));
+    d11 = MLA(d11, d1, *(B + 1));
+    d11 = MLA(d11, d2, *(B + 2));
+    d11 = MLA(d11, d3, *(B + 3));
+    d11 = MLA(d11, d4, *(B + 4));
+    d11 = MLA(d11, d5, *(B + 5));
+    d11 = MLA(d11, d6, *(B + 6));
+    d11 = MLA(d11, d7, *(B + 7));
+    B += 8;
+
+    B += LDB;
+
+    for (; K > 0; K -= 8) {
+        d0 = GiLoadFloat16(A);
+        A = A + 8;
+        d1 = GiLoadFloat16(A);
+        A = A + 8;
+        d2 = GiLoadFloat16(A);
+        A = A + 8;
+        d3 = GiLoadFloat16(A);
+        A = A + 8;
+        d4 = GiLoadFloat16(A);
+        A = A + 8;
+        d5 = GiLoadFloat16(A);
+        A = A + 8;
+        d6 = GiLoadFloat16(A);
+        A = A + 8;
+        d7 = GiLoadFloat16(A);
+        A = A + 8;
+
+        d8 = MLA(d8, d0, *(B));
+        d8 = MLA(d8, d1, *(B + 1));
+        d8 = MLA(d8, d2, *(B + 2));
+        d8 = MLA(d8, d3, *(B + 3));
+        d8 = MLA(d8, d4, *(B + 4));
+        d8 = MLA(d8, d5, *(B + 5));
+        d8 = MLA(d8, d6, *(B + 6));
+        d8 = MLA(d8, d7, *(B + 7));
+        B += 8;
+
+        d9 = MLA(d9, d0, *(B));
+        d9 = MLA(d9, d1, *(B + 1));
+        d9 = MLA(d9, d2, *(B + 2));
+        d9 = MLA(d9, d3, *(B + 3));
+        d9 = MLA(d9, d4, *(B + 4));
+        d9 = MLA(d9, d5, *(B + 5));
+        d9 = MLA(d9, d6, *(B + 6));
+        d9 = MLA(d9, d7, *(B + 7));
+        B += 8;
+
+        d10 = MLA(d10, d0, *(B));
+        d10 = MLA(d10, d1, *(B + 1));
+        d10 = MLA(d10, d2, *(B + 2));
+        d10 = MLA(d10, d3, *(B + 3));
+        d10 = MLA(d10, d4, *(B + 4));
+        d10 = MLA(d10, d5, *(B + 5));
+        d10 = MLA(d10, d6, *(B + 6));
+        d10 = MLA(d10, d7, *(B + 7));
+        B += 8;
+
+        d11 = MLA(d11, d0, *(B));
+        d11 = MLA(d11, d1, *(B + 1));
+        d11 = MLA(d11, d2, *(B + 2));
+        d11 = MLA(d11, d3, *(B + 3));
+        d11 = MLA(d11, d4, *(B + 4));
+        d11 = MLA(d11, d5, *(B + 5));
+        d11 = MLA(d11, d6, *(B + 6));
+        d11 = MLA(d11, d7, *(B + 7));
+        B += 8;
+
+        B += LDB;
+    }
+
+    GiStoreFloat16(C, d8);
+    C = C + 8;
+    GiStoreFloat16(C, d9);
+    C = C + 8;
+    GiStoreFloat16(C, d10);
+    C = C + 8;
+    GiStoreFloat16(C, d11);
+    C = C + 8;
+}
+
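+// kern_8x8: the widest tile. Eight accumulators (d8..d15) cover a full 8x8
+// block of C, so every A block load is reused eight times. LDB is pre-adjusted
+// by 64 because the eight per-column B reads already advance the pointer.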
+void kern_8x8(
+        const gi_float16_t* A, const gi_float16_t* B, size_t LDB, size_t K,
+        gi_float16_t* C) {
+    LDB -= 64;
+    K = K - 8;
+
+    GI_FLOAT16_t d0 = GiLoadFloat16(A);
+    A = A + 8;
+    GI_FLOAT16_t d1 = GiLoadFloat16(A);
+    A = A + 8;
+    GI_FLOAT16_t d2 = GiLoadFloat16(A);
+    A = A + 8;
+    GI_FLOAT16_t d3 = GiLoadFloat16(A);
+    A = A + 8;
+    GI_FLOAT16_t d4 = GiLoadFloat16(A);
+    A = A + 8;
+    GI_FLOAT16_t d5 = GiLoadFloat16(A);
+    A = A + 8;
+    GI_FLOAT16_t d6 = GiLoadFloat16(A);
+    A = A + 8;
+    GI_FLOAT16_t d7 = GiLoadFloat16(A);
+    A = A + 8;
+
+    GI_FLOAT16_t vfzero = GiZeroFloat16();
+
+    GI_FLOAT16_t d8 = MLA(vfzero, d0, *(B));
+    d8 = MLA(d8, d1, *(B + 1));
+    d8 = MLA(d8, d2, *(B + 2));
+    d8 = MLA(d8, d3, *(B + 3));
+    d8 = MLA(d8, d4, *(B + 4));
+    d8 = MLA(d8, d5, *(B + 5));
+    d8 = MLA(d8, d6, *(B + 6));
+    d8 = MLA(d8, d7, *(B + 7));
+    B = B + 8;
+
+    GI_FLOAT16_t d9 = MLA(vfzero, d0, *(B));
+    d9 = MLA(d9, d1, *(B + 1));
+    d9 = MLA(d9, d2, *(B + 2));
+    d9 = MLA(d9, d3, *(B + 3));
+    d9 = MLA(d9, d4, *(B + 4));
+    d9 = MLA(d9, d5, *(B + 5));
+    d9 = MLA(d9, d6, *(B + 6));
+    d9 = MLA(d9, d7, *(B + 7));
+    B = B + 8;
+
+    GI_FLOAT16_t d10 = MLA(vfzero, d0, *(B));
+    d10 = MLA(d10, d1, *(B + 1));
+    d10 = MLA(d10, d2, *(B + 2));
+    d10 = MLA(d10, d3, *(B + 3));
+    d10 = MLA(d10, d4, *(B + 4));
+    d10 = MLA(d10, d5, *(B + 5));
+    d10 = MLA(d10, d6, *(B + 6));
+    d10 = MLA(d10, d7, *(B + 7));
+    B = B + 8;
+
+    GI_FLOAT16_t d11 = MLA(vfzero, d0, *(B));
+    d11 = MLA(d11, d1, *(B + 1));
+    d11 = MLA(d11, d2, *(B + 2));
+    d11 = MLA(d11, d3, *(B + 3));
+    d11 = MLA(d11, d4, *(B + 4));
+    d11 = MLA(d11, d5, *(B + 5));
+    d11 = MLA(d11, d6, *(B + 6));
+    d11 = MLA(d11, d7, *(B + 7));
+    B = B + 8;
+
+    GI_FLOAT16_t d12 = MLA(vfzero, d0, *(B));
+    d12 = MLA(d12, d1, *(B + 1));
+    d12 = MLA(d12, d2, *(B + 2));
+    d12 = MLA(d12, d3, *(B + 3));
+    d12 = MLA(d12, d4, *(B + 4));
+    d12 = MLA(d12, d5, *(B + 5));
+    d12 = MLA(d12, d6, *(B + 6));
+    d12 = MLA(d12, d7, *(B + 7));
+    B = B + 8;
+
+    GI_FLOAT16_t d13 = MLA(vfzero, d0, *(B));
+    d13 = MLA(d13, d1, *(B + 1));
+    d13 = MLA(d13, d2, *(B + 2));
+    d13 = MLA(d13, d3, *(B + 3));
+    d13 = MLA(d13, d4, *(B + 4));
+    d13 = MLA(d13, d5, *(B + 5));
+    d13 = MLA(d13, d6, *(B + 6));
+    d13 = MLA(d13, d7, *(B + 7));
+    B = B + 8;
+
+    GI_FLOAT16_t d14 = MLA(vfzero, d0, *(B));
+    d14 = MLA(d14, d1, *(B + 1));
+    d14 = MLA(d14, d2, *(B + 2));
+    d14 = MLA(d14, d3, *(B + 3));
+    d14 = MLA(d14, d4, *(B + 4));
+    d14 = MLA(d14, d5, *(B + 5));
+    d14 = MLA(d14, d6, *(B + 6));
+    d14 = MLA(d14, d7, *(B + 7));
+    B = B + 8;
+
+    GI_FLOAT16_t d15 = MLA(vfzero, d0, *(B));
+    d15 = MLA(d15, d1, *(B + 1));
+    d15 = MLA(d15, d2, *(B + 2));
+    d15 = MLA(d15, d3, *(B + 3));
+    d15 = MLA(d15, d4, *(B + 4));
+    d15 = MLA(d15, d5, *(B + 5));
+    d15 = MLA(d15, d6, *(B + 6));
+    d15 = MLA(d15, d7, *(B + 7));
+    B = B + 8;
+
+    B = B + LDB;
+    for (; K > 0; K -= 8) {
+        d0 = GiLoadFloat16(A);
+        A = A + 8;
+        d1 = GiLoadFloat16(A);
+        A = A + 8;
+        d2 = GiLoadFloat16(A);
+        A = A + 8;
+        d3 = GiLoadFloat16(A);
+        A = A + 8;
+        d4 = GiLoadFloat16(A);
+        A = A + 8;
+        d5 = GiLoadFloat16(A);
+        A = A + 8;
+        d6 = GiLoadFloat16(A);
+        A = A + 8;
+        d7 = GiLoadFloat16(A);
+        A = A + 8;
+
+        d8 = MLA(d8, d0, *(B));
+        d8 = MLA(d8, d1, *(B + 1));
+        d8 = MLA(d8, d2, *(B + 2));
+        d8 = MLA(d8, d3, *(B + 3));
+        d8 = MLA(d8, d4, *(B + 4));
+        d8 = MLA(d8, d5, *(B + 5));
+        d8 = MLA(d8, d6, *(B + 6));
+        d8 = MLA(d8, d7, *(B + 7));
+        B = B + 8;
+
+        d9 = MLA(d9, d0, *(B));
+        d9 = MLA(d9, d1, *(B + 1));
+        d9 = MLA(d9, d2, *(B + 2));
+        d9 = MLA(d9, d3, *(B + 3));
+        d9 = MLA(d9, d4, *(B + 4));
+        d9 = MLA(d9, d5, *(B + 5));
+        d9 = MLA(d9, d6, *(B + 6));
+        d9 = MLA(d9, d7, *(B + 7));
+        B = B + 8;
+
+        d10 = MLA(d10, d0, *(B));
+        d10 = MLA(d10, d1, *(B + 1));
+        d10 = MLA(d10, d2, *(B + 2));
+        d10 = MLA(d10, d3, *(B + 3));
+        d10 = MLA(d10, d4, *(B + 4));
+        d10 = MLA(d10, d5, *(B + 5));
+        d10 = MLA(d10, d6, *(B + 6));
+        d10 = MLA(d10, d7, *(B + 7));
+        B = B + 8;
+
+        d11 = MLA(d11, d0, *(B));
+        d11 = MLA(d11, d1, *(B + 1));
+        d11 = MLA(d11, d2, *(B + 2));
+        d11 = MLA(d11, d3, *(B + 3));
+        d11 = MLA(d11, d4, *(B + 4));
+        d11 = MLA(d11, d5, *(B + 5));
+        d11 = MLA(d11, d6, *(B + 6));
+        d11 = MLA(d11, d7, *(B + 7));
+        B = B + 8;
+
+        d12 = MLA(d12, d0, *(B));
+        d12 = MLA(d12, d1, *(B + 1));
+        d12 = MLA(d12, d2, *(B + 2));
+        d12 = MLA(d12, d3, *(B + 3));
+        d12 = MLA(d12, d4, *(B + 4));
+        d12 = MLA(d12, d5, *(B + 5));
+        d12 = MLA(d12, d6, *(B + 6));
+        d12 = MLA(d12, d7, *(B + 7));
+        B = B + 8;
+
+        d13 = MLA(d13, d0, *(B));
+        d13 = MLA(d13, d1, *(B + 1));
+        d13 = MLA(d13, d2, *(B + 2));
+        d13 = MLA(d13, d3, *(B + 3));
+        d13 = MLA(d13, d4, *(B + 4));
+        d13 = MLA(d13, d5, *(B + 5));
+        d13 = MLA(d13, d6, *(B + 6));
+        d13 = MLA(d13, d7, *(B + 7));
+        B = B + 8;
+
+        d14 = MLA(d14, d0, *(B));
+        d14 = MLA(d14, d1, *(B + 1));
+        d14 = MLA(d14, d2, *(B + 2));
+        d14 = MLA(d14, d3, *(B + 3));
+        d14 = MLA(d14, d4, *(B + 4));
+        d14 = MLA(d14, d5, *(B + 5));
+        d14 = MLA(d14, d6, *(B + 6));
+        d14 = MLA(d14, d7, *(B + 7));
+        B = B + 8;
+
+        d15 = MLA(d15, d0, *(B));
+        d15 = MLA(d15, d1, *(B + 1));
+        d15 = MLA(d15, d2, *(B + 2));
+        d15 = MLA(d15, d3, *(B + 3));
+        d15 = MLA(d15, d4, *(B + 4));
+        d15 = MLA(d15, d5, *(B + 5));
+        d15 = MLA(d15, d6, *(B + 6));
+        d15 = MLA(d15, d7, *(B + 7));
+        B = B + 8 + LDB;
+    }
+    GiStoreFloat16(C, d8);
+    C = C + 8;
+    GiStoreFloat16(C, d9);
+    C = C + 8;
+    GiStoreFloat16(C, d10);
+    C = C + 8;
+    GiStoreFloat16(C, d11);
+    C = C + 8;
+    GiStoreFloat16(C, d12);
+    C = C + 8;
+    GiStoreFloat16(C, d13);
+    C = C + 8;
+    GiStoreFloat16(C, d14);
+    C = C + 8;
+    GiStoreFloat16(C, d15);
+    C = C + 8;
+}
+
+#undef MLA
+} // namespace
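+
+// In kern() below, N is tiled greedily per 8-row block of C: 8-wide tiles via
+// kern_8x8, then one 4-wide tile via kern_8x4 if at least 4 columns remain,
+// then single columns via kern_8x1.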
+
+MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gi_sgemm_nopack_mk8_8x8_fp16);
+
+void gi_sgemm_nopack_mk8_8x8_fp16::kern(
+        const dt_float16* A, size_t LDA, const dt_float16* B, size_t LDB,
+        dt_float16* C, size_t LDC, size_t M, size_t K, size_t N,
+        const dt_float16*, void*, bool trA, bool trB) const {
+    constexpr size_t MB = 8;
+    constexpr size_t KB = 8;
+    constexpr size_t NB = 8;
+    constexpr size_t NB_HALF = 4;
+    megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0);
+
+    for (size_t m = 0; m < M; m += MB) {
+        gi_float16_t* output = reinterpret_cast<gi_float16_t*>(C) + (m / MB) * LDC;
+        const gi_float16_t* cur_B = reinterpret_cast<const gi_float16_t*>(B);
+        size_t n = 0;
+        for (; n + NB - 1 < N; n += NB) {
+            kern_8x8(reinterpret_cast<const gi_float16_t*>(A), cur_B, LDB, K, output);
+            cur_B += KB * NB;
+            output += MB * NB;
+        }
+        if (N - n >= 4) {
+            kern_8x4(reinterpret_cast<const gi_float16_t*>(A), cur_B, LDB, K, output);
+            cur_B += KB * NB_HALF;
+            output += MB * NB_HALF;
+            n += 4;
+        }
+        while (n < N) {
+            kern_8x1(reinterpret_cast<const gi_float16_t*>(A), cur_B, LDB, K, output);
+            cur_B += KB;
+            output += MB;
+            n++;
+        }
+        A += LDA;
+    }
+}
+
+#endif
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/fallback/matrix_mul/opr_impl.cpp b/dnn/src/fallback/matrix_mul/opr_impl.cpp
index 8343fbb26..a57be5c53 100644
--- a/dnn/src/fallback/matrix_mul/opr_impl.cpp
+++ b/dnn/src/fallback/matrix_mul/opr_impl.cpp
@@ -5,6 +5,7 @@
 #include "src/common/algo_chooser.h"
 #include "src/common/metahelper.h"
 #include "src/common/utils.h"
+#include "src/fallback/general_intrinsic/gi_common.h"
 #include "src/fallback/matrix_mul/algos.h"
 #include "src/fallback/matrix_mul/gemm_impl.h"
 #include "src/fallback/matrix_mul/generic_strategy.h"
@@ -30,6 +31,9 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
     AlgoF32GiMK4_4x8 f32_mk4_4x8;
     AlgoF32GiMK4Pack4x12 f32_mk4_gi_pack_4x12;
     AlgoF32Gi4x12 f32_4x8;
+#if defined(GI_SUPPORT_F16)
+    AlgoF16GiMK8_8x8 f16_mk8_8x8;
+#endif
     SmallVector<AlgoBase*> m_all_algos;
     AlgoBase::Mapper m_all_algos_map;
 
@@ -39,6 +43,9 @@ public:
         m_all_algos.emplace_back(&f32_mk4_4x8);
         m_all_algos.emplace_back(&f32_mk4_gi_pack_4x12);
         m_all_algos.emplace_back(&f32_4x8);
+#if defined(GI_SUPPORT_F16)
+        m_all_algos.emplace_back(&f16_mk8_8x8);
+#endif
         m_all_algos.emplace_back(&gemv);
         m_all_algos.emplace_back(&f32_k8x12x1);
         m_all_algos.emplace_back(&naive);
diff --git a/dnn/src/fallback/matrix_mul/opr_impl.h b/dnn/src/fallback/matrix_mul/opr_impl.h
index b0fed7b54..8dd29e62b 100644
--- a/dnn/src/fallback/matrix_mul/opr_impl.h
+++ b/dnn/src/fallback/matrix_mul/opr_impl.h
@@ -103,6 +103,7 @@ public:
         FB_NAIVE,
         FB_GI_F32_GEMV_MK4,
         FB_GI_F32_MK4_4x8,
+        FB_GI_F16_MK8_8x8,
         FB_GI_F32_MK4_PACK_4x12,
         FB_GI_F32_4x12,
 
@@ -237,6 +238,7 @@ private:
     class AlgoF32GiMK4_4x8;      // fallback F32 gi Gemm NCHW44
     class AlgoF32GiMK4Pack4x12;  // fallback F32 gi Gemm pack NCHW44
    class AlgoF32Gi4x12;         // fallback F32 gi Gemm
+    class AlgoF16GiMK8_8x8;
     class AlgoGemv;
     class AlgoNaive;
     class AlgoPack;
diff --git a/dnn/test/fallback/matrix_mul.cpp b/dnn/test/fallback/matrix_mul.cpp
index 3a8531dbb..b03ae5142 100644
--- a/dnn/test/fallback/matrix_mul.cpp
+++ b/dnn/test/fallback/matrix_mul.cpp
@@ -1,4 +1,5 @@
 #include "test/common/matrix_mul.h"
+#include "src/fallback/general_intrinsic/gi_common.h"
 #include "test/common/checker.h"
 #include "test/common/rng.h"
 #include "test/common/task_record_check.h"
@@ -42,6 +43,14 @@ TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) {
             "FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1);
 }
 
+#if defined(GI_SUPPORT_F16)
+TEST_F(FALLBACK, MATRIX_MUL_FP16_MK8_GI) {
+    matrix_mul::check_matrix_mul(
+            dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, handle(),
+            "FB_GI_F16_MK8_8x8", param::MatrixMul::Format::MK8, 1);
+}
+#endif
+
 TEST_F(FALLBACK, MATRIX_MUL_GI_F32_4x12) {
     matrix_mul::check_matrix_mul(
             dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
@@ -183,6 +192,15 @@ TEST_F(FALLBACK, BENCHMARK_MATRIX_FB_GI_F32_MK4_4x8) {
             "FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4);
 }
 
+#if defined(GI_SUPPORT_F16)
+TEST_F(FALLBACK, BENCHMARK_MATRIX_FB_GI_F16_MK8_8x8) {
+    auto args = matrix_mul::get_benchmark_matmul_args();
+    matrix_mul::benchmark_single_algo(
+            handle(), args, dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
+            "FB_GI_F16_MK8_8x8", param::MatrixMul::Format::MK8);
+}
+#endif
+
 #endif
 }  // namespace test
 }  // namespace megdnn
-- 
GitLab
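
Review note (not part of the patch): the MK8 layout implemented by the kernels
above is equivalent to the scalar reference below, a minimal sketch for
cross-checking only. The name naive_mk8 is ours, all strides are in elements,
and float stands in for fp16, so it matches the fp16 kernels only up to
rounding.

#include <cstddef>

// A: M/8 row-blocks, LDA elements apart (kern() does A += LDA per block);
//    inside a block, depth k is the slow index and the 8 rows m are packed
//    contiguously, i.e. A[mb * LDA + (kb * 8 + k) * 8 + m].
// B: K/8 depth-panels, LDB elements apart; each panel stores N columns of
//    8 depths, i.e. B[kb * LDB + n * 8 + k].
// C: M/8 row-blocks, LDC elements apart; each block stores N columns of
//    8 rows, i.e. C[mb * LDC + n * 8 + m].
void naive_mk8(
        const float* A, const float* B, float* C, size_t M, size_t K, size_t N,
        size_t LDA, size_t LDB, size_t LDC) {
    for (size_t mb = 0; mb < M / 8; ++mb)        // one 8-row block of C
        for (size_t n = 0; n < N; ++n)           // one output column
            for (size_t m = 0; m < 8; ++m) {     // row inside the block
                float acc = 0.f;
                for (size_t kb = 0; kb < K / 8; ++kb)  // 8-deep K blocks
                    for (size_t k = 0; k < 8; ++k)
                        acc += A[mb * LDA + (kb * 8 + k) * 8 + m] *
                               B[kb * LDB + n * 8 + k];
                C[mb * LDC + n * 8 + m] = acc;
            }
}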