Commit 5160f973 authored by zhaojiaying01

add armv8 version of gemm

Parent b5c14d86
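The armv8 path added below keeps the packed-panel scheme of the existing NEON code but widens the micro-tile to MR = 6 rows of A and NR = 16 columns of B (see the header change at the end of the diff). As a rough scalar model of what one micro-kernel call computes on the packed panels — a sketch only, with AddDotRef being a hypothetical name that is not part of this commit:

// Illustrative scalar reference for one micro-tile, assuming packed panels:
// a holds MR consecutive A values per k-step, b holds NR consecutive B values
// per k-step, and c is an MR x NR tile with row stride ldc. The real kernels
// (AddDot6x8 / AddDot6x16) zero their accumulators first, so c is overwritten.
void AddDotRef(int k, int MR, int NR, const float *a, const float *b, float *c,
               int ldc) {
  for (int i = 0; i < MR; ++i) {
    for (int j = 0; j < NR; ++j) {
      c[i * ldc + j] = 0.0f;
    }
  }
  for (int p = 0; p < k; ++p) {
    for (int i = 0; i < MR; ++i) {
      for (int j = 0; j < NR; ++j) {
        c[i * ldc + j] += a[p * MR + i] * b[p * NR + j];
      }
    }
  }
}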
@@ -107,20 +107,22 @@ void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
*buffer++ = *a3++;
}
}
if (m_tail != 0) {
a0 = &A(m - m_tail, 0);
a1 = a0 + lda;
a2 = a0 + 2 * lda;
a3 = a0 + 3 * lda;
switch (m_tail) {
case 1:
a1 = zero;
case 2:
a2 = zero;
case 3:
a3 = zero;
break;
default:
break;
}
for (int j = 0; j < k; ++j) {
*buffer++ = *a0++;
@@ -150,28 +152,89 @@ void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
*buffer++ = *a5++;
}
}
if (m_tail != 0) {
a0 = &A(m - m_tail, 0);
a1 = a0 + lda;
a2 = a0 + 2 * lda;
a3 = a0 + 3 * lda;
a4 = a0 + 4 * lda;
a5 = a0 + 5 * lda;
switch (m_tail) {
case 1:
a1 = zero;
case 2:
a2 = zero;
case 3:
a3 = zero;
case 4:
a4 = zero;
case 5:
a5 = zero;
break;
default:
break;
}
for (int j = 0; j < k; ++j) {
*buffer++ = *a0++;
*buffer++ = *a1++;
*buffer++ = *a2++;
*buffer++ = *a3++;
*buffer++ = *a4++;
*buffer++ = *a5++;
}
}
}
void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer) {
const float *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7;
for (int i = 0; i < m - m_tail; i += MR) {
a0 = A + i * lda;
a1 = A + (i + 1) * lda;
a2 = A + (i + 2) * lda;
a3 = A + (i + 3) * lda;
a4 = A + (i + 4) * lda;
a5 = A + (i + 5) * lda;
a6 = A + (i + 6) * lda;
a7 = A + (i + 7) * lda;
for (int j = 0; j < k; ++j) {
*buffer++ = *a0++;
*buffer++ = *a1++;
*buffer++ = *a2++;
*buffer++ = *a3++;
*buffer++ = *a4++;
*buffer++ = *a5++;
*buffer++ = *a6++;
*buffer++ = *a7++;
}
}
if (m_tail != 0) {
a0 = &A(m - m_tail, 0);
a1 = a0 + lda;
a2 = a0 + 2 * lda;
a3 = a0 + 3 * lda;
a4 = a0 + 4 * lda;
a5 = a0 + 5 * lda;
a6 = a0 + 6 * lda;
a7 = a0 + 7 * lda;
switch (m_tail) {
case 1:
a1 = zero;
case 2:
a2 = zero;
case 3:
a3 = zero;
case 4:
a4 = zero;
case 5:
a5 = zero;
case 6:
a6 = zero;
case 7:
a7 = zero;
break;
default:
break;
}
for (int j = 0; j < k; ++j) {
*buffer++ = *a0++;
@@ -180,6 +243,8 @@ void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
*buffer++ = *a3++;
*buffer++ = *a4++;
*buffer++ = *a5++;
*buffer++ = *a6++;
*buffer++ = *a7++;
}
}
}
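For reference, the A-panel layout produced by the packing above can be checked index-wise: each k-step stores the MR row values back to back, and the switch statements rely on intentional fall-through so that every row past m_tail reads from the shared zero buffer. A sketch only, assuming the usual row-major macro A(i, j) = A[(i)*lda + (j)] and a full panel (m_tail == 0); CheckPackedA6r is a hypothetical helper, not part of the commit:

// Hypothetical layout check: element (i, j) of a full 6-row panel that starts
// at source row i0 lands at offset j * 6 + i of that panel's packed buffer.
bool CheckPackedA6r(const float *A, int lda, int i0, int k,
                    const float *buffer) {
  for (int j = 0; j < k; ++j) {
    for (int i = 0; i < 6; ++i) {
      if (buffer[j * 6 + i] != A[(i0 + i) * lda + j]) {
        return false;
      }
    }
  }
  return true;
}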
@@ -234,15 +299,78 @@ void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
}
}
#if __aarch64__
void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
const float *b0;
for (int j = 0; j < n - n_tail; j += NR) {
for (int i = 0; i < k; ++i) {
b0 = &B(i, j);
asm volatile(
"prfm pldl2keep, [%[b0], #64] \n\t"
"ld1 {v0.4s, v1.4s, v2.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s, v2.4s}, [%[buffer]], #48 \n\t"
: [buffer] "+r"(buffer)
: [b0] "r"(b0)
: "memory", "v0", "v1", "v2");
}
}
if (n_tail != 0) {
for (int i = 0; i < k; ++i) {
b0 = &B(i, n - n_tail);
for (int j = n - n_tail; j < n; ++j) {
*buffer++ = *b0++;
}
for (int j = n; j < n + (NR - n_tail); ++j) {
*buffer++ = 0;
}
}
}
}
void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
const float *b0;
for (int j = 0; j < n - n_tail; j += NR) {
for (int i = 0; i < k; ++i) {
b0 = &B(i, j);
asm volatile(
"prfm pldl2keep, [%[b0], #64] \n\t"
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[buffer]], #64 \n\t"
: [buffer] "+r"(buffer)
: [b0] "r"(b0)
: "memory", "v0", "v1", "v2", "v3");
}
}
if (n_tail != 0) {
for (int i = 0; i < k; ++i) {
b0 = &B(i, n - n_tail);
for (int j = n - n_tail; j < n; ++j) {
*buffer++ = *b0++;
}
for (int j = n; j < n + (NR - n_tail); ++j) {
*buffer++ = 0;
}
}
}
}
#endif // __aarch64__
// Blocked matrix multiplication
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
float beta, float *c, float *C, int ldc, bool relu) {
#pragma omp parallel for
for (int j = 0; j < nc; j += NR) {
for (int i = 0; i < mc; i += MR) {
#if __aarch64__
// AddDot8x12(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot6x16(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#else
// AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
// AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#endif
}
}
@@ -271,9 +399,14 @@ void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
#pragma omp parallel for
for (int j = 0; j < nc; j += NR) {
for (int i = 0; i < mc; i += MR) {
#if __aarch64__
// AddDot8x12(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot6x16(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#else
// AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
// AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#endif
}
}
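The NEON copy added in PackMatrixB_16c above is simply a 16-float row-chunk copy per k-step; a scalar sketch of the same main loop, assuming the row-major macro B(i, j) = B[(i)*ldb + (j)] (PackB16cRef is a hypothetical name, not part of the commit):

// Scalar sketch of PackMatrixB_16c's main loop: for each 16-column panel of B,
// copy one 16-float row chunk per k-step into the packed buffer, which is what
// the ld1/st1 of v0..v3 does. The n_tail columns are handled by the plain C
// loop already present in the diff.
void PackB16cRef(int k, int n, int n_tail, const float *B, int ldb,
                 float *buffer) {
  for (int j = 0; j < n - n_tail; j += 16) {
    for (int i = 0; i < k; ++i) {
      const float *b0 = B + i * ldb + j;
      for (int t = 0; t < 16; ++t) {
        *buffer++ = b0[t];
      }
    }
  }
}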
@@ -1956,10 +2089,20 @@ void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
int mc, nc;
for (int j = 0; j < n; j += NC) {
nc = s_min(n - j, NC);
#if __aarch64__
// PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#else
PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#endif
for (int i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
#if __aarch64__
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
// PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#else
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#endif
InnerKernel(mc, nc, alpha, packedA, packedB, beta, packedC, &C(i, j), ldc,
relu);
}
@@ -2009,10 +2152,20 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
int mc, nc;
for (int j = 0; j < n; j += NC) {
nc = s_min(n - j, NC);
#if __aarch64__
// PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#else
PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#endif
for (int i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
#if __aarch64__
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
// PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#else
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#endif
InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC,
&C(i, j), ldc, relu, new_scale + i, new_bias + i);
}
@@ -2239,6 +2392,192 @@ void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) {
#endif  // __ARM_NEON
}
#if __aarch64__
void AddDot8x12(int k, const float *a, const float *b, float *c, int ldc) {
const float *a_ptr, *b_ptr;
a_ptr = a;
b_ptr = b;
int kc1 = k;
int step = 4 * ldc;
asm volatile(
"dup v5.4s, wzr \n\t"
"dup v6.4s, wzr \n\t"
"dup v7.4s, wzr \n\t"
"dup v8.4s, wzr \n\t"
"dup v9.4s, wzr \n\t"
"dup v10.4s, wzr \n\t"
"dup v11.4s, wzr \n\t"
"dup v12.4s, wzr \n\t"
"dup v13.4s, wzr \n\t"
"dup v14.4s, wzr \n\t"
"dup v15.4s, wzr \n\t"
"dup v16.4s, wzr \n\t"
"dup v17.4s, wzr \n\t"
"dup v18.4s, wzr \n\t"
"dup v19.4s, wzr \n\t"
"dup v20.4s, wzr \n\t"
"dup v21.4s, wzr \n\t"
"dup v22.4s, wzr \n\t"
"dup v23.4s, wzr \n\t"
"dup v24.4s, wzr \n\t"
"dup v25.4s, wzr \n\t"
"dup v26.4s, wzr \n\t"
"dup v27.4s, wzr \n\t"
"dup v28.4s, wzr \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"blt end_kc1_%= \n\t"
"loop_kc1_%=: \n\t"
"prfm pldl1keep, [%[a_ptr], #32] \n\t"
"prfm pldl1keep, [%[b_ptr], #48] \n\t"
"ld1 {v0.4s, v1.4s}, [%[a_ptr]], #32 \n\t"
"ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48 \n\t"
"fmla v5.4s, v2.4s, v0.s[0] \n\t"
"fmla v6.4s, v3.4s, v0.s[0] \n\t"
"fmla v7.4s, v4.4s, v0.s[0] \n\t"
"fmla v8.4s, v2.4s, v0.s[1] \n\t"
"fmla v9.4s, v3.4s, v0.s[1] \n\t"
"fmla v10.4s, v4.4s, v0.s[1] \n\t"
"fmla v11.4s, v2.4s, v0.s[2] \n\t"
"fmla v12.4s, v3.4s, v0.s[2] \n\t"
"fmla v13.4s, v4.4s, v0.s[2] \n\t"
"fmla v14.4s, v2.4s, v0.s[3] \n\t"
"fmla v15.4s, v3.4s, v0.s[3] \n\t"
"fmla v16.4s, v4.4s, v0.s[3] \n\t"
"fmla v17.4s, v2.4s, v1.s[0] \n\t"
"fmla v18.4s, v3.4s, v1.s[0] \n\t"
"fmla v19.4s, v4.4s, v1.s[0] \n\t"
"fmla v20.4s, v2.4s, v1.s[1] \n\t"
"fmla v21.4s, v3.4s, v1.s[1] \n\t"
"fmla v22.4s, v4.4s, v1.s[1] \n\t"
"fmla v23.4s, v2.4s, v1.s[2] \n\t"
"fmla v24.4s, v3.4s, v1.s[2] \n\t"
"fmla v25.4s, v4.4s, v1.s[2] \n\t"
"fmla v26.4s, v2.4s, v1.s[3] \n\t"
"fmla v27.4s, v3.4s, v1.s[3] \n\t"
"fmla v28.4s, v4.4s, v1.s[3] \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"bge loop_kc1_%= \n\t"
"end_kc1_%=: \n\t"
"st1 {v5.4s, v6.4s, v7.4s}, [%[c]], %[step] \n\t"
"st1 {v8.4s, v9.4s, v10.4s}, [%[c]], %[step] \n\t"
"st1 {v11.4s, v12.4s, v13.4s}, [%[c]], %[step] \n\t"
"st1 {v14.4s, v15.4s, v16.4s}, [%[c]], %[step] \n\t"
"st1 {v17.4s, v18.4s, v19.4s}, [%[c]], %[step] \n\t"
"st1 {v20.4s, v21.4s, v22.4s}, [%[c]], %[step] \n\t"
"st1 {v23.4s, v24.4s, v25.4s}, [%[c]], %[step] \n\t"
"st1 {v26.4s, v27.4s, v28.4s}, [%[c]], %[step] \n\t"
:
: [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1),
[step] "r"(step)
: "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28");
}
void AddDot6x16(int k, const float *a, const float *b, float *c, int ldc) {
const float *a_ptr, *b_ptr;
a_ptr = a;
b_ptr = b;
int kc1 = k;
int step = 4 * ldc;
int step1 = 4 * 6;
asm volatile(
"dup v6.4s, wzr \n\t"
"dup v7.4s, wzr \n\t"
"dup v8.4s, wzr \n\t"
"dup v9.4s, wzr \n\t"
"dup v10.4s, wzr \n\t"
"dup v11.4s, wzr \n\t"
"dup v12.4s, wzr \n\t"
"dup v13.4s, wzr \n\t"
"dup v14.4s, wzr \n\t"
"dup v15.4s, wzr \n\t"
"dup v16.4s, wzr \n\t"
"dup v17.4s, wzr \n\t"
"dup v18.4s, wzr \n\t"
"dup v19.4s, wzr \n\t"
"dup v20.4s, wzr \n\t"
"dup v21.4s, wzr \n\t"
"dup v22.4s, wzr \n\t"
"dup v23.4s, wzr \n\t"
"dup v24.4s, wzr \n\t"
"dup v25.4s, wzr \n\t"
"dup v26.4s, wzr \n\t"
"dup v27.4s, wzr \n\t"
"dup v28.4s, wzr \n\t"
"dup v29.4s, wzr \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"blt end_kc1_%= \n\t"
"loop_kc1_%=: \n\t"
"prfm pldl1keep, [%[a_ptr], #24] \n\t"
"prfm pldl1keep, [%[b_ptr], #64] \n\t"
"ld1 {v0.4s, v1.4s}, [%[a_ptr]], %[step1] \n\t"
"ld1 {v2.4s, v3.4s, v4.4s, v5.4s}, [%[b_ptr]], #64 \n\t"
"fmla v6.4s, v2.4s, v0.s[0] \n\t"
"fmla v7.4s, v3.4s, v0.s[0] \n\t"
"fmla v8.4s, v4.4s, v0.s[0] \n\t"
"fmla v9.4s, v5.4s, v0.s[0] \n\t"
"fmla v10.4s, v2.4s, v0.s[1] \n\t"
"fmla v11.4s, v3.4s, v0.s[1] \n\t"
"fmla v12.4s, v4.4s, v0.s[1] \n\t"
"fmla v13.4s, v5.4s, v0.s[1] \n\t"
"fmla v14.4s, v2.4s, v0.s[2] \n\t"
"fmla v15.4s, v3.4s, v0.s[2] \n\t"
"fmla v16.4s, v4.4s, v0.s[2] \n\t"
"fmla v17.4s, v5.4s, v0.s[2] \n\t"
"fmla v18.4s, v2.4s, v0.s[3] \n\t"
"fmla v19.4s, v3.4s, v0.s[3] \n\t"
"fmla v20.4s, v4.4s, v0.s[3] \n\t"
"fmla v21.4s, v5.4s, v0.s[3] \n\t"
"fmla v22.4s, v2.4s, v1.s[0] \n\t"
"fmla v23.4s, v3.4s, v1.s[0] \n\t"
"fmla v24.4s, v4.4s, v1.s[0] \n\t"
"fmla v25.4s, v5.4s, v1.s[0] \n\t"
"fmla v26.4s, v2.4s, v1.s[1] \n\t"
"fmla v27.4s, v3.4s, v1.s[1] \n\t"
"fmla v28.4s, v4.4s, v1.s[1] \n\t"
"fmla v29.4s, v5.4s, v1.s[1] \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"bge loop_kc1_%= \n\t"
"end_kc1_%=: \n\t"
"st1 {v6.4s, v7.4s, v8.4s, v9.4s}, [%[c]], %[step] \n\t"
"st1 {v10.4s, v11.4s, v12.4s, v13.4s}, [%[c]], %[step] \n\t"
"st1 {v14.4s, v15.4s, v16.4s, v17.4s}, [%[c]], %[step] \n\t"
"st1 {v18.4s, v19.4s, v20.4s, v21.4s}, [%[c]], %[step] \n\t"
"st1 {v22.4s, v23.4s, v24.4s, v25.4s}, [%[c]], %[step] \n\t"
"st1 {v26.4s, v27.4s, v28.4s, v29.4s}, [%[c]], %[step] \n\t"
:
: [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1),
[step] "r"(step), [step1] "r"(step1)
: "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29");
}
#endif // __aarch64__
}  // namespace math
}  // namespace operators
}  // namespace paddle_mobile
@@ -19,8 +19,13 @@ limitations under the License. */
#define B(i, j) B[(i)*ldb + (j)]
#define C(i, j) C[(i)*ldc + (j)]
#if __aarch64__
#define MR 6
#define NR 16
#else
#define MR 6
#define NR 8
#endif
#define s_min(i, j) ((i) < (j) ? (i) : (j))
@@ -43,10 +48,16 @@ void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
// Copy matrix B block by block into contiguous memory (RowMajor)
void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
// Blocked matrix multiplication
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
@@ -70,6 +81,8 @@ void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc);
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc);
void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc);
void AddDot8x12(int k, const float *a, const float *b, float *c, int ldc);
void AddDot6x16(int k, const float *a, const float *b, float *c, int ldc);
// Write back the results of the blocked matrix multiplication
// C = A * B
@@ -114,10 +127,6 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias);
// 64-bit double matrix multiplication
void dgemm(int m, int n, int k, float alpha, const double *A, int lda,
const double *B, int ldb, float beta, double *C, int ldc);
}  // namespace math
}  // namespace operators
}  // namespace paddle_mobile
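Call-side usage is unchanged by this commit. A minimal sketch of a dense row-major call, assuming the full Sgemm signature matches SgemmWithBn above minus the scale/bias arguments, i.e. Sgemm(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, relu); SgemmExample and the sizes are illustrative only:

// Minimal usage sketch (not from the commit): dense row-major matrices,
// alpha = 1, beta = 0, no relu; lda/ldb/ldc are the row widths.
#include <vector>

void SgemmExample() {
  const int m = 64, n = 64, k = 64;
  std::vector<float> A(m * k, 1.0f), B(k * n, 1.0f), C(m * n, 0.0f);
  paddle_mobile::operators::math::Sgemm(m, n, k, 1.0f, A.data(), k, B.data(),
                                        n, 0.0f, C.data(), n, false);
}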