diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index e41829761bcd6ac87a73c6378ec36a17458aff56..f4cf22decb5a722f29560f4b563bc8a81001b922 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -27,390 +27,418 @@ namespace paddle_mobile { namespace operators { namespace math { -// 将A矩阵分块复制到连续内存(RowMajor) -void Gemm::PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda, - float *buffer) { - const float *a0, *a1, *a2, *a3; - for (int i = 0; i < m - m_tail; i += MR) { - a0 = A + i * lda; - a1 = A + (i + 1) * lda; - a2 = A + (i + 2) * lda; - a3 = A + (i + 3) * lda; - for (int j = 0; j < k; ++j) { - *buffer++ = *a0++; - *buffer++ = *a1++; - *buffer++ = *a2++; - *buffer++ = *a3++; - } - } - - if (m_tail != 0) { - a0 = &A(m - m_tail, 0); - a1 = a0 + lda; - a2 = a0 + 2 * lda; - a3 = a0 + 3 * lda; - switch (m_tail) { - case 1: - a1 = zero; - case 2: - a2 = zero; - case 3: - a3 = zero; - break; - default: - break; - } - for (int j = 0; j < k; ++j) { - *buffer++ = *a0++; - *buffer++ = *a1++; - *buffer++ = *a2++; - *buffer++ = *a3++; - } - } +#if __ARM_NEON +inline float32x4_t vandq_f32(float32x4_t x, uint32x4_t mask) { + return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask)); } +#endif void Gemm::PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda, - float *buffer) { - const int i_length = m - m_tail; - for (int i = 0; i < i_length; i += MR) { + float *buffer, const bool parallel) { + uint32_t mask[8] = {0, 1, 2, 3, 4, 5, 4, 5}; + int remain_k = k & 0x3; + uint32x4_t vzero = vdupq_n_u32(0); + uint32x4_t vmask1 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_k)); + + #pragma omp parallel for if (parallel) + for (int i = 0; i < m - 5; i += 6) { const float *a0 = A + i * lda; const float *a1 = A + (i + 1) * lda; const float *a2 = A + (i + 2) * lda; const float *a3 = A + (i + 3) * lda; const float *a4 = A + (i + 4) * lda; const float *a5 = A + (i + 5) * lda; - float *local_buffer = buffer + i * k; - for (int j = 0; j < k; ++j) { - *local_buffer++ = *a0++; - *local_buffer++ = *a1++; - *local_buffer++ = *a2++; - *local_buffer++ = *a3++; - *local_buffer++ = *a4++; - *local_buffer++ = *a5++; - } - } - if (m_tail != 0) { - const float *a0 = &A(i_length, 0); - const float *a1 = a0 + lda; - const float *a2 = a0 + 2 * lda; - const float *a3 = a0 + 3 * lda; - const float *a4 = a0 + 4 * lda; - const float *a5 = a0 + 5 * lda; - float *local_buffer = buffer + i_length * k; - switch (m_tail) { - case 1: - a1 = zero; - case 2: - a2 = zero; - case 3: - a3 = zero; - case 4: - a4 = zero; - case 5: - a5 = zero; - break; - default: - break; - } - for (int j = 0; j < k; ++j) { - *local_buffer++ = *a0++; - *local_buffer++ = *a1++; - *local_buffer++ = *a2++; - *local_buffer++ = *a3++; - *local_buffer++ = *a4++; - *local_buffer++ = *a5++; - } - } -} + float *out_ptr = buffer + i * k; -void Gemm::PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda, - float *buffer) { - const int i_length = m - m_tail; -#pragma omp parallel for - for (int i = 0; i < i_length; i += MR) { - const float *a0 = A + i * lda; - const float *a1 = A + (i + 1) * lda; - const float *a2 = A + (i + 2) * lda; - const float *a3 = A + (i + 3) * lda; - const float *a4 = A + (i + 4) * lda; - const float *a5 = A + (i + 5) * lda; - float *local_buffer = buffer + i * k; - for (int j = 0; j < k; ++j) { - *local_buffer++ = *a0++; - *local_buffer++ = *a1++; - *local_buffer++ = *a2++; - *local_buffer++ = *a3++; - *local_buffer++ = *a4++; - *local_buffer++ = *a5++; - } 
- } - if (m_tail != 0) { - const float *a0 = &A(i_length, 0); - const float *a1 = a0 + lda; - const float *a2 = a0 + 2 * lda; - const float *a3 = a0 + 3 * lda; - const float *a4 = a0 + 4 * lda; - const float *a5 = a0 + 5 * lda; - float *local_buffer = buffer + i_length * k; - switch (m_tail) { - case 1: - a1 = zero; - case 2: - a2 = zero; - case 3: - a3 = zero; - case 4: - a4 = zero; - case 5: - a5 = zero; - break; - default: - break; - } - for (int j = 0; j < k; ++j) { - *local_buffer++ = *a0++; - *local_buffer++ = *a1++; - *local_buffer++ = *a2++; - *local_buffer++ = *a3++; - *local_buffer++ = *a4++; - *local_buffer++ = *a5++; + int loops = k >> 2; + if (loops > 0) { +#if __aarch64__ + for (int l = 0; l < loops; ++l) { + float32x4_t _d0 = vld1q_f32(a0); + float32x4_t _d1 = vld1q_f32(a1); + float32x4_t _d2 = vld1q_f32(a2); + float32x4_t _d3 = vld1q_f32(a3); + float32x4_t _d4 = vld1q_f32(a4); + float32x4_t _d5 = vld1q_f32(a5); + + float32x4x2_t _q0 = vtrnq_f32(_d0, _d1); + float32x4x2_t _q1 = vtrnq_f32(_d2, _d3); + float32x4x2_t _q3 = vtrnq_f32(_d4, _d5); + _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0])); + _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1])); + _d2 = + vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0])); + _d3 = + vcombine_f32(vget_high_f32(_q0.val[1]), vget_high_f32(_q1.val[1])); + + vst1q_f32(out_ptr, _d0); + vst1_f32(out_ptr + 4, vget_low_f32(_q3.val[0])); + vst1q_f32(out_ptr + 6, _d1); + vst1_f32(out_ptr + 10, vget_low_f32(_q3.val[1])); + vst1q_f32(out_ptr + 12, _d2); + vst1_f32(out_ptr + 16, vget_high_f32(_q3.val[0])); + vst1q_f32(out_ptr + 18, _d3); + vst1_f32(out_ptr + 22, vget_high_f32(_q3.val[1])); + + a0 += 4; + a1 += 4; + a2 += 4; + a3 += 4; + a4 += 4; + a5 += 4; + out_ptr += 24; + } +#else + asm volatile( + "loop_4k_%=: \n" + "vld1.32 {d0-d1}, [%[a0]]! \n" + "vld1.32 {d2-d3}, [%[a1]]! \n" + "vld1.32 {d4-d5}, [%[a2]]! \n" + "vld1.32 {d6-d7}, [%[a3]]! \n" + "vld1.32 {d8-d9}, [%[a4]]! \n" + "vld1.32 {d10-d11}, [%[a5]]! \n" + "vtrn.32 q0, q1 \n" + "vtrn.32 q2, q3 \n" + "vtrn.32 q4, q5 \n" + "vswp.32 d1, d4 \n" + "vswp.32 d3, d6 \n" + + "vst1.32 {q0}, [%[out]]! \n" + "vst1.32 {d8}, [%[out]]! \n" + "vst1.32 {q1}, [%[out]]! \n" + "vst1.32 {d10}, [%[out]]! \n" + "vst1.32 {q2}, [%[out]]! \n" + "vst1.32 {d9}, [%[out]]! \n" + "vst1.32 {q3}, [%[out]]! \n" + "vst1.32 {d11}, [%[out]]! 
\n" + + "subs %[loops], #1 \n" + "bne loop_4k_%= \n" + : [out] "+r"(out_ptr), [a0] "+r"(a0), [a1] "+r"(a1), [a2] "+r"(a2), + [a3] "+r"(a3), [a4] "+r"(a4), [a5] "+r"(a5), [loops] "+r"(loops) + : + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5"); +#endif } - } -} -void Gemm::PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda, - float *buffer) { - const int i_length = m - m_tail; - for (int i = 0; i < i_length; i += MR) { - const float *a0 = A + i * lda; - const float *a1 = A + (i + 1) * lda; - const float *a2 = A + (i + 2) * lda; - const float *a3 = A + (i + 3) * lda; - const float *a4 = A + (i + 4) * lda; - const float *a5 = A + (i + 5) * lda; - const float *a6 = A + (i + 6) * lda; - const float *a7 = A + (i + 7) * lda; - float *local_buffer = buffer + i * k; - for (int j = 0; j < k; ++j) { - *local_buffer++ = *a0++; - *local_buffer++ = *a1++; - *local_buffer++ = *a2++; - *local_buffer++ = *a3++; - *local_buffer++ = *a4++; - *local_buffer++ = *a5++; - *local_buffer++ = *a6++; - *local_buffer++ = *a7++; - } - } - if (m_tail != 0) { - const float *a0 = &A(i_length, 0); + if (remain_k > 0) { + float32x4_t _d0 = vld1q_f32(a0); + float32x4_t _d1 = vld1q_f32(a1); + float32x4_t _d2 = vld1q_f32(a2); + float32x4_t _d3 = vld1q_f32(a3); + float32x4_t _d4 = vld1q_f32(a4); + float32x4_t _d5 = vld1q_f32(a5); + + _d0 = vandq_f32(_d0, vmask1); + _d1 = vandq_f32(_d1, vmask1); + _d2 = vandq_f32(_d2, vmask1); + _d3 = vandq_f32(_d3, vmask1); + _d4 = vandq_f32(_d4, vmask1); + _d5 = vandq_f32(_d5, vmask1); + + float32x4x2_t _q0 = vtrnq_f32(_d0, _d1); + float32x4x2_t _q1 = vtrnq_f32(_d2, _d3); + float32x4x2_t _q3 = vtrnq_f32(_d4, _d5); + _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0])); + _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1])); + _d2 = vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0])); + + switch (remain_k) { + case 3: + vst1q_f32(out_ptr + 12, _d2); + vst1_f32(out_ptr + 16, vget_high_f32(_q3.val[0])); + case 2: + vst1q_f32(out_ptr + 6, _d1); + vst1_f32(out_ptr + 10, vget_low_f32(_q3.val[1])); + case 1: + vst1q_f32(out_ptr, _d0); + vst1_f32(out_ptr + 4, vget_low_f32(_q3.val[0])); + default: + break; + } + } + } + + int remain_m = m % 6; + if (remain_m) { + int remain_m_start = m - remain_m; + const float *a0 = A + remain_m_start * lda; const float *a1 = a0 + lda; const float *a2 = a0 + 2 * lda; const float *a3 = a0 + 3 * lda; const float *a4 = a0 + 4 * lda; const float *a5 = a0 + 5 * lda; - const float *a6 = a0 + 6 * lda; - const float *a7 = a0 + 7 * lda; - float *local_buffer = buffer + i_length * k; - switch (m_tail) { - case 1: - a1 = zero; - case 2: - a2 = zero; - case 3: - a3 = zero; - case 4: - a4 = zero; - case 5: - a5 = zero; - case 6: - a6 = zero; - case 7: - a7 = zero; - break; - default: - break; - } - for (int j = 0; j < k; ++j) { - *local_buffer++ = *a0++; - *local_buffer++ = *a1++; - *local_buffer++ = *a2++; - *local_buffer++ = *a3++; - *local_buffer++ = *a4++; - *local_buffer++ = *a5++; - *local_buffer++ = *a6++; - *local_buffer++ = *a7++; + float *out_ptr = buffer + remain_m_start * k; + + uint32x4_t vmask2 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_m)); + uint32x4_t vmask3 = vcltq_u32(vld1q_u32(mask + 4), vdupq_n_u32(remain_m)); + + int loops = k >> 2; + if (loops > 0) { +#if __aarch64__ + for (int l = 0; l < loops; ++l) { + float32x4_t _d0 = vld1q_f32(a0); + float32x4_t _d1 = vld1q_f32(a1); + float32x4_t _d2 = vld1q_f32(a2); + float32x4_t _d3 = vld1q_f32(a3); + float32x4_t _d4 = vld1q_f32(a4); + 
float32x4_t _d5 = vld1q_f32(a5); + + float32x4x2_t _q0 = vtrnq_f32(_d0, _d1); + float32x4x2_t _q1 = vtrnq_f32(_d2, _d3); + float32x4x2_t _q3 = vtrnq_f32(_d4, _d5); + _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0])); + _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1])); + _d2 = + vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0])); + _d3 = + vcombine_f32(vget_high_f32(_q0.val[1]), vget_high_f32(_q1.val[1])); + + _d0 = vandq_f32(_d0, vmask2); + _d1 = vandq_f32(_d1, vmask2); + _d2 = vandq_f32(_d2, vmask2); + _d3 = vandq_f32(_d3, vmask2); + _d4 = vandq_f32(_q3.val[0], vmask3); + _d5 = vandq_f32(_q3.val[1], vmask3); + + vst1q_f32(out_ptr, _d0); + vst1_f32(out_ptr + 4, vget_low_f32(_d4)); + vst1q_f32(out_ptr + 6, _d1); + vst1_f32(out_ptr + 10, vget_low_f32(_d5)); + vst1q_f32(out_ptr + 12, _d2); + vst1_f32(out_ptr + 16, vget_high_f32(_d4)); + vst1q_f32(out_ptr + 18, _d3); + vst1_f32(out_ptr + 22, vget_high_f32(_d5)); + + a0 += 4; + a1 += 4; + a2 += 4; + a3 += 4; + a4 += 4; + a5 += 4; + out_ptr += 24; + } +#else + asm volatile( + "loop_4k_%=: \n" + "vld1.32 {d0-d1}, [%[a0]]! \n" + "vld1.32 {d2-d3}, [%[a1]]! \n" + "vld1.32 {d4-d5}, [%[a2]]! \n" + "vld1.32 {d6-d7}, [%[a3]]! \n" + "vld1.32 {d8-d9}, [%[a4]]! \n" + "vld1.32 {d10-d11}, [%[a5]]! \n" + "vtrn.32 q0, q1 \n" + "vtrn.32 q2, q3 \n" + "vtrn.32 q4, q5 \n" + "vswp.32 d1, d4 \n" + "vswp.32 d3, d6 \n" + + "vbif q0, %q[vzero], %q[vmask2] \n" + "vbif q1, %q[vzero], %q[vmask2] \n" + "vbif q2, %q[vzero], %q[vmask2] \n" + "vbif q3, %q[vzero], %q[vmask2] \n" + "vbif q4, %q[vzero], %q[vmask3] \n" + "vbif q5, %q[vzero], %q[vmask3] \n" + + "vst1.32 {q0}, [%[out]]! \n" + "vst1.32 {d8}, [%[out]]! \n" + "vst1.32 {q1}, [%[out]]! \n" + "vst1.32 {d10}, [%[out]]! \n" + "vst1.32 {q2}, [%[out]]! \n" + "vst1.32 {d9}, [%[out]]! \n" + "vst1.32 {q3}, [%[out]]! \n" + "vst1.32 {d11}, [%[out]]! 
\n" + + "subs %[loops], #1 \n" + "bne loop_4k_%= \n" + : [out] "+r"(out_ptr), [a0] "+r"(a0), [a1] "+r"(a1), [a2] "+r"(a2), + [a3] "+r"(a3), [a4] "+r"(a4), [a5] "+r"(a5), [loops] "+r"(loops) + : [vmask2] "w"(vmask2), [vmask3] "w"(vmask3), [vzero] "w"(vzero) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5"); +#endif } - } -} -void Gemm::PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda, - float *buffer) { - const int i_length = m - m_tail; -#pragma omp parallel for - for (int i = 0; i < i_length; i += MR) { - const float *a0 = A + i * lda; - const float *a1 = A + (i + 1) * lda; - const float *a2 = A + (i + 2) * lda; - const float *a3 = A + (i + 3) * lda; - const float *a4 = A + (i + 4) * lda; - const float *a5 = A + (i + 5) * lda; - const float *a6 = A + (i + 6) * lda; - const float *a7 = A + (i + 7) * lda; - float *local_buffer = buffer + i * k; - for (int j = 0; j < k; ++j) { - *local_buffer++ = *a0++; - *local_buffer++ = *a1++; - *local_buffer++ = *a2++; - *local_buffer++ = *a3++; - *local_buffer++ = *a4++; - *local_buffer++ = *a5++; - *local_buffer++ = *a6++; - *local_buffer++ = *a7++; - } - } - if (m_tail != 0) { - const float *a0 = &A(i_length, 0); - const float *a1 = a0 + lda; - const float *a2 = a0 + 2 * lda; - const float *a3 = a0 + 3 * lda; - const float *a4 = a0 + 4 * lda; - const float *a5 = a0 + 5 * lda; - const float *a6 = a0 + 6 * lda; - const float *a7 = a0 + 7 * lda; - float *local_buffer = buffer + i_length * k; - switch (m_tail) { - case 1: - a1 = zero; - case 2: - a2 = zero; - case 3: - a3 = zero; - case 4: - a4 = zero; - case 5: - a5 = zero; - case 6: - a6 = zero; - case 7: - a7 = zero; - break; - default: - break; - } - for (int j = 0; j < k; ++j) { - *local_buffer++ = *a0++; - *local_buffer++ = *a1++; - *local_buffer++ = *a2++; - *local_buffer++ = *a3++; - *local_buffer++ = *a4++; - *local_buffer++ = *a5++; - *local_buffer++ = *a6++; - *local_buffer++ = *a7++; + if (remain_k > 0) { + float32x4_t _d0 = vld1q_f32(a0); + float32x4_t _d1 = vld1q_f32(a1); + float32x4_t _d2 = vld1q_f32(a2); + float32x4_t _d3 = vld1q_f32(a3); + float32x4_t _d4 = vld1q_f32(a4); + float32x4_t _d5 = vld1q_f32(a5); + + _d0 = vandq_f32(_d0, vmask1); + _d1 = vandq_f32(_d1, vmask1); + _d2 = vandq_f32(_d2, vmask1); + _d3 = vandq_f32(_d3, vmask1); + _d4 = vandq_f32(_d4, vmask1); + _d5 = vandq_f32(_d5, vmask1); + + float32x4x2_t _q0 = vtrnq_f32(_d0, _d1); + float32x4x2_t _q1 = vtrnq_f32(_d2, _d3); + float32x4x2_t _q3 = vtrnq_f32(_d4, _d5); + _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0])); + _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1])); + _d2 = vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0])); + // _d3 = vcombine_f32(vget_high_f32(_q0.val[1]), + // vget_high_f32(_q1.val[1])); + + _d0 = vandq_f32(_d0, vmask2); + _d1 = vandq_f32(_d1, vmask2); + _d2 = vandq_f32(_d2, vmask2); + // _d3 = vandq_f32(_d3, vmask2); + _d4 = vandq_f32(_q3.val[0], vmask3); + _d5 = vandq_f32(_q3.val[1], vmask3); + + switch (remain_k) { + case 3: + vst1q_f32(out_ptr + 12, _d2); + vst1_f32(out_ptr + 16, vget_high_f32(_d4)); + case 2: + vst1q_f32(out_ptr + 6, _d1); + vst1_f32(out_ptr + 10, vget_low_f32(_d5)); + case 1: + vst1q_f32(out_ptr, _d0); + vst1_f32(out_ptr + 4, vget_low_f32(_d4)); + default: + break; + } } } } // 将B矩阵分块复制到连续内存(RowMajor) void Gemm::PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer) { + float *buffer, const bool parallel) { const int j_length = n - n_tail; - for (int j = 0; j < j_length; j 
+= NR) { - float *local_buffer = buffer + j * k; - for (int i = 0; i < k; ++i) { + + #pragma omp parallel for if (parallel) + for (int i = 0; i < k; ++i) { + int j = 0; + for (; j < j_length - 31; j += 32) { + float *local_buffer0 = buffer + j * k + i * NR; + float *local_buffer1 = buffer + (j + 8) * k + i * NR; + float *local_buffer2 = buffer + (j + 16) * k + i * NR; + float *local_buffer3 = buffer + (j + 24) * k + i * NR; + const float *b0 = B + i * ldb + j; +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%[b0]] \n" + "ld1 {v0.4s, v1.4s}, [%[b0]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[b0]], #32 \n" + "ld1 {v4.4s, v5.4s}, [%[b0]], #32 \n" + "ld1 {v6.4s, v7.4s}, [%[b0]] \n" + "st1 {v0.4s, v1.4s}, [%[local_buffer0]], #32 \n" + "st1 {v2.4s, v3.4s}, [%[local_buffer1]], #32 \n" + "st1 {v4.4s, v5.4s}, [%[local_buffer2]], #32 \n" + "st1 {v6.4s, v7.4s}, [%[local_buffer3]], #32 \n" + : [local_buffer0] "+r"(local_buffer0), + [local_buffer1] "+r"(local_buffer1), + [local_buffer2] "+r"(local_buffer2), + [local_buffer3] "+r"(local_buffer3), [b0] "+r"(b0) + : + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); +#else + asm volatile( + // "pld [%[b]] \n" + "vld1.32 {q0, q1}, [%[b0]]! \n" + "vld1.32 {q2, q3}, [%[b0]]! \n" + "vld1.32 {q4, q5}, [%[b0]]! \n" + "vld1.32 {q6, q7}, [%[b0]]! \n" + "vst1.32 {q0, q1}, [%[local_buffer0]]! \n" + "vst1.32 {q2, q3}, [%[local_buffer1]]! \n" + "vst1.32 {q4, q5}, [%[local_buffer2]]! \n" + "vst1.32 {q6, q7}, [%[local_buffer3]]! \n" + : [local_buffer0] "+r"(local_buffer0), + [local_buffer1] "+r"(local_buffer1), + [local_buffer2] "+r"(local_buffer2), + [local_buffer3] "+r"(local_buffer3), [b0] "+r"(b0) + : + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); +#endif // __aarch64__ + } + for (; j < j_length - 15; j += 16) { + float *local_buffer0 = buffer + j * k + i * NR; + float *local_buffer1 = buffer + (j + 8) * k + i * NR; const float *b0 = &B(i, j); #if __ARM_NEON #if __aarch64__ asm volatile( - "prfm pldl1keep, [%[b0]] \n\t" - "ld1 {v0.4s, v1.4s}, [%[b0]] \n\t" - "st1 {v0.4s, v1.4s}, [%[local_buffer]], #32 \n\t" - : [local_buffer] "+r"(local_buffer) - : [b0] "r"(b0) - : "memory", "v0", "v1"); + "prfm pldl1keep, [%[b0]] \n" + "ld1 {v0.4s, v1.4s}, [%[b0]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[b0]] \n" + "st1 {v0.4s, v1.4s}, [%[local_buffer0]], #32 \n" + "st1 {v2.4s, v3.4s}, [%[local_buffer1]], #32 \n" + : [local_buffer0] "+r"(local_buffer0), + [local_buffer1] "+r"(local_buffer1), [b0] "+r"(b0) + : + : "memory", "v0", "v1", "v2", "v3"); #else asm volatile( - // "pld [%[b0]] \n\t" - "vld1.32 {q0, q1}, [%[b0]] \n\t" - "vst1.32 {q0, q1}, [%[local_buffer]]! \n\t" - : [local_buffer] "+r"(local_buffer) - : [b0] "r"(b0) - : "memory", "q0", "q1"); + // "pld [%[b0]] \n" + "vld1.32 {q0, q1}, [%[b0]]! \n" + "vld1.32 {q2, q3}, [%[b0]] \n" + "vst1.32 {q0, q1}, [%[local_buffer0]]! \n" + "vst1.32 {q2, q3}, [%[local_buffer1]]! 
\n" + : [local_buffer0] "+r"(local_buffer0), + [local_buffer1] "+r"(local_buffer1), [b0] "+r"(b0) + : + : "memory", "q0", "q1", "q2", "q3"); #endif // __aarch64__ -#else - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; #endif // __ARM_NEON } - } - if (n_tail != 0) { - float *local_buffer = buffer + j_length * k; - for (int i = 0; i < k; ++i) { - const float *b0 = &B(i, j_length); - for (int j = j_length; j < n; ++j) { - *local_buffer++ = *b0++; - } - for (int j = n; j < j_length + NR; ++j) { - *local_buffer++ = 0; - } - } - } -} - -void Gemm::PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer) { - const int j_length = n - n_tail; -#pragma omp parallel for - for (int j = 0; j < j_length; j += NR) { - float *local_buffer = buffer + j * k; - for (int i = 0; i < k; ++i) { + for (; j < j_length; j += NR) { + float *local_buffer = buffer + j * k + i * NR; const float *b0 = &B(i, j); -#if __ARM_NEON #if __aarch64__ asm volatile( - "prfm pldl1keep, [%[b0]] \n\t" - "ld1 {v0.4s, v1.4s}, [%[b0]] \n\t" - "st1 {v0.4s, v1.4s}, [%[local_buffer]], #32 \n\t" + "prfm pldl1keep, [%[b0]] \n" + "ld1 {v0.4s, v1.4s}, [%[b0]] \n" + "st1 {v0.4s, v1.4s}, [%[local_buffer]], #32 \n" : [local_buffer] "+r"(local_buffer) : [b0] "r"(b0) : "memory", "v0", "v1"); #else asm volatile( - // "pld [%[b0]] \n\t" - "vld1.32 {q0, q1}, [%[b0]] \n\t" - "vst1.32 {q0, q1}, [%[local_buffer]]! \n\t" + // "pld [%[b]] \n" + "vld1.32 {q0, q1}, [%[b0]] \n" + "vst1.32 {q0, q1}, [%[local_buffer]] \n" : [local_buffer] "+r"(local_buffer) : [b0] "r"(b0) : "memory", "q0", "q1"); #endif // __aarch64__ -#else - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; -#endif // __ARM_NEON } } if (n_tail != 0) { + uint32_t mask[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + uint32x4_t vzero = vdupq_n_u32(0); + uint32x4_t vmask1 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(n_tail)); + uint32x4_t vmask2 = vcltq_u32(vld1q_u32(mask + 4), vdupq_n_u32(n_tail)); + float *local_buffer = buffer + j_length * k; for (int i = 0; i < k; ++i) { const float *b0 = &B(i, j_length); - for (int j = j_length; j < n; ++j) { - *local_buffer++ = *b0++; - } - for (int j = n; j < j_length + NR; ++j) { - *local_buffer++ = 0; - } +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%[b0]] \n" + "ld1 {v0.4s, v1.4s}, [%[b0]] \n" + "BIF v0.8b, %[vzero].8b, %[vmask1].8b \n" + "BIF v1.8b, %[vzero].8b, %[vmask2].8b \n" + "st1 {v0.4s, v1.4s}, [%[local_buffer]], #32 \n" + : [local_buffer] "+r"(local_buffer) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero), + [b0] "r"(b0) + : "memory", "v0", "v1"); +#else + asm volatile( + "vld1.32 {q0, q1}, [%[b0]] \n" + "vbif q0, %q[vzero], %q[vmask1] \n" + "vbif q1, %q[vzero], %q[vmask2] \n" + "vst1.32 {q0, q1}, [%[local_buffer]]! 
\n" + : [local_buffer] "+r"(local_buffer) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero), + [b0] "r"(b0) + : "memory", "q0", "q1"); +#endif } } } @@ -418,39 +446,10 @@ void Gemm::PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb, #if __ARM_NEON #if __aarch64__ void Gemm::PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer) { + float *buffer, const bool parallel) { const int j_length = n - n_tail; - for (int j = 0; j < j_length; j += NR) { - float *local_buffer = buffer + j * k; - for (int i = 0; i < k; ++i) { - const float *b0 = &B(i, j); - asm volatile( - "prfm pldl2keep, [%[b0], #64] \n\t" - "ld1 {v0.4s, v1.4s, v2.4s}, [%[b0]] \n\t" - "st1 {v0.4s, v1.4s, v2.4s}, [%[local_buffer]], #48 \n\t" - : [local_buffer] "+r"(local_buffer) - : [b0] "r"(b0) - : "memory", "v0", "v1", "v2"); - } - } - if (n_tail != 0) { - float *local_buffer = buffer + j_length * k; - for (int i = 0; i < k; ++i) { - const float *b0 = &B(i, j_length); - for (int j = j_length; j < n; ++j) { - *local_buffer++ = *b0++; - } - for (int j = n; j < j_length + NR; ++j) { - *local_buffer++ = 0; - } - } - } -} -void Gemm::PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, - int ldb, float *buffer) { - const int j_length = n - n_tail; -#pragma omp parallel for + #pragma omp parallel for if (parallel) for (int j = 0; j < j_length; j += NR) { float *local_buffer = buffer + j * k; for (int i = 0; i < k; ++i) { @@ -479,39 +478,10 @@ void Gemm::PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, } void Gemm::PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer) { + float *buffer, const bool parallel) { const int j_length = n - n_tail; - for (int j = 0; j < n - n_tail; j += NR) { - float *local_buffer = buffer + j * k; - for (int i = 0; i < k; ++i) { - const float *b0 = &B(i, j); - asm volatile( - "prfm pldl2keep, [%[b0], #64] \n\t" - "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]] \n\t" - "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[local_buffer]], #64 \n\t" - : [local_buffer] "+r"(local_buffer) - : [b0] "r"(b0) - : "memory", "v0", "v1", "v2", "v3"); - } - } - if (n_tail != 0) { - float *local_buffer = buffer + j_length * k; - for (int i = 0; i < k; ++i) { - const float *b0 = &B(i, j_length); - for (int j = j_length; j < n; ++j) { - *local_buffer++ = *b0++; - } - for (int j = n; j < j_length + NR; ++j) { - *local_buffer++ = 0; - } - } - } -} -void Gemm::PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, - int ldb, float *buffer) { - const int j_length = n - n_tail; -#pragma omp parallel for + #pragma omp parallel for if (parallel) for (int j = 0; j < n - n_tail; j += NR) { float *local_buffer = buffer + j * k; for (int i = 0; i < k; ++i) { @@ -2971,7 +2941,48 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, // C = A * B void Gemm::VecWriteBasic(int n, float *c, float *C, int ldc) { - memcpy(C, c, n * sizeof(float)); + int nc1 = n / 16; + int _nc1 = n % 16; + int nc2 = _nc1 / 4; + int nc3 = 16 - 4 * (_nc1 % 4); + + asm volatile( + "subs %[nc1], %[nc1], #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" + + "vld1.32 {q0, q1}, [%[c]]! \n\t" + "vst1.32 {q0, q1}, [%[C]]! \n\t" + + "vld1.32 {q2, q3}, [%[c]]! \n\t" + "vst1.32 {q2, q3}, [%[C]]! \n\t" + + "subs %[nc1], %[nc1], #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" + + "subs %[nc2], %[nc2], #1 \n\t" + "blt end_nc2_%= \n\t" + "loop_nc2_%=: \n\t" + + "vld1.32 {q4}, [%[c]]! \n\t" + "vst1.32 {q4}, [%[C]]! 
\n\t" + + "subs %[nc2], %[nc2], #1 \n\t" + "bge loop_nc2_%= \n\t" + "end_nc2_%=: \n\t" + + "cmp %[nc3], #16 \n\t" + "beq end_nc3_%= \n\t" + "sub %[c], %[c], %[nc3] \n\t" + "sub %[C], %[C], %[nc3] \n\t" + "vld1.32 {q5}, [%[c]]! \n\t" + "vst1.32 {q5}, [%[C]]! \n\t" + "end_nc3_%=: \n\t" + + : + : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5"); } // C = alpha * A * B + beta * C @@ -3252,17 +3263,17 @@ void Gemm::Sgemm(int m, int n, int k, float alpha, const float *A, int lda, nc = s_min(n - j, NC); #if __aarch64__ // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB); - PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB); + PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false); #else - PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB); + PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false); #endif for (int i = 0; i < m; i += MC) { mc = s_min(m - i, MC); #if __aarch64__ - PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); + PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false); // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA); #else - PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); + PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false); #endif if (bias == nullptr) { InnerKernelWithBias(mc, nc, alpha, packedA, packedB, beta, packedC, @@ -3325,17 +3336,17 @@ void Gemm::SgemmWithBn(int m, int n, int k, float alpha, const float *A, nc = s_min(n - j, NC); #if __aarch64__ // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB); - PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB); + PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false); #else - PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB); + PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false); #endif for (int i = 0; i < m; i += MC) { mc = s_min(m - i, MC); #if __aarch64__ - PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); + PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false); // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA); #else - PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); + PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false); #endif if (bias == nullptr) { InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC, @@ -3401,17 +3412,17 @@ void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda, nc = s_min(n - j, NC); #if __aarch64__ // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB); - PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB); + PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false); #else - PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB); + PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false); #endif for (int i = 0; i < m; i += MC) { mc = s_min(m - i, MC); #if __aarch64__ - PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); + PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false); // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA); #else - PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); + PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false); #endif if (bias1 == nullptr) { InnerKernelWithPRelu(mc, nc, packedA, packedB, packedC, &C(i, j), ldc, @@ -3465,17 +3476,17 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, #if __aarch64__ procPackA = &Gemm::PackMatrixA_6r; - procPackB = &Gemm::PackMatrixB_omp_16c; + procPackB = &Gemm::PackMatrixB_16c; procAddDot = 
&Gemm::AddDot6x16; #else procPackA = &Gemm::PackMatrixA_6r; - procPackB = &Gemm::PackMatrixB_omp_8c; + procPackB = &Gemm::PackMatrixB_8c; procAddDot = &Gemm::AddDot6x8; #endif packedB = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); - (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB); + (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB, true); packedA = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); } else { @@ -3492,19 +3503,19 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, MC = (m + MR - 1) / MR * MR; #if __aarch64__ - procPackA = &Gemm::PackMatrixA_omp_6r; + procPackA = &Gemm::PackMatrixA_6r; procPackB = &Gemm::PackMatrixB_16c; procAddDot = &Gemm::AddDot6x16; #else - procPackA = &Gemm::PackMatrixA_omp_6r; + procPackA = &Gemm::PackMatrixA_6r; procPackB = &Gemm::PackMatrixB_8c; procAddDot = &Gemm::AddDot6x8; #endif packedA = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); - (*this.*procPackA)(m, KC, m % MR, A, lda, packedA); + (*this.*procPackA)(m, KC, m % MR, A, lda, packedA, true); packedB = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); } @@ -3524,7 +3535,7 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, mc = s_min(m - i, MC); float *local_A = packedA + MC * KC * local_threads; float *local_C = packedC + MC * NC * local_threads; - (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A); + (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A, false); if (bias == nullptr) { InnerKernelWithBias(mc, n, alpha, local_A, packedB, beta, local_C, &C(i, 0), ldc, relu, nullptr); @@ -3546,7 +3557,7 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, nc = s_min(n - j, NC); float *local_B = packedB + KC * NC * local_threads; float *local_C = packedC + MC * NC * local_threads; - (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B); + (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B, false); InnerKernelWithBias(m, nc, alpha, packedA, local_B, beta, local_C, &C(0, j), ldc, relu, bias); } @@ -3587,17 +3598,17 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, #if __aarch64__ procPackA = &Gemm::PackMatrixA_6r; - procPackB = &Gemm::PackMatrixB_omp_16c; + procPackB = &Gemm::PackMatrixB_16c; procAddDot = &Gemm::AddDot6x16; #else procPackA = &Gemm::PackMatrixA_6r; - procPackB = &Gemm::PackMatrixB_omp_8c; + procPackB = &Gemm::PackMatrixB_8c; procAddDot = &Gemm::AddDot6x8; #endif packedB = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); - (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB); + (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB, true); packedA = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); } else { @@ -3614,18 +3625,18 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, MC = (m + MR - 1) / MR * MR; #if __aarch64__ - procPackA = &Gemm::PackMatrixA_omp_6r; + procPackA = &Gemm::PackMatrixA_6r; procPackB = &Gemm::PackMatrixB_16c; procAddDot = &Gemm::AddDot6x16; #else - procPackA = &Gemm::PackMatrixA_omp_6r; + procPackA = &Gemm::PackMatrixA_6r; procPackB = &Gemm::PackMatrixB_8c; procAddDot = &Gemm::AddDot6x8; #endif packedA = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); - (*this.*procPackA)(m, KC, m % MR, A, lda, packedA); + (*this.*procPackA)(m, KC, m % MR, A, lda, packedA, true); packedB = static_cast( 
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); } @@ -3645,7 +3656,7 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, mc = s_min(m - i, MC); float *local_A = packedA + MC * KC * local_threads; float *local_C = packedC + MC * NC * local_threads; - (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A); + (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A, false); if (bias == nullptr) { InnerKernelWithBn(mc, n, alpha, local_A, packedB, beta, local_C, &C(i, 0), ldc, relu, new_scale + i, new_bias + i); @@ -3668,7 +3679,7 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, nc = s_min(n - j, NC); float *local_B = packedB + KC * NC * local_threads; float *local_C = packedC + MC * NC * local_threads; - (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B); + (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B, false); if (bias == nullptr) { InnerKernelWithBn(m, nc, alpha, packedA, local_B, beta, local_C, &C(0, j), ldc, relu, new_scale, new_bias); @@ -3715,17 +3726,17 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, #if __aarch64__ procPackA = &Gemm::PackMatrixA_6r; - procPackB = &Gemm::PackMatrixB_omp_16c; + procPackB = &Gemm::PackMatrixB_16c; procAddDot = &Gemm::AddDot6x16; #else procPackA = &Gemm::PackMatrixA_6r; - procPackB = &Gemm::PackMatrixB_omp_8c; + procPackB = &Gemm::PackMatrixB_8c; procAddDot = &Gemm::AddDot6x8; #endif packedB = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); - (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB); + (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB, true); packedA = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); } else { @@ -3742,18 +3753,18 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, MC = (m + MR - 1) / MR * MR; #if __aarch64__ - procPackA = &Gemm::PackMatrixA_omp_6r; + procPackA = &Gemm::PackMatrixA_6r; procPackB = &Gemm::PackMatrixB_16c; procAddDot = &Gemm::AddDot6x16; #else - procPackA = &Gemm::PackMatrixA_omp_6r; + procPackA = &Gemm::PackMatrixA_6r; procPackB = &Gemm::PackMatrixB_8c; procAddDot = &Gemm::AddDot6x8; #endif packedA = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); - (*this.*procPackA)(m, KC, m % MR, A, lda, packedA); + (*this.*procPackA)(m, KC, m % MR, A, lda, packedA, true); packedB = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); } @@ -3773,7 +3784,7 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, mc = s_min(m - i, MC); float *local_A = packedA + MC * KC * local_threads; float *local_C = packedC + MC * NC * local_threads; - (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A); + (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A, false); if (bias1 == nullptr) { InnerKernelWithPRelu(mc, n, local_A, packedB, local_C, &C(i, 0), ldc, p + i, mode, bias + i, nullptr); @@ -3795,7 +3806,7 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, nc = s_min(n - j, NC); float *local_B = packedB + KC * NC * local_threads; float *local_C = packedC + MC * NC * local_threads; - (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B); + (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B, false); if (bias1 == nullptr) { InnerKernelWithPRelu(m, nc, packedA, local_B, local_C, &C(0, j), ldc, p, mode, bias, nullptr); diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index 
effab20b2045fbe93590189e28bac24d1f72ab2c..3fc418003c7faa804c0f7a146b1f9108e0b01789 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -46,37 +46,25 @@ namespace math { class Gemm { public: - typedef void (Gemm::*FnPack)(int, int, int, const float *, int, float *); + typedef void (Gemm::*FnPack)(int, int, int, const float *, int, float *, + const bool); typedef void (Gemm::*FnAddDot)(int, const float *, const float *, float *, int); FnPack procPackA; FnPack procPackB; FnAddDot procAddDot; - // 将 A\B 矩阵分块复制到连续内存(RowMajor) - void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda, - float *buffer); void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda, - float *buffer); - void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda, - float *buffer); + float *buffer, const bool parallel); void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda, - float *buffer); - void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda, - float *buffer); + float *buffer, const bool parallel); void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer); - void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer); + float *buffer, const bool parallel); #if __aarch64__ void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer); - void PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer); + float *buffer, const bool parallel); void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer); - void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer); + float *buffer, const bool parallel); #endif // 分块矩阵乘法
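
Note on the API change in this patch: the separate PackMatrixA_omp_* / PackMatrixB_omp_* entry points are folded into the non-omp functions by adding a trailing `const bool parallel` argument and gating the OpenMP pragma on it (`#pragma omp parallel for if (parallel)`). A minimal sketch of that pattern, using a hypothetical free function `pack_rows` rather than the actual Gemm member:

    // Hypothetical helper illustrating "#pragma omp parallel for if (parallel)":
    // one body serves both the single-threaded and the OpenMP path.
    void pack_rows(int m, int k, const float *A, int lda, float *buffer,
                   const bool parallel) {
    #pragma omp parallel for if (parallel)
      for (int i = 0; i < m; ++i) {
        const float *src = A + i * lda;  // row i of A
        float *dst = buffer + i * k;     // contiguous slice of the packed buffer
        for (int j = 0; j < k; ++j) {
          dst[j] = src[j];
        }
      }
    }

Call sites in the diff mirror this: Sgemm_omp, SgemmWithBn_omp and SgemmWithPRelu_omp pass `true` for the single large pack done once up front and `false` for the per-block packs issued inside the already-parallel loops, while the plain Sgemm / SgemmWithBn / SgemmWithPRelu paths always pass `false`.

The k and m tails in the rewritten PackMatrixA_6r are handled with lane masks instead of scalar copy loops: a `vcltq_u32` comparison against the remaining count yields an all-ones/all-zeros mask per lane, and the new `vandq_f32` helper defined at the top of the file clears the out-of-range lanes before the full-width stores. A self-contained sketch of that idea (hypothetical `load_masked` helper; assumes at least four readable floats at `p`):

    #include <arm_neon.h>

    // Load 4 floats and zero every lane whose index is >= remain,
    // so a full vector store never writes stale source data.
    static inline float32x4_t load_masked(const float *p, uint32_t remain) {
      const uint32_t idx[4] = {0, 1, 2, 3};
      uint32x4_t mask = vcltq_u32(vld1q_u32(idx), vdupq_n_u32(remain));
      float32x4_t v = vld1q_f32(p);
      return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(v), mask));
    }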