提交 d5067d4c 编写于 作者: Z zhaojiaying01

update gemm with neon assembly

上级 e6ebcfe8
......@@ -3,6 +3,11 @@ project(paddle-mobile)
#add_definitions(-DPADDLE_MOBILE_DEBUG)
add_definitions(-DENABLE_EXCEPTION)
add_definitions(-DARMV7)
#add_definitions(-DARMV8)
#add_definitions(-DIOS)
#add_definitions(-DX86)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
set(CMAKE_BUILD_TYPE RelWithDebInfo)
set(CMAKE_VERBOSE_MAKEFILE ON)
......
......@@ -13,10 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/math/gemm.h"
#ifndef X86
#include <arm_neon.h>
#endif
namespace paddle_mobile {
namespace operators {
namespace math {
float ab[MR * NR];
// 将A矩阵分块复制到连续内存(ColMajor)
void PackMatrixA(int m, int k, int paddingM, const float *A, int lda,
float *buffer) {
......@@ -170,17 +174,197 @@ void InnerKernel(int m, int n, int k, float alpha, const float *A, int lda,
}
// 计算一个更小的 4 * 4 的 C 矩阵分块
#if defined(IOS)
void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b,
int ldb, float beta, float *C, int ldc, int mc, int nc) {
// init C
float32x4_t cv0 = vdupq_n_f32(0.0);
float32x4_t cv1 = vdupq_n_f32(0.0);
float32x4_t cv2 = vdupq_n_f32(0.0);
float32x4_t cv3 = vdupq_n_f32(0.0);
float32x4_t av;
float32x4_t bv;
float32x2_t av01;
float32x2_t av23;
for (int p = 0; p < k; p += 1) {
av = vld1q_f32(a);
bv = vld1q_f32(b);
av01 = vget_low_f32(av);
cv0 = vmlaq_lane_f32(cv0, bv, av01, 0);
cv1 = vmlaq_lane_f32(cv1, bv, av01, 1);
av23 = vget_high_f32(av);
cv2 = vmlaq_lane_f32(cv2, bv, av23, 0);
cv3 = vmlaq_lane_f32(cv3, bv, av23, 1);
a += MR;
b += NR;
}
float32x4x4_t cv = {cv0, cv1, cv2, cv3};
int i, j;
for (i = 0; i < mc; ++i) {
for (j = 0; j < nc; ++j) {
if (beta == 0.0) {
C(i, j) = 0.0;
} else if (beta != 1.0) {
C(i, j) *= beta;
}
if (j == 0) {
C(i, j) += alpha * vgetq_lane_f32(cv.val[i], 0);
} else if (j == 1) {
C(i, j) += alpha * vgetq_lane_f32(cv.val[i], 1);
} else if (j == 2) {
C(i, j) += alpha * vgetq_lane_f32(cv.val[i], 2);
} else if (j == 3) {
C(i, j) += alpha * vgetq_lane_f32(cv.val[i], 3);
}
}
}
}
#elif defined(ARMV7)
void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b,
int ldb, float beta, float *C, int ldc, int mc, int nc) {
int kc1 = k / 2, kc2 = k % 2;
int bytes_ldc = 4 * ldc;
int flag_alpha = (alpha == 1.0) ? 1 : 2;
int flag_beta;
if (beta == 0.0) {
flag_beta = 0;
} else if (beta == 1.0) {
flag_beta = 1;
} else {
flag_beta = 2;
}
asm volatile(
"vmov.f32 q10, #0.0 \n\t"
"vmov.f32 q11, #0.0 \n\t"
"vmov.f32 q12, #0.0 \n\t"
"vmov.f32 q13, #0.0 \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"blt end_kc1_%= \n\t"
"loop_kc1_%=: \n\t"
"vld1.32 {q0, q1}, [%[a]]! \n\t"
"vld1.32 {q2, q3}, [%[b]]! \n\t"
"vmla.f32 q10, q2, d0[0] \n\t"
"vmla.f32 q11, q2, d0[1] \n\t"
"vmla.f32 q12, q2, d1[0] \n\t"
"vmla.f32 q13, q2, d1[1] \n\t"
"vmla.f32 q10, q3, d2[0] \n\t"
"vmla.f32 q11, q3, d2[1] \n\t"
"vmla.f32 q12, q3, d3[0] \n\t"
"vmla.f32 q13, q3, d3[1] \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"bge loop_kc1_%= \n\t"
"end_kc1_%=: \n\t"
"subs %[kc2], %[kc2], #1 \n\t"
"blt end_kc2_%= \n\t"
"vld1.32 {q0}, [%[a]]! \n\t"
"vld1.32 {q1}, [%[b]]! \n\t"
"vmla.f32 q10, q1, d0[0] \n\t"
"vmla.f32 q11, q1, d0[1] \n\t"
"vmla.f32 q12, q1, d1[0] \n\t"
"vmla.f32 q13, q1, d1[1] \n\t"
"end_kc2_%=: \n\t"
"cmp %[mc], #4 \n\t"
"bne temp_%= \n\t"
"cmp %[nc], #4 \n\t"
"bne temp_%= \n\t"
"vmov.f32 d8[0], %[alpha] \n\t"
"vmov.f32 d8[1], %[beta] \n\t"
"cmp %[flag_alpha], #1 \n\t"
"bne alpha_%= \n\t"
"alpha_%=: \n\t"
"vmul.f32 q10, q10, d8[0] \n\t"
"vmul.f32 q11, q11, d8[0] \n\t"
"vmul.f32 q12, q12, d8[0] \n\t"
"vmul.f32 q13, q13, d8[0] \n\t"
"beta_%=: \n\t"
"cmp %[flag_beta], #0 \n\t"
"beq memory_%= \n\t"
"mov r4, %[C] \n\t"
"mov r6, %[bytes_ldc]\n\t"
"vld1.32 {q0}, [r4], r6 \n\t"
"vld1.32 {q1}, [r4], r6 \n\t"
"vld1.32 {q2}, [r4], r6 \n\t"
"vld1.32 {q3}, [r4] \n\t"
"cmp %[flag_beta], #1 \n\t"
"beq beta_eq1_%= \n\t"
"bne beta_ne1_%= \n\t"
"beta_eq1_%=: \n\t"
"vadd.f32 q10, q10, q0 \n\t"
"vadd.f32 q11, q11, q1 \n\t"
"vadd.f32 q12, q12, q2 \n\t"
"vadd.f32 q13, q13, q3 \n\t"
"b memory_%= \n\t"
"beta_ne1_%=: \n\t"
"vmla.f32 q10, q0, d8[1] \n\t"
"vmla.f32 q11, q1, d8[1] \n\t"
"vmla.f32 q12, q2, d8[1] \n\t"
"vmla.f32 q13, q3, d8[1] \n\t"
"memory_%=: \n\t"
"mov r5, %[C] \n\t"
"mov r6, %[bytes_ldc]\n\t"
"vst1.32 {q10}, [r5], r6 \n\t"
"vst1.32 {q11}, [r5], r6 \n\t"
"vst1.32 {q12}, [r5], r6 \n\t"
"vst1.32 {q13}, [r5] \n\t"
"b end_%= \n\t"
"temp_%=: \n\t"
"vst1.32 {q10, q11}, [%[ab]]!\n\t"
"vst1.32 {q12, q13}, [%[ab]] \n\t"
"end_%=: \n\t"
:
: [a] "r"(a), [b] "r"(b), [C] "r"(C), [ab] "r"(ab), [kc1] "r"(kc1),
[kc2] "r"(kc2), [mc] "r"(mc), [nc] "r"(nc), [alpha] "r"(alpha),
[beta] "r"(beta), [bytes_ldc] "r"(bytes_ldc),
[flag_alpha] "r"(flag_alpha), [flag_beta] "r"(flag_beta)
: "memory", "q0", "q1", "q2", "q3", "q4", "q10", "q11", "q12", "q13");
if (mc != MR || nc != NR) {
int i, j;
for (i = 0; i < mc; ++i) {
for (j = 0; j < nc; ++j) {
if (beta == 0.0) {
if (alpha != 1.0) {
C(i, j) = alpha * ab[i * MR + j];
} else {
C(i, j) = ab[i * MR + j];
}
} else {
if (beta != 1.0) {
C(i, j) *= beta;
}
if (alpha != 1.0) {
C(i, j) += alpha * ab[i * MR + j];
} else {
C(i, j) += ab[i * MR + j];
}
}
}
}
}
}
#else
void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b,
int ldb, float beta, float *C, int ldc, int mc, int nc) {
float c[16] = {0};
float reg_a0, reg_a1, reg_a2, reg_a3, reg_b0, reg_b1, reg_b2, reg_b3;
// // init C
// float32x4_t cv0 = vdup_n_f32(0.0);
// float32x4_t cv1 = vdup_n_f32(0.0);
// float32x4_t cv2 = vdup_n_f32(0.0);
// float32x4_t cv3 = vdup_n_f32(0.0);
for (int p = 0; p < k; p += 1) {
reg_b0 = *b++;
reg_b1 = *b++;
......@@ -232,6 +416,7 @@ void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b,
}
}
}
#endif
// 32位 float 矩阵乘法
void sgemm(int m, int n, int k, float alpha, const float *A, int lda,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册