From bd50e457ee04a22082526c08c4b2fc30db089c9a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 14 Jul 2022 18:34:18 +0800 Subject: [PATCH] feat(x86/rvv): make MATRIX_MUL_GI_F32_4x12 and FP32_GEMV_MK4_GI adapt to vv and vf model GitOrigin-RevId: 691434c59843e20f15caca366dce0d7e445325a9 --- .../matrix_mul/gi/fp32/exec_sgemv.cpp | 12 +- .../matrix_mul/gi/fp32/strategy_4x12.cpp | 229 +++++++++++------- 2 files changed, 154 insertions(+), 87 deletions(-) diff --git a/dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp b/dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp index adeb9c936..476fed8af 100644 --- a/dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp +++ b/dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp @@ -31,15 +31,22 @@ void sgemv_gi_naive_n_mk4( auto Bptr = B; size_t k = 0; while (k < K) { +#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS) +//! x86 and rvv GiSimdFmaLane API is slowly, as an alternate, use +//! GiMultiplyAddScalarFloat32 +#define MLA(a, b, c, d) GiMultiplyAddScalarFloat32(a, b, *(c + d)) + const float* b = Bptr; +#else +#define MLA(a, b, c, d) GiSimdFmaLane(a, b, c, d) GI_FLOAT32_t b = GiLoadFloat32(Bptr); +#endif GI_FLOAT32_V4_t a; #define LOAD_A(step) GiSetSubVectorFloat32V4(a, step, GiLoadFloat32(Aptr0 + step * 4)); UNROLL_CALL_RAW(4, LOAD_A) #undef LOAD_A #define COMPT(step) \ - t = GiSimdFmaLane( \ - GiGetSubVectorFloat32V4(c, step), GiGetSubVectorFloat32V4(a, step), b, \ + t = MLA(GiGetSubVectorFloat32V4(c, step), GiGetSubVectorFloat32V4(a, step), b, \ step % 4); \ GiSetSubVectorFloat32V4(c, step, t); @@ -49,6 +56,7 @@ void sgemv_gi_naive_n_mk4( Bptr += Bstride; Aptr0 += PACK_SIZE * PACK_SIZE; k += PACK_SIZE; +#undef MLA } #define ADD_C(step, stride) \ diff --git a/dnn/src/fallback/matrix_mul/gi/fp32/strategy_4x12.cpp b/dnn/src/fallback/matrix_mul/gi/fp32/strategy_4x12.cpp index 2a842ff43..0944ab3b4 100644 --- a/dnn/src/fallback/matrix_mul/gi/fp32/strategy_4x12.cpp +++ b/dnn/src/fallback/matrix_mul/gi/fp32/strategy_4x12.cpp @@ -19,6 +19,17 @@ using namespace matmul::fallback; namespace { +#undef PREFER_VF +#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS) +#define PREFER_VF +#endif + +#if defined(PREFER_VF) +#define MLA(a, b, c, d) GiMultiplyAddScalarFloat32(a, b, *(c + d)) +#else +#define MLA(a, b, c, d) GiSimdFmaLane(a, b, c, d) +#endif + void kern_4x12( const float* packA, const float* packB, int K, float* output, int LDC, bool is_first_k, int m_remain) { @@ -32,8 +43,13 @@ void kern_4x12( float* r2 = r1 + LDC; float* r3 = r2 + LDC; - GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15, d16d17, d18d19, - d20d21, d22d23, d24d25, d26d27, d28d29, d30d31; +#if defined(PREFER_VF) + const float* d0d1; +#else + GI_FLOAT32_t d0d1; +#endif + GI_FLOAT32_t d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15, d16d17, d18d19, d20d21, + d22d23, d24d25, d26d27, d28d29, d30d31; if (is_first_k) { d8d9 = GiBroadcastFloat32(0.0f); @@ -99,23 +115,31 @@ void kern_4x12( b_ptr = b_ptr + 4; for (; K > 0; K--) { +#if defined(PREFER_VF) + d0d1 = a_ptr; +#else d0d1 = GiLoadFloat32(a_ptr); +#endif a_ptr = a_ptr + 4; - d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); - d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); - d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); - d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); - d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); - d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); - d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); - d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); - d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); - d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); - d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); - d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); - + d8d9 = MLA(d8d9, d2d3, d0d1, 0); + d10d11 = MLA(d10d11, d4d5, d0d1, 0); + d12d13 = MLA(d12d13, d6d7, d0d1, 0); + d14d15 = MLA(d14d15, d2d3, d0d1, 1); + d16d17 = MLA(d16d17, d4d5, d0d1, 1); + d18d19 = MLA(d18d19, d6d7, d0d1, 1); + d20d21 = MLA(d20d21, d2d3, d0d1, 2); + d22d23 = MLA(d22d23, d4d5, d0d1, 2); + d24d25 = MLA(d24d25, d6d7, d0d1, 2); + d26d27 = MLA(d26d27, d2d3, d0d1, 3); + d28d29 = MLA(d28d29, d4d5, d0d1, 3); + d30d31 = MLA(d30d31, d6d7, d0d1, 3); + +#if defined(PREFER_VF) + d0d1 = a_ptr; +#else d0d1 = GiLoadFloat32(a_ptr); +#endif a_ptr = a_ptr + 4; d2d3 = GiLoadFloat32(b_ptr); b_ptr = b_ptr + 4; @@ -124,18 +148,18 @@ void kern_4x12( d6d7 = GiLoadFloat32(b_ptr); b_ptr = b_ptr + 4; - d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); - d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); - d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); - d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); - d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); - d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); - d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); - d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); - d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); - d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); - d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); - d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); + d8d9 = MLA(d8d9, d2d3, d0d1, 0); + d10d11 = MLA(d10d11, d4d5, d0d1, 0); + d12d13 = MLA(d12d13, d6d7, d0d1, 0); + d14d15 = MLA(d14d15, d2d3, d0d1, 1); + d16d17 = MLA(d16d17, d4d5, d0d1, 1); + d18d19 = MLA(d18d19, d6d7, d0d1, 1); + d20d21 = MLA(d20d21, d2d3, d0d1, 2); + d22d23 = MLA(d22d23, d4d5, d0d1, 2); + d24d25 = MLA(d24d25, d6d7, d0d1, 2); + d26d27 = MLA(d26d27, d2d3, d0d1, 3); + d28d29 = MLA(d28d29, d4d5, d0d1, 3); + d30d31 = MLA(d30d31, d6d7, d0d1, 3); d2d3 = GiLoadFloat32(b_ptr); b_ptr = b_ptr + 4; @@ -146,40 +170,52 @@ void kern_4x12( } if (1 == oddk) { +#if defined(PREFER_VF) + d0d1 = a_ptr; +#else d0d1 = GiLoadFloat32(a_ptr); +#endif a_ptr = a_ptr + 4; - d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); - d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); - d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); - d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); - d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); - d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); - d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); - d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); - d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); - d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); - d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); - d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); + d8d9 = MLA(d8d9, d2d3, d0d1, 0); + d10d11 = MLA(d10d11, d4d5, d0d1, 0); + d12d13 = MLA(d12d13, d6d7, d0d1, 0); + d14d15 = MLA(d14d15, d2d3, d0d1, 1); + d16d17 = MLA(d16d17, d4d5, d0d1, 1); + d18d19 = MLA(d18d19, d6d7, d0d1, 1); + d20d21 = MLA(d20d21, d2d3, d0d1, 2); + d22d23 = MLA(d22d23, d4d5, d0d1, 2); + d24d25 = MLA(d24d25, d6d7, d0d1, 2); + d26d27 = MLA(d26d27, d2d3, d0d1, 3); + d28d29 = MLA(d28d29, d4d5, d0d1, 3); + d30d31 = MLA(d30d31, d6d7, d0d1, 3); } else { +#if defined(PREFER_VF) + d0d1 = a_ptr; +#else d0d1 = GiLoadFloat32(a_ptr); +#endif a_ptr = a_ptr + 4; - d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); - d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); - d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); - d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); - d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); - d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); - d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); - d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); - d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); - d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); - d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); - d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); - + d8d9 = MLA(d8d9, d2d3, d0d1, 0); + d10d11 = MLA(d10d11, d4d5, d0d1, 0); + d12d13 = MLA(d12d13, d6d7, d0d1, 0); + d14d15 = MLA(d14d15, d2d3, d0d1, 1); + d16d17 = MLA(d16d17, d4d5, d0d1, 1); + d18d19 = MLA(d18d19, d6d7, d0d1, 1); + d20d21 = MLA(d20d21, d2d3, d0d1, 2); + d22d23 = MLA(d22d23, d4d5, d0d1, 2); + d24d25 = MLA(d24d25, d6d7, d0d1, 2); + d26d27 = MLA(d26d27, d2d3, d0d1, 3); + d28d29 = MLA(d28d29, d4d5, d0d1, 3); + d30d31 = MLA(d30d31, d6d7, d0d1, 3); + +#if defined(PREFER_VF) + d0d1 = a_ptr; +#else d0d1 = GiLoadFloat32(a_ptr); +#endif a_ptr = a_ptr + 4; d2d3 = GiLoadFloat32(b_ptr); b_ptr = b_ptr + 4; @@ -188,18 +224,18 @@ void kern_4x12( d6d7 = GiLoadFloat32(b_ptr); b_ptr = b_ptr + 4; - d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); - d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); - d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); - d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); - d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); - d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); - d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); - d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); - d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); - d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); - d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); - d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); + d8d9 = MLA(d8d9, d2d3, d0d1, 0); + d10d11 = MLA(d10d11, d4d5, d0d1, 0); + d12d13 = MLA(d12d13, d6d7, d0d1, 0); + d14d15 = MLA(d14d15, d2d3, d0d1, 1); + d16d17 = MLA(d16d17, d4d5, d0d1, 1); + d18d19 = MLA(d18d19, d6d7, d0d1, 1); + d20d21 = MLA(d20d21, d2d3, d0d1, 2); + d22d23 = MLA(d22d23, d4d5, d0d1, 2); + d24d25 = MLA(d24d25, d6d7, d0d1, 2); + d26d27 = MLA(d26d27, d2d3, d0d1, 3); + d28d29 = MLA(d28d29, d4d5, d0d1, 3); + d30d31 = MLA(d30d31, d6d7, d0d1, 3); } if (m_remain == 4) { @@ -259,7 +295,13 @@ void kern_4x4( float* r3 = r2 + LDC; size_t d_size = sizeof(float); - GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15; +#if defined(PREFER_VF) + const float* d0d1; + const float* d2d3; +#else + GI_FLOAT32_t d0d1, d2d3; +#endif + GI_FLOAT32_t d4d5, d6d7, d8d9, d10d11, d12d13, d14d15; float tmp[4]; if (is_first_k) { d8d9 = GiBroadcastFloat32(0.0f); @@ -412,54 +454,70 @@ void kern_4x4( } } +#if defined(PREFER_VF) + d0d1 = a_ptr; +#else d0d1 = GiLoadFloat32(a_ptr); +#endif a_ptr = a_ptr + 4; d4d5 = GiLoadFloat32(b_ptr); b_ptr = b_ptr + 4; for (; K > 0; K--) { +#if defined(PREFER_VF) + d2d3 = a_ptr; +#else d2d3 = GiLoadFloat32(a_ptr); +#endif a_ptr = a_ptr + 4; d6d7 = GiLoadFloat32(b_ptr); b_ptr = b_ptr + 4; - d8d9 = GiSimdFmaLane(d8d9, d4d5, d0d1, 0); - d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 1); - d12d13 = GiSimdFmaLane(d12d13, d4d5, d0d1, 2); - d14d15 = GiSimdFmaLane(d14d15, d4d5, d0d1, 3); + d8d9 = MLA(d8d9, d4d5, d0d1, 0); + d10d11 = MLA(d10d11, d4d5, d0d1, 1); + d12d13 = MLA(d12d13, d4d5, d0d1, 2); + d14d15 = MLA(d14d15, d4d5, d0d1, 3); +#if defined(PREFER_VF) + d0d1 = a_ptr; +#else d0d1 = GiLoadFloat32(a_ptr); +#endif a_ptr = a_ptr + 4; d4d5 = GiLoadFloat32(b_ptr); b_ptr = b_ptr + 4; - d8d9 = GiSimdFmaLane(d8d9, d6d7, d2d3, 0); - d10d11 = GiSimdFmaLane(d10d11, d6d7, d2d3, 1); - d12d13 = GiSimdFmaLane(d12d13, d6d7, d2d3, 2); - d14d15 = GiSimdFmaLane(d14d15, d6d7, d2d3, 3); + d8d9 = MLA(d8d9, d6d7, d2d3, 0); + d10d11 = MLA(d10d11, d6d7, d2d3, 1); + d12d13 = MLA(d12d13, d6d7, d2d3, 2); + d14d15 = MLA(d14d15, d6d7, d2d3, 3); } if (1 == oddk) { - d8d9 = GiSimdFmaLane(d8d9, d4d5, d0d1, 0); - d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 1); - d12d13 = GiSimdFmaLane(d12d13, d4d5, d0d1, 2); - d14d15 = GiSimdFmaLane(d14d15, d4d5, d0d1, 3); + d8d9 = MLA(d8d9, d4d5, d0d1, 0); + d10d11 = MLA(d10d11, d4d5, d0d1, 1); + d12d13 = MLA(d12d13, d4d5, d0d1, 2); + d14d15 = MLA(d14d15, d4d5, d0d1, 3); } else { +#if defined(PREFER_VF) + d2d3 = a_ptr; +#else d2d3 = GiLoadFloat32(a_ptr); +#endif a_ptr = a_ptr + 4; d6d7 = GiLoadFloat32(b_ptr); b_ptr = b_ptr + 4; - d8d9 = GiSimdFmaLane(d8d9, d4d5, d0d1, 0); - d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 1); - d12d13 = GiSimdFmaLane(d12d13, d4d5, d0d1, 2); - d14d15 = GiSimdFmaLane(d14d15, d4d5, d0d1, 3); + d8d9 = MLA(d8d9, d4d5, d0d1, 0); + d10d11 = MLA(d10d11, d4d5, d0d1, 1); + d12d13 = MLA(d12d13, d4d5, d0d1, 2); + d14d15 = MLA(d14d15, d4d5, d0d1, 3); - d8d9 = GiSimdFmaLane(d8d9, d6d7, d2d3, 0); - d10d11 = GiSimdFmaLane(d10d11, d6d7, d2d3, 1); - d12d13 = GiSimdFmaLane(d12d13, d6d7, d2d3, 2); - d14d15 = GiSimdFmaLane(d14d15, d6d7, d2d3, 3); + d8d9 = MLA(d8d9, d6d7, d2d3, 0); + d10d11 = MLA(d10d11, d6d7, d2d3, 1); + d12d13 = MLA(d12d13, d6d7, d2d3, 2); + d14d15 = MLA(d14d15, d6d7, d2d3, 3); } if (m_remain == 4) { @@ -882,6 +940,7 @@ void gi_sgemm_4x12_pack_B_t( } } +#undef MLA } // namespace MEGDNN_REG_GEMM_STRATEGY_IMPL(gi_sgemm_4x12); -- GitLab