From a54d9cb9cd7b6b493690cfe65c3b6532a8c11f86 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 11 Jul 2022 13:53:21 +0800 Subject: [PATCH] feat(x86/rvv): opt FB_GI_F32_MK4_PACK_4x12 algo GitOrigin-RevId: a80805c119c2d572d9ea6447a3a32d0c2e2063fc --- .../matrix_mul/gi/fp32/strategy_mk_4x12.cpp | 227 +++++++++--------- 1 file changed, 110 insertions(+), 117 deletions(-) diff --git a/dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk_4x12.cpp b/dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk_4x12.cpp index 2752e8d70..65f66cc95 100644 --- a/dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk_4x12.cpp +++ b/dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk_4x12.cpp @@ -20,6 +20,9 @@ using namespace matmul::fallback; namespace { +//! x86 and rvv GiSimdFmaLane API is slowly, as an alternate, use +//! GiMultiplyAddScalarFloat32 +#define MLA GiMultiplyAddScalarFloat32 void kern_4x12( const float* packA, const float* packB, int K, float* output, int LDC, bool is_first_k) { @@ -32,24 +35,18 @@ void kern_4x12( K = ((K + 1) / 2) - 1; float* r1 = output; - GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15, d16d17, d18d19, - d20d21, d22d23, d24d25, d26d27, d28d29, d30d31; + GI_FLOAT32_t d0d1, d2d3, d8d9, d10d11, d12d13, d14d15, d16d17, d18d19, d20d21, + d22d23, d24d25, d26d27, d28d29, d30d31; if (is_first_k) { d8d9 = GiBroadcastFloat32(0.0f); d10d11 = GiBroadcastFloat32(0.0f); d12d13 = GiBroadcastFloat32(0.0f); d14d15 = GiBroadcastFloat32(0.0f); - d0d1 = GiLoadFloat32(a_ptr); - a_ptr = a_ptr + 4; d16d17 = GiBroadcastFloat32(0.0f); d18d19 = GiBroadcastFloat32(0.0f); d20d21 = GiBroadcastFloat32(0.0f); d22d23 = GiBroadcastFloat32(0.0f); - d4d5 = GiLoadFloat32(b_ptr); - b_ptr = b_ptr + 4; - d6d7 = GiLoadFloat32(b_ptr); - b_ptr = b_ptr + 4; d24d25 = GiBroadcastFloat32(0.0f); d26d27 = GiBroadcastFloat32(0.0f); d28d29 = GiBroadcastFloat32(0.0f); @@ -84,145 +81,144 @@ void kern_4x12( r1 = r1 + 4; d30d31 = GiLoadFloat32(r1); r1 = r1 + 4; - + } + for (; K > 0; K--) { d0d1 = GiLoadFloat32(a_ptr); a_ptr = a_ptr + 4; - d4d5 = GiLoadFloat32(b_ptr); + d8d9 = MLA(d8d9, d0d1, *(b_ptr)); + d10d11 = MLA(d10d11, d0d1, *(b_ptr + 1)); + d12d13 = MLA(d12d13, d0d1, *(b_ptr + 2)); + d14d15 = MLA(d14d15, d0d1, *(b_ptr + 3)); b_ptr = b_ptr + 4; - } - for (; K > 0; K--) { - d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0); - d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1); - d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2); - d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3); - d4d5 = GiLoadFloat32(b_ptr); + + d16d17 = MLA(d16d17, d0d1, *(b_ptr)); + d18d19 = MLA(d18d19, d0d1, *(b_ptr + 1)); + d20d21 = MLA(d20d21, d0d1, *(b_ptr + 2)); + d22d23 = MLA(d22d23, d0d1, *(b_ptr + 3)); b_ptr = b_ptr + 4; - d16d17 = GiSimdFmaLane(d16d17, d0d1, d6d7, 0); - d18d19 = GiSimdFmaLane(d18d19, d0d1, d6d7, 1); - d20d21 = GiSimdFmaLane(d20d21, d0d1, d6d7, 2); + + d24d25 = MLA(d24d25, d0d1, *(b_ptr)); + d26d27 = MLA(d26d27, d0d1, *(b_ptr + 1)); + d28d29 = MLA(d28d29, d0d1, *(b_ptr + 2)); + d30d31 = MLA(d30d31, d0d1, *(b_ptr + 3)); + b_ptr = b_ptr + 4; + d2d3 = GiLoadFloat32(a_ptr); a_ptr = a_ptr + 4; - d22d23 = GiSimdFmaLane(d22d23, d0d1, d6d7, 3); - d6d7 = GiLoadFloat32(b_ptr); - b_ptr = b_ptr + 4; - d24d25 = GiSimdFmaLane(d24d25, d0d1, d4d5, 0); - d26d27 = GiSimdFmaLane(d26d27, d0d1, d4d5, 1); - d28d29 = GiSimdFmaLane(d28d29, d0d1, d4d5, 2); - d30d31 = GiSimdFmaLane(d30d31, d0d1, d4d5, 3); - d4d5 = GiLoadFloat32(b_ptr); - b_ptr = b_ptr + 4; - d8d9 = GiSimdFmaLane(d8d9, d2d3, d6d7, 0); - d10d11 = GiSimdFmaLane(d10d11, d2d3, d6d7, 1); - d12d13 = GiSimdFmaLane(d12d13, d2d3, d6d7, 2); - d14d15 = GiSimdFmaLane(d14d15, d2d3, d6d7, 3); - d6d7 = GiLoadFloat32(b_ptr); + d8d9 = MLA(d8d9, d2d3, *(b_ptr)); + d10d11 = MLA(d10d11, d2d3, *(b_ptr + 1)); + d12d13 = MLA(d12d13, d2d3, *(b_ptr + 2)); + d14d15 = MLA(d14d15, d2d3, *(b_ptr + 3)); b_ptr = b_ptr + 4; - d16d17 = GiSimdFmaLane(d16d17, d2d3, d4d5, 0); - d18d19 = GiSimdFmaLane(d18d19, d2d3, d4d5, 1); - d0d1 = GiLoadFloat32(a_ptr); - a_ptr = a_ptr + 4; - d20d21 = GiSimdFmaLane(d20d21, d2d3, d4d5, 2); - d22d23 = GiSimdFmaLane(d22d23, d2d3, d4d5, 3); - d4d5 = GiLoadFloat32(b_ptr); + + d16d17 = MLA(d16d17, d2d3, *(b_ptr)); + d18d19 = MLA(d18d19, d2d3, *(b_ptr + 1)); + d20d21 = MLA(d20d21, d2d3, *(b_ptr + 2)); + d22d23 = MLA(d22d23, d2d3, *(b_ptr + 3)); b_ptr = b_ptr + 4; - d24d25 = GiSimdFmaLane(d24d25, d2d3, d6d7, 0); - d26d27 = GiSimdFmaLane(d26d27, d2d3, d6d7, 1); - d28d29 = GiSimdFmaLane(d28d29, d2d3, d6d7, 2); - d30d31 = GiSimdFmaLane(d30d31, d2d3, d6d7, 3); - d6d7 = GiLoadFloat32(b_ptr); + + d24d25 = MLA(d24d25, d2d3, *(b_ptr)); + d26d27 = MLA(d26d27, d2d3, *(b_ptr + 1)); + d28d29 = MLA(d28d29, d2d3, *(b_ptr + 2)); + d30d31 = MLA(d30d31, d2d3, *(b_ptr + 3)); b_ptr = b_ptr + 4; } + d0d1 = GiLoadFloat32(a_ptr); + a_ptr = a_ptr + 4; if (1 == oddk) { - d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0); - d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1); - d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2); - d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3); - d4d5 = GiLoadFloat32(b_ptr); + d8d9 = MLA(d8d9, d0d1, *(b_ptr)); + d10d11 = MLA(d10d11, d0d1, *(b_ptr + 1)); + d12d13 = MLA(d12d13, d0d1, *(b_ptr + 2)); + d14d15 = MLA(d14d15, d0d1, *(b_ptr + 3)); b_ptr = b_ptr + 4; - d16d17 = GiSimdFmaLane(d16d17, d0d1, d6d7, 0); + + d16d17 = MLA(d16d17, d0d1, *(b_ptr)); GiStoreFloat32(output0, d8d9); output0 = output0 + 4; GiStoreFloat32(output0, d10d11); output0 = output0 + 4; - d18d19 = GiSimdFmaLane(d18d19, d0d1, d6d7, 1); - d20d21 = GiSimdFmaLane(d20d21, d0d1, d6d7, 2); + d18d19 = MLA(d18d19, d0d1, *(b_ptr + 1)); + d20d21 = MLA(d20d21, d0d1, *(b_ptr + 2)); GiStoreFloat32(output0, d12d13); output0 = output0 + 4; GiStoreFloat32(output0, d14d15); output0 = output0 + 4; - d22d23 = GiSimdFmaLane(d22d23, d0d1, d6d7, 3); - d24d25 = GiSimdFmaLane(d24d25, d0d1, d4d5, 0); + d22d23 = MLA(d22d23, d0d1, *(b_ptr + 3)); + b_ptr = b_ptr + 4; + + d24d25 = MLA(d24d25, d0d1, *(b_ptr)); GiStoreFloat32(output0, d16d17); output0 = output0 + 4; GiStoreFloat32(output0, d18d19); output0 = output0 + 4; - d26d27 = GiSimdFmaLane(d26d27, d0d1, d4d5, 1); + d26d27 = MLA(d26d27, d0d1, *(b_ptr + 1)); GiStoreFloat32(output0, d20d21); output0 = output0 + 4; GiStoreFloat32(output0, d22d23); output0 = output0 + 4; - d28d29 = GiSimdFmaLane(d28d29, d0d1, d4d5, 2); + d28d29 = MLA(d28d29, d0d1, *(b_ptr + 2)); GiStoreFloat32(output0, d24d25); output0 = output0 + 4; GiStoreFloat32(output0, d26d27); output0 = output0 + 4; - d30d31 = GiSimdFmaLane(d30d31, d0d1, d4d5, 3); + d30d31 = MLA(d30d31, d0d1, *(b_ptr + 3)); GiStoreFloat32(output0, d28d29); output0 = output0 + 4; GiStoreFloat32(output0, d30d31); output0 = output0 + 4; - + b_ptr = b_ptr + 4; } else { - d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0); - d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1); - d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2); - d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3); - d4d5 = GiLoadFloat32(b_ptr); + d8d9 = MLA(d8d9, d0d1, *(b_ptr)); + d10d11 = MLA(d10d11, d0d1, *(b_ptr + 1)); + d12d13 = MLA(d12d13, d0d1, *(b_ptr + 2)); + d14d15 = MLA(d14d15, d0d1, *(b_ptr + 3)); b_ptr = b_ptr + 4; - d16d17 = GiSimdFmaLane(d16d17, d0d1, d6d7, 0); - d18d19 = GiSimdFmaLane(d18d19, d0d1, d6d7, 1); - d20d21 = GiSimdFmaLane(d20d21, d0d1, d6d7, 2); + + d16d17 = MLA(d16d17, d0d1, *(b_ptr)); + d18d19 = MLA(d18d19, d0d1, *(b_ptr + 1)); + d20d21 = MLA(d20d21, d0d1, *(b_ptr + 2)); d2d3 = GiLoadFloat32(a_ptr); a_ptr = a_ptr + 4; - d22d23 = GiSimdFmaLane(d22d23, d0d1, d6d7, 3); - d6d7 = GiLoadFloat32(b_ptr); + d22d23 = MLA(d22d23, d0d1, *(b_ptr + 3)); b_ptr = b_ptr + 4; - d24d25 = GiSimdFmaLane(d24d25, d0d1, d4d5, 0); - d26d27 = GiSimdFmaLane(d26d27, d0d1, d4d5, 1); - d28d29 = GiSimdFmaLane(d28d29, d0d1, d4d5, 2); - d30d31 = GiSimdFmaLane(d30d31, d0d1, d4d5, 3); - d4d5 = GiLoadFloat32(b_ptr); + + d24d25 = MLA(d24d25, d0d1, *(b_ptr)); + d26d27 = MLA(d26d27, d0d1, *(b_ptr + 1)); + d28d29 = MLA(d28d29, d0d1, *(b_ptr + 2)); + d30d31 = MLA(d30d31, d0d1, *(b_ptr + 3)); b_ptr = b_ptr + 4; - d8d9 = GiSimdFmaLane(d8d9, d2d3, d6d7, 0); - d10d11 = GiSimdFmaLane(d10d11, d2d3, d6d7, 1); - d12d13 = GiSimdFmaLane(d12d13, d2d3, d6d7, 2); - d14d15 = GiSimdFmaLane(d14d15, d2d3, d6d7, 3); - d6d7 = GiLoadFloat32(b_ptr); + d8d9 = MLA(d8d9, d2d3, *(b_ptr)); + d10d11 = MLA(d10d11, d2d3, *(b_ptr + 1)); + d12d13 = MLA(d12d13, d2d3, *(b_ptr + 2)); + d14d15 = MLA(d14d15, d2d3, *(b_ptr + 3)); b_ptr = b_ptr + 4; - d16d17 = GiSimdFmaLane(d16d17, d2d3, d4d5, 0); - d18d19 = GiSimdFmaLane(d18d19, d2d3, d4d5, 1); + + d16d17 = MLA(d16d17, d2d3, *(b_ptr)); + d18d19 = MLA(d18d19, d2d3, *(b_ptr + 1)); GiStoreFloat32(output0, d8d9); output0 = output0 + 4; GiStoreFloat32(output0, d10d11); output0 = output0 + 4; - d20d21 = GiSimdFmaLane(d20d21, d2d3, d4d5, 2); - d22d23 = GiSimdFmaLane(d22d23, d2d3, d4d5, 3); + d20d21 = MLA(d20d21, d2d3, *(b_ptr + 2)); + d22d23 = MLA(d22d23, d2d3, *(b_ptr + 3)); GiStoreFloat32(output0, d12d13); output0 = output0 + 4; GiStoreFloat32(output0, d14d15); output0 = output0 + 4; - d24d25 = GiSimdFmaLane(d24d25, d2d3, d6d7, 0); - d26d27 = GiSimdFmaLane(d26d27, d2d3, d6d7, 1); + b_ptr = b_ptr + 4; + + d24d25 = MLA(d24d25, d2d3, *(b_ptr)); + d26d27 = MLA(d26d27, d2d3, *(b_ptr + 1)); GiStoreFloat32(output0, d16d17); output0 = output0 + 4; GiStoreFloat32(output0, d18d19); output0 = output0 + 4; - d28d29 = GiSimdFmaLane(d28d29, d2d3, d6d7, 2); - d30d31 = GiSimdFmaLane(d30d31, d2d3, d6d7, 3); + d28d29 = MLA(d28d29, d2d3, *(b_ptr + 2)); + d30d31 = MLA(d30d31, d2d3, *(b_ptr + 3)); GiStoreFloat32(output0, d20d21); output0 = output0 + 4; GiStoreFloat32(output0, d22d23); @@ -235,6 +231,7 @@ void kern_4x12( output0 = output0 + 4; GiStoreFloat32(output0, d30d31); output0 = output0 + 4; + b_ptr = b_ptr + 4; } } @@ -249,7 +246,7 @@ void kern_4x4( K = ((K + 1) / 2) - 1; float* r1 = output; - GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15; + GI_FLOAT32_t d0d1, d2d3, d8d9, d10d11, d12d13, d14d15; if (is_first_k) { d8d9 = GiBroadcastFloat32(0.0f); @@ -260,9 +257,6 @@ void kern_4x4( d12d13 = GiBroadcastFloat32(0.0f); - d4d5 = GiLoadFloat32(b_ptr); - b_ptr = b_ptr + 4; - d14d15 = GiBroadcastFloat32(0.0f); } else { if (n_remain == 4) { @@ -293,44 +287,43 @@ void kern_4x4( } for (; K > 0; K--) { - d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0); + d8d9 = MLA(d8d9, d0d1, *(b_ptr)); d2d3 = GiLoadFloat32(a_ptr); a_ptr = a_ptr + 4; - d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1); - d6d7 = GiLoadFloat32(b_ptr); + d10d11 = MLA(d10d11, d0d1, *(b_ptr + 1)); + d12d13 = MLA(d12d13, d0d1, *(b_ptr + 2)); + d14d15 = MLA(d14d15, d0d1, *(b_ptr + 3)); b_ptr = b_ptr + 4; - d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2); - d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3); - d4d5 = GiLoadFloat32(b_ptr); - b_ptr = b_ptr + 4; - d8d9 = GiSimdFmaLane(d8d9, d2d3, d6d7, 0); - d10d11 = GiSimdFmaLane(d10d11, d2d3, d6d7, 1); + d8d9 = MLA(d8d9, d2d3, *(b_ptr)); + d10d11 = MLA(d10d11, d2d3, *(b_ptr + 1)); d0d1 = GiLoadFloat32(a_ptr); a_ptr = a_ptr + 4; - d12d13 = GiSimdFmaLane(d12d13, d2d3, d6d7, 2); - d14d15 = GiSimdFmaLane(d14d15, d2d3, d6d7, 3); + d12d13 = MLA(d12d13, d2d3, *(b_ptr + 2)); + d14d15 = MLA(d14d15, d2d3, *(b_ptr + 3)); + b_ptr = b_ptr + 4; } if (1 == oddk) { - d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0); - d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1); - d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2); - d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3); + d8d9 = MLA(d8d9, d0d1, *(b_ptr)); + d10d11 = MLA(d10d11, d0d1, *(b_ptr + 1)); + d12d13 = MLA(d12d13, d0d1, *(b_ptr + 2)); + d14d15 = MLA(d14d15, d0d1, *(b_ptr + 3)); + b_ptr = b_ptr + 4; } else { - d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0); + d8d9 = MLA(d8d9, d0d1, *(b_ptr)); d2d3 = GiLoadFloat32(a_ptr); a_ptr = a_ptr + 4; - d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1); - d6d7 = GiLoadFloat32(b_ptr); + d10d11 = MLA(d10d11, d0d1, *(b_ptr + 1)); + d12d13 = MLA(d12d13, d0d1, *(b_ptr + 2)); + d14d15 = MLA(d14d15, d0d1, *(b_ptr + 3)); b_ptr = b_ptr + 4; - d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2); - d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3); - d8d9 = GiSimdFmaLane(d8d9, d2d3, d6d7, 0); - d10d11 = GiSimdFmaLane(d10d11, d2d3, d6d7, 1); - d12d13 = GiSimdFmaLane(d12d13, d2d3, d6d7, 2); - d14d15 = GiSimdFmaLane(d14d15, d2d3, d6d7, 3); + d8d9 = MLA(d8d9, d2d3, *(b_ptr)); + d10d11 = MLA(d10d11, d2d3, *(b_ptr + 1)); + d12d13 = MLA(d12d13, d2d3, *(b_ptr + 2)); + d14d15 = MLA(d14d15, d2d3, *(b_ptr + 3)); + b_ptr = b_ptr + 4; } if (n_remain == 4) { @@ -359,7 +352,7 @@ void kern_4x4( output = output + 4; } } - +#undef MLA } // namespace MEGDNN_REG_GEMM_STRATEGY_IMPL(gi_sgemm_mk4_pack_4x12); -- GitLab