Commit ce0f2bfd, authored by H hjchen2

Optimize GEMM data packing; this gives a 22% speedup on the OCR detection model.

Parent commit: c070770c
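The change replaces the element-by-element packing of A and B with NEON loads/stores, handles the row/column tails with compare-generated masks instead of zero-padded scalar loops, and folds the separate *_omp packing variants into a single `parallel` flag. For orientation, a minimal scalar sketch of the panel layout that PackMatrixA_6r produces (reference only; the function name and bounds below are illustrative, not part of the patch):

  void pack_a_6r_reference(int rows, int k, const float *A, int lda, float *out) {
    // rows <= 6; element (r, c) of the strip is written to out[c * 6 + r],
    // with missing rows zero-filled, so the micro-kernel reads it sequentially.
    for (int c = 0; c < k; ++c) {
      for (int r = 0; r < 6; ++r) {
        out[c * 6 + r] = (r < rows) ? A[r * lda + c] : 0.0f;
      }
    }
  }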
@@ -27,390 +27,418 @@ namespace paddle_mobile {
namespace operators {
namespace math {
// (The scalar PackMatrixA_4r packing routine is removed in this commit.)
#if __ARM_NEON
inline float32x4_t vandq_f32(float32x4_t x, uint32x4_t mask) {
  return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask));
}
#endif
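The vandq_f32 helper above ANDs the float bit patterns with a 0/0xFFFFFFFF comparison mask, so lanes whose mask is zero become 0.0f. A sketch of how the patch builds such a mask for a partial column count (remain_k is the leftover k % 4, as in PackMatrixA_6r below; illustration only):

  uint32_t lane_idx[4] = {0, 1, 2, 3};
  uint32x4_t keep = vcltq_u32(vld1q_u32(lane_idx), vdupq_n_u32(remain_k));
  float32x4_t tail = vandq_f32(vld1q_f32(a0), keep);  // lanes >= remain_k become 0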
// Pack blocks of matrix A into contiguous memory, six rows per panel (RowMajor).
// The separate PackMatrixA_omp_6r / PackMatrixA_omp_8r variants are dropped; a
// `parallel` flag now selects OpenMP, and the copies are vectorized with NEON.
void Gemm::PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
                          float *buffer, const bool parallel) {
  uint32_t mask[8] = {0, 1, 2, 3, 4, 5, 4, 5};
  int remain_k = k & 0x3;
  uint32x4_t vzero = vdupq_n_u32(0);
  uint32x4_t vmask1 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_k));

  #pragma omp parallel for if (parallel)
  for (int i = 0; i < m - 5; i += 6) {
    const float *a0 = A + i * lda;
    const float *a1 = A + (i + 1) * lda;
    const float *a2 = A + (i + 2) * lda;
    const float *a3 = A + (i + 3) * lda;
    const float *a4 = A + (i + 4) * lda;
    const float *a5 = A + (i + 5) * lda;
    float *out_ptr = buffer + i * k;

    int loops = k >> 2;
    if (loops > 0) {
#if __aarch64__
      for (int l = 0; l < loops; ++l) {
        float32x4_t _d0 = vld1q_f32(a0);
        float32x4_t _d1 = vld1q_f32(a1);
        float32x4_t _d2 = vld1q_f32(a2);
        float32x4_t _d3 = vld1q_f32(a3);
        float32x4_t _d4 = vld1q_f32(a4);
        float32x4_t _d5 = vld1q_f32(a5);

        float32x4x2_t _q0 = vtrnq_f32(_d0, _d1);
        float32x4x2_t _q1 = vtrnq_f32(_d2, _d3);
        float32x4x2_t _q3 = vtrnq_f32(_d4, _d5);
        _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0]));
        _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1]));
        _d2 =
            vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0]));
        _d3 =
            vcombine_f32(vget_high_f32(_q0.val[1]), vget_high_f32(_q1.val[1]));

        vst1q_f32(out_ptr, _d0);
        vst1_f32(out_ptr + 4, vget_low_f32(_q3.val[0]));
        vst1q_f32(out_ptr + 6, _d1);
        vst1_f32(out_ptr + 10, vget_low_f32(_q3.val[1]));
        vst1q_f32(out_ptr + 12, _d2);
        vst1_f32(out_ptr + 16, vget_high_f32(_q3.val[0]));
        vst1q_f32(out_ptr + 18, _d3);
        vst1_f32(out_ptr + 22, vget_high_f32(_q3.val[1]));

        a0 += 4;
        a1 += 4;
        a2 += 4;
        a3 += 4;
        a4 += 4;
        a5 += 4;
        out_ptr += 24;
      }
#else
      asm volatile(
          "loop_4k_%=:                          \n"
          "vld1.32    {d0-d1}, [%[a0]]!         \n"
          "vld1.32    {d2-d3}, [%[a1]]!         \n"
          "vld1.32    {d4-d5}, [%[a2]]!         \n"
          "vld1.32    {d6-d7}, [%[a3]]!         \n"
          "vld1.32    {d8-d9}, [%[a4]]!         \n"
          "vld1.32    {d10-d11}, [%[a5]]!       \n"
          "vtrn.32    q0, q1                    \n"
          "vtrn.32    q2, q3                    \n"
          "vtrn.32    q4, q5                    \n"
          "vswp.32    d1, d4                    \n"
          "vswp.32    d3, d6                    \n"
          "vst1.32    {q0}, [%[out]]!           \n"
          "vst1.32    {d8}, [%[out]]!           \n"
          "vst1.32    {q1}, [%[out]]!           \n"
          "vst1.32    {d10}, [%[out]]!          \n"
          "vst1.32    {q2}, [%[out]]!           \n"
          "vst1.32    {d9}, [%[out]]!           \n"
          "vst1.32    {q3}, [%[out]]!           \n"
          "vst1.32    {d11}, [%[out]]!          \n"
          "subs       %[loops], #1              \n"
          "bne        loop_4k_%=                \n"
          : [out] "+r"(out_ptr), [a0] "+r"(a0), [a1] "+r"(a1), [a2] "+r"(a2),
            [a3] "+r"(a3), [a4] "+r"(a4), [a5] "+r"(a5), [loops] "+r"(loops)
          :
          : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5");
#endif
    }

    if (remain_k > 0) {
      float32x4_t _d0 = vld1q_f32(a0);
      float32x4_t _d1 = vld1q_f32(a1);
      float32x4_t _d2 = vld1q_f32(a2);
      float32x4_t _d3 = vld1q_f32(a3);
      float32x4_t _d4 = vld1q_f32(a4);
      float32x4_t _d5 = vld1q_f32(a5);

      _d0 = vandq_f32(_d0, vmask1);
      _d1 = vandq_f32(_d1, vmask1);
      _d2 = vandq_f32(_d2, vmask1);
      _d3 = vandq_f32(_d3, vmask1);
      _d4 = vandq_f32(_d4, vmask1);
      _d5 = vandq_f32(_d5, vmask1);

      float32x4x2_t _q0 = vtrnq_f32(_d0, _d1);
      float32x4x2_t _q1 = vtrnq_f32(_d2, _d3);
      float32x4x2_t _q3 = vtrnq_f32(_d4, _d5);
      _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0]));
      _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1]));
      _d2 = vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0]));

      switch (remain_k) {
        case 3:
          vst1q_f32(out_ptr + 12, _d2);
          vst1_f32(out_ptr + 16, vget_high_f32(_q3.val[0]));
        case 2:
          vst1q_f32(out_ptr + 6, _d1);
          vst1_f32(out_ptr + 10, vget_low_f32(_q3.val[1]));
        case 1:
          vst1q_f32(out_ptr, _d0);
          vst1_f32(out_ptr + 4, vget_low_f32(_q3.val[0]));
        default:
          break;
      }
    }
  }

  int remain_m = m % 6;
  if (remain_m) {
    int remain_m_start = m - remain_m;
    const float *a0 = A + remain_m_start * lda;
    const float *a1 = a0 + lda;
    const float *a2 = a0 + 2 * lda;
    const float *a3 = a0 + 3 * lda;
    const float *a4 = a0 + 4 * lda;
    const float *a5 = a0 + 5 * lda;
    float *out_ptr = buffer + remain_m_start * k;

    uint32x4_t vmask2 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_m));
    uint32x4_t vmask3 = vcltq_u32(vld1q_u32(mask + 4), vdupq_n_u32(remain_m));

    int loops = k >> 2;
    if (loops > 0) {
#if __aarch64__
      for (int l = 0; l < loops; ++l) {
        float32x4_t _d0 = vld1q_f32(a0);
        float32x4_t _d1 = vld1q_f32(a1);
        float32x4_t _d2 = vld1q_f32(a2);
        float32x4_t _d3 = vld1q_f32(a3);
        float32x4_t _d4 = vld1q_f32(a4);
        float32x4_t _d5 = vld1q_f32(a5);

        float32x4x2_t _q0 = vtrnq_f32(_d0, _d1);
        float32x4x2_t _q1 = vtrnq_f32(_d2, _d3);
        float32x4x2_t _q3 = vtrnq_f32(_d4, _d5);
        _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0]));
        _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1]));
        _d2 =
            vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0]));
        _d3 =
            vcombine_f32(vget_high_f32(_q0.val[1]), vget_high_f32(_q1.val[1]));

        _d0 = vandq_f32(_d0, vmask2);
        _d1 = vandq_f32(_d1, vmask2);
        _d2 = vandq_f32(_d2, vmask2);
        _d3 = vandq_f32(_d3, vmask2);
        _d4 = vandq_f32(_q3.val[0], vmask3);
        _d5 = vandq_f32(_q3.val[1], vmask3);

        vst1q_f32(out_ptr, _d0);
        vst1_f32(out_ptr + 4, vget_low_f32(_d4));
        vst1q_f32(out_ptr + 6, _d1);
        vst1_f32(out_ptr + 10, vget_low_f32(_d5));
        vst1q_f32(out_ptr + 12, _d2);
        vst1_f32(out_ptr + 16, vget_high_f32(_d4));
        vst1q_f32(out_ptr + 18, _d3);
        vst1_f32(out_ptr + 22, vget_high_f32(_d5));

        a0 += 4;
        a1 += 4;
        a2 += 4;
        a3 += 4;
        a4 += 4;
        a5 += 4;
        out_ptr += 24;
      }
#else
      asm volatile(
          "loop_4k_%=:                          \n"
          "vld1.32    {d0-d1}, [%[a0]]!         \n"
          "vld1.32    {d2-d3}, [%[a1]]!         \n"
          "vld1.32    {d4-d5}, [%[a2]]!         \n"
          "vld1.32    {d6-d7}, [%[a3]]!         \n"
          "vld1.32    {d8-d9}, [%[a4]]!         \n"
          "vld1.32    {d10-d11}, [%[a5]]!       \n"
          "vtrn.32    q0, q1                    \n"
          "vtrn.32    q2, q3                    \n"
          "vtrn.32    q4, q5                    \n"
          "vswp.32    d1, d4                    \n"
          "vswp.32    d3, d6                    \n"
          "vbif       q0, %q[vzero], %q[vmask2] \n"
          "vbif       q1, %q[vzero], %q[vmask2] \n"
          "vbif       q2, %q[vzero], %q[vmask2] \n"
          "vbif       q3, %q[vzero], %q[vmask2] \n"
          "vbif       q4, %q[vzero], %q[vmask3] \n"
          "vbif       q5, %q[vzero], %q[vmask3] \n"
          "vst1.32    {q0}, [%[out]]!           \n"
          "vst1.32    {d8}, [%[out]]!           \n"
          "vst1.32    {q1}, [%[out]]!           \n"
          "vst1.32    {d10}, [%[out]]!          \n"
          "vst1.32    {q2}, [%[out]]!           \n"
          "vst1.32    {d9}, [%[out]]!           \n"
          "vst1.32    {q3}, [%[out]]!           \n"
          "vst1.32    {d11}, [%[out]]!          \n"
          "subs       %[loops], #1              \n"
          "bne        loop_4k_%=                \n"
          : [out] "+r"(out_ptr), [a0] "+r"(a0), [a1] "+r"(a1), [a2] "+r"(a2),
            [a3] "+r"(a3), [a4] "+r"(a4), [a5] "+r"(a5), [loops] "+r"(loops)
          : [vmask2] "w"(vmask2), [vmask3] "w"(vmask3), [vzero] "w"(vzero)
          : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5");
#endif
    }

    if (remain_k > 0) {
      float32x4_t _d0 = vld1q_f32(a0);
      float32x4_t _d1 = vld1q_f32(a1);
      float32x4_t _d2 = vld1q_f32(a2);
      float32x4_t _d3 = vld1q_f32(a3);
      float32x4_t _d4 = vld1q_f32(a4);
      float32x4_t _d5 = vld1q_f32(a5);

      _d0 = vandq_f32(_d0, vmask1);
      _d1 = vandq_f32(_d1, vmask1);
      _d2 = vandq_f32(_d2, vmask1);
      _d3 = vandq_f32(_d3, vmask1);
      _d4 = vandq_f32(_d4, vmask1);
      _d5 = vandq_f32(_d5, vmask1);

      float32x4x2_t _q0 = vtrnq_f32(_d0, _d1);
      float32x4x2_t _q1 = vtrnq_f32(_d2, _d3);
      float32x4x2_t _q3 = vtrnq_f32(_d4, _d5);
      _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0]));
      _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1]));
      _d2 = vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0]));
      // _d3 = vcombine_f32(vget_high_f32(_q0.val[1]),
      //                    vget_high_f32(_q1.val[1]));

      _d0 = vandq_f32(_d0, vmask2);
      _d1 = vandq_f32(_d1, vmask2);
      _d2 = vandq_f32(_d2, vmask2);
      // _d3 = vandq_f32(_d3, vmask2);
      _d4 = vandq_f32(_q3.val[0], vmask3);
      _d5 = vandq_f32(_q3.val[1], vmask3);

      switch (remain_k) {
        case 3:
          vst1q_f32(out_ptr + 12, _d2);
          vst1_f32(out_ptr + 16, vget_high_f32(_d4));
        case 2:
          vst1q_f32(out_ptr + 6, _d1);
          vst1_f32(out_ptr + 10, vget_low_f32(_d5));
        case 1:
          vst1q_f32(out_ptr, _d0);
          vst1_f32(out_ptr + 4, vget_low_f32(_d4));
        default:
          break;
      }
    }
  }
}
// Pack blocks of matrix B into contiguous memory (RowMajor). The old scalar
// path and the separate PackMatrixB_omp_8c variant are folded into this one
// NEON implementation, parallelised over rows when `parallel` is set.
void Gemm::PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
                          float *buffer, const bool parallel) {
  const int j_length = n - n_tail;

  #pragma omp parallel for if (parallel)
  for (int i = 0; i < k; ++i) {
    int j = 0;
    for (; j < j_length - 31; j += 32) {
      float *local_buffer0 = buffer + j * k + i * NR;
      float *local_buffer1 = buffer + (j + 8) * k + i * NR;
      float *local_buffer2 = buffer + (j + 16) * k + i * NR;
      float *local_buffer3 = buffer + (j + 24) * k + i * NR;
      const float *b0 = B + i * ldb + j;
#if __aarch64__
      asm volatile(
          "prfm   pldl1keep,       [%[b0]]                 \n"
          "ld1    {v0.4s, v1.4s},  [%[b0]], #32            \n"
          "ld1    {v2.4s, v3.4s},  [%[b0]], #32            \n"
          "ld1    {v4.4s, v5.4s},  [%[b0]], #32            \n"
          "ld1    {v6.4s, v7.4s},  [%[b0]]                 \n"
          "st1    {v0.4s, v1.4s},  [%[local_buffer0]], #32 \n"
          "st1    {v2.4s, v3.4s},  [%[local_buffer1]], #32 \n"
          "st1    {v4.4s, v5.4s},  [%[local_buffer2]], #32 \n"
          "st1    {v6.4s, v7.4s},  [%[local_buffer3]], #32 \n"
          : [local_buffer0] "+r"(local_buffer0),
            [local_buffer1] "+r"(local_buffer1),
            [local_buffer2] "+r"(local_buffer2),
            [local_buffer3] "+r"(local_buffer3), [b0] "+r"(b0)
          :
          : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
#else
      asm volatile(
          // "pld [%[b]]                                   \n"
          "vld1.32   {q0, q1},  [%[b0]]!                   \n"
          "vld1.32   {q2, q3},  [%[b0]]!                   \n"
          "vld1.32   {q4, q5},  [%[b0]]!                   \n"
          "vld1.32   {q6, q7},  [%[b0]]!                   \n"
          "vst1.32   {q0, q1},  [%[local_buffer0]]!        \n"
          "vst1.32   {q2, q3},  [%[local_buffer1]]!        \n"
          "vst1.32   {q4, q5},  [%[local_buffer2]]!        \n"
          "vst1.32   {q6, q7},  [%[local_buffer3]]!        \n"
          : [local_buffer0] "+r"(local_buffer0),
            [local_buffer1] "+r"(local_buffer1),
            [local_buffer2] "+r"(local_buffer2),
            [local_buffer3] "+r"(local_buffer3), [b0] "+r"(b0)
          :
          : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
#endif  // __aarch64__
    }
    for (; j < j_length - 15; j += 16) {
      float *local_buffer0 = buffer + j * k + i * NR;
      float *local_buffer1 = buffer + (j + 8) * k + i * NR;
      const float *b0 = &B(i, j);
#if __ARM_NEON
#if __aarch64__
      asm volatile(
          "prfm   pldl1keep,       [%[b0]]                 \n"
          "ld1    {v0.4s, v1.4s},  [%[b0]], #32            \n"
          "ld1    {v2.4s, v3.4s},  [%[b0]]                 \n"
          "st1    {v0.4s, v1.4s},  [%[local_buffer0]], #32 \n"
          "st1    {v2.4s, v3.4s},  [%[local_buffer1]], #32 \n"
          : [local_buffer0] "+r"(local_buffer0),
            [local_buffer1] "+r"(local_buffer1), [b0] "+r"(b0)
          :
          : "memory", "v0", "v1", "v2", "v3");
#else
      asm volatile(
          // "pld [%[b0]]                                  \n"
          "vld1.32   {q0, q1},  [%[b0]]!                   \n"
          "vld1.32   {q2, q3},  [%[b0]]                    \n"
          "vst1.32   {q0, q1},  [%[local_buffer0]]!        \n"
          "vst1.32   {q2, q3},  [%[local_buffer1]]!        \n"
          : [local_buffer0] "+r"(local_buffer0),
            [local_buffer1] "+r"(local_buffer1), [b0] "+r"(b0)
          :
          : "memory", "q0", "q1", "q2", "q3");
#endif  // __aarch64__
#endif  // __ARM_NEON
    }
    for (; j < j_length; j += NR) {
      float *local_buffer = buffer + j * k + i * NR;
      const float *b0 = &B(i, j);
#if __aarch64__
      asm volatile(
          "prfm   pldl1keep,       [%[b0]]                 \n"
          "ld1    {v0.4s, v1.4s},  [%[b0]]                 \n"
          "st1    {v0.4s, v1.4s},  [%[local_buffer]], #32  \n"
          : [local_buffer] "+r"(local_buffer)
          : [b0] "r"(b0)
          : "memory", "v0", "v1");
#else
      asm volatile(
          // "pld [%[b]]                                   \n"
          "vld1.32   {q0, q1},  [%[b0]]                    \n"
          "vst1.32   {q0, q1},  [%[local_buffer]]          \n"
          : [local_buffer] "+r"(local_buffer)
          : [b0] "r"(b0)
          : "memory", "q0", "q1");
#endif  // __aarch64__
    }
  }
  if (n_tail != 0) {
    uint32_t mask[8] = {0, 1, 2, 3, 4, 5, 6, 7};
    uint32x4_t vzero = vdupq_n_u32(0);
    uint32x4_t vmask1 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(n_tail));
    uint32x4_t vmask2 = vcltq_u32(vld1q_u32(mask + 4), vdupq_n_u32(n_tail));
    float *local_buffer = buffer + j_length * k;
    for (int i = 0; i < k; ++i) {
      const float *b0 = &B(i, j_length);
#if __aarch64__
      asm volatile(
          "prfm   pldl1keep,       [%[b0]]                 \n"
          "ld1    {v0.4s, v1.4s},  [%[b0]]                 \n"
          "BIF    v0.8b, %[vzero].8b, %[vmask1].8b         \n"
          "BIF    v1.8b, %[vzero].8b, %[vmask2].8b         \n"
          "st1    {v0.4s, v1.4s},  [%[local_buffer]], #32  \n"
          : [local_buffer] "+r"(local_buffer)
          : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero),
            [b0] "r"(b0)
          : "memory", "v0", "v1");
#else
      asm volatile(
          "vld1.32   {q0, q1},  [%[b0]]                    \n"
          "vbif      q0, %q[vzero], %q[vmask1]             \n"
          "vbif      q1, %q[vzero], %q[vmask2]             \n"
          "vst1.32   {q0, q1},  [%[local_buffer]]!         \n"
          : [local_buffer] "+r"(local_buffer)
          : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero),
            [b0] "r"(b0)
          : "memory", "q0", "q1");
#endif
    }
  }
}
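For reference, the panel layout PackMatrixB_8c produces is unchanged from the old scalar version: B is cut into NR-column panels (NR is 8 for this routine) stored row after row, with the final partial panel zero-filled; the vbif instructions above perform that zero fill in registers. A scalar sketch (illustration only, not part of the patch):

  void pack_b_8c_reference(int k, int n, const float *B, int ldb, float *out) {
    for (int j = 0; j < n; j += 8) {          // one 8-column panel per iteration
      float *panel = out + j * k;
      for (int i = 0; i < k; ++i) {
        for (int c = 0; c < 8; ++c) {
          panel[i * 8 + c] = (j + c < n) ? B[i * ldb + j + c] : 0.0f;
        }
      }
    }
  }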
@@ -418,39 +446,10 @@ void Gemm::PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb,
#if __ARM_NEON
#if __aarch64__
void Gemm::PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
                           float *buffer, const bool parallel) {
  const int j_length = n - n_tail;
  // (The separate PackMatrixB_omp_12c variant is removed; the `parallel` flag
  //  now controls the OpenMP pragma below.)
  #pragma omp parallel for if (parallel)
  for (int j = 0; j < j_length; j += NR) {
    float *local_buffer = buffer + j * k;
    for (int i = 0; i < k; ++i) {
@@ -479,39 +478,10 @@ void Gemm::PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B,
}

void Gemm::PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
                           float *buffer, const bool parallel) {
  const int j_length = n - n_tail;
  // (The separate PackMatrixB_omp_16c variant is removed.)
  #pragma omp parallel for if (parallel)
  for (int j = 0; j < n - n_tail; j += NR) {
    float *local_buffer = buffer + j * k;
    for (int i = 0; i < k; ++i) {
@@ -2971,7 +2941,48 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
// C = A * B
void Gemm::VecWriteBasic(int n, float *c, float *C, int ldc) {
  int nc1 = n / 16;
  int _nc1 = n % 16;
  int nc2 = _nc1 / 4;
  int nc3 = 16 - 4 * (_nc1 % 4);

  asm volatile(
      "subs       %[nc1], %[nc1], #1    \n\t"
      "blt        end_nc1_%=            \n\t"
      "loop_nc1_%=:                     \n\t"

      "vld1.32    {q0, q1}, [%[c]]!     \n\t"
      "vst1.32    {q0, q1}, [%[C]]!     \n\t"

      "vld1.32    {q2, q3}, [%[c]]!     \n\t"
      "vst1.32    {q2, q3}, [%[C]]!     \n\t"

      "subs       %[nc1], %[nc1], #1    \n\t"
      "bge        loop_nc1_%=           \n\t"
      "end_nc1_%=:                      \n\t"

      "subs       %[nc2], %[nc2], #1    \n\t"
      "blt        end_nc2_%=            \n\t"
      "loop_nc2_%=:                     \n\t"

      "vld1.32    {q4}, [%[c]]!         \n\t"
      "vst1.32    {q4}, [%[C]]!         \n\t"

      "subs       %[nc2], %[nc2], #1    \n\t"
      "bge        loop_nc2_%=           \n\t"
      "end_nc2_%=:                      \n\t"

      "cmp        %[nc3], #16           \n\t"
      "beq        end_nc3_%=            \n\t"
      "sub        %[c], %[c], %[nc3]    \n\t"
      "sub        %[C], %[C], %[nc3]    \n\t"
      "vld1.32    {q5}, [%[c]]!         \n\t"
      "vst1.32    {q5}, [%[C]]!         \n\t"
      "end_nc3_%=:                      \n\t"

      :
      : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3)
      : "memory", "q0", "q1", "q2", "q3", "q4", "q5");
}
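The rewritten VecWriteBasic replaces the old memcpy with NEON copies: 16 floats per loop iteration, then 4, and a non-multiple-of-4 tail is finished by stepping both pointers back so that one final 4-float store ends exactly at element n. A scalar sketch of the same behaviour (assumes n >= 4; reference only):

  void vec_write_basic_reference(int n, const float *c, float *C) {
    int full = n & ~3;                 // covered by the 16- and 4-wide loops
    for (int i = 0; i < full; ++i) C[i] = c[i];
    if (n & 3) {                       // overlapping final 4-wide copy
      for (int i = n - 4; i < n; ++i) C[i] = c[i];
    }
  }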
// C = alpha * A * B + beta * C
@@ -3252,17 +3263,17 @@ void Gemm::Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
    nc = s_min(n - j, NC);
#if __aarch64__
    // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
    PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false);
#else
    PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false);
#endif
    for (int i = 0; i < m; i += MC) {
      mc = s_min(m - i, MC);
#if __aarch64__
      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false);
      // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#else
      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false);
#endif
      if (bias == nullptr) {
        InnerKernelWithBias(mc, nc, alpha, packedA, packedB, beta, packedC,
@@ -3325,17 +3336,17 @@ void Gemm::SgemmWithBn(int m, int n, int k, float alpha, const float *A,
    nc = s_min(n - j, NC);
#if __aarch64__
    // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
    PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false);
#else
    PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false);
#endif
    for (int i = 0; i < m; i += MC) {
      mc = s_min(m - i, MC);
#if __aarch64__
      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false);
      // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#else
      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false);
#endif
      if (bias == nullptr) {
        InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC,
@@ -3401,17 +3412,17 @@ void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda,
    nc = s_min(n - j, NC);
#if __aarch64__
    // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
    PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false);
#else
    PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false);
#endif
    for (int i = 0; i < m; i += MC) {
      mc = s_min(m - i, MC);
#if __aarch64__
      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false);
      // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#else
      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false);
#endif
      if (bias1 == nullptr) {
        InnerKernelWithPRelu(mc, nc, packedA, packedB, packedC, &C(i, j), ldc,
@@ -3465,17 +3476,17 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
#if __aarch64__
    procPackA = &Gemm::PackMatrixA_6r;
    procPackB = &Gemm::PackMatrixB_16c;
    procAddDot = &Gemm::AddDot6x16;
#else
    procPackA = &Gemm::PackMatrixA_6r;
    procPackB = &Gemm::PackMatrixB_8c;
    procAddDot = &Gemm::AddDot6x8;
#endif
    packedB = static_cast<float *>(
        paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
    (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB, true);
    packedA = static_cast<float *>(
        paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads));
  } else {
@@ -3492,19 +3503,19 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
    MC = (m + MR - 1) / MR * MR;
#if __aarch64__
    procPackA = &Gemm::PackMatrixA_6r;
    procPackB = &Gemm::PackMatrixB_16c;
    procAddDot = &Gemm::AddDot6x16;
#else
    procPackA = &Gemm::PackMatrixA_6r;
    procPackB = &Gemm::PackMatrixB_8c;
    procAddDot = &Gemm::AddDot6x8;
#endif
    packedA = static_cast<float *>(
        paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
    (*this.*procPackA)(m, KC, m % MR, A, lda, packedA, true);
    packedB = static_cast<float *>(
        paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads));
  }
@@ -3524,7 +3535,7 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
      mc = s_min(m - i, MC);
      float *local_A = packedA + MC * KC * local_threads;
      float *local_C = packedC + MC * NC * local_threads;
      (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A, false);
      if (bias == nullptr) {
        InnerKernelWithBias(mc, n, alpha, local_A, packedB, beta, local_C,
                            &C(i, 0), ldc, relu, nullptr);
@@ -3546,7 +3557,7 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
      nc = s_min(n - j, NC);
      float *local_B = packedB + KC * NC * local_threads;
      float *local_C = packedC + MC * NC * local_threads;
      (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B, false);
      InnerKernelWithBias(m, nc, alpha, packedA, local_B, beta, local_C,
                          &C(0, j), ldc, relu, bias);
    }
@@ -3587,17 +3598,17 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A,
#if __aarch64__
    procPackA = &Gemm::PackMatrixA_6r;
    procPackB = &Gemm::PackMatrixB_16c;
    procAddDot = &Gemm::AddDot6x16;
#else
    procPackA = &Gemm::PackMatrixA_6r;
    procPackB = &Gemm::PackMatrixB_8c;
    procAddDot = &Gemm::AddDot6x8;
#endif
    packedB = static_cast<float *>(
        paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
    (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB, true);
    packedA = static_cast<float *>(
        paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads));
  } else {
@@ -3614,18 +3625,18 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A,
    MC = (m + MR - 1) / MR * MR;
#if __aarch64__
    procPackA = &Gemm::PackMatrixA_6r;
    procPackB = &Gemm::PackMatrixB_16c;
    procAddDot = &Gemm::AddDot6x16;
#else
    procPackA = &Gemm::PackMatrixA_6r;
    procPackB = &Gemm::PackMatrixB_8c;
    procAddDot = &Gemm::AddDot6x8;
#endif
    packedA = static_cast<float *>(
        paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
    (*this.*procPackA)(m, KC, m % MR, A, lda, packedA, true);
    packedB = static_cast<float *>(
        paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads));
  }
@@ -3645,7 +3656,7 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A,
      mc = s_min(m - i, MC);
      float *local_A = packedA + MC * KC * local_threads;
      float *local_C = packedC + MC * NC * local_threads;
      (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A, false);
      if (bias == nullptr) {
        InnerKernelWithBn(mc, n, alpha, local_A, packedB, beta, local_C,
                          &C(i, 0), ldc, relu, new_scale + i, new_bias + i);
@@ -3668,7 +3679,7 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A,
      nc = s_min(n - j, NC);
      float *local_B = packedB + KC * NC * local_threads;
      float *local_C = packedC + MC * NC * local_threads;
      (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B, false);
      if (bias == nullptr) {
        InnerKernelWithBn(m, nc, alpha, packedA, local_B, beta, local_C,
                          &C(0, j), ldc, relu, new_scale, new_bias);
@@ -3715,17 +3726,17 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
#if __aarch64__
    procPackA = &Gemm::PackMatrixA_6r;
    procPackB = &Gemm::PackMatrixB_16c;
    procAddDot = &Gemm::AddDot6x16;
#else
    procPackA = &Gemm::PackMatrixA_6r;
    procPackB = &Gemm::PackMatrixB_8c;
    procAddDot = &Gemm::AddDot6x8;
#endif
    packedB = static_cast<float *>(
        paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
    (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB, true);
    packedA = static_cast<float *>(
        paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads));
  } else {
@@ -3742,18 +3753,18 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
    MC = (m + MR - 1) / MR * MR;
#if __aarch64__
    procPackA = &Gemm::PackMatrixA_6r;
    procPackB = &Gemm::PackMatrixB_16c;
    procAddDot = &Gemm::AddDot6x16;
#else
    procPackA = &Gemm::PackMatrixA_6r;
    procPackB = &Gemm::PackMatrixB_8c;
    procAddDot = &Gemm::AddDot6x8;
#endif
    packedA = static_cast<float *>(
        paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
    (*this.*procPackA)(m, KC, m % MR, A, lda, packedA, true);
    packedB = static_cast<float *>(
        paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads));
  }
@@ -3773,7 +3784,7 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
      mc = s_min(m - i, MC);
      float *local_A = packedA + MC * KC * local_threads;
      float *local_C = packedC + MC * NC * local_threads;
      (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A, false);
      if (bias1 == nullptr) {
        InnerKernelWithPRelu(mc, n, local_A, packedB, local_C, &C(i, 0), ldc,
                             p + i, mode, bias + i, nullptr);
@@ -3795,7 +3806,7 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
      nc = s_min(n - j, NC);
      float *local_B = packedB + KC * NC * local_threads;
      float *local_C = packedC + MC * NC * local_threads;
      (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B, false);
      if (bias1 == nullptr) {
        InnerKernelWithPRelu(m, nc, packedA, local_B, local_C, &C(0, j), ldc, p,
                             mode, bias, nullptr);
......
@@ -46,37 +46,25 @@ namespace math {
class Gemm {
 public:
  typedef void (Gemm::*FnPack)(int, int, int, const float *, int, float *,
                               const bool);
  typedef void (Gemm::*FnAddDot)(int, const float *, const float *, float *,
                                 int);
  FnPack procPackA;
  FnPack procPackB;
  FnAddDot procAddDot;

  // Pack blocks of the A/B matrices into contiguous memory (RowMajor).
  // (PackMatrixA_4r and the separate *_omp_* packing declarations are removed;
  //  every packing routine now takes a `parallel` flag.)
  void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
                      float *buffer, const bool parallel);
  void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
                      float *buffer, const bool parallel);
  void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
                      float *buffer, const bool parallel);
#if __aarch64__
  void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
                       float *buffer, const bool parallel);
  void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
                       float *buffer, const bool parallel);
#endif
  // Blocked matrix multiplication
......