提交 bd50e457 编写于 作者: M Megvii Engine Team

feat(x86/rvv): make MATRIX_MUL_GI_F32_4x12 and FP32_GEMV_MK4_GI

adapt to vv and vf model

GitOrigin-RevId: 691434c59843e20f15caca366dce0d7e445325a9
上级 5c3b4e95
...@@ -31,15 +31,22 @@ void sgemv_gi_naive_n_mk4( ...@@ -31,15 +31,22 @@ void sgemv_gi_naive_n_mk4(
auto Bptr = B; auto Bptr = B;
size_t k = 0; size_t k = 0;
while (k < K) { while (k < K) {
#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS)
//! x86 and rvv GiSimdFmaLane API is slowly, as an alternate, use
//! GiMultiplyAddScalarFloat32
#define MLA(a, b, c, d) GiMultiplyAddScalarFloat32(a, b, *(c + d))
const float* b = Bptr;
#else
#define MLA(a, b, c, d) GiSimdFmaLane(a, b, c, d)
GI_FLOAT32_t b = GiLoadFloat32(Bptr); GI_FLOAT32_t b = GiLoadFloat32(Bptr);
#endif
GI_FLOAT32_V4_t a; GI_FLOAT32_V4_t a;
#define LOAD_A(step) GiSetSubVectorFloat32V4(a, step, GiLoadFloat32(Aptr0 + step * 4)); #define LOAD_A(step) GiSetSubVectorFloat32V4(a, step, GiLoadFloat32(Aptr0 + step * 4));
UNROLL_CALL_RAW(4, LOAD_A) UNROLL_CALL_RAW(4, LOAD_A)
#undef LOAD_A #undef LOAD_A
#define COMPT(step) \ #define COMPT(step) \
t = GiSimdFmaLane( \ t = MLA(GiGetSubVectorFloat32V4(c, step), GiGetSubVectorFloat32V4(a, step), b, \
GiGetSubVectorFloat32V4(c, step), GiGetSubVectorFloat32V4(a, step), b, \
step % 4); \ step % 4); \
GiSetSubVectorFloat32V4(c, step, t); GiSetSubVectorFloat32V4(c, step, t);
...@@ -49,6 +56,7 @@ void sgemv_gi_naive_n_mk4( ...@@ -49,6 +56,7 @@ void sgemv_gi_naive_n_mk4(
Bptr += Bstride; Bptr += Bstride;
Aptr0 += PACK_SIZE * PACK_SIZE; Aptr0 += PACK_SIZE * PACK_SIZE;
k += PACK_SIZE; k += PACK_SIZE;
#undef MLA
} }
#define ADD_C(step, stride) \ #define ADD_C(step, stride) \
......
...@@ -19,6 +19,17 @@ using namespace matmul::fallback; ...@@ -19,6 +19,17 @@ using namespace matmul::fallback;
namespace { namespace {
#undef PREFER_VF
#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS)
#define PREFER_VF
#endif
#if defined(PREFER_VF)
#define MLA(a, b, c, d) GiMultiplyAddScalarFloat32(a, b, *(c + d))
#else
#define MLA(a, b, c, d) GiSimdFmaLane(a, b, c, d)
#endif
void kern_4x12( void kern_4x12(
const float* packA, const float* packB, int K, float* output, int LDC, const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int m_remain) { bool is_first_k, int m_remain) {
...@@ -32,8 +43,13 @@ void kern_4x12( ...@@ -32,8 +43,13 @@ void kern_4x12(
float* r2 = r1 + LDC; float* r2 = r1 + LDC;
float* r3 = r2 + LDC; float* r3 = r2 + LDC;
GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15, d16d17, d18d19, #if defined(PREFER_VF)
d20d21, d22d23, d24d25, d26d27, d28d29, d30d31; const float* d0d1;
#else
GI_FLOAT32_t d0d1;
#endif
GI_FLOAT32_t d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15, d16d17, d18d19, d20d21,
d22d23, d24d25, d26d27, d28d29, d30d31;
if (is_first_k) { if (is_first_k) {
d8d9 = GiBroadcastFloat32(0.0f); d8d9 = GiBroadcastFloat32(0.0f);
...@@ -99,23 +115,31 @@ void kern_4x12( ...@@ -99,23 +115,31 @@ void kern_4x12(
b_ptr = b_ptr + 4; b_ptr = b_ptr + 4;
for (; K > 0; K--) { for (; K > 0; K--) {
#if defined(PREFER_VF)
d0d1 = a_ptr;
#else
d0d1 = GiLoadFloat32(a_ptr); d0d1 = GiLoadFloat32(a_ptr);
#endif
a_ptr = a_ptr + 4; a_ptr = a_ptr + 4;
d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); d8d9 = MLA(d8d9, d2d3, d0d1, 0);
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); d10d11 = MLA(d10d11, d4d5, d0d1, 0);
d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); d12d13 = MLA(d12d13, d6d7, d0d1, 0);
d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); d14d15 = MLA(d14d15, d2d3, d0d1, 1);
d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); d16d17 = MLA(d16d17, d4d5, d0d1, 1);
d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); d18d19 = MLA(d18d19, d6d7, d0d1, 1);
d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); d20d21 = MLA(d20d21, d2d3, d0d1, 2);
d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); d22d23 = MLA(d22d23, d4d5, d0d1, 2);
d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); d24d25 = MLA(d24d25, d6d7, d0d1, 2);
d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); d26d27 = MLA(d26d27, d2d3, d0d1, 3);
d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); d28d29 = MLA(d28d29, d4d5, d0d1, 3);
d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); d30d31 = MLA(d30d31, d6d7, d0d1, 3);
#if defined(PREFER_VF)
d0d1 = a_ptr;
#else
d0d1 = GiLoadFloat32(a_ptr); d0d1 = GiLoadFloat32(a_ptr);
#endif
a_ptr = a_ptr + 4; a_ptr = a_ptr + 4;
d2d3 = GiLoadFloat32(b_ptr); d2d3 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4; b_ptr = b_ptr + 4;
...@@ -124,18 +148,18 @@ void kern_4x12( ...@@ -124,18 +148,18 @@ void kern_4x12(
d6d7 = GiLoadFloat32(b_ptr); d6d7 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4; b_ptr = b_ptr + 4;
d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); d8d9 = MLA(d8d9, d2d3, d0d1, 0);
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); d10d11 = MLA(d10d11, d4d5, d0d1, 0);
d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); d12d13 = MLA(d12d13, d6d7, d0d1, 0);
d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); d14d15 = MLA(d14d15, d2d3, d0d1, 1);
d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); d16d17 = MLA(d16d17, d4d5, d0d1, 1);
d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); d18d19 = MLA(d18d19, d6d7, d0d1, 1);
d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); d20d21 = MLA(d20d21, d2d3, d0d1, 2);
d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); d22d23 = MLA(d22d23, d4d5, d0d1, 2);
d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); d24d25 = MLA(d24d25, d6d7, d0d1, 2);
d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); d26d27 = MLA(d26d27, d2d3, d0d1, 3);
d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); d28d29 = MLA(d28d29, d4d5, d0d1, 3);
d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); d30d31 = MLA(d30d31, d6d7, d0d1, 3);
d2d3 = GiLoadFloat32(b_ptr); d2d3 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4; b_ptr = b_ptr + 4;
...@@ -146,40 +170,52 @@ void kern_4x12( ...@@ -146,40 +170,52 @@ void kern_4x12(
} }
if (1 == oddk) { if (1 == oddk) {
#if defined(PREFER_VF)
d0d1 = a_ptr;
#else
d0d1 = GiLoadFloat32(a_ptr); d0d1 = GiLoadFloat32(a_ptr);
#endif
a_ptr = a_ptr + 4; a_ptr = a_ptr + 4;
d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); d8d9 = MLA(d8d9, d2d3, d0d1, 0);
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); d10d11 = MLA(d10d11, d4d5, d0d1, 0);
d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); d12d13 = MLA(d12d13, d6d7, d0d1, 0);
d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); d14d15 = MLA(d14d15, d2d3, d0d1, 1);
d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); d16d17 = MLA(d16d17, d4d5, d0d1, 1);
d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); d18d19 = MLA(d18d19, d6d7, d0d1, 1);
d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); d20d21 = MLA(d20d21, d2d3, d0d1, 2);
d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); d22d23 = MLA(d22d23, d4d5, d0d1, 2);
d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); d24d25 = MLA(d24d25, d6d7, d0d1, 2);
d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); d26d27 = MLA(d26d27, d2d3, d0d1, 3);
d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); d28d29 = MLA(d28d29, d4d5, d0d1, 3);
d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); d30d31 = MLA(d30d31, d6d7, d0d1, 3);
} else { } else {
#if defined(PREFER_VF)
d0d1 = a_ptr;
#else
d0d1 = GiLoadFloat32(a_ptr); d0d1 = GiLoadFloat32(a_ptr);
#endif
a_ptr = a_ptr + 4; a_ptr = a_ptr + 4;
d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); d8d9 = MLA(d8d9, d2d3, d0d1, 0);
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); d10d11 = MLA(d10d11, d4d5, d0d1, 0);
d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); d12d13 = MLA(d12d13, d6d7, d0d1, 0);
d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); d14d15 = MLA(d14d15, d2d3, d0d1, 1);
d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); d16d17 = MLA(d16d17, d4d5, d0d1, 1);
d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); d18d19 = MLA(d18d19, d6d7, d0d1, 1);
d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); d20d21 = MLA(d20d21, d2d3, d0d1, 2);
d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); d22d23 = MLA(d22d23, d4d5, d0d1, 2);
d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); d24d25 = MLA(d24d25, d6d7, d0d1, 2);
d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); d26d27 = MLA(d26d27, d2d3, d0d1, 3);
d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); d28d29 = MLA(d28d29, d4d5, d0d1, 3);
d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); d30d31 = MLA(d30d31, d6d7, d0d1, 3);
#if defined(PREFER_VF)
d0d1 = a_ptr;
#else
d0d1 = GiLoadFloat32(a_ptr); d0d1 = GiLoadFloat32(a_ptr);
#endif
a_ptr = a_ptr + 4; a_ptr = a_ptr + 4;
d2d3 = GiLoadFloat32(b_ptr); d2d3 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4; b_ptr = b_ptr + 4;
...@@ -188,18 +224,18 @@ void kern_4x12( ...@@ -188,18 +224,18 @@ void kern_4x12(
d6d7 = GiLoadFloat32(b_ptr); d6d7 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4; b_ptr = b_ptr + 4;
d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); d8d9 = MLA(d8d9, d2d3, d0d1, 0);
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); d10d11 = MLA(d10d11, d4d5, d0d1, 0);
d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); d12d13 = MLA(d12d13, d6d7, d0d1, 0);
d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); d14d15 = MLA(d14d15, d2d3, d0d1, 1);
d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); d16d17 = MLA(d16d17, d4d5, d0d1, 1);
d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); d18d19 = MLA(d18d19, d6d7, d0d1, 1);
d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); d20d21 = MLA(d20d21, d2d3, d0d1, 2);
d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); d22d23 = MLA(d22d23, d4d5, d0d1, 2);
d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); d24d25 = MLA(d24d25, d6d7, d0d1, 2);
d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); d26d27 = MLA(d26d27, d2d3, d0d1, 3);
d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); d28d29 = MLA(d28d29, d4d5, d0d1, 3);
d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); d30d31 = MLA(d30d31, d6d7, d0d1, 3);
} }
if (m_remain == 4) { if (m_remain == 4) {
...@@ -259,7 +295,13 @@ void kern_4x4( ...@@ -259,7 +295,13 @@ void kern_4x4(
float* r3 = r2 + LDC; float* r3 = r2 + LDC;
size_t d_size = sizeof(float); size_t d_size = sizeof(float);
GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15; #if defined(PREFER_VF)
const float* d0d1;
const float* d2d3;
#else
GI_FLOAT32_t d0d1, d2d3;
#endif
GI_FLOAT32_t d4d5, d6d7, d8d9, d10d11, d12d13, d14d15;
float tmp[4]; float tmp[4];
if (is_first_k) { if (is_first_k) {
d8d9 = GiBroadcastFloat32(0.0f); d8d9 = GiBroadcastFloat32(0.0f);
...@@ -412,54 +454,70 @@ void kern_4x4( ...@@ -412,54 +454,70 @@ void kern_4x4(
} }
} }
#if defined(PREFER_VF)
d0d1 = a_ptr;
#else
d0d1 = GiLoadFloat32(a_ptr); d0d1 = GiLoadFloat32(a_ptr);
#endif
a_ptr = a_ptr + 4; a_ptr = a_ptr + 4;
d4d5 = GiLoadFloat32(b_ptr); d4d5 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4; b_ptr = b_ptr + 4;
for (; K > 0; K--) { for (; K > 0; K--) {
#if defined(PREFER_VF)
d2d3 = a_ptr;
#else
d2d3 = GiLoadFloat32(a_ptr); d2d3 = GiLoadFloat32(a_ptr);
#endif
a_ptr = a_ptr + 4; a_ptr = a_ptr + 4;
d6d7 = GiLoadFloat32(b_ptr); d6d7 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4; b_ptr = b_ptr + 4;
d8d9 = GiSimdFmaLane(d8d9, d4d5, d0d1, 0); d8d9 = MLA(d8d9, d4d5, d0d1, 0);
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 1); d10d11 = MLA(d10d11, d4d5, d0d1, 1);
d12d13 = GiSimdFmaLane(d12d13, d4d5, d0d1, 2); d12d13 = MLA(d12d13, d4d5, d0d1, 2);
d14d15 = GiSimdFmaLane(d14d15, d4d5, d0d1, 3); d14d15 = MLA(d14d15, d4d5, d0d1, 3);
#if defined(PREFER_VF)
d0d1 = a_ptr;
#else
d0d1 = GiLoadFloat32(a_ptr); d0d1 = GiLoadFloat32(a_ptr);
#endif
a_ptr = a_ptr + 4; a_ptr = a_ptr + 4;
d4d5 = GiLoadFloat32(b_ptr); d4d5 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4; b_ptr = b_ptr + 4;
d8d9 = GiSimdFmaLane(d8d9, d6d7, d2d3, 0); d8d9 = MLA(d8d9, d6d7, d2d3, 0);
d10d11 = GiSimdFmaLane(d10d11, d6d7, d2d3, 1); d10d11 = MLA(d10d11, d6d7, d2d3, 1);
d12d13 = GiSimdFmaLane(d12d13, d6d7, d2d3, 2); d12d13 = MLA(d12d13, d6d7, d2d3, 2);
d14d15 = GiSimdFmaLane(d14d15, d6d7, d2d3, 3); d14d15 = MLA(d14d15, d6d7, d2d3, 3);
} }
if (1 == oddk) { if (1 == oddk) {
d8d9 = GiSimdFmaLane(d8d9, d4d5, d0d1, 0); d8d9 = MLA(d8d9, d4d5, d0d1, 0);
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 1); d10d11 = MLA(d10d11, d4d5, d0d1, 1);
d12d13 = GiSimdFmaLane(d12d13, d4d5, d0d1, 2); d12d13 = MLA(d12d13, d4d5, d0d1, 2);
d14d15 = GiSimdFmaLane(d14d15, d4d5, d0d1, 3); d14d15 = MLA(d14d15, d4d5, d0d1, 3);
} else { } else {
#if defined(PREFER_VF)
d2d3 = a_ptr;
#else
d2d3 = GiLoadFloat32(a_ptr); d2d3 = GiLoadFloat32(a_ptr);
#endif
a_ptr = a_ptr + 4; a_ptr = a_ptr + 4;
d6d7 = GiLoadFloat32(b_ptr); d6d7 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4; b_ptr = b_ptr + 4;
d8d9 = GiSimdFmaLane(d8d9, d4d5, d0d1, 0); d8d9 = MLA(d8d9, d4d5, d0d1, 0);
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 1); d10d11 = MLA(d10d11, d4d5, d0d1, 1);
d12d13 = GiSimdFmaLane(d12d13, d4d5, d0d1, 2); d12d13 = MLA(d12d13, d4d5, d0d1, 2);
d14d15 = GiSimdFmaLane(d14d15, d4d5, d0d1, 3); d14d15 = MLA(d14d15, d4d5, d0d1, 3);
d8d9 = GiSimdFmaLane(d8d9, d6d7, d2d3, 0); d8d9 = MLA(d8d9, d6d7, d2d3, 0);
d10d11 = GiSimdFmaLane(d10d11, d6d7, d2d3, 1); d10d11 = MLA(d10d11, d6d7, d2d3, 1);
d12d13 = GiSimdFmaLane(d12d13, d6d7, d2d3, 2); d12d13 = MLA(d12d13, d6d7, d2d3, 2);
d14d15 = GiSimdFmaLane(d14d15, d6d7, d2d3, 3); d14d15 = MLA(d14d15, d6d7, d2d3, 3);
} }
if (m_remain == 4) { if (m_remain == 4) {
...@@ -882,6 +940,7 @@ void gi_sgemm_4x12_pack_B_t( ...@@ -882,6 +940,7 @@ void gi_sgemm_4x12_pack_B_t(
} }
} }
#undef MLA
} // namespace } // namespace
MEGDNN_REG_GEMM_STRATEGY_IMPL(gi_sgemm_4x12); MEGDNN_REG_GEMM_STRATEGY_IMPL(gi_sgemm_4x12);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册