From 3d5e261be64191a3583cb3007763ce1777ef3ae1 Mon Sep 17 00:00:00 2001 From: yiicy Date: Wed, 20 Nov 2019 11:14:46 +0800 Subject: [PATCH] [ARM] sgemv support transA, test=develop (#2453) * [ARM] sgemv support transA, test=develop * add sgemv ut, test=develop --- lite/backends/arm/math/conv_impl.cc | 6 +- lite/backends/arm/math/sgemv.cc | 570 ++++++++++++++++++++++++-- lite/backends/arm/math/sgemv.h | 9 +- lite/kernels/arm/fc_compute.cc | 3 +- lite/kernels/arm/matmul_compute.cc | 2 +- lite/kernels/arm/mul_compute.cc | 5 +- lite/tests/math/CMakeLists.txt | 1 + lite/tests/math/sgemv_compute_test.cc | 194 +++++++++ 8 files changed, 739 insertions(+), 51 deletions(-) create mode 100644 lite/tests/math/sgemv_compute_test.cc diff --git a/lite/backends/arm/math/conv_impl.cc b/lite/backends/arm/math/conv_impl.cc index 010563bf93..02a49cf157 100644 --- a/lite/backends/arm/math/conv_impl.cc +++ b/lite/backends/arm/math/conv_impl.cc @@ -202,7 +202,8 @@ void conv1x1s1_gemm(const float* i_data, k, flag_bias, bias_group, - flag_relu); + flag_relu, + ctx); } else { sgemm_prepack(false, m, @@ -395,7 +396,8 @@ void conv_im2col_gemm(const float* i_data, k, flag_bias, bias_group, - flag_relu); + flag_relu, + ctx); } else { int ldb = n; sgemm_prepack(false, diff --git a/lite/backends/arm/math/sgemv.cc b/lite/backends/arm/math/sgemv.cc index 506451932d..1830423136 100644 --- a/lite/backends/arm/math/sgemv.cc +++ b/lite/backends/arm/math/sgemv.cc @@ -14,6 +14,7 @@ #include "lite/backends/arm/math/sgemv.h" #include +#include #include "lite/utils/cp_logging.h" namespace paddle { @@ -50,6 +51,495 @@ void sgemv_bias_relu(const bool transA, const float *x, float *y, const float *bias); +#ifdef __aarch64__ +void sgemv_trans(const int M, + const int N, + const float *A, + const float *x, + float *y, + bool flag_bias, + const float *bias, + bool flag_relu, + const ARMContext *ctx) { + int m_cnt16 = M >> 4; + int m_cnt8 = (M & 15) >> 3; + int m_cnt4 = (M & 15 & 7) >> 2; + int m_remain = M & 15 & 7 & 3; + int ths = ctx->threads(); + int valid_ths = std::min((N + 3) / 4, ths); + int valid_block = std::max(4, (N / valid_ths + 3) / 4 * 4); + valid_ths = (N + valid_block - 1) / valid_block; + int block_cnt = valid_block / 4; + float zero_buf[M]; // NOLINT + float y_buf[valid_ths * M]; // NOLINT + memset(zero_buf, 0, M * sizeof(float)); + if (flag_bias) { + memcpy(y_buf, bias, M * sizeof(float)); + memset(y_buf + M, 0, (valid_ths - 1) * M * sizeof(float)); + } else { + memset(y_buf, 0, valid_ths * M * sizeof(float)); + } +#pragma omp parallel for + for (int t = 0; t < valid_ths; ++t) { + float *block_y = y_buf + t * M; + const float *block_x = x + t * valid_block; + const float *block_A = A + t * valid_block * M; + for (int i = 0; i < block_cnt; ++i) { + float *y_ptr = block_y; + const float *x_ptr = block_x + i * 4; + const float *in0_ptr = block_A + i * 4 * M; + const float *in1_ptr = in0_ptr + M; + const float *in2_ptr = in1_ptr + M; + const float *in3_ptr = in2_ptr + M; + int offset = t * valid_block + (i + 1) * 4 - N; + if (offset > 0) { + if (offset > 3) { + in0_ptr = zero_buf; + in1_ptr = zero_buf; + in2_ptr = zero_buf; + in3_ptr = zero_buf; + } else { + switch (offset) { + case 3: + in1_ptr = zero_buf; + case 2: + in2_ptr = zero_buf; + case 1: + in3_ptr = zero_buf; + default: + break; + } + } + } + // clang-format off + if (m_cnt16 > 0) { + int cnt16 = m_cnt16; + asm volatile( + "ld1 {v4.4s}, [%[x]] \n" /* load x to v4 */ + "ld1 {v5.4s, v6.4s, v7.4s, v8.4s}, [%[in0]], #64 \n" /* load in0 to v5, v6, v7, v8 */ + "ld1 
{v9.4s, v10.4s, v11.4s, v12.4s}, [%[in1]], #64 \n" /* load in1 to v9, v10, v11, v12 */ + "ld1 {v13.4s, v14.4s, v15.4s, v16.4s}, [%[in2]], #64 \n" /* load in2 to v13, v14, v15, v16 */ + "ld1 {v17.4s, v18.4s, v19.4s, v20.4s}, [%[in3]], #64 \n" /* load in3 to v17, v18, v19, v20 */ + "1:\n" + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[y]] \n" /*load y to v0, v1, v2, v3 */ + "fmla v0.4s, v5.4s, v4.s[0] \n" /* v0 += v5 * v4[0] */ + "fmla v1.4s, v6.4s, v4.s[0] \n" /* v1 += v6 * v4[0] */ + "fmla v2.4s, v7.4s, v4.s[0] \n" /* v2 += v7 * v4[0] */ + "fmla v3.4s, v8.4s, v4.s[0] \n" /* v3 += v8 * v4[0] */ + "ld1 {v5.4s, v6.4s, v7.4s, v8.4s}, [%[in0]], #64 \n" /* load in0 to v5, v6, v7, v8 */ + "fmla v0.4s, v9.4s, v4.s[1] \n" /* v0 += v9 * v4[1] */ + "fmla v1.4s, v10.4s, v4.s[1] \n" /* v1 += v10 * v4[1] */ + "fmla v2.4s, v11.4s, v4.s[1] \n" /* v2 += v11 * v4[1] */ + "fmla v3.4s, v12.4s, v4.s[1] \n" /* v3 += v12 * v4[1] */ + "ld1 {v9.4s, v10.4s, v11.4s, v12.4s}, [%[in1]], #64 \n" /* load in1 to v9, v10, v11, v12 */ + "fmla v0.4s, v13.4s, v4.s[2] \n" /* v0 += v13 * v4[2] */ + "fmla v1.4s, v14.4s, v4.s[2] \n" /* v1 += v14 * v4[2] */ + "fmla v2.4s, v15.4s, v4.s[2] \n" /* v2 += v15 * v4[2] */ + "fmla v3.4s, v16.4s, v4.s[2] \n" /* v3 += v16 * v4[2] */ + "ld1 {v13.4s, v14.4s, v15.4s, v16.4s}, [%[in2]], #64 \n" /* load in2 to v13, v14, v15, v16 */ + "fmla v0.4s, v17.4s, v4.s[3] \n" /* v0 += v17 * v4[3] */ + "fmla v1.4s, v18.4s, v4.s[3] \n" /* v1 += v18 * v4[3] */ + "fmla v2.4s, v19.4s, v4.s[3] \n" /* v2 += v19 * v4[3] */ + "fmla v3.4s, v20.4s, v4.s[3] \n" /* v3 += v20 * v4[3] */ + "ld1 {v17.4s, v18.4s, v19.4s, v20.4s}, [%[in3]], #64 \n" /* load in3 to v17, v18, v19, v20 */ + "subs %w[cnt], %w[cnt], #1 \n" /* sub cnt */ + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[y]], #64 \n" /* store v0, v1, v2, v3 to y */ + "bne 1b \n" /* branch to label 1 */ + "sub %[in0], %[in0], #64 \n" /* restore in0 address */ + "sub %[in1], %[in1], #64 \n" /* restore in1 address */ + "sub %[in2], %[in2], #64 \n" /* restore in2 address */ + "sub %[in3], %[in3], #64 \n" /* restore in3 address */ + : [cnt] "+r"(cnt16), + [in0] "+r"(in0_ptr), + [in1] "+r"(in1_ptr), + [in2] "+r"(in2_ptr), + [in3] "+r"(in3_ptr), + [y] "+r"(y_ptr) + : [x] "r"(x_ptr) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", + "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "cc", "memory" + ); + } + if (m_cnt8 > 0) { + int cnt8 = m_cnt8; + asm volatile( + "ld1 {v2.4s}, [%[x]] \n" /* load x to v2 */ + "ld1 {v3.4s, v4.4s}, [%[in0]], #32 \n" /* load in0 to v3, v4 */ + "ld1 {v5.4s, v6.4s}, [%[in1]], #32 \n" /* load in1 to v5, v6 */ + "ld1 {v7.4s, v8.4s}, [%[in2]], #32 \n" /* load in2 to v7, v8 */ + "ld1 {v9.4s, v10.4s}, [%[in3]], #32 \n" /* load in3 to v9, v10*/ + "1:\n" + "ld1 {v0.4s, v1.4s}, [%[y]] \n" /* load y to v0, v1 */ + "fmla v0.4s, v3.4s, v2.s[0] \n" /* v0 += v3 * v2[0] */ + "fmla v1.4s, v4.4s, v2.s[0] \n" /* v1 += v4 * v2[0] */ + "prfm pldl1keep, [%[in0]] \n" /* preload in0 */ + "ld1 {v3.4s, v4.4s}, [%[in0]], #32 \n" /* load in0 to v3, v4 */ + "fmla v0.4s, v5.4s, v2.s[1] \n" /* v0 += v5 * v2[1] */ + "fmla v1.4s, v6.4s, v2.s[1] \n" /* v1 += v6 * v2[1] */ + "prfm pldl1keep, [%[in1]] \n" /* preload in1 */ + "ld1 {v5.4s, v6.4s}, [%[in1]], #32 \n" /* load in0 to v5, v6 */ + "fmla v0.4s, v7.4s, v2.s[2] \n" /* v0 += v7 * v2[2] */ + "fmla v1.4s, v8.4s, v2.s[2] \n" /* v1 += v8 * v2[2] */ + "prfm pldl1keep, [%[in2]] \n" /* preload in2 */ + "ld1 {v7.4s, v8.4s}, [%[in2]], #32 \n" /* load in0 to v7, v8 */ + "fmla v0.4s, v9.4s, v2.s[3] \n" /* 
v0 += v9 * v2[3] */ + "fmla v1.4s, v10.4s, v2.s[3] \n" /* v1 += v10 * v2[3] */ + "subs %w[cnt], %w[cnt], #1 \n" /* sub cnt */ + "prfm pldl1keep, [%[in3]] \n" /* preload in3 */ + "st1 {v0.4s, v1.4s}, [%[y]], #32 \n" /* store v0, v1 to y */ + "ld1 {v9.4s, v10.4s},[%[in3]], #32 \n" /* load in0 to v9, v10*/ + "bne 1b \n" /* branch to label 1 */ + "sub %[in0], %[in0], #32 \n" /* restore in0 address */ + "sub %[in1], %[in1], #32 \n" /* restore in1 address */ + "sub %[in2], %[in2], #32 \n" /* restore in2 address */ + "sub %[in3], %[in3], #32 \n" /* restore in3 address */ + : [cnt] "+r"(cnt8), + [in0] "+r"(in0_ptr), + [in1] "+r"(in1_ptr), + [in2] "+r"(in2_ptr), + [in3] "+r"(in3_ptr), + [y] "+r"(y_ptr) + : [x] "r"(x_ptr) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v8", "v9", "v10", "cc", "memory" + ); + } + if (m_cnt4 > 0) { + int cnt4 = m_cnt4; + asm volatile( + "ld1 {v1.4s}, [%[in0]], #16 \n" /* load in0 to v1 */ + "ld1 {v2.4s}, [%[in1]], #16 \n" /* load in1 to v2 */ + "ld1 {v3.4s}, [%[in2]], #16 \n" /* load in2 to v3 */ + "ld1 {v4.4s}, [%[in3]], #16 \n" /* load in3 to v4 */ + "ld1 {v5.4s}, [%[x]] \n" /* load x to v5 */ + "1:\n" + "ld1 {v0.4s}, [%[y]] \n" /* load y to v0 */ + "fmla v0.4s, v1.4s, v5.s[0] \n" /* v0 += v1 * v5[0] */ + "prfm pldl1keep, [%[in0]] \n" /* preload in0 */ + "ld1 {v1.4s}, [%[in0]], #16 \n" /* load in0 to v1 */ + "fmla v0.4s, v2.4s, v5.s[1] \n" /* v0 += v2 * v5[1] */ + "prfm pldl1keep, [%[in1]] \n" /* preload in1 */ + "ld1 {v2.4s}, [%[in1]], #16 \n" /* load in1 to v2 */ + "fmla v0.4s, v3.4s, v5.s[2] \n" /* v0 += v3 * v5[2] */ + "prfm pldl1keep, [%[in2]] \n" /* preload in2 */ + "ld1 {v3.4s}, [%[in2]], #16 \n" /* load in2 to v3 */ + "fmla v0.4s, v4.4s, v5.s[3] \n" /* v0 += v4 * v5[3] */ + "subs %w[cnt], %w[cnt], #1 \n" /* sub cnt */ + "prfm pldl1keep, [%[in3]] \n" /* preload in3 */ + "st1 {v0.4s}, [%[y]], #16 \n" /* store v0 to y */ + "ld1 {v4.4s}, [%[in3]], #16 \n" /* load in3 to v4 */ + "bne 1b \n" /* branch to label 1 */ + "sub %[in0], %[in0], #16 \n" /* restore in0 address*/ + "sub %[in1], %[in1], #16 \n" /* restore in1 address*/ + "sub %[in2], %[in2], #16 \n" /* restore in2 address*/ + "sub %[in3], %[in3], #16 \n" /* restore in3 address*/ + : [cnt] "+r"(cnt4), + [in0] "+r"(in0_ptr), + [in1] "+r"(in1_ptr), + [in2] "+r"(in2_ptr), + [in3] "+r"(in3_ptr), + [y] "+r"(y_ptr) + : [x] "r"(x_ptr) + : "v0", "v1", "v2", "v3", "v4", "v5", "cc", "memory" + ); + } + // clang-format on + for (int r = 0; r < m_remain; ++r) { + float val0 = x_ptr[0] * in0_ptr[r]; + float val1 = x_ptr[1] * in1_ptr[r]; + float val2 = x_ptr[2] * in2_ptr[r]; + float val3 = x_ptr[3] * in3_ptr[r]; + y_ptr[r] += val0 + val1 + val2 + val3; + } + } + } + int cnt4 = M >> 2; + int remain = M & 3; + //! 
do reduction + int rdc_ths = valid_ths >> 1; + while (rdc_ths > 0) { +#pragma omp parallel for + for (int t = 0; t < rdc_ths; ++t) { + float *y0 = y_buf + t * M; + for (int i = t + rdc_ths; i < valid_ths; i += rdc_ths) { + float *y0_ptr = y0; + float *y_ptr = y_buf + i * M; + for (int j = 0; j < cnt4; ++j) { + float32x4_t val0 = vld1q_f32(y0_ptr + j * 4); + float32x4_t val1 = vld1q_f32(y_ptr + j * 4); + float32x4_t val = vaddq_f32(val0, val1); + vst1q_f32(y0_ptr + j * 4, val); + } + y0_ptr += cnt4 * 4; + y_ptr += cnt4 * 4; + for (int j = 0; j < remain; ++j) { + y0_ptr[j] += y_ptr[j]; + } + } + } + valid_ths = rdc_ths; + rdc_ths = rdc_ths >> 1; + } + if (flag_relu) { + float *in_y = y_buf; + float32x4_t vzero = vdupq_n_f32(0.f); + if (cnt4 > 0) { + int cnt = cnt4; + asm volatile( + "ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */ + "1:\n" + "fmax v1.4s, v0.4s, %[vzero].4s \n" /* v0 relu */ + "ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */ + "subs %w[cnt], %w[cnt], #1 \n" /* sub cnt */ + "st1 {v1.4s}, [%[out_y]], #16 \n" /* store v1 to y */ + "bne 1b \n" /* branch to label 1*/ + "sub %[in_y], %[in_y], #16 \n" /* restore in_y */ + : [cnt] "+r"(cnt), [in_y] "+r"(in_y), [out_y] "+r"(y) + : [vzero] "w"(vzero) + : "v0", "v1", "cc", "memory"); + } + for (int r = 0; r < remain; ++r) { + y[r] = in_y[r] > 0.f ? in_y[r] : 0.f; + } + } else { + memcpy(y, y_buf, M * sizeof(float)); + } +} +#else +void sgemv_trans(const int M, + const int N, + const float *A, + const float *x, + float *y, + bool flag_bias, + const float *bias, + bool flag_relu, + const ARMContext *ctx) { + int m_cnt8 = M >> 3; + int m_cnt4 = (M & 7) >> 2; + int m_remain = M & 7 & 3; + int ths = ctx->threads(); + int valid_ths = std::min((N + 3) / 4, ths); + int valid_block = std::max(4, (N / valid_ths + 3) / 4 * 4); + valid_ths = (N + valid_block - 1) / valid_block; + int block_cnt = valid_block / 4; + float zero_buf[M]; // NOLINT + float y_buf[valid_ths * M]; // NOLINT + memset(zero_buf, 0, M * sizeof(float)); + if (flag_bias) { + memcpy(y_buf, bias, M * sizeof(float)); + memset(y_buf + M, 0, (valid_ths - 1) * M * sizeof(float)); + } else { + memset(y_buf, 0, valid_ths * M * sizeof(float)); + } +#pragma omp parallel for + for (int t = 0; t < valid_ths; ++t) { + float *block_y = y_buf + t * M; + const float *block_x = x + t * valid_block; + const float *block_A = A + t * valid_block * M; + for (int i = 0; i < block_cnt; ++i) { + float *y_ptr = block_y; + const float *x_ptr = block_x + i * 4; + const float *in0_ptr = block_A + i * 4 * M; + const float *in1_ptr = in0_ptr + M; + const float *in2_ptr = in1_ptr + M; + const float *in3_ptr = in2_ptr + M; + int offset = t * valid_block + (i + 1) * 4 - N; + if (offset > 0) { + if (offset > 3) { + in0_ptr = zero_buf; + in1_ptr = zero_buf; + in2_ptr = zero_buf; + in3_ptr = zero_buf; + } else { + switch (offset) { + case 3: + in1_ptr = zero_buf; + case 2: + in2_ptr = zero_buf; + case 1: + in3_ptr = zero_buf; + default: + break; + } + } + } + // clang-format off + if (m_cnt8 > 0) { + int cnt8 = m_cnt8; + asm volatile( + "vld1.32 {d4-d5}, [%[x]] \n" /* load x to q2 */ + "vld1.32 {d6-d9}, [%[in0]]! \n" /* load in0 to q3, q4 */ + "vld1.32 {d10-d13},[%[in1]]! \n" /* load in1 to q5, q6 */ + "vld1.32 {d14-d17},[%[in2]]! \n" /* load in2 to q7, q8 */ + "vld1.32 {d18-d21},[%[in3]]! 
\n" /* load in3 to q9, q10*/ + "1:\n" + "vld1.32 {d0-d3}, [%[y]] \n" /* load y to q0, q1 */ + "vmla.f32 q0, q3, d4[0] \n" /* q0 += q3 * q2[0] */ + "vmla.f32 q1, q4, d4[0] \n" /* q1 += q4 * q2[0] */ + "pld [%[in0]] \n" /* preload in0 */ + "vld1.32 {d6-d9}, [%[in0]]! \n" /* load in0 to q3, q4 */ + "vmla.f32 q0, q5, d4[1] \n" /* q0 += q5 * q2[1] */ + "vmla.f32 q1, q6, d4[1] \n" /* q1 += q6 * q2[1] */ + "pld [%[in1]] \n" /* preload in1 */ + "vld1.32 {d10-d13},[%[in1]]! \n" /* load in0 to q5, q6 */ + "vmla.f32 q0, q7, d5[0] \n" /* q0 += q7 * q2[2] */ + "vmla.f32 q1, q8, d5[0] \n" /* q1 += q8 * q2[2] */ + "pld [%[in2]] \n" /* preload in2 */ + "vld1.32 {d14-d17},[%[in2]]! \n" /* load in0 to q7, q8 */ + "vmla.f32 q0, q9, d5[1] \n" /* q0 += q9 * q2[3] */ + "vmla.f32 q1, q10, d5[1] \n" /* q1 += q10 * q2[3] */ + "subs %[cnt], %[cnt], #1 \n" /* sub cnt */ + "pld [%[in3]] \n" /* preload in3 */ + "vst1.32 {d0-d3}, [%[y]]! \n" /* store q0, q1 to y */ + "vld1.32 {d18-d21},[%[in3]]! \n" /* load in0 to q9, q10*/ + "pld [%[y], #32] \n" /* preload y */ + "bne 1b \n" /* branch to label 1 */ + "sub %[in0], %[in0], #32 \n" /* restore in0 address */ + "sub %[in1], %[in1], #32 \n" /* restore in1 address */ + "sub %[in2], %[in2], #32 \n" /* restore in2 address */ + "sub %[in3], %[in3], #32 \n" /* restore in3 address */ + : [cnt] "+r"(cnt8), + [in0] "+r"(in0_ptr), + [in1] "+r"(in1_ptr), + [in2] "+r"(in2_ptr), + [in3] "+r"(in3_ptr), + [y] "+r"(y_ptr) + : [x] "r"(x_ptr) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q8", "q9", "q10", "cc", "memory" + ); + } + if (m_cnt4 > 0) { + int cnt4 = m_cnt4; + asm volatile( + "vld1.32 {d2-d3}, [%[in0]]! \n" /* load in0 to q1 */ + "vld1.32 {d4-d5}, [%[in1]]! \n" /* load in1 to q2 */ + "vld1.32 {d6-d7}, [%[in2]]! \n" /* load in2 to q3 */ + "vld1.32 {d8-d9}, [%[in3]]! \n" /* load in3 to q4 */ + "vld1.32 {d10-d11},[%[x]] \n" /* load x to q5 */ + "1:\n" + "vld1.32 {d0-d1}, [%[y]] \n" /* load y to q0 */ + "vmla.f32 q0, q1, d10[0] \n" /* q0 += q1 * q5[0] */ + "pld [%[in0]] \n" /* preload in0 */ + "vld1.32 {d2-d3}, [%[in0]]! \n" /* load in0 to q1 */ + "vmla.f32 q0, q2, d10[1] \n" /* q0 += q2 * q5[1] */ + "pld [%[in1]] \n" /* preload in1 */ + "vld1.32 {d4-d5}, [%[in1]]! \n" /* load in0 to q2 */ + "vmla.f32 q0, q3, d11[0] \n" /* q0 += q3 * q5[2] */ + "pld [%[in2]] \n" /* preload in2 */ + "vld1.32 {d6-d7}, [%[in2]]! \n" /* load in0 to q3 */ + "vmla.f32 q0, q4, d11[1] \n" /* q0 += q4 * q5[3] */ + "subs %[cnt], %[cnt], #1 \n" /* sub cnt */ + "pld [%[in3]] \n" /* preload in3 */ + "vst1.32 {d0-d1}, [%[y]]! \n" /* store q0 to y */ + "vld1.32 {d8-d9}, [%[in3]]! \n" /* load in0 to q4 */ + "bne 1b \n" /* branch to label 1 */ + "sub %[in0], %[in0], #16 \n" /* restore in0 address*/ + "sub %[in1], %[in1], #16 \n" /* restore in1 address*/ + "sub %[in2], %[in2], #16 \n" /* restore in2 address*/ + "sub %[in3], %[in3], #16 \n" /* restore in3 address*/ + : [cnt] "+r"(cnt4), + [in0] "+r"(in0_ptr), + [in1] "+r"(in1_ptr), + [in2] "+r"(in2_ptr), + [in3] "+r"(in3_ptr), + [y] "+r"(y_ptr) + : [x] "r"(x_ptr) + : "q0", "q1", "q2", "q3", "q4", "q5", "cc", "memory" + ); + } + // clang-format on + for (int r = 0; r < m_remain; ++r) { + float val0 = x_ptr[0] * in0_ptr[r]; + float val1 = x_ptr[1] * in1_ptr[r]; + float val2 = x_ptr[2] * in2_ptr[r]; + float val3 = x_ptr[3] * in3_ptr[r]; + y_ptr[r] += val0 + val1 + val2 + val3; + } + } + } + //! 
do reduction + int rdc_ths = valid_ths >> 1; + while (rdc_ths > 0) { +#pragma omp parallel for + for (int t = 0; t < rdc_ths; ++t) { + float *y0 = y_buf + t * M; + for (int i = t + rdc_ths; i < valid_ths; i += rdc_ths) { + float *y0_ptr = y0; + float *y_ptr = y_buf + i * M; + for (int j = 0; j < m_cnt8; ++j) { + float32x4_t val00 = vld1q_f32(y0_ptr + j * 8); + float32x4_t val01 = vld1q_f32(y0_ptr + j * 8 + 4); + float32x4_t val10 = vld1q_f32(y_ptr + j * 8); + float32x4_t val11 = vld1q_f32(y_ptr + j * 8 + 4); + float32x4_t val0 = vaddq_f32(val00, val10); + float32x4_t val1 = vaddq_f32(val01, val11); + vst1q_f32(y0_ptr + j * 8, val0); + vst1q_f32(y0_ptr + j * 8 + 4, val1); + } + y0_ptr += m_cnt8 * 8; + y_ptr += m_cnt8 * 8; + for (int j = 0; j < m_cnt4; ++j) { + float32x4_t val0 = vld1q_f32(y0_ptr + j * 4); + float32x4_t val1 = vld1q_f32(y_ptr + j * 4); + float32x4_t val = vaddq_f32(val0, val1); + vst1q_f32(y0_ptr + j * 4, val); + } + y0_ptr += m_cnt4 * 4; + y_ptr += m_cnt4 * 4; + for (int j = 0; j < m_remain; ++j) { + y0_ptr[j] += y_ptr[j]; + } + } + } + valid_ths = rdc_ths; + rdc_ths = rdc_ths >> 1; + } + if (flag_relu) { + float *in_y = y_buf; + float32x4_t vzero = vdupq_n_f32(0.f); + if (m_cnt8 > 0) { + int cnt8 = m_cnt8; + asm volatile( + "vld1.32 {d0-d3}, [%[in_y]]! \n" /* load y to q0, q1 */ + "1:\n" + "vmax.f32 q2, q0, %q[vzero] \n" /* q0 relu */ + "vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */ + "vmax.f32 q3, q1, %q[vzero] \n" /* q1 relu */ + "subs %[cnt], %[cnt], #1 \n" /* sub cnt */ + "vst1.32 {d4-d7}, [%[out_y]]! \n" /* store q0, q1 to y*/ + "vld1.32 {d2-d3}, [%[in_y]]! \n" /* load y to q0 */ + "bne 1b \n" /* branch to label 1*/ + "sub %[in_y], %[in_y], #32 \n" /* restore in_y */ + : [cnt] "+r"(cnt8), [in_y] "+r"(in_y), [out_y] "+r"(y) + : [vzero] "w"(vzero) + : "q0", "q1", "q2", "q3", "cc", "memory"); + } + if (m_cnt4 > 0) { + int cnt4 = m_cnt4; + asm volatile( + "vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */ + "1:\n" + "vmax.f32 q1, q0, %q[vzero] \n" /* q0 relu */ + "vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */ + "subs %[cnt], %[cnt], #1 \n" /* sub cnt */ + "vst1.32 {d2-d3}, [%[out_y]]! \n" /* store q1 to y */ + "bne 1b \n" /* branch to label 1*/ + "sub %[in_y], %[in_y], #16 \n" /* restore in_y */ + : [cnt] "+r"(cnt4), [in_y] "+r"(in_y), [out_y] "+r"(y) + : [vzero] "w"(vzero) + : "q0", "q1", "cc", "memory"); + } + for (int r = 0; r < m_remain; ++r) { + y[r] = in_y[r] > 0.f ? in_y[r] : 0.f; + } + } else { + memcpy(y, y_buf, M * sizeof(float)); + } +} +#endif // __aarch64__ bool sgemv(const float *A, const float *x, @@ -59,33 +549,34 @@ bool sgemv(const float *A, int N, bool is_bias, const float *bias, - bool is_relu) { + bool is_relu, + const ARMContext *ctx) { if (transA) { - LOG(ERROR) << " sgemv, transA is not supported now"; - return false; - } - if (is_bias) { - //! with bias - if (is_relu) { - //! with relu - sgemv_bias_relu(transA, M, N, A, x, y, bias); - } else { - //! without relu - sgemv_bias(transA, M, N, A, x, y, bias); - } + sgemv_trans(M, N, A, x, y, is_bias, bias, is_relu, ctx); } else { - //! without bias - if (is_relu) { - //! with relu - sgemv_relu(transA, M, N, A, x, y); + if (is_bias) { + //! with bias + if (is_relu) { + //! with relu + sgemv_bias_relu(transA, M, N, A, x, y, bias); + } else { + //! without relu + sgemv_bias(transA, M, N, A, x, y, bias); + } } else { - //! without relu - sgemv(transA, M, N, A, x, y); + //! without bias + if (is_relu) { + //! with relu + sgemv_relu(transA, M, N, A, x, y); + } else { + //! 
without relu + sgemv(transA, M, N, A, x, y); + } } } return true; } - +// clang-format off //! define compute kernel #ifdef __aarch64__ #define SGEMV_IN_8 \ @@ -179,8 +670,8 @@ bool sgemv(const float *A, "fmla v5.4s, v9.4s, v21.4s \n" /* mul + add*/ \ "fmla v6.4s, v9.4s, v23.4s \n" /* mul + add*/ \ "fmla v7.4s, v9.4s, v25.4s \n" /* mul + add*/ \ - "bne 1b \n" /* jump to main loop */ /* pair add to final \ - result */ \ + "bne 1b \n" /* jump to main loop */ \ + /* pair add to final result */ \ "2: \n" /* reduce to scale */ \ "faddp v16.4s, v0.4s, v0.4s\n" /* pair add to vector */ \ "faddp s8, v16.2s \n" /* pair add to scale */ \ @@ -231,8 +722,8 @@ bool sgemv(const float *A, "fmla v0.4s, v8.4s, v10.4s \n" /* mul + add*/ \ "subs %w[cnt], %w[cnt], #1 \n" /* sub main loop count */ \ "fmla v1.4s, v9.4s, v11.4s \n" /* mul + add*/ \ - "bne 1b \n" /* jump to main loop */ /* pair add to final \ - result */ \ + "bne 1b \n" /* jump to main loop */ \ + /* pair add to final result */ \ "2: \n" /* reduce to scale */ \ "fadd v9.4s, v0.4s, v1.4s \n" /* add 2 vector */ \ "faddp v10.4s, v9.4s, v9.4s\n" /* pair add to vector */ \ @@ -283,7 +774,7 @@ bool sgemv(const float *A, "fmax s8, s8, s0 \n" /* relu */ \ "str s8, [%[out]] \n" /* save result */ -#else //__aarch64__ +#else // __aarch64__ #define SGEMV_IN_4 \ "pld [%[in]] @ preload cache line, input\n" \ @@ -349,8 +840,8 @@ bool sgemv(const float *A, "vmla.f32 q1, q5, q9 @ mul add\n" \ "vmla.f32 q2, q5, q11 @ mul add\n" \ "vmla.f32 q3, q5, q13 @ mul add\n" \ - "bne 1b @ jump to main loop\n" /* pair add to final \ - result */ \ + "bne 1b @ jump to main loop\n" \ + /* pair add to final result */ \ "2: @ pair add \n" \ "vpadd.f32 d8, d0, d1 @ pair add, first step\n" \ "vpadd.f32 d9, d2, d3 @ pair add, first step\n" \ @@ -382,13 +873,10 @@ bool sgemv(const float *A, "vmla.f32 q0, q12, q14 @ mul add\n" \ "vmla.f32 q0, q13, q15 @ mul add\n" \ "subs %[cnt] , #1 @ sub loop count \n" \ - "bne 1b @ jump to main loop\n" /* pair add to \ - final result \ - */ \ + "bne 1b @ jump to main loop\n" \ "2: @ end processing\n" \ "vpadd.f32 d2, d0, d1 @ pair add, first step\n" \ - "vpadd.f32 d0, d2, d2 @ pair add, final step\n" /* check tails \ - */ \ + "vpadd.f32 d0, d2, d2 @ pair add, final step\n"/*check tails*/ \ "cmp %[tail], #1 @ check whether has mid cols\n" \ "blt 4f @ jump to end\n" \ "3: @ tail loop\n" \ @@ -422,7 +910,7 @@ bool sgemv(const float *A, "vmax.f32 d0, d0, d1 @ relu\n" \ "vst1.32 {d0[0]}, [%[out]] @ save result\n" #endif - +// clang-format on void sgemv(const bool transA, const int M, const int N, @@ -523,7 +1011,7 @@ void sgemv(const bool transA, [tmp4] "r"(tmp4) : "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory"); } -#else //__aarch64__ +#else // __aarch64__ int out_cnt = M >> 2; #pragma omp parallel for for (int j = 0; j < out_cnt; j++) { @@ -579,7 +1067,7 @@ void sgemv(const bool transA, : [out] "r"(ptr_out) : "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory"); } -#endif //__aarch64__ +#endif // __aarch64__ } void sgemv_relu(const bool transA, @@ -671,7 +1159,7 @@ void sgemv_relu(const bool transA, : [out] "r"(ptr_out) : "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory"); } -#else //__aarch64__ +#else // __aarch64__ int out_cnt = M >> 2; #pragma omp parallel for for (int j = 0; j < out_cnt; j++) { @@ -727,7 +1215,7 @@ void sgemv_relu(const bool transA, : [out] "r"(ptr_out) : "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory"); } -#endif //__aarch64__ +#endif // __aarch64__ } void sgemv_bias(const bool 
transA, @@ -822,7 +1310,7 @@ void sgemv_bias(const bool transA, : [out] "r"(ptr_out), [bias0] "r"(bias0) : "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory"); } -#else //__aarch64__ +#else // __aarch64__ int out_cnt = M >> 2; #pragma omp parallel for for (int j = 0; j < out_cnt; j++) { @@ -887,7 +1375,7 @@ void sgemv_bias(const bool transA, : [out] "r"(ptr_out), [bias0] "r"(bias0) : "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory"); } -#endif //__aarch64__ +#endif // __aarch64__ } void sgemv_bias_relu(const bool transA, @@ -980,7 +1468,7 @@ void sgemv_bias_relu(const bool transA, : [out] "r"(ptr_out), [bias0] "r"(bias0) : "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory"); } -#else //__aarch64__ +#else // __aarch64__ int out_cnt = M >> 2; #pragma omp parallel for for (int j = 0; j < out_cnt; j++) { @@ -1045,7 +1533,7 @@ void sgemv_bias_relu(const bool transA, : [out] "r"(ptr_out), [bias0] "r"(bias0) : "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory"); } -#endif //__aarch64__ +#endif // __aarch64__ } } // namespace math diff --git a/lite/backends/arm/math/sgemv.h b/lite/backends/arm/math/sgemv.h index 4d74006f93..aa17349c99 100644 --- a/lite/backends/arm/math/sgemv.h +++ b/lite/backends/arm/math/sgemv.h @@ -15,6 +15,8 @@ #pragma once #include +#include "lite/core/context.h" +#include "lite/core/device_info.h" namespace paddle { namespace lite { @@ -28,9 +30,10 @@ bool sgemv(const float* A, bool transA, int M, int N, - bool is_bias = false, - const float* bias = nullptr, - bool is_relu = false); + bool is_bias, + const float* bias, + bool is_relu, + const ARMContext* ctx); } // namespace math } // namespace arm diff --git a/lite/kernels/arm/fc_compute.cc b/lite/kernels/arm/fc_compute.cc index 1983c73318..525eca269b 100644 --- a/lite/kernels/arm/fc_compute.cc +++ b/lite/kernels/arm/fc_compute.cc @@ -127,7 +127,8 @@ void FcCompute::Run() { k_, param.bias != nullptr, b_data, - false); + false, + &ctx); } } } diff --git a/lite/kernels/arm/matmul_compute.cc b/lite/kernels/arm/matmul_compute.cc index 29be34d0c2..d00a5bdc06 100644 --- a/lite/kernels/arm/matmul_compute.cc +++ b/lite/kernels/arm/matmul_compute.cc @@ -232,7 +232,7 @@ void MatMulCompute::Run() { int ldc = n_; if (n_ == 1) { lite::arm::math::sgemv( - x_data, y_data, o_data, false, m_, k_, false, nullptr, false); + x_data, y_data, o_data, false, m_, k_, false, nullptr, false, &ctx); if (fabsf(alpha - 1.f) > 1e-8f) { for (size_t i = 0; i < param.Out->dims().production(); ++i) { o_data[i] *= alpha; diff --git a/lite/kernels/arm/mul_compute.cc b/lite/kernels/arm/mul_compute.cc index fa43b6cf8e..debe9e907c 100644 --- a/lite/kernels/arm/mul_compute.cc +++ b/lite/kernels/arm/mul_compute.cc @@ -48,14 +48,13 @@ void MulCompute::Run() { CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h"; k_ = x_w; - + auto& ctx = this->ctx_->template As(); if (n_ == 1) { lite::arm::math::sgemv( - x_data, y_data, o_data, false, m_, k_, false, nullptr, false); + x_data, y_data, o_data, false, m_, k_, false, nullptr, false, &ctx); } else { constexpr bool is_tranposed_y = false; - auto& ctx = this->ctx_->template As(); int hblock = lite::arm::math::get_hblock(&ctx); int m_round = hblock * ((m_ + hblock - 1) / hblock); ctx.ExtendWorkspace(m_round * k_ * sizeof(float)); diff --git a/lite/tests/math/CMakeLists.txt b/lite/tests/math/CMakeLists.txt index 87324375e0..d2acd14c83 100644 --- a/lite/tests/math/CMakeLists.txt +++ b/lite/tests/math/CMakeLists.txt @@ -1,5 +1,6 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND 
(LITE_WITH_X86 OR LITE_WITH_ARM)) lite_cc_test(sgemm_compute_test SRCS sgemm_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(sgemv_compute_test SRCS sgemv_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(gemm_int8_compute_test SRCS gemm_int8_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(gemv_int8_compute_test SRCS gemv_int8_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(conv_compute_test SRCS conv_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) diff --git a/lite/tests/math/sgemv_compute_test.cc b/lite/tests/math/sgemv_compute_test.cc new file mode 100644 index 0000000000..3c8965cb2c --- /dev/null +++ b/lite/tests/math/sgemv_compute_test.cc @@ -0,0 +1,194 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "lite/tests/utils/fill_data.h" +#include "lite/tests/utils/naive_math_impl.h" +#ifdef LITE_WITH_ARM +#include "lite/backends/arm/math/funcs.h" +#endif // LITE_WITH_ARM +#include "lite/core/context.h" +#include "lite/core/tensor.h" +#include "lite/tests/utils/tensor_utils.h" +#include "lite/tests/utils/timer.h" + +typedef paddle::lite::Tensor Tensor; + +DEFINE_int32(cluster, 3, "cluster id"); +DEFINE_int32(threads, 1, "threads num"); +DEFINE_int32(warmup, 0, "warmup times"); +DEFINE_int32(repeats, 1, "repeats times"); +DEFINE_bool(basic_test, true, "do all tests"); +DEFINE_bool(check_result, true, "check the result"); + +DEFINE_int32(M, 512, "sgemv: M"); +DEFINE_int32(K, 512, "sgemv: K"); + +DEFINE_bool(traA, false, "gemv: A transpose"); + +DEFINE_bool(flag_relu, false, "do relu"); +DEFINE_bool(flag_bias, false, "with bias"); + +bool test_sgemv( + bool tra, int m, int k, bool has_bias, bool has_relu, int cls, int ths) { + Tensor ta; + Tensor tb; + Tensor tc; + Tensor tc_basic; + Tensor tbias; + + ta.Resize({m, k}); + tb.Resize({k, 1}); + tc.Resize({m, 1}); + tc_basic.Resize({m, 1}); + tbias.Resize({m}); + + ta.set_precision(PRECISION(kFloat)); + tb.set_precision(PRECISION(kFloat)); + tc.set_precision(PRECISION(kFloat)); + tc_basic.set_precision(PRECISION(kFloat)); + tbias.set_precision(PRECISION(kFloat)); + + fill_tensor_rand(ta, -1.f, 1.f); + // fill_tensor_const(ta, 1.f); + fill_tensor_rand(tb, -1.f, 1.f); + // fill_tensor_const(tb, 1.f); + fill_tensor_rand(tbias, -1.f, 1.f); + + LOG(INFO) << "sgemv M: " << m << ", K: " << k + << ", transA: " << (tra ? "true" : "false") + << ", relu: " << (has_relu ? "true" : "false") + << ", bias: " << (has_bias ? 
"true" : "false"); +#ifdef LITE_WITH_ARM + + auto da = ta.mutable_data(); + auto db = tb.mutable_data(); + auto dc = tc.mutable_data(); + auto dc_basic = tc_basic.mutable_data(); + auto dbias = tbias.mutable_data(); + + if (FLAGS_check_result) { + basic_gemv( + m, k, da, db, dbias, dc_basic, 1.f, 0.f, tra, has_bias, has_relu); + } + paddle::lite::Timer t0; + //! compute + double ops = 2.0 * m * k; + std::unique_ptr ctx1( + new paddle::lite::KernelContext); + auto& ctx = ctx1->As(); + ctx.SetRunMode(static_cast(cls), ths); + /// warmup + for (int j = 0; j < FLAGS_warmup; ++j) { + paddle::lite::arm::math::sgemv( + da, db, dc, tra, m, k, has_bias, dbias, has_relu, &ctx); + } + + t0.clear(); + for (int i = 0; i < FLAGS_repeats; ++i) { + t0.start(); + paddle::lite::arm::math::sgemv( + da, db, dc, tra, m, k, has_bias, dbias, has_relu, &ctx); + t0.end(); + } + LOG(INFO) << "gemv output: M: " << m << ", K: " << k << ", cluster: " << cls + << ", threads: " << ths << ", GOPS: " << ops * 1e-9f + << " GOPS, avg time: " << t0.get_average_ms() + << " ms, min time: " << t0.get_min_time() + << " ms, mean GOPs: " << ops * 1e-6f / t0.get_average_ms() + << " GOPs, max GOPs: " << ops * 1e-6f / t0.get_min_time() + << " GOPs"; + + if (FLAGS_check_result) { + double max_ratio = 0; + double max_diff = 0; + /// fp32 result + tensor_cmp_host(tc_basic, tc, max_ratio, max_diff); + LOG(INFO) << "compare result, max diff: " << max_diff + << ", max ratio: " << max_ratio; + if (std::abs(max_ratio) > 1e-4f && std::abs(max_diff) > 5e-5f) { + Tensor tdiff; + tdiff.set_precision(PRECISION(kFloat)); + tdiff.Resize(tc.dims()); + tensor_diff(tc_basic, tc, tdiff); + LOG(INFO) << "basic result: "; + print_tensor(tc_basic); + LOG(INFO) << "saber result: "; + print_tensor(tc); + LOG(INFO) << "diff result: "; + print_tensor(tdiff); + return false; + } + } +#endif + return true; +} + +TEST(TestLiteSgemv, Sgemv) { + if (FLAGS_basic_test) { +#ifdef LITE_WITH_ARM + paddle::lite::DeviceInfo::Init(); +#endif + LOG(INFO) << "run basic sgemv test"; + for (auto& m : {1, 3, 8, 21, 32, 397}) { + for (auto& k : {1, 3, 8, 17, 59, 234}) { + for (auto& tra : {true, false}) { + for (auto& has_bias : {false, true}) { + for (auto& has_relu : {false, true}) { + for (auto& th : {1, 2, 4}) { + auto flag = test_sgemv( + tra, m, k, has_bias, has_relu, FLAGS_cluster, th); + if (flag) { + LOG(INFO) << "test m = " << m << ", k=" << k + << ", bias: " << (has_bias ? "true" : "false") + << ", relu: " << (has_relu ? "true" : "false") + << ", trans A: " << (tra ? "true" : "false") + << ", threads: " << th << " passed\n"; + } else { + LOG(FATAL) << "test m = " << m << ", k=" << k + << ", bias: " << (has_bias ? "true" : "false") + << ", relu: " << (has_relu ? "true" : "false") + << ", trans A: " << (tra ? "true" : "false") + << ", threads: " << th << " failed\n"; + } + } + } + } + } + } + } + } +} + +TEST(TestSgemvCustom, Sgemv_custom) { +#ifdef LITE_WITH_ARM + paddle::lite::DeviceInfo::Init(); +#endif + auto flag = test_sgemv(FLAGS_traA, + FLAGS_M, + FLAGS_K, + FLAGS_flag_bias, + FLAGS_flag_relu, + FLAGS_cluster, + FLAGS_threads); + if (!flag) { + LOG(FATAL) << "test m = " << FLAGS_M << ", k=" << FLAGS_K + << ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias + << ", relu: " << FLAGS_flag_relu << " failed!!"; + } + LOG(INFO) << "test m = " << FLAGS_M << ", k=" << FLAGS_K + << ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias + << ", relu: " << FLAGS_flag_relu << " passed!!"; +} -- GitLab