From 74fb63db29a23cdfd3483fc95239846078d17e4e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 14 Jun 2022 18:13:40 +0800 Subject: [PATCH] feat(gi): make matrix_mul apply gi class type GitOrigin-RevId: 0c0029ee60d669465701333530ebcc160e13577b --- dnn/src/fallback/matrix_mul/gi/fp32/common.h | 16 +++++------ .../matrix_mul/gi/fp32/exec_sgemv.cpp | 28 +++++++++++++------ .../matrix_mul/gi/fp32/strategy_mk4_4x8.cpp | 2 ++ 3 files changed, 29 insertions(+), 17 deletions(-) diff --git a/dnn/src/fallback/matrix_mul/gi/fp32/common.h b/dnn/src/fallback/matrix_mul/gi/fp32/common.h index b969f698..f282bc1a 100644 --- a/dnn/src/fallback/matrix_mul/gi/fp32/common.h +++ b/dnn/src/fallback/matrix_mul/gi/fp32/common.h @@ -193,24 +193,24 @@ static GI_FORCEINLINE void transpose_4x4_1_s( GI_FLOAT32_V2_t q0q1 = GiZipqFloat32(d0d1, d2d3); GI_FLOAT32_V2_t q2q3 = GiZipqFloat32(d4d5, d6d7); - GiSt1Float32(outptr, GiGetLowFloat32(q0q1.val[0])); + GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(q0q1, 0))); outptr += 2; - GiSt1Float32(outptr, GiGetLowFloat32(q2q3.val[0])); + GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(q2q3, 0))); outptr += stride; - GiSt1Float32(outptr, GiGetHighFloat32(q0q1.val[0])); + GiSt1Float32(outptr, GiGetHighFloat32(GiGetSubVectorFloat32V2(q0q1, 0))); outptr += 2; - GiSt1Float32(outptr, GiGetHighFloat32(q2q3.val[0])); + GiSt1Float32(outptr, GiGetHighFloat32(GiGetSubVectorFloat32V2(q2q3, 0))); outptr += stride; - GiSt1Float32(outptr, GiGetLowFloat32(q0q1.val[1])); + GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(q0q1, 1))); outptr += 2; - GiSt1Float32(outptr, GiGetLowFloat32(q2q3.val[1])); + GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(q2q3, 1))); outptr += stride; - GiSt1Float32(outptr, GiGetHighFloat32(q0q1.val[1])); + GiSt1Float32(outptr, GiGetHighFloat32(GiGetSubVectorFloat32V2(q0q1, 1))); outptr += 2; - GiSt1Float32(outptr, GiGetHighFloat32(q2q3.val[1])); + GiSt1Float32(outptr, GiGetHighFloat32(GiGetSubVectorFloat32V2(q2q3, 1))); outptr += stride; } diff --git a/dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp b/dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp index 317091d4..adeb9c93 100644 --- a/dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp +++ b/dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp @@ -24,21 +24,26 @@ void sgemv_gi_naive_n_mk4( while (m < M) { auto Aptr0 = Aptr; auto Cptr0 = Cptr; - GI_FLOAT32_t c[4]; -#define INIT(step) c[step] = GiBroadcastFloat32(0.0f); + GI_FLOAT32_V4_t c; +#define INIT(step) GiSetSubVectorFloat32V4(c, step, GiBroadcastFloat32(0.0f)); UNROLL_CALL_RAW(4, INIT) #undef INIT auto Bptr = B; size_t k = 0; while (k < K) { GI_FLOAT32_t b = GiLoadFloat32(Bptr); - GI_FLOAT32_V2_t a[2]; -#define LOAD_A(step) a[step] = GiLoadFloat32V2(Aptr0 + step * 8); - UNROLL_CALL_RAW(2, LOAD_A) + GI_FLOAT32_V4_t a; +#define LOAD_A(step) GiSetSubVectorFloat32V4(a, step, GiLoadFloat32(Aptr0 + step * 4)); + UNROLL_CALL_RAW(4, LOAD_A) #undef LOAD_A -#define COMPT(step) \ - c[step] = GiSimdFmaLane(c[step], a[step / 2].val[step % 2], b, step % 4); +#define COMPT(step) \ + t = GiSimdFmaLane( \ + GiGetSubVectorFloat32V4(c, step), GiGetSubVectorFloat32V4(a, step), b, \ + step % 4); \ + GiSetSubVectorFloat32V4(c, step, t); + + GI_FLOAT32_t t; UNROLL_CALL_RAW(4, COMPT) #undef COMPT Bptr += Bstride; @@ -46,11 +51,16 @@ void sgemv_gi_naive_n_mk4( k += PACK_SIZE; } -#define ADD_C(step, stride) c[step] = GiAddFloat32(c[step], c[step + stride]); +#define ADD_C(step, stride) \ + t = GiAddFloat32( \ + GiGetSubVectorFloat32V4(c, step), \ + GiGetSubVectorFloat32V4(c, step + stride)); \ + GiSetSubVectorFloat32V4(c, step, t); + GI_FLOAT32_t t; UNROLL_CALL_RAW(2, ADD_C, 2) UNROLL_CALL_RAW(1, ADD_C, 1) #undef ADD_C - GiStoreFloat32(Cptr0, c[0]); + GiStoreFloat32(Cptr0, GiGetSubVectorFloat32V4(c, 0)); Aptr += Astride; Cptr += Cstride; diff --git a/dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk4_4x8.cpp b/dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk4_4x8.cpp index 2ee011a6..dc97b8c1 100644 --- a/dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk4_4x8.cpp +++ b/dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk4_4x8.cpp @@ -82,6 +82,7 @@ void kern_4x4(const float* A, const float* B, size_t LDB, size_t K, float* C) { GI_FLOAT32_t d6d7 = GiLoadFloat32(B); B = B + 4; + GI_FLOAT32_t vfzero = GiBroadcastFloat32(0.0f); GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0); GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0); GI_FLOAT32_t d20d21 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0); @@ -173,6 +174,7 @@ void kern_4x8(const float* A, const float* B, size_t LDB, size_t K, float* C) { B = B + 4; GI_FLOAT32_t d6d7 = GiLoadFloat32(B); B = B + 4; + GI_FLOAT32_t vfzero = GiBroadcastFloat32(0.0f); GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0); d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0); -- GitLab