提交 0d82e9b7 编写于 作者: M Megvii Engine Team

feat(x86/rvv): opt FB_GI_F32_MK4_4x8

GitOrigin-RevId: 9e17de18b4dd79e3b2902e1286e278a9a43c10b6
上级 3fbceb3a
......@@ -7,6 +7,9 @@ using namespace matmul::fallback;
namespace {
//! x86 and rvv GiSimdFmaLane API is slowly, as an alternate, use
//! GiMultiplyAddScalarFloat32
#define MLA GiMultiplyAddScalarFloat32
void kern_4x1(const float* A, const float* B, size_t LDB, size_t K, float* C) {
LDB = LDB - 4;
K = K - 4;
......@@ -24,34 +27,32 @@ void kern_4x1(const float* A, const float* B, size_t LDB, size_t K, float* C) {
GI_FLOAT32_t d20d21 = GiBroadcastFloat32(0.0f);
GI_FLOAT32_t d22d23 = GiBroadcastFloat32(0.0f);
GI_FLOAT32_t d0d1 = GiLoadFloat32(B);
d16d17 = MLA(d16d17, d8d9, *(B));
d18d19 = MLA(d18d19, d10d11, *(B + 1));
B = B + 4;
d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0);
d18d19 = GiSimdFmaLane(d18d19, d10d11, d0d1, 1);
for (; K > 0; K -= 4) {
d8d9 = GiLoadFloat32(A);
A = A + 4;
d10d11 = GiLoadFloat32(A);
A = A + 4;
d20d21 = GiSimdFmaLane(d20d21, d12d13, d0d1, 2);
d22d23 = GiSimdFmaLane(d22d23, d14d15, d0d1, 3);
d20d21 = MLA(d20d21, d12d13, *(B + 2 - 4));
d22d23 = MLA(d22d23, d14d15, *(B + 3 - 4));
B = B + LDB;
d0d1 = GiLoadFloat32(B);
B = B + 4;
d12d13 = GiLoadFloat32(A);
A = A + 4;
d14d15 = GiLoadFloat32(A);
A = A + 4;
d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0);
d18d19 = GiSimdFmaLane(d18d19, d10d11, d0d1, 1);
d16d17 = MLA(d16d17, d8d9, *(B));
d18d19 = MLA(d18d19, d10d11, *(B + 1));
B = B + 4;
}
d20d21 = GiSimdFmaLane(d20d21, d12d13, d0d1, 2);
d22d23 = GiSimdFmaLane(d22d23, d14d15, d0d1, 3);
d20d21 = MLA(d20d21, d12d13, *(B + 2 - 4));
d22d23 = MLA(d22d23, d14d15, *(B + 3 - 4));
d16d17 = GiAddFloat32(d16d17, d20d21);
d18d19 = GiAddFloat32(d18d19, d22d23);
d16d17 = GiAddFloat32(d16d17, d18d19);
......@@ -73,25 +74,19 @@ void kern_4x4(const float* A, const float* B, size_t LDB, size_t K, float* C) {
GI_FLOAT32_t d14d15 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d0d1 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d2d3 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d4d5 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d6d7 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t vfzero = GiBroadcastFloat32(0.0f);
GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0);
GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0);
GI_FLOAT32_t d20d21 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0);
GI_FLOAT32_t d22d23 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0);
GI_FLOAT32_t d16d17 = MLA(vfzero, d8d9, *(B));
d16d17 = MLA(d16d17, d10d11, *(B + 1));
GI_FLOAT32_t d18d19 = MLA(vfzero, d8d9, *(B + 4));
d18d19 = MLA(d18d19, d10d11, *(B + 5));
GI_FLOAT32_t d20d21 = MLA(vfzero, d8d9, *(B + 8));
d20d21 = MLA(d20d21, d10d11, *(B + 9));
d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1);
d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1);
d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1);
d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1);
GI_FLOAT32_t d22d23 = MLA(vfzero, d8d9, *(B + 12));
d22d23 = MLA(d22d23, d10d11, *(B + 13));
B = B + 16;
for (; K > 0; K -= 4) {
d8d9 = GiLoadFloat32(A);
......@@ -99,51 +94,50 @@ void kern_4x4(const float* A, const float* B, size_t LDB, size_t K, float* C) {
d10d11 = GiLoadFloat32(A);
A = A + 4;
d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2);
d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2);
d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2);
d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2);
d16d17 = MLA(d16d17, d12d13, *(B + 2 - 16));
d16d17 = MLA(d16d17, d14d15, *(B + 3 - 16));
d18d19 = MLA(d18d19, d12d13, *(B + 6 - 16));
d18d19 = MLA(d18d19, d14d15, *(B + 7 - 16));
d20d21 = MLA(d20d21, d12d13, *(B + 10 - 16));
d20d21 = MLA(d20d21, d14d15, *(B + 11 - 16));
d22d23 = MLA(d22d23, d12d13, *(B + 14 - 16));
d22d23 = MLA(d22d23, d14d15, *(B + 15 - 16));
B = B + LDB;
d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3);
d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3);
d0d1 = GiLoadFloat32(B);
B = B + 4;
d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3);
d2d3 = GiLoadFloat32(B);
B = B + 4;
d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3);
d4d5 = GiLoadFloat32(B);
B = B + 4;
d16d17 = MLA(d16d17, d8d9, *(B));
d16d17 = MLA(d16d17, d10d11, *(B + 1));
d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0);
d6d7 = GiLoadFloat32(B);
B = B + 4;
d18d19 = GiSimdFmaLane(d18d19, d8d9, d2d3, 0);
d20d21 = GiSimdFmaLane(d20d21, d8d9, d4d5, 0);
d22d23 = GiSimdFmaLane(d22d23, d8d9, d6d7, 0);
d18d19 = MLA(d18d19, d8d9, *(B + 4));
d18d19 = MLA(d18d19, d10d11, *(B + 5));
d20d21 = MLA(d20d21, d8d9, *(B + 8));
d20d21 = MLA(d20d21, d10d11, *(B + 9));
d22d23 = MLA(d22d23, d8d9, *(B + 12));
d22d23 = MLA(d22d23, d10d11, *(B + 13));
d12d13 = GiLoadFloat32(A);
A = A + 4;
d14d15 = GiLoadFloat32(A);
A = A + 4;
d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1);
d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1);
d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1);
d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1);
B = B + 16;
}
d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2);
d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2);
d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2);
d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2);
d16d17 = MLA(d16d17, d12d13, *(B + 2 - 16));
d16d17 = MLA(d16d17, d14d15, *(B + 3 - 16));
d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3);
d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3);
d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3);
d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3);
d18d19 = MLA(d18d19, d12d13, *(B + 6 - 16));
d18d19 = MLA(d18d19, d14d15, *(B + 7 - 16));
d20d21 = MLA(d20d21, d12d13, *(B + 10 - 16));
d20d21 = MLA(d20d21, d14d15, *(B + 11 - 16));
d22d23 = MLA(d22d23, d12d13, *(B + 14 - 16));
d22d23 = MLA(d22d23, d14d15, *(B + 15 - 16));
GiStoreFloat32(C, d16d17);
C = C + 4;
......@@ -166,56 +160,55 @@ void kern_4x8(const float* A, const float* B, size_t LDB, size_t K, float* C) {
GI_FLOAT32_t d14d15 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d0d1 = GiLoadFloat32(B);
GI_FLOAT32_t vfzero = GiBroadcastFloat32(0.0f);
GI_FLOAT32_t d16d17 = MLA(vfzero, d8d9, *(B));
d16d17 = MLA(d16d17, d10d11, *(B + 1));
d16d17 = MLA(d16d17, d12d13, *(B + 2));
d16d17 = MLA(d16d17, d14d15, *(B + 3));
B = B + 4;
GI_FLOAT32_t d2d3 = GiLoadFloat32(B);
GI_FLOAT32_t d18d19 = MLA(vfzero, d8d9, *(B));
d18d19 = MLA(d18d19, d10d11, *(B + 1));
d18d19 = MLA(d18d19, d12d13, *(B + 2));
d18d19 = MLA(d18d19, d14d15, *(B + 3));
B = B + 4;
GI_FLOAT32_t d4d5 = GiLoadFloat32(B);
GI_FLOAT32_t d20d21 = MLA(vfzero, d8d9, *(B));
d20d21 = MLA(d20d21, d10d11, *(B + 1));
d20d21 = MLA(d20d21, d12d13, *(B + 2));
d20d21 = MLA(d20d21, d14d15, *(B + 3));
B = B + 4;
GI_FLOAT32_t d6d7 = GiLoadFloat32(B);
GI_FLOAT32_t d22d23 = MLA(vfzero, d8d9, *(B));
d22d23 = MLA(d22d23, d10d11, *(B + 1));
d22d23 = MLA(d22d23, d12d13, *(B + 2));
d22d23 = MLA(d22d23, d14d15, *(B + 3));
B = B + 4;
GI_FLOAT32_t vfzero = GiBroadcastFloat32(0.0f);
GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0);
d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1);
GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0);
d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2);
d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1);
d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3);
d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2);
d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3);
d0d1 = GiLoadFloat32(B);
GI_FLOAT32_t d24d25 = MLA(vfzero, d8d9, *(B));
d24d25 = MLA(d24d25, d10d11, *(B + 1));
d24d25 = MLA(d24d25, d12d13, *(B + 2));
d24d25 = MLA(d24d25, d14d15, *(B + 3));
B = B + 4;
d2d3 = GiLoadFloat32(B);
GI_FLOAT32_t d26d27 = MLA(vfzero, d8d9, *(B));
d26d27 = MLA(d26d27, d10d11, *(B + 1));
d26d27 = MLA(d26d27, d12d13, *(B + 2));
d26d27 = MLA(d26d27, d14d15, *(B + 3));
B = B + 4;
GI_FLOAT32_t d20d21 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0);
d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1);
GI_FLOAT32_t d22d23 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0);
d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2);
d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1);
d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3);
d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2);
d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3);
d4d5 = GiLoadFloat32(B);
GI_FLOAT32_t d28d29 = MLA(vfzero, d8d9, *(B));
d28d29 = MLA(d28d29, d10d11, *(B + 1));
d28d29 = MLA(d28d29, d12d13, *(B + 2));
d28d29 = MLA(d28d29, d14d15, *(B + 3));
B = B + 4;
d6d7 = GiLoadFloat32(B);
GI_FLOAT32_t d30d31 = MLA(vfzero, d8d9, *(B));
d30d31 = MLA(d30d31, d10d11, *(B + 1));
d30d31 = MLA(d30d31, d12d13, *(B + 2));
d30d31 = MLA(d30d31, d14d15, *(B + 3));
B = B + 4;
GI_FLOAT32_t d24d25 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0);
d24d25 = GiSimdFmaLane(d24d25, d10d11, d0d1, 1);
GI_FLOAT32_t d26d27 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0);
d24d25 = GiSimdFmaLane(d24d25, d12d13, d0d1, 2);
d26d27 = GiSimdFmaLane(d26d27, d10d11, d2d3, 1);
d24d25 = GiSimdFmaLane(d24d25, d14d15, d0d1, 3);
d26d27 = GiSimdFmaLane(d26d27, d12d13, d2d3, 2);
d26d27 = GiSimdFmaLane(d26d27, d14d15, d2d3, 3);
GI_FLOAT32_t d28d29 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0);
d28d29 = GiSimdFmaLane(d28d29, d10d11, d4d5, 1);
GI_FLOAT32_t d30d31 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0);
d28d29 = GiSimdFmaLane(d28d29, d12d13, d4d5, 2);
d30d31 = GiSimdFmaLane(d30d31, d10d11, d6d7, 1);
d28d29 = GiSimdFmaLane(d28d29, d14d15, d4d5, 3);
d30d31 = GiSimdFmaLane(d30d31, d12d13, d6d7, 2);
d30d31 = GiSimdFmaLane(d30d31, d14d15, d6d7, 3);
B = B + LDB;
K = K - 4;
......@@ -229,56 +222,53 @@ void kern_4x8(const float* A, const float* B, size_t LDB, size_t K, float* C) {
d14d15 = GiLoadFloat32(A);
A = A + 4;
d0d1 = GiLoadFloat32(B);
B = B + 4;
d2d3 = GiLoadFloat32(B);
d16d17 = MLA(d16d17, d8d9, *(B));
d16d17 = MLA(d16d17, d10d11, *(B + 1));
d16d17 = MLA(d16d17, d12d13, *(B + 2));
d16d17 = MLA(d16d17, d14d15, *(B + 3));
B = B + 4;
d4d5 = GiLoadFloat32(B);
d18d19 = MLA(d18d19, d8d9, *(B));
d18d19 = MLA(d18d19, d10d11, *(B + 1));
d18d19 = MLA(d18d19, d12d13, *(B + 2));
d18d19 = MLA(d18d19, d14d15, *(B + 3));
B = B + 4;
d6d7 = GiLoadFloat32(B);
d20d21 = MLA(d20d21, d8d9, *(B));
d20d21 = MLA(d20d21, d10d11, *(B + 1));
d20d21 = MLA(d20d21, d12d13, *(B + 2));
d20d21 = MLA(d20d21, d14d15, *(B + 3));
B = B + 4;
d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0);
d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1);
d18d19 = GiSimdFmaLane(d18d19, d8d9, d2d3, 0);
d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2);
d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1);
d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3);
d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2);
d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3);
d0d1 = GiLoadFloat32(B);
d22d23 = MLA(d22d23, d8d9, *(B));
d22d23 = MLA(d22d23, d10d11, *(B + 1));
d22d23 = MLA(d22d23, d12d13, *(B + 2));
d22d23 = MLA(d22d23, d14d15, *(B + 3));
B = B + 4;
d2d3 = GiLoadFloat32(B);
d24d25 = MLA(d24d25, d8d9, *(B));
d24d25 = MLA(d24d25, d10d11, *(B + 1));
d24d25 = MLA(d24d25, d12d13, *(B + 2));
d24d25 = MLA(d24d25, d14d15, *(B + 3));
B = B + 4;
d20d21 = GiSimdFmaLane(d20d21, d8d9, d4d5, 0);
d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1);
d22d23 = GiSimdFmaLane(d22d23, d8d9, d6d7, 0);
d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2);
d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1);
d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3);
d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2);
d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3);
d4d5 = GiLoadFloat32(B);
d26d27 = MLA(d26d27, d8d9, *(B));
d26d27 = MLA(d26d27, d10d11, *(B + 1));
d26d27 = MLA(d26d27, d12d13, *(B + 2));
d26d27 = MLA(d26d27, d14d15, *(B + 3));
B = B + 4;
d6d7 = GiLoadFloat32(B);
d28d29 = MLA(d28d29, d8d9, *(B));
d28d29 = MLA(d28d29, d10d11, *(B + 1));
d28d29 = MLA(d28d29, d12d13, *(B + 2));
d28d29 = MLA(d28d29, d14d15, *(B + 3));
B = B + 4;
d24d25 = GiSimdFmaLane(d24d25, d8d9, d0d1, 0);
d24d25 = GiSimdFmaLane(d24d25, d10d11, d0d1, 1);
d26d27 = GiSimdFmaLane(d26d27, d8d9, d2d3, 0);
d24d25 = GiSimdFmaLane(d24d25, d12d13, d0d1, 2);
d26d27 = GiSimdFmaLane(d26d27, d10d11, d2d3, 1);
d24d25 = GiSimdFmaLane(d24d25, d14d15, d0d1, 3);
d26d27 = GiSimdFmaLane(d26d27, d12d13, d2d3, 2);
d26d27 = GiSimdFmaLane(d26d27, d14d15, d2d3, 3);
d28d29 = GiSimdFmaLane(d28d29, d8d9, d4d5, 0);
d28d29 = GiSimdFmaLane(d28d29, d10d11, d4d5, 1);
d30d31 = GiSimdFmaLane(d30d31, d8d9, d6d7, 0);
d28d29 = GiSimdFmaLane(d28d29, d12d13, d4d5, 2);
d30d31 = GiSimdFmaLane(d30d31, d10d11, d6d7, 1);
d28d29 = GiSimdFmaLane(d28d29, d14d15, d4d5, 3);
d30d31 = GiSimdFmaLane(d30d31, d12d13, d6d7, 2);
d30d31 = GiSimdFmaLane(d30d31, d14d15, d6d7, 3);
B = B + LDB;
d30d31 = MLA(d30d31, d8d9, *(B));
d30d31 = MLA(d30d31, d10d11, *(B + 1));
d30d31 = MLA(d30d31, d12d13, *(B + 2));
d30d31 = MLA(d30d31, d14d15, *(B + 3));
B = B + 4 + LDB;
}
GiStoreFloat32(C, d16d17);
C = C + 4;
......@@ -298,6 +288,7 @@ void kern_4x8(const float* A, const float* B, size_t LDB, size_t K, float* C) {
C = C + 4;
}
#undef MLA
} // namespace
MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gi_sgemm_nopack_4x8);
......
......@@ -176,6 +176,13 @@ TEST_F(FALLBACK, BENCHMARK_MATRIX_MUL_GI_PACK_MK4) {
"FB_GI_F32_MK4_PACK_4x12", param::MatrixMul::Format::MK4);
}
TEST_F(FALLBACK, BENCHMARK_MATRIX_FB_GI_F32_MK4_4x8) {
auto args = matrix_mul::get_benchmark_matmul_args();
matrix_mul::benchmark_single_algo(
handle(), args, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
"FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4);
}
#endif
} // namespace test
} // namespace megdnn
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册