提交 74fb63db 编写于 作者: M Megvii Engine Team

feat(gi): make matrix_mul apply gi class type

GitOrigin-RevId: 0c0029ee60d669465701333530ebcc160e13577b
上级 45b26400
......@@ -193,24 +193,24 @@ static GI_FORCEINLINE void transpose_4x4_1_s(
GI_FLOAT32_V2_t q0q1 = GiZipqFloat32(d0d1, d2d3);
GI_FLOAT32_V2_t q2q3 = GiZipqFloat32(d4d5, d6d7);
GiSt1Float32(outptr, GiGetLowFloat32(q0q1.val[0]));
GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(q0q1, 0)));
outptr += 2;
GiSt1Float32(outptr, GiGetLowFloat32(q2q3.val[0]));
GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(q2q3, 0)));
outptr += stride;
GiSt1Float32(outptr, GiGetHighFloat32(q0q1.val[0]));
GiSt1Float32(outptr, GiGetHighFloat32(GiGetSubVectorFloat32V2(q0q1, 0)));
outptr += 2;
GiSt1Float32(outptr, GiGetHighFloat32(q2q3.val[0]));
GiSt1Float32(outptr, GiGetHighFloat32(GiGetSubVectorFloat32V2(q2q3, 0)));
outptr += stride;
GiSt1Float32(outptr, GiGetLowFloat32(q0q1.val[1]));
GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(q0q1, 1)));
outptr += 2;
GiSt1Float32(outptr, GiGetLowFloat32(q2q3.val[1]));
GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(q2q3, 1)));
outptr += stride;
GiSt1Float32(outptr, GiGetHighFloat32(q0q1.val[1]));
GiSt1Float32(outptr, GiGetHighFloat32(GiGetSubVectorFloat32V2(q0q1, 1)));
outptr += 2;
GiSt1Float32(outptr, GiGetHighFloat32(q2q3.val[1]));
GiSt1Float32(outptr, GiGetHighFloat32(GiGetSubVectorFloat32V2(q2q3, 1)));
outptr += stride;
}
......
......@@ -24,21 +24,26 @@ void sgemv_gi_naive_n_mk4(
while (m < M) {
auto Aptr0 = Aptr;
auto Cptr0 = Cptr;
GI_FLOAT32_t c[4];
#define INIT(step) c[step] = GiBroadcastFloat32(0.0f);
GI_FLOAT32_V4_t c;
#define INIT(step) GiSetSubVectorFloat32V4(c, step, GiBroadcastFloat32(0.0f));
UNROLL_CALL_RAW(4, INIT)
#undef INIT
auto Bptr = B;
size_t k = 0;
while (k < K) {
GI_FLOAT32_t b = GiLoadFloat32(Bptr);
GI_FLOAT32_V2_t a[2];
#define LOAD_A(step) a[step] = GiLoadFloat32V2(Aptr0 + step * 8);
UNROLL_CALL_RAW(2, LOAD_A)
GI_FLOAT32_V4_t a;
#define LOAD_A(step) GiSetSubVectorFloat32V4(a, step, GiLoadFloat32(Aptr0 + step * 4));
UNROLL_CALL_RAW(4, LOAD_A)
#undef LOAD_A
#define COMPT(step) \
c[step] = GiSimdFmaLane(c[step], a[step / 2].val[step % 2], b, step % 4);
#define COMPT(step) \
t = GiSimdFmaLane( \
GiGetSubVectorFloat32V4(c, step), GiGetSubVectorFloat32V4(a, step), b, \
step % 4); \
GiSetSubVectorFloat32V4(c, step, t);
GI_FLOAT32_t t;
UNROLL_CALL_RAW(4, COMPT)
#undef COMPT
Bptr += Bstride;
......@@ -46,11 +51,16 @@ void sgemv_gi_naive_n_mk4(
k += PACK_SIZE;
}
#define ADD_C(step, stride) c[step] = GiAddFloat32(c[step], c[step + stride]);
#define ADD_C(step, stride) \
t = GiAddFloat32( \
GiGetSubVectorFloat32V4(c, step), \
GiGetSubVectorFloat32V4(c, step + stride)); \
GiSetSubVectorFloat32V4(c, step, t);
GI_FLOAT32_t t;
UNROLL_CALL_RAW(2, ADD_C, 2)
UNROLL_CALL_RAW(1, ADD_C, 1)
#undef ADD_C
GiStoreFloat32(Cptr0, c[0]);
GiStoreFloat32(Cptr0, GiGetSubVectorFloat32V4(c, 0));
Aptr += Astride;
Cptr += Cstride;
......
......@@ -82,6 +82,7 @@ void kern_4x4(const float* A, const float* B, size_t LDB, size_t K, float* C) {
GI_FLOAT32_t d6d7 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t vfzero = GiBroadcastFloat32(0.0f);
GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0);
GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0);
GI_FLOAT32_t d20d21 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0);
......@@ -173,6 +174,7 @@ void kern_4x8(const float* A, const float* B, size_t LDB, size_t K, float* C) {
B = B + 4;
GI_FLOAT32_t d6d7 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t vfzero = GiBroadcastFloat32(0.0f);
GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0);
d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1);
GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册