Unverified commit dde12f0d, authored by yiicy, committed by GitHub

[ARM] sgemv support transA, test=develop (#2453)

* [ARM] sgemv support transA, test=develop

* add sgemv ut, test=develop
Parent b094b2b6
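For context, the new sgemv_trans kernel added below computes y = A^T * x with an optional bias add and ReLU, reading element (n, m) of A at A[n * M + m]; per-thread partial results go into private slices of a y buffer and are reduced afterwards. A minimal scalar sketch of the intended semantics (the helper name sgemv_trans_ref is illustrative, not part of this patch):

void sgemv_trans_ref(int M, int N, const float *A, const float *x, float *y,
                     bool flag_bias, const float *bias, bool flag_relu) {
  // x has N entries, y has M entries; A is read as A[n * M + m].
  for (int m = 0; m < M; ++m) {
    float sum = flag_bias ? bias[m] : 0.f;
    for (int n = 0; n < N; ++n) {
      sum += A[n * M + m] * x[n];
    }
    y[m] = (flag_relu && sum < 0.f) ? 0.f : sum;
  }
}

Callers now pass the ARMContext so the kernel can choose its thread count; the FC, matmul, and mul call sites updated in this diff all follow roughly this pattern (variable names here are illustrative):

auto &ctx = this->ctx_->template As<ARMContext>();
lite::arm::math::sgemv(
    a_data, x_data, o_data, /*transA=*/false, m_, k_,
    /*is_bias=*/false, /*bias=*/nullptr, /*is_relu=*/false, &ctx);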
......@@ -202,7 +202,8 @@ void conv1x1s1_gemm(const float* i_data,
k,
flag_bias,
bias_group,
flag_relu);
flag_relu,
ctx);
} else {
sgemm_prepack(false,
m,
......@@ -395,7 +396,8 @@ void conv_im2col_gemm(const float* i_data,
k,
flag_bias,
bias_group,
flag_relu);
flag_relu,
ctx);
} else {
int ldb = n;
sgemm_prepack(false,
......
......@@ -14,6 +14,7 @@
#include "lite/backends/arm/math/sgemv.h"
#include <arm_neon.h>
#include <algorithm>
#include "lite/utils/cp_logging.h"
namespace paddle {
......@@ -50,6 +51,495 @@ void sgemv_bias_relu(const bool transA,
const float *x,
float *y,
const float *bias);
#ifdef __aarch64__
void sgemv_trans(const int M,
const int N,
const float *A,
const float *x,
float *y,
bool flag_bias,
const float *bias,
bool flag_relu,
const ARMContext *ctx) {
int m_cnt16 = M >> 4;
int m_cnt8 = (M & 15) >> 3;
int m_cnt4 = (M & 15 & 7) >> 2;
int m_remain = M & 15 & 7 & 3;
int ths = ctx->threads();
int valid_ths = std::min((N + 3) / 4, ths);
int valid_block = std::max(4, (N / valid_ths + 3) / 4 * 4);
valid_ths = (N + valid_block - 1) / valid_block;
int block_cnt = valid_block / 4;
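// Note: the N dimension is split across valid_ths threads; each thread owns a
// valid_block-sized chunk of x (a multiple of 4) plus the matching rows of A,
// and accumulates into its own M-length slice of y_buf. The per-thread partial
// sums are combined by the reduction loop further down.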
float zero_buf[M]; // NOLINT
float y_buf[valid_ths * M]; // NOLINT
memset(zero_buf, 0, M * sizeof(float));
if (flag_bias) {
memcpy(y_buf, bias, M * sizeof(float));
memset(y_buf + M, 0, (valid_ths - 1) * M * sizeof(float));
} else {
memset(y_buf, 0, valid_ths * M * sizeof(float));
}
#pragma omp parallel for
for (int t = 0; t < valid_ths; ++t) {
float *block_y = y_buf + t * M;
const float *block_x = x + t * valid_block;
const float *block_A = A + t * valid_block * M;
for (int i = 0; i < block_cnt; ++i) {
float *y_ptr = block_y;
const float *x_ptr = block_x + i * 4;
const float *in0_ptr = block_A + i * 4 * M;
const float *in1_ptr = in0_ptr + M;
const float *in2_ptr = in1_ptr + M;
const float *in3_ptr = in2_ptr + M;
int offset = t * valid_block + (i + 1) * 4 - N;
if (offset > 0) {
if (offset > 3) {
in0_ptr = zero_buf;
in1_ptr = zero_buf;
in2_ptr = zero_buf;
in3_ptr = zero_buf;
} else {
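// Intentional switch fall-through: the last `offset` of the four rows lie
// beyond N, so their pointers are redirected to zero_buf.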
switch (offset) {
case 3:
in1_ptr = zero_buf;
case 2:
in2_ptr = zero_buf;
case 1:
in3_ptr = zero_buf;
default:
break;
}
}
}
// clang-format off
if (m_cnt16 > 0) {
int cnt16 = m_cnt16;
asm volatile(
"ld1 {v4.4s}, [%[x]] \n" /* load x to v4 */
"ld1 {v5.4s, v6.4s, v7.4s, v8.4s}, [%[in0]], #64 \n" /* load in0 to v5, v6, v7, v8 */
"ld1 {v9.4s, v10.4s, v11.4s, v12.4s}, [%[in1]], #64 \n" /* load in1 to v9, v10, v11, v12 */
"ld1 {v13.4s, v14.4s, v15.4s, v16.4s}, [%[in2]], #64 \n" /* load in2 to v13, v14, v15, v16 */
"ld1 {v17.4s, v18.4s, v19.4s, v20.4s}, [%[in3]], #64 \n" /* load in3 to v17, v18, v19, v20 */
"1:\n"
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[y]] \n" /*load y to v0, v1, v2, v3 */
"fmla v0.4s, v5.4s, v4.s[0] \n" /* v0 += v5 * v4[0] */
"fmla v1.4s, v6.4s, v4.s[0] \n" /* v1 += v6 * v4[0] */
"fmla v2.4s, v7.4s, v4.s[0] \n" /* v2 += v7 * v4[0] */
"fmla v3.4s, v8.4s, v4.s[0] \n" /* v3 += v8 * v4[0] */
"ld1 {v5.4s, v6.4s, v7.4s, v8.4s}, [%[in0]], #64 \n" /* load in0 to v5, v6, v7, v8 */
"fmla v0.4s, v9.4s, v4.s[1] \n" /* v0 += v9 * v4[1] */
"fmla v1.4s, v10.4s, v4.s[1] \n" /* v1 += v10 * v4[1] */
"fmla v2.4s, v11.4s, v4.s[1] \n" /* v2 += v11 * v4[1] */
"fmla v3.4s, v12.4s, v4.s[1] \n" /* v3 += v12 * v4[1] */
"ld1 {v9.4s, v10.4s, v11.4s, v12.4s}, [%[in1]], #64 \n" /* load in1 to v9, v10, v11, v12 */
"fmla v0.4s, v13.4s, v4.s[2] \n" /* v0 += v13 * v4[2] */
"fmla v1.4s, v14.4s, v4.s[2] \n" /* v1 += v14 * v4[2] */
"fmla v2.4s, v15.4s, v4.s[2] \n" /* v2 += v15 * v4[2] */
"fmla v3.4s, v16.4s, v4.s[2] \n" /* v3 += v16 * v4[2] */
"ld1 {v13.4s, v14.4s, v15.4s, v16.4s}, [%[in2]], #64 \n" /* load in2 to v13, v14, v15, v16 */
"fmla v0.4s, v17.4s, v4.s[3] \n" /* v0 += v17 * v4[3] */
"fmla v1.4s, v18.4s, v4.s[3] \n" /* v1 += v18 * v4[3] */
"fmla v2.4s, v19.4s, v4.s[3] \n" /* v2 += v19 * v4[3] */
"fmla v3.4s, v20.4s, v4.s[3] \n" /* v3 += v20 * v4[3] */
"ld1 {v17.4s, v18.4s, v19.4s, v20.4s}, [%[in3]], #64 \n" /* load in3 to v17, v18, v19, v20 */
"subs %w[cnt], %w[cnt], #1 \n" /* sub cnt */
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[y]], #64 \n" /* store v0, v1, v2, v3 to y */
"bne 1b \n" /* branch to label 1 */
"sub %[in0], %[in0], #64 \n" /* restore in0 address */
"sub %[in1], %[in1], #64 \n" /* restore in1 address */
"sub %[in2], %[in2], #64 \n" /* restore in2 address */
"sub %[in3], %[in3], #64 \n" /* restore in3 address */
: [cnt] "+r"(cnt16),
[in0] "+r"(in0_ptr),
[in1] "+r"(in1_ptr),
[in2] "+r"(in2_ptr),
[in3] "+r"(in3_ptr),
[y] "+r"(y_ptr)
: [x] "r"(x_ptr)
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8",
"v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v17", "v18", "v19", "v20", "cc", "memory"
);
}
if (m_cnt8 > 0) {
int cnt8 = m_cnt8;
asm volatile(
"ld1 {v2.4s}, [%[x]] \n" /* load x to v2 */
"ld1 {v3.4s, v4.4s}, [%[in0]], #32 \n" /* load in0 to v3, v4 */
"ld1 {v5.4s, v6.4s}, [%[in1]], #32 \n" /* load in1 to v5, v6 */
"ld1 {v7.4s, v8.4s}, [%[in2]], #32 \n" /* load in2 to v7, v8 */
"ld1 {v9.4s, v10.4s}, [%[in3]], #32 \n" /* load in3 to v9, v10*/
"1:\n"
"ld1 {v0.4s, v1.4s}, [%[y]] \n" /* load y to v0, v1 */
"fmla v0.4s, v3.4s, v2.s[0] \n" /* v0 += v3 * v2[0] */
"fmla v1.4s, v4.4s, v2.s[0] \n" /* v1 += v4 * v2[0] */
"prfm pldl1keep, [%[in0]] \n" /* preload in0 */
"ld1 {v3.4s, v4.4s}, [%[in0]], #32 \n" /* load in0 to v3, v4 */
"fmla v0.4s, v5.4s, v2.s[1] \n" /* v0 += v5 * v2[1] */
"fmla v1.4s, v6.4s, v2.s[1] \n" /* v1 += v6 * v2[1] */
"prfm pldl1keep, [%[in1]] \n" /* preload in1 */
"ld1 {v5.4s, v6.4s}, [%[in1]], #32 \n" /* load in0 to v5, v6 */
"fmla v0.4s, v7.4s, v2.s[2] \n" /* v0 += v7 * v2[2] */
"fmla v1.4s, v8.4s, v2.s[2] \n" /* v1 += v8 * v2[2] */
"prfm pldl1keep, [%[in2]] \n" /* preload in2 */
"ld1 {v7.4s, v8.4s}, [%[in2]], #32 \n" /* load in0 to v7, v8 */
"fmla v0.4s, v9.4s, v2.s[3] \n" /* v0 += v9 * v2[3] */
"fmla v1.4s, v10.4s, v2.s[3] \n" /* v1 += v10 * v2[3] */
"subs %w[cnt], %w[cnt], #1 \n" /* sub cnt */
"prfm pldl1keep, [%[in3]] \n" /* preload in3 */
"st1 {v0.4s, v1.4s}, [%[y]], #32 \n" /* store v0, v1 to y */
"ld1 {v9.4s, v10.4s},[%[in3]], #32 \n" /* load in0 to v9, v10*/
"bne 1b \n" /* branch to label 1 */
"sub %[in0], %[in0], #32 \n" /* restore in0 address */
"sub %[in1], %[in1], #32 \n" /* restore in1 address */
"sub %[in2], %[in2], #32 \n" /* restore in2 address */
"sub %[in3], %[in3], #32 \n" /* restore in3 address */
: [cnt] "+r"(cnt8),
[in0] "+r"(in0_ptr),
[in1] "+r"(in1_ptr),
[in2] "+r"(in2_ptr),
[in3] "+r"(in3_ptr),
[y] "+r"(y_ptr)
: [x] "r"(x_ptr)
: "v0", "v1", "v2", "v3", "v4", "v5", "v6",
"v7", "v8", "v9", "v10", "cc", "memory"
);
}
if (m_cnt4 > 0) {
int cnt4 = m_cnt4;
asm volatile(
"ld1 {v1.4s}, [%[in0]], #16 \n" /* load in0 to v1 */
"ld1 {v2.4s}, [%[in1]], #16 \n" /* load in1 to v2 */
"ld1 {v3.4s}, [%[in2]], #16 \n" /* load in2 to v3 */
"ld1 {v4.4s}, [%[in3]], #16 \n" /* load in3 to v4 */
"ld1 {v5.4s}, [%[x]] \n" /* load x to v5 */
"1:\n"
"ld1 {v0.4s}, [%[y]] \n" /* load y to v0 */
"fmla v0.4s, v1.4s, v5.s[0] \n" /* v0 += v1 * v5[0] */
"prfm pldl1keep, [%[in0]] \n" /* preload in0 */
"ld1 {v1.4s}, [%[in0]], #16 \n" /* load in0 to v1 */
"fmla v0.4s, v2.4s, v5.s[1] \n" /* v0 += v2 * v5[1] */
"prfm pldl1keep, [%[in1]] \n" /* preload in1 */
"ld1 {v2.4s}, [%[in1]], #16 \n" /* load in1 to v2 */
"fmla v0.4s, v3.4s, v5.s[2] \n" /* v0 += v3 * v5[2] */
"prfm pldl1keep, [%[in2]] \n" /* preload in2 */
"ld1 {v3.4s}, [%[in2]], #16 \n" /* load in2 to v3 */
"fmla v0.4s, v4.4s, v5.s[3] \n" /* v0 += v4 * v5[3] */
"subs %w[cnt], %w[cnt], #1 \n" /* sub cnt */
"prfm pldl1keep, [%[in3]] \n" /* preload in3 */
"st1 {v0.4s}, [%[y]], #16 \n" /* store v0 to y */
"ld1 {v4.4s}, [%[in3]], #16 \n" /* load in3 to v4 */
"bne 1b \n" /* branch to label 1 */
"sub %[in0], %[in0], #16 \n" /* restore in0 address*/
"sub %[in1], %[in1], #16 \n" /* restore in1 address*/
"sub %[in2], %[in2], #16 \n" /* restore in2 address*/
"sub %[in3], %[in3], #16 \n" /* restore in3 address*/
: [cnt] "+r"(cnt4),
[in0] "+r"(in0_ptr),
[in1] "+r"(in1_ptr),
[in2] "+r"(in2_ptr),
[in3] "+r"(in3_ptr),
[y] "+r"(y_ptr)
: [x] "r"(x_ptr)
: "v0", "v1", "v2", "v3", "v4", "v5", "cc", "memory"
);
}
// clang-format on
for (int r = 0; r < m_remain; ++r) {
float val0 = x_ptr[0] * in0_ptr[r];
float val1 = x_ptr[1] * in1_ptr[r];
float val2 = x_ptr[2] * in2_ptr[r];
float val3 = x_ptr[3] * in3_ptr[r];
y_ptr[r] += val0 + val1 + val2 + val3;
}
}
}
int cnt4 = M >> 2;
int remain = M & 3;
//! do reduction
int rdc_ths = valid_ths >> 1;
while (rdc_ths > 0) {
#pragma omp parallel for
for (int t = 0; t < rdc_ths; ++t) {
float *y0 = y_buf + t * M;
for (int i = t + rdc_ths; i < valid_ths; i += rdc_ths) {
float *y0_ptr = y0;
float *y_ptr = y_buf + i * M;
for (int j = 0; j < cnt4; ++j) {
float32x4_t val0 = vld1q_f32(y0_ptr + j * 4);
float32x4_t val1 = vld1q_f32(y_ptr + j * 4);
float32x4_t val = vaddq_f32(val0, val1);
vst1q_f32(y0_ptr + j * 4, val);
}
y0_ptr += cnt4 * 4;
y_ptr += cnt4 * 4;
for (int j = 0; j < remain; ++j) {
y0_ptr[j] += y_ptr[j];
}
}
}
valid_ths = rdc_ths;
rdc_ths = rdc_ths >> 1;
}
if (flag_relu) {
float *in_y = y_buf;
float32x4_t vzero = vdupq_n_f32(0.f);
if (cnt4 > 0) {
int cnt = cnt4;
asm volatile(
"ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */
"1:\n"
"fmax v1.4s, v0.4s, %[vzero].4s \n" /* v0 relu */
"ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */
"subs %w[cnt], %w[cnt], #1 \n" /* sub cnt */
"st1 {v1.4s}, [%[out_y]], #16 \n" /* store v1 to y */
"bne 1b \n" /* branch to label 1*/
"sub %[in_y], %[in_y], #16 \n" /* restore in_y */
: [cnt] "+r"(cnt), [in_y] "+r"(in_y), [out_y] "+r"(y)
: [vzero] "w"(vzero)
: "v0", "v1", "cc", "memory");
}
for (int r = 0; r < remain; ++r) {
y[r] = in_y[r] > 0.f ? in_y[r] : 0.f;
}
} else {
memcpy(y, y_buf, M * sizeof(float));
}
}
#else
void sgemv_trans(const int M,
const int N,
const float *A,
const float *x,
float *y,
bool flag_bias,
const float *bias,
bool flag_relu,
const ARMContext *ctx) {
int m_cnt8 = M >> 3;
int m_cnt4 = (M & 7) >> 2;
int m_remain = M & 7 & 3;
int ths = ctx->threads();
int valid_ths = std::min((N + 3) / 4, ths);
int valid_block = std::max(4, (N / valid_ths + 3) / 4 * 4);
valid_ths = (N + valid_block - 1) / valid_block;
int block_cnt = valid_block / 4;
float zero_buf[M]; // NOLINT
float y_buf[valid_ths * M]; // NOLINT
memset(zero_buf, 0, M * sizeof(float));
if (flag_bias) {
memcpy(y_buf, bias, M * sizeof(float));
memset(y_buf + M, 0, (valid_ths - 1) * M * sizeof(float));
} else {
memset(y_buf, 0, valid_ths * M * sizeof(float));
}
#pragma omp parallel for
for (int t = 0; t < valid_ths; ++t) {
float *block_y = y_buf + t * M;
const float *block_x = x + t * valid_block;
const float *block_A = A + t * valid_block * M;
for (int i = 0; i < block_cnt; ++i) {
float *y_ptr = block_y;
const float *x_ptr = block_x + i * 4;
const float *in0_ptr = block_A + i * 4 * M;
const float *in1_ptr = in0_ptr + M;
const float *in2_ptr = in1_ptr + M;
const float *in3_ptr = in2_ptr + M;
int offset = t * valid_block + (i + 1) * 4 - N;
if (offset > 0) {
if (offset > 3) {
in0_ptr = zero_buf;
in1_ptr = zero_buf;
in2_ptr = zero_buf;
in3_ptr = zero_buf;
} else {
switch (offset) {
case 3:
in1_ptr = zero_buf;
case 2:
in2_ptr = zero_buf;
case 1:
in3_ptr = zero_buf;
default:
break;
}
}
}
// clang-format off
if (m_cnt8 > 0) {
int cnt8 = m_cnt8;
asm volatile(
"vld1.32 {d4-d5}, [%[x]] \n" /* load x to q2 */
"vld1.32 {d6-d9}, [%[in0]]! \n" /* load in0 to q3, q4 */
"vld1.32 {d10-d13},[%[in1]]! \n" /* load in1 to q5, q6 */
"vld1.32 {d14-d17},[%[in2]]! \n" /* load in2 to q7, q8 */
"vld1.32 {d18-d21},[%[in3]]! \n" /* load in3 to q9, q10*/
"1:\n"
"vld1.32 {d0-d3}, [%[y]] \n" /* load y to q0, q1 */
"vmla.f32 q0, q3, d4[0] \n" /* q0 += q3 * q2[0] */
"vmla.f32 q1, q4, d4[0] \n" /* q1 += q4 * q2[0] */
"pld [%[in0]] \n" /* preload in0 */
"vld1.32 {d6-d9}, [%[in0]]! \n" /* load in0 to q3, q4 */
"vmla.f32 q0, q5, d4[1] \n" /* q0 += q5 * q2[1] */
"vmla.f32 q1, q6, d4[1] \n" /* q1 += q6 * q2[1] */
"pld [%[in1]] \n" /* preload in1 */
"vld1.32 {d10-d13},[%[in1]]! \n" /* load in0 to q5, q6 */
"vmla.f32 q0, q7, d5[0] \n" /* q0 += q7 * q2[2] */
"vmla.f32 q1, q8, d5[0] \n" /* q1 += q8 * q2[2] */
"pld [%[in2]] \n" /* preload in2 */
"vld1.32 {d14-d17},[%[in2]]! \n" /* load in0 to q7, q8 */
"vmla.f32 q0, q9, d5[1] \n" /* q0 += q9 * q2[3] */
"vmla.f32 q1, q10, d5[1] \n" /* q1 += q10 * q2[3] */
"subs %[cnt], %[cnt], #1 \n" /* sub cnt */
"pld [%[in3]] \n" /* preload in3 */
"vst1.32 {d0-d3}, [%[y]]! \n" /* store q0, q1 to y */
"vld1.32 {d18-d21},[%[in3]]! \n" /* load in0 to q9, q10*/
"pld [%[y], #32] \n" /* preload y */
"bne 1b \n" /* branch to label 1 */
"sub %[in0], %[in0], #32 \n" /* restore in0 address */
"sub %[in1], %[in1], #32 \n" /* restore in1 address */
"sub %[in2], %[in2], #32 \n" /* restore in2 address */
"sub %[in3], %[in3], #32 \n" /* restore in3 address */
: [cnt] "+r"(cnt8),
[in0] "+r"(in0_ptr),
[in1] "+r"(in1_ptr),
[in2] "+r"(in2_ptr),
[in3] "+r"(in3_ptr),
[y] "+r"(y_ptr)
: [x] "r"(x_ptr)
: "q0", "q1", "q2", "q3", "q4", "q5", "q6",
"q7", "q8", "q9", "q10", "cc", "memory"
);
}
if (m_cnt4 > 0) {
int cnt4 = m_cnt4;
asm volatile(
"vld1.32 {d2-d3}, [%[in0]]! \n" /* load in0 to q1 */
"vld1.32 {d4-d5}, [%[in1]]! \n" /* load in1 to q2 */
"vld1.32 {d6-d7}, [%[in2]]! \n" /* load in2 to q3 */
"vld1.32 {d8-d9}, [%[in3]]! \n" /* load in3 to q4 */
"vld1.32 {d10-d11},[%[x]] \n" /* load x to q5 */
"1:\n"
"vld1.32 {d0-d1}, [%[y]] \n" /* load y to q0 */
"vmla.f32 q0, q1, d10[0] \n" /* q0 += q1 * q5[0] */
"pld [%[in0]] \n" /* preload in0 */
"vld1.32 {d2-d3}, [%[in0]]! \n" /* load in0 to q1 */
"vmla.f32 q0, q2, d10[1] \n" /* q0 += q2 * q5[1] */
"pld [%[in1]] \n" /* preload in1 */
"vld1.32 {d4-d5}, [%[in1]]! \n" /* load in0 to q2 */
"vmla.f32 q0, q3, d11[0] \n" /* q0 += q3 * q5[2] */
"pld [%[in2]] \n" /* preload in2 */
"vld1.32 {d6-d7}, [%[in2]]! \n" /* load in0 to q3 */
"vmla.f32 q0, q4, d11[1] \n" /* q0 += q4 * q5[3] */
"subs %[cnt], %[cnt], #1 \n" /* sub cnt */
"pld [%[in3]] \n" /* preload in3 */
"vst1.32 {d0-d1}, [%[y]]! \n" /* store q0 to y */
"vld1.32 {d8-d9}, [%[in3]]! \n" /* load in0 to q4 */
"bne 1b \n" /* branch to label 1 */
"sub %[in0], %[in0], #16 \n" /* restore in0 address*/
"sub %[in1], %[in1], #16 \n" /* restore in1 address*/
"sub %[in2], %[in2], #16 \n" /* restore in2 address*/
"sub %[in3], %[in3], #16 \n" /* restore in3 address*/
: [cnt] "+r"(cnt4),
[in0] "+r"(in0_ptr),
[in1] "+r"(in1_ptr),
[in2] "+r"(in2_ptr),
[in3] "+r"(in3_ptr),
[y] "+r"(y_ptr)
: [x] "r"(x_ptr)
: "q0", "q1", "q2", "q3", "q4", "q5", "cc", "memory"
);
}
// clang-format on
for (int r = 0; r < m_remain; ++r) {
float val0 = x_ptr[0] * in0_ptr[r];
float val1 = x_ptr[1] * in1_ptr[r];
float val2 = x_ptr[2] * in2_ptr[r];
float val3 = x_ptr[3] * in3_ptr[r];
y_ptr[r] += val0 + val1 + val2 + val3;
}
}
}
//! do reduction
int rdc_ths = valid_ths >> 1;
while (rdc_ths > 0) {
#pragma omp parallel for
for (int t = 0; t < rdc_ths; ++t) {
float *y0 = y_buf + t * M;
for (int i = t + rdc_ths; i < valid_ths; i += rdc_ths) {
float *y0_ptr = y0;
float *y_ptr = y_buf + i * M;
for (int j = 0; j < m_cnt8; ++j) {
float32x4_t val00 = vld1q_f32(y0_ptr + j * 8);
float32x4_t val01 = vld1q_f32(y0_ptr + j * 8 + 4);
float32x4_t val10 = vld1q_f32(y_ptr + j * 8);
float32x4_t val11 = vld1q_f32(y_ptr + j * 8 + 4);
float32x4_t val0 = vaddq_f32(val00, val10);
float32x4_t val1 = vaddq_f32(val01, val11);
vst1q_f32(y0_ptr + j * 8, val0);
vst1q_f32(y0_ptr + j * 8 + 4, val1);
}
y0_ptr += m_cnt8 * 8;
y_ptr += m_cnt8 * 8;
for (int j = 0; j < m_cnt4; ++j) {
float32x4_t val0 = vld1q_f32(y0_ptr + j * 4);
float32x4_t val1 = vld1q_f32(y_ptr + j * 4);
float32x4_t val = vaddq_f32(val0, val1);
vst1q_f32(y0_ptr + j * 4, val);
}
y0_ptr += m_cnt4 * 4;
y_ptr += m_cnt4 * 4;
for (int j = 0; j < m_remain; ++j) {
y0_ptr[j] += y_ptr[j];
}
}
}
valid_ths = rdc_ths;
rdc_ths = rdc_ths >> 1;
}
if (flag_relu) {
float *in_y = y_buf;
float32x4_t vzero = vdupq_n_f32(0.f);
if (m_cnt8 > 0) {
int cnt8 = m_cnt8;
asm volatile(
"vld1.32 {d0-d3}, [%[in_y]]! \n" /* load y to q0, q1 */
"1:\n"
"vmax.f32 q2, q0, %q[vzero] \n" /* q0 relu */
"vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */
"vmax.f32 q3, q1, %q[vzero] \n" /* q1 relu */
"subs %[cnt], %[cnt], #1 \n" /* sub cnt */
"vst1.32 {d4-d7}, [%[out_y]]! \n" /* store q0, q1 to y*/
"vld1.32 {d2-d3}, [%[in_y]]! \n" /* load y to q0 */
"bne 1b \n" /* branch to label 1*/
"sub %[in_y], %[in_y], #32 \n" /* restore in_y */
: [cnt] "+r"(cnt8), [in_y] "+r"(in_y), [out_y] "+r"(y)
: [vzero] "w"(vzero)
: "q0", "q1", "q2", "q3", "cc", "memory");
}
if (m_cnt4 > 0) {
int cnt4 = m_cnt4;
asm volatile(
"vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */
"1:\n"
"vmax.f32 q1, q0, %q[vzero] \n" /* q0 relu */
"vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */
"subs %[cnt], %[cnt], #1 \n" /* sub cnt */
"vst1.32 {d2-d3}, [%[out_y]]! \n" /* store q1 to y */
"bne 1b \n" /* branch to label 1*/
"sub %[in_y], %[in_y], #16 \n" /* restore in_y */
: [cnt] "+r"(cnt4), [in_y] "+r"(in_y), [out_y] "+r"(y)
: [vzero] "w"(vzero)
: "q0", "q1", "cc", "memory");
}
for (int r = 0; r < m_remain; ++r) {
y[r] = in_y[r] > 0.f ? in_y[r] : 0.f;
}
} else {
memcpy(y, y_buf, M * sizeof(float));
}
}
#endif // __aarch64__
bool sgemv(const float *A,
const float *x,
......@@ -59,33 +549,34 @@ bool sgemv(const float *A,
int N,
bool is_bias,
const float *bias,
bool is_relu) {
bool is_relu,
const ARMContext *ctx) {
if (transA) {
LOG(ERROR) << " sgemv, transA is not supported now";
return false;
}
if (is_bias) {
//! with bias
if (is_relu) {
//! with relu
sgemv_bias_relu(transA, M, N, A, x, y, bias);
} else {
//! without relu
sgemv_bias(transA, M, N, A, x, y, bias);
}
sgemv_trans(M, N, A, x, y, is_bias, bias, is_relu, ctx);
} else {
//! without bias
if (is_relu) {
//! with relu
sgemv_relu(transA, M, N, A, x, y);
if (is_bias) {
//! with bias
if (is_relu) {
//! with relu
sgemv_bias_relu(transA, M, N, A, x, y, bias);
} else {
//! without relu
sgemv_bias(transA, M, N, A, x, y, bias);
}
} else {
//! without relu
sgemv(transA, M, N, A, x, y);
//! without bias
if (is_relu) {
//! with relu
sgemv_relu(transA, M, N, A, x, y);
} else {
//! without relu
sgemv(transA, M, N, A, x, y);
}
}
}
return true;
}
// clang-format off
//! define compute kernel
#ifdef __aarch64__
#define SGEMV_IN_8 \
......@@ -179,8 +670,8 @@ bool sgemv(const float *A,
"fmla v5.4s, v9.4s, v21.4s \n" /* mul + add*/ \
"fmla v6.4s, v9.4s, v23.4s \n" /* mul + add*/ \
"fmla v7.4s, v9.4s, v25.4s \n" /* mul + add*/ \
"bne 1b \n" /* jump to main loop */ /* pair add to final \
result */ \
"bne 1b \n" /* jump to main loop */ \
/* pair add to final result */ \
"2: \n" /* reduce to scale */ \
"faddp v16.4s, v0.4s, v0.4s\n" /* pair add to vector */ \
"faddp s8, v16.2s \n" /* pair add to scale */ \
......@@ -231,8 +722,8 @@ bool sgemv(const float *A,
"fmla v0.4s, v8.4s, v10.4s \n" /* mul + add*/ \
"subs %w[cnt], %w[cnt], #1 \n" /* sub main loop count */ \
"fmla v1.4s, v9.4s, v11.4s \n" /* mul + add*/ \
"bne 1b \n" /* jump to main loop */ /* pair add to final \
result */ \
"bne 1b \n" /* jump to main loop */ \
/* pair add to final result */ \
"2: \n" /* reduce to scale */ \
"fadd v9.4s, v0.4s, v1.4s \n" /* add 2 vector */ \
"faddp v10.4s, v9.4s, v9.4s\n" /* pair add to vector */ \
......@@ -283,7 +774,7 @@ bool sgemv(const float *A,
"fmax s8, s8, s0 \n" /* relu */ \
"str s8, [%[out]] \n" /* save result */
#else //__aarch64__
#else // __aarch64__
#define SGEMV_IN_4 \
"pld [%[in]] @ preload cache line, input\n" \
......@@ -349,8 +840,8 @@ bool sgemv(const float *A,
"vmla.f32 q1, q5, q9 @ mul add\n" \
"vmla.f32 q2, q5, q11 @ mul add\n" \
"vmla.f32 q3, q5, q13 @ mul add\n" \
"bne 1b @ jump to main loop\n" /* pair add to final \
result */ \
"bne 1b @ jump to main loop\n" \
/* pair add to final result */ \
"2: @ pair add \n" \
"vpadd.f32 d8, d0, d1 @ pair add, first step\n" \
"vpadd.f32 d9, d2, d3 @ pair add, first step\n" \
......@@ -382,13 +873,10 @@ bool sgemv(const float *A,
"vmla.f32 q0, q12, q14 @ mul add\n" \
"vmla.f32 q0, q13, q15 @ mul add\n" \
"subs %[cnt] , #1 @ sub loop count \n" \
"bne 1b @ jump to main loop\n" /* pair add to \
final result \
*/ \
"bne 1b @ jump to main loop\n" \
"2: @ end processing\n" \
"vpadd.f32 d2, d0, d1 @ pair add, first step\n" \
"vpadd.f32 d0, d2, d2 @ pair add, final step\n" /* check tails \
*/ \
"vpadd.f32 d0, d2, d2 @ pair add, final step\n"/*check tails*/ \
"cmp %[tail], #1 @ check whether has mid cols\n" \
"blt 4f @ jump to end\n" \
"3: @ tail loop\n" \
......@@ -422,7 +910,7 @@ bool sgemv(const float *A,
"vmax.f32 d0, d0, d1 @ relu\n" \
"vst1.32 {d0[0]}, [%[out]] @ save result\n"
#endif
// clang-format on
void sgemv(const bool transA,
const int M,
const int N,
......@@ -523,7 +1011,7 @@ void sgemv(const bool transA,
[tmp4] "r"(tmp4)
: "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory");
}
#else //__aarch64__
#else // __aarch64__
int out_cnt = M >> 2;
#pragma omp parallel for
for (int j = 0; j < out_cnt; j++) {
......@@ -579,7 +1067,7 @@ void sgemv(const bool transA,
: [out] "r"(ptr_out)
: "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory");
}
#endif //__aarch64__
#endif // __aarch64__
}
void sgemv_relu(const bool transA,
......@@ -671,7 +1159,7 @@ void sgemv_relu(const bool transA,
: [out] "r"(ptr_out)
: "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory");
}
#else //__aarch64__
#else // __aarch64__
int out_cnt = M >> 2;
#pragma omp parallel for
for (int j = 0; j < out_cnt; j++) {
......@@ -727,7 +1215,7 @@ void sgemv_relu(const bool transA,
: [out] "r"(ptr_out)
: "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory");
}
#endif //__aarch64__
#endif // __aarch64__
}
void sgemv_bias(const bool transA,
......@@ -822,7 +1310,7 @@ void sgemv_bias(const bool transA,
: [out] "r"(ptr_out), [bias0] "r"(bias0)
: "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory");
}
#else //__aarch64__
#else // __aarch64__
int out_cnt = M >> 2;
#pragma omp parallel for
for (int j = 0; j < out_cnt; j++) {
......@@ -887,7 +1375,7 @@ void sgemv_bias(const bool transA,
: [out] "r"(ptr_out), [bias0] "r"(bias0)
: "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory");
}
#endif //__aarch64__
#endif // __aarch64__
}
void sgemv_bias_relu(const bool transA,
......@@ -980,7 +1468,7 @@ void sgemv_bias_relu(const bool transA,
: [out] "r"(ptr_out), [bias0] "r"(bias0)
: "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory");
}
#else //__aarch64__
#else // __aarch64__
int out_cnt = M >> 2;
#pragma omp parallel for
for (int j = 0; j < out_cnt; j++) {
......@@ -1045,7 +1533,7 @@ void sgemv_bias_relu(const bool transA,
: [out] "r"(ptr_out), [bias0] "r"(bias0)
: "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory");
}
#endif //__aarch64__
#endif // __aarch64__
}
} // namespace math
......
......@@ -15,6 +15,8 @@
#pragma once
#include <cmath>
#include "lite/core/context.h"
#include "lite/core/device_info.h"
namespace paddle {
namespace lite {
......@@ -28,9 +30,10 @@ bool sgemv(const float* A,
bool transA,
int M,
int N,
bool is_bias = false,
const float* bias = nullptr,
bool is_relu = false);
bool is_bias,
const float* bias,
bool is_relu,
const ARMContext* ctx);
} // namespace math
} // namespace arm
......
......@@ -127,7 +127,8 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
k_,
param.bias != nullptr,
b_data,
false);
false,
&ctx);
}
}
}
......
......@@ -232,7 +232,7 @@ void MatMulCompute::Run() {
int ldc = n_;
if (n_ == 1) {
lite::arm::math::sgemv(
x_data, y_data, o_data, false, m_, k_, false, nullptr, false);
x_data, y_data, o_data, false, m_, k_, false, nullptr, false, &ctx);
if (fabsf(alpha - 1.f) > 1e-8f) {
for (size_t i = 0; i < param.Out->dims().production(); ++i) {
o_data[i] *= alpha;
......
......@@ -48,14 +48,13 @@ void MulCompute::Run() {
CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h";
k_ = x_w;
auto& ctx = this->ctx_->template As<ARMContext>();
if (n_ == 1) {
lite::arm::math::sgemv(
x_data, y_data, o_data, false, m_, k_, false, nullptr, false);
x_data, y_data, o_data, false, m_, k_, false, nullptr, false, &ctx);
} else {
constexpr bool is_tranposed_y = false;
auto& ctx = this->ctx_->template As<ARMContext>();
int hblock = lite::arm::math::get_hblock(&ctx);
int m_round = hblock * ((m_ + hblock - 1) / hblock);
ctx.ExtendWorkspace(m_round * k_ * sizeof(float));
......
if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
lite_cc_test(sgemm_compute_test SRCS sgemm_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(sgemv_compute_test SRCS sgemv_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(gemm_int8_compute_test SRCS gemm_int8_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(gemv_int8_compute_test SRCS gemv_int8_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(conv_compute_test SRCS conv_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "lite/tests/utils/fill_data.h"
#include "lite/tests/utils/naive_math_impl.h"
#ifdef LITE_WITH_ARM
#include "lite/backends/arm/math/funcs.h"
#endif // LITE_WITH_ARM
#include "lite/core/context.h"
#include "lite/core/tensor.h"
#include "lite/tests/utils/tensor_utils.h"
#include "lite/tests/utils/timer.h"
typedef paddle::lite::Tensor Tensor;
DEFINE_int32(cluster, 3, "cluster id");
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
DEFINE_bool(basic_test, true, "do all tests");
DEFINE_bool(check_result, true, "check the result");
DEFINE_int32(M, 512, "sgemv: M");
DEFINE_int32(K, 512, "sgemv: K");
DEFINE_bool(traA, false, "gemv: A transpose");
DEFINE_bool(flag_relu, false, "do relu");
DEFINE_bool(flag_bias, false, "with bias");
bool test_sgemv(
bool tra, int m, int k, bool has_bias, bool has_relu, int cls, int ths) {
Tensor ta;
Tensor tb;
Tensor tc;
Tensor tc_basic;
Tensor tbias;
ta.Resize({m, k});
tb.Resize({k, 1});
tc.Resize({m, 1});
tc_basic.Resize({m, 1});
tbias.Resize({m});
ta.set_precision(PRECISION(kFloat));
tb.set_precision(PRECISION(kFloat));
tc.set_precision(PRECISION(kFloat));
tc_basic.set_precision(PRECISION(kFloat));
tbias.set_precision(PRECISION(kFloat));
fill_tensor_rand(ta, -1.f, 1.f);
// fill_tensor_const(ta, 1.f);
fill_tensor_rand(tb, -1.f, 1.f);
// fill_tensor_const(tb, 1.f);
fill_tensor_rand(tbias, -1.f, 1.f);
LOG(INFO) << "sgemv M: " << m << ", K: " << k
<< ", transA: " << (tra ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< ", bias: " << (has_bias ? "true" : "false");
#ifdef LITE_WITH_ARM
auto da = ta.mutable_data<float>();
auto db = tb.mutable_data<float>();
auto dc = tc.mutable_data<float>();
auto dc_basic = tc_basic.mutable_data<float>();
auto dbias = tbias.mutable_data<float>();
if (FLAGS_check_result) {
basic_gemv(
m, k, da, db, dbias, dc_basic, 1.f, 0.f, tra, has_bias, has_relu);
}
paddle::lite::Timer t0;
//! compute
double ops = 2.0 * m * k;
std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext);
auto& ctx = ctx1->As<paddle::lite::ARMContext>();
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), ths);
/// warmup
for (int j = 0; j < FLAGS_warmup; ++j) {
paddle::lite::arm::math::sgemv(
da, db, dc, tra, m, k, has_bias, dbias, has_relu, &ctx);
}
t0.clear();
for (int i = 0; i < FLAGS_repeats; ++i) {
t0.start();
paddle::lite::arm::math::sgemv(
da, db, dc, tra, m, k, has_bias, dbias, has_relu, &ctx);
t0.end();
}
LOG(INFO) << "gemv output: M: " << m << ", K: " << k << ", cluster: " << cls
<< ", threads: " << ths << ", GOPS: " << ops * 1e-9f
<< " GOPS, avg time: " << t0.get_average_ms()
<< " ms, min time: " << t0.get_min_time()
<< " ms, mean GOPs: " << ops * 1e-6f / t0.get_average_ms()
<< " GOPs, max GOPs: " << ops * 1e-6f / t0.get_min_time()
<< " GOPs";
if (FLAGS_check_result) {
double max_ratio = 0;
double max_diff = 0;
/// fp32 result
tensor_cmp_host(tc_basic, tc, max_ratio, max_diff);
LOG(INFO) << "compare result, max diff: " << max_diff
<< ", max ratio: " << max_ratio;
if (std::abs(max_ratio) > 1e-4f && std::abs(max_diff) > 5e-5f) {
Tensor tdiff;
tdiff.set_precision(PRECISION(kFloat));
tdiff.Resize(tc.dims());
tensor_diff(tc_basic, tc, tdiff);
LOG(INFO) << "basic result: ";
print_tensor(tc_basic);
LOG(INFO) << "saber result: ";
print_tensor(tc);
LOG(INFO) << "diff result: ";
print_tensor(tdiff);
return false;
}
}
#endif
return true;
}
TEST(TestLiteSgemv, Sgemv) {
if (FLAGS_basic_test) {
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
LOG(INFO) << "run basic sgemv test";
for (auto& m : {1, 3, 8, 21, 32, 397}) {
for (auto& k : {1, 3, 8, 17, 59, 234}) {
for (auto& tra : {true, false}) {
for (auto& has_bias : {false, true}) {
for (auto& has_relu : {false, true}) {
for (auto& th : {1, 2, 4}) {
auto flag = test_sgemv(
tra, m, k, has_bias, has_relu, FLAGS_cluster, th);
if (flag) {
LOG(INFO) << "test m = " << m << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< ", trans A: " << (tra ? "true" : "false")
<< ", threads: " << th << " passed\n";
} else {
LOG(FATAL) << "test m = " << m << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< ", trans A: " << (tra ? "true" : "false")
<< ", threads: " << th << " failed\n";
}
}
}
}
}
}
}
}
}
TEST(TestSgemvCustom, Sgemv_custom) {
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
auto flag = test_sgemv(FLAGS_traA,
FLAGS_M,
FLAGS_K,
FLAGS_flag_bias,
FLAGS_flag_relu,
FLAGS_cluster,
FLAGS_threads);
if (!flag) {
LOG(FATAL) << "test m = " << FLAGS_M << ", k=" << FLAGS_K
<< ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias
<< ", relu: " << FLAGS_flag_relu << " failed!!";
}
LOG(INFO) << "test m = " << FLAGS_M << ", k=" << FLAGS_K
<< ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias
<< ", relu: " << FLAGS_flag_relu << " passed!!";
}
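The custom test is driven by the gflags defined above, so a single shape can presumably be exercised by passing --M, --K, --traA, --flag_bias, --flag_relu, --threads and --cluster to the sgemv_compute_test binary; by default (basic_test=true, cluster=3) TestLiteSgemv sweeps the full shape/transpose/bias/relu/thread grid.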