diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index 733511c37652f315b64683fad6fa5f4ed5d06c91..a914f64734fdcd62b5fd76ad4065630c39c2488b 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -33,6 +33,7 @@ float *packedA; float *packedB; float *packedC; float *zero; +/* // 将A矩阵分块复制到连续内存(ColMajor) void PackMatrixA(int m, int k, int m_tail, const float *A, int lda, float *buffer) { @@ -60,6 +61,36 @@ void PackMatrixA(int m, int k, int m_tail, const float *A, int lda, } } +// 将B矩阵分块复制到连续内存(ColMajor) +void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, + float *buffer) { + int i, j; + const float *Bj, *Bj1, *Bj2, *Bj3; + for (j = 0; j < n - n_tail; j += NR) { + Bj = &B(0, j); + Bj1 = &B(0, j + 1); + Bj2 = &B(0, j + 2); + Bj3 = &B(0, j + 3); + for (i = 0; i < k; ++i) { + *buffer++ = *Bj++; + *buffer++ = *Bj1++; + *buffer++ = *Bj2++; + *buffer++ = *Bj3++; + } + } + if (n_tail != 0) { + for (i = 0; i < k; ++i) { + for (int j = n - n_tail; j < n; ++j) { + *buffer++ = B(i, j); + } + for (int j = n; j < n + (NR - n_tail); ++j) { + *buffer++ = 0; + } + } + } +} +*/ + // 将A矩阵分块复制到连续内存(RowMajor) void PackMatrixA_(int m, int k, int m_tail, const float *A, int lda, float *buffer) { @@ -100,35 +131,6 @@ void PackMatrixA_(int m, int k, int m_tail, const float *A, int lda, } } -// 将B矩阵分块复制到连续内存(ColMajor) -void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, - float *buffer) { - int i, j; - const float *Bj, *Bj1, *Bj2, *Bj3; - for (j = 0; j < n - n_tail; j += NR) { - Bj = &B(0, j); - Bj1 = &B(0, j + 1); - Bj2 = &B(0, j + 2); - Bj3 = &B(0, j + 3); - for (i = 0; i < k; ++i) { - *buffer++ = *Bj++; - *buffer++ = *Bj1++; - *buffer++ = *Bj2++; - *buffer++ = *Bj3++; - } - } - if (n_tail != 0) { - for (i = 0; i < k; ++i) { - for (int j = n - n_tail; j < n; ++j) { - *buffer++ = B(i, j); - } - for (int j = n; j < n + (NR - n_tail); ++j) { - *buffer++ = 0; - } - } - } -} - // 将B矩阵分块复制到连续内存(RowMajor) void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb, float *buffer) { @@ -138,18 +140,31 @@ void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb, b0 = &B(i, j); #if __ARM_NEON #if __aarch64__ - + asm volatile( + "prfm pldl1keep, [%[b0]] \n\t" + "ld1 {v0.4s, v1.4s}, [%[b0]] \n\t" + "st1 {v0.4s, v1.4s}, [%[buffer]], #32 \n\t" + : [buffer] "+r"(buffer) + : [b0] "r"(b0) + : "memory", "v0", "v1"); #else asm volatile( - "pld [%[b0]] \n\t" - "vld1.32 {q0, q1}, [%[b0]] \n\t" - "vst1.32 {q0, q1}, [%[buffer]]! \n\t" + "pld [%[b0]] \n\t" + "vld1.32 {q0, q1}, [%[b0]] \n\t" + "vst1.32 {q0, q1}, [%[buffer]]! \n\t" : [buffer] "+r"(buffer) : [b0] "r"(b0) - : "memory", "q0", "q0"); + : "memory", "q0", "q1"); #endif // __aarch64__ #else - + *buffer++ = *b0++; + *buffer++ = *b0++; + *buffer++ = *b0++; + *buffer++ = *b0++; + *buffer++ = *b0++; + *buffer++ = *b0++; + *buffer++ = *b0++; + *buffer++ = *b0++; #endif // __ARM_NEON } } @@ -217,7 +232,7 @@ void InnerKernelWithBn(int mc, int nc, float alpha, const float *a, #if __ARM_NEON #if __aarch64__ -void AddDot4x4(int k, const float *a, const float *b, float *C, int ldc) { +void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) { // init C float32x4_t cv0 = vdupq_n_f32(0.0); float32x4_t cv1 = vdupq_n_f32(0.0); @@ -244,23 +259,264 @@ void AddDot4x4(int k, const float *a, const float *b, float *C, int ldc) { a += MR; b += NR; } - float32x4x4_t cv = {cv0, cv1, cv2, cv3}; - int i, j; - for (i = 0; i < mc; ++i) { - for (j = 0; j < nc; ++j) { - if (beta == 0.0) { - C(i, j) = 0.0; - } else if (beta != 1.0) { - C(i, j) *= beta; + + vst1q_f32(c, cv0); + vst1q_f32(c + ldc, cv1); + vst1q_f32(c + 2 * ldc, cv2); + vst1q_f32(c + 3 * ldc, cv3); + // float32x4x4_t cv = {cv0, cv1, cv2, cv3}; +} + +void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) { + // init C + float32x4_t cv0 = vdupq_n_f32(0.0); + float32x4_t cv1 = vdupq_n_f32(0.0); + float32x4_t cv2 = vdupq_n_f32(0.0); + float32x4_t cv3 = vdupq_n_f32(0.0); + float32x4_t cv4 = vdupq_n_f32(0.0); + float32x4_t cv5 = vdupq_n_f32(0.0); + float32x4_t cv6 = vdupq_n_f32(0.0); + float32x4_t cv7 = vdupq_n_f32(0.0); + + float32x4_t av; + float32x4_t bv0; + float32x4_t bv1; + + float32x2_t av01; + float32x2_t av23; + + for (int p = 0; p < k; p += 1) { + av = vld1q_f32(a); + bv0 = vld1q_f32(b); + bv1 = vld1q_f32(b + 4); + + av01 = vget_low_f32(av); + cv0 = vmlaq_lane_f32(cv0, bv0, av01, 0); + cv1 = vmlaq_lane_f32(cv1, bv1, av01, 0); + cv2 = vmlaq_lane_f32(cv2, bv0, av01, 1); + cv3 = vmlaq_lane_f32(cv3, bv1, av01, 1); + av23 = vget_high_f32(av); + cv4 = vmlaq_lane_f32(cv4, bv0, av23, 0); + cv5 = vmlaq_lane_f32(cv5, bv1, av23, 0); + cv6 = vmlaq_lane_f32(cv6, bv0, av23, 1); + cv7 = vmlaq_lane_f32(cv7, bv1, av23, 1); + + a += MR; + b += NR; + } + + vst1q_f32(c, cv0); + vst1q_f32(c + 4, cv1); + vst1q_f32(c + ldc, cv2); + vst1q_f32(c + ldc + 4, cv3); + vst1q_f32(c + 2 * ldc, cv4); + vst1q_f32(c + 2 * ldc + 4, cv5); + vst1q_f32(c + 3 * ldc, cv6); + vst1q_f32(c + 3 * ldc + 4, cv7); +} + +// 分块矩阵乘法结果回写 +// C = A * B +void WriteBasic(int mc, int nc, float *c, float *C, int ldc) { + int nc1 = nc / 4; + int _nc1 = nc % 4; + + float *c_ptr, *C_ptr; + float32x4_t cv; + for (int i = 0; i < mc; ++i) { + c_ptr = c + i * NC; + C_ptr = C + i * ldc; + for (int j = 0; j < nc1; ++j) { + cv = vld1q_f32(c_ptr); + vst1q_f32(C_ptr, cv); + c_ptr += 4; + C_ptr += 4; + } + if (_nc1 != 0) { + cv = vld1q_f32(c_ptr); + if (_nc1 >= 1) { + vst1q_lane_f32(C_ptr, cv, 0); + C_ptr++; + } + if (_nc1 >= 2) { + vst1q_lane_f32(C_ptr, cv, 1); + C_ptr++; + } + if (_nc1 >= 3) { + vst1q_lane_f32(C_ptr, cv, 2); + } + } + } +} + +// C = alpha * A * B + beta * C +void WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc) {} + +// C = A * B + C +void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) { + int nc1 = nc / 4; + int _nc1 = nc % 4; + + float *c_ptr, *C_ptr; + float32x4_t cv; + float32x4_t cv1; + for (int i = 0; i < mc; ++i) { + c_ptr = c + i * NC; + C_ptr = C + i * ldc; + for (int j = 0; j < nc1; ++j) { + cv = vld1q_f32(c_ptr); + cv1 = vld1q_f32(C_ptr); + cv = vaddq_f32(cv, cv1); + vst1q_f32(C_ptr, cv); + c_ptr += 4; + C_ptr += 4; + } + if (_nc1 != 0) { + cv = vld1q_f32(c_ptr); + cv1 = vld1q_f32(C_ptr); + cv = vaddq_f32(cv, cv1); + if (_nc1 >= 1) { + vst1q_lane_f32(C_ptr, cv, 0); + C_ptr++; + } + if (_nc1 >= 2) { + vst1q_lane_f32(C_ptr, cv, 1); + C_ptr++; + } + if (_nc1 >= 3) { + vst1q_lane_f32(C_ptr, cv, 2); + } + } + } +} + +// C = A * B + C, relu(C) +void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) { + int nc1 = nc / 4; + int _nc1 = nc % 4; + + float *c_ptr, *C_ptr; + float32x4_t cv; + float32x4_t cv1; + float32x4_t zero = vdupq_n_f32(0.0); + for (int i = 0; i < mc; ++i) { + c_ptr = c + i * NC; + C_ptr = C + i * ldc; + for (int j = 0; j < nc1; ++j) { + cv = vld1q_f32(c_ptr); + cv1 = vld1q_f32(C_ptr); + cv = vaddq_f32(cv, cv1); + cv = vmaxq_f32(cv, zero); + vst1q_f32(C_ptr, cv); + c_ptr += 4; + C_ptr += 4; + } + if (_nc1 != 0) { + cv = vld1q_f32(c_ptr); + cv1 = vld1q_f32(C_ptr); + cv = vaddq_f32(cv, cv1); + cv = vmaxq_f32(cv, zero); + if (_nc1 >= 1) { + vst1q_lane_f32(C_ptr, cv, 0); + C_ptr++; + } + if (_nc1 >= 2) { + vst1q_lane_f32(C_ptr, cv, 1); + C_ptr++; + } + if (_nc1 >= 3) { + vst1q_lane_f32(C_ptr, cv, 2); + } + } + } +} + +// C = A * B, batchnorm(C) +void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale, + float *new_bias) { + int nc1 = nc / 4; + int _nc1 = nc % 4; + + float *c_ptr, *C_ptr; + float32x4_t cv; + float32x4_t cv1; + float32x4_t bias; + float32x2_t scale; + for (int i = 0; i < mc; ++i) { + c_ptr = c + i * NC; + C_ptr = C + i * ldc; + bias = vld1q_dup_f32(new_bias); + scale = vld1_dup_f32(new_scale); + new_bias++; + new_scale++; + float scale0 = vget_lane_f32(scale, 0); + for (int j = 0; j < nc1; ++j) { + cv = vld1q_f32(c_ptr); + cv = vmlaq_n_f32(bias, cv, scale0); + vst1q_f32(C_ptr, cv); + c_ptr += 4; + C_ptr += 4; + } + if (_nc1 != 0) { + cv = vld1q_f32(c_ptr); + cv = vmlaq_n_f32(bias, cv, scale0); + if (_nc1 >= 1) { + vst1q_lane_f32(C_ptr, cv, 0); + C_ptr++; + } + if (_nc1 >= 2) { + vst1q_lane_f32(C_ptr, cv, 1); + C_ptr++; + } + if (_nc1 >= 3) { + vst1q_lane_f32(C_ptr, cv, 2); + C_ptr++; + } + } + } +} + +// C = A * B, batchnorm(C), relu(C) +void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, + float *new_scale, float *new_bias) { + int nc1 = nc / 4; + int _nc1 = nc % 4; + + float *c_ptr, *C_ptr; + float32x4_t cv; + float32x4_t bias; + float32x2_t scale; + float32x4_t zero = vdupq_n_f32(0.0); + for (int i = 0; i < mc; ++i) { + c_ptr = c + i * NC; + C_ptr = C + i * ldc; + bias = vld1q_dup_f32(new_bias); + scale = vld1_dup_f32(new_scale); + new_bias++; + new_scale++; + float scale0 = vget_lane_f32(scale, 0); + for (int j = 0; j < nc1; ++j) { + cv = vld1q_f32(c_ptr); + cv = vmlaq_n_f32(bias, cv, scale0); + cv = vmaxq_f32(cv, zero); + vst1q_f32(C_ptr, cv); + c_ptr += 4; + C_ptr += 4; + } + if (_nc1 != 0) { + cv = vld1q_f32(c_ptr); + cv = vmlaq_n_f32(bias, cv, scale0); + cv = vmaxq_f32(cv, zero); + if (_nc1 >= 1) { + vst1q_lane_f32(C_ptr, cv, 0); + C_ptr++; + } + if (_nc1 >= 2) { + vst1q_lane_f32(C_ptr, cv, 1); + C_ptr++; } - if (j == 0) { - C(i, j) += alpha * vgetq_lane_f32(cv.val[i], 0); - } else if (j == 1) { - C(i, j) += alpha * vgetq_lane_f32(cv.val[i], 1); - } else if (j == 2) { - C(i, j) += alpha * vgetq_lane_f32(cv.val[i], 2); - } else if (j == 3) { - C(i, j) += alpha * vgetq_lane_f32(cv.val[i], 3); + if (_nc1 >= 3) { + vst1q_lane_f32(C_ptr, cv, 2); } } } @@ -338,6 +594,7 @@ void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) { "q10", "q11", "q12", "q13"); } +/* void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, bool relu) { @@ -770,6 +1027,7 @@ void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A, VecWriteWithBn(n, bufferC, C, ldc, new_scale, new_bias); } } +*/ void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) { const float *a_ptr, *b_ptr; @@ -1288,6 +1546,7 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale, "q8", "q10", "q11", "q12", "q13", "q14"); } +/* // C = A * B void VecWriteBasic(int n, float *c, float *C, int ldc) { int nc1 = n / 16; @@ -1563,6 +1822,7 @@ void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *scale, : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11", "q12", "q13", "q14"); } +*/ #endif // __aarch64__ #else diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index b4bce43c7a29fba09ade7512cbc660f0ac2888ab..d8b305a7282b871d61ed588b1237f4f8f1cb56f8 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -28,6 +28,7 @@ namespace paddle_mobile { namespace operators { namespace math { +/* // 将 A 矩阵分块复制到连续内存(ColMajor) void PackMatrixA(int m, int k, int m_tail, const float *A, int lda, float *buffer); @@ -35,6 +36,7 @@ void PackMatrixA(int m, int k, int m_tail, const float *A, int lda, // 将 B 矩阵分块复制到连续内存(ColMajor) void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, float *buffer); +*/ // 将 A 矩阵分块复制到连续内存(RowMajor) void PackMatrixA_(int m, int k, int m_tail, const float *A, int lda, @@ -51,7 +53,7 @@ void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b, void InnerKernelWithBn(int mc, int nc, float alpha, const float *a, const float *b, float beta, float *c, float *C, int ldc, bool relu, float *new_scale, float *new_bias); - +/* // 向量矩阵乘法 (M = 1) void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, @@ -60,6 +62,7 @@ void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda, void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, bool relu, float *new_scale, float *new_bias); +*/ // 计算一个更小的 C 矩阵分块 void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc); @@ -81,6 +84,7 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale, void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *new_scale, float *new_bias); +/* // 向量矩阵乘法结果回写 // C = A * B void VecWriteBasic(int n, float *c, float *C, int ldc); @@ -96,6 +100,7 @@ void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale, // C = A * B, batchnorm(C), relu(C) void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale, float *new_bias); +*/ // 32位 float 矩阵乘法 void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,