提交 237cf93b 编写于 作者: S StarryRain 提交者: Yanzhan Yang

add faster sgemv_notrans_mx1, fix test_fusion_op (#1772)

* add CPU_ARCH info, improve the performance of GEMM1*1s1

* improve the performance of gemm1*1s1_conv_add and gemm1*1s1_conv_add_bn_relu

* improve the performance of slidingwindow_bn_relu,slidingwindow_add,slidingwindow_add_bn_relu,gemm1*1s1_bn_relu,gemm1*1s1_add_relu

* add faster sgemv_notrans_mx1, fix test_fusion_op
上级 549ebc0a
......@@ -22,6 +22,27 @@ namespace operators {
// One-time kernel setup. When the first dimension of InputX is 1, the
// contents of InputY (read as a row-major rows x cols matrix) are rewritten in
// place in transposed element order. The tensor's dims are not modified here.
template <>
bool FusionFcKernel<CPU, float>::Init(FusionFcParam<CPU> *param) {
  const int M = static_cast<int>(param->InputX()->dims()[0]);
  if (M != 1) {
    // Nothing to prepare for any other first dimension.
    return true;
  }

  int rows = param->InputY()->dims()[0];
  int cols = param->InputY()->dims()[1];
  float *B = param->InputY()->data<float>();

  // Scratch buffer that receives the transposed copy.
  framework::Tensor matrix_trans;
  float *trans_b = matrix_trans.mutable_data<float>({rows, cols});

  // Walk the source column by column so the scratch buffer is filled
  // sequentially; `filled` ends up as the number of elements written.
  int filled = 0;
  for (int col = 0; col < cols; ++col) {
    for (int row = 0; row < rows; ++row) {
      trans_b[filled] = B[row * cols + col];
      ++filled;
    }
  }

  // Copy the transposed values back over the original storage.
  for (int k = 0; k < filled; ++k) {
    B[k] = trans_b[k];
  }
  return true;
}
......
......@@ -32,6 +32,7 @@ void FusionFcCompute(const FusionFcParam<CPU> &param) {
int axis = param.Axis();
Tensor *out = param.Out();
auto *out_data = out->mutable_data<Itype>();
int M = (int)input_x->dims()[0];
const Tensor x_matrix =
input_x->dims().size() > 2
......@@ -57,9 +58,15 @@ void FusionFcCompute(const FusionFcParam<CPU> &param) {
for (int i = 0; i < out_dim[0]; i++) {
memory::Copy(out_data + i * classes, input_z_data, sizeof(Otype) * classes);
}
math::MatMul<Itype, Otype>(x_matrix, false, y_matrix, false,
static_cast<float>(1), out, static_cast<float>(1),
false);
if (M == 1) {
math::MatMul<Itype, Otype>(x_matrix, false, y_matrix, true,
static_cast<float>(1), out,
static_cast<float>(1), false);
} else {
math::MatMul<Itype, Otype>(x_matrix, false, y_matrix, false,
static_cast<float>(1), out,
static_cast<float>(1), false);
}
}
} // namespace operators
......
......@@ -414,6 +414,226 @@ void sgemv_notrans_mx1(const int M, const int N, const float alpha,
}
}
/**
 * Row-major matrix-vector product without transposition:
 *
 *   C[m] = alpha * sum_n(A[m * lda + n] * B[n]) + beta * C[m],  0 <= m < M
 *
 * @param M      number of rows of A (and number of entries of C)
 * @param N      number of columns of A (and number of entries of B)
 * @param alpha  factor applied to every dot product
 * @param A      matrix data; row m begins at A + m * lda
 * @param lda    distance, in floats, between the starts of consecutive rows
 * @param B      vector of N floats
 * @param beta   factor applied to the incoming value of C[m]
 * @param C      M floats, read and then overwritten in place
 *
 * Rows are processed in groups of four; rows left over when M is not a
 * multiple of four are processed one at a time at the end.
 */
void sgemv_notrans_mx1_faster(const int M, const int N, const float alpha,
                              const float *A, const int lda, const float *B,
                              const float beta, float *C) {
#pragma omp parallel for
  for (int m = 0; m < M - 3; m += 4) {
    // Cursors into the four rows of this group and into the vector.
    const float *row0 = A + m * lda;
    const float *row1 = row0 + lda;
    const float *row2 = row1 + lda;
    const float *row3 = row2 + lda;
    const float *vec = B;
    float *dst = C + m;

    // One running dot product per row.
    float acc0 = 0.f;
    float acc1 = 0.f;
    float acc2 = 0.f;
    float acc3 = 0.f;

    // Number of columns already consumed.
    int n = 0;
#if __ARM_NEON
    /* Vector path: each pass of the assembly loop multiplies a 4x8 tile of A
     * by 8 consecutive entries of B and accumulates into four vector
     * registers, one per row. After the loop the four lanes of every
     * accumulator are added pairwise and the four row sums are stored
     * through s_ptr.
     *
     * Asm operands:
     *   a_ptr0..a_ptr3  row cursors (post-incremented by the loads)
     *   b_ptr           vector cursor (post-incremented by the loads)
     *   s_ptr           destination of the four reduced sums
     *   loop            remaining iteration count
     *
     * Registers: Q(V)4-Q(V)11 hold A, Q(V)0-Q(V)1 hold B,
     *            Q(V)12-Q(V)15 hold the accumulators.
     */
    float lane_sums[4] = {0.f};
    float *s_ptr = lane_sums;
    int loop = N / 8;
#if __aarch64__
    if (loop > 0) {
      asm volatile(
          // set v12-v15 to 0
          "movi       v12.4s,  #0           \n"
          "movi       v13.4s,  #0           \n"
          "movi       v14.4s,  #0           \n"
          "movi       v15.4s,  #0           \n"
          "0:                               \n"
          // load A and B
          "ld1  {v0.4s, v1.4s}, [%[b_ptr]] , #32 \n"
          "ld1  {v4.4s, v5.4s}, [%[a_ptr0]], #32 \n"
          "ld1  {v6.4s, v7.4s}, [%[a_ptr1]], #32 \n"
          "ld1  {v8.4s, v9.4s}, [%[a_ptr2]], #32 \n"
          "ld1  {v10.4s, v11.4s}, [%[a_ptr3]], #32 \n"
          "fmla   v12.4s,  v4.4s,  v0.4s   \n"  // s0=A(r0c0-r0c3)*B(r0-r3)
          "fmla   v13.4s,  v6.4s,  v0.4s   \n"  // s1=A(r1c0-r1c3)*B(r0-r3)
          "fmla   v14.4s,  v8.4s,  v0.4s   \n"  // s2=A(r2c0-r2c3)*B(r0-r3)
          "fmla   v15.4s,  v10.4s, v0.4s   \n"  // s3=A(r3c0-r3c3)*B(r0-r3)
          "fmla   v12.4s,  v5.4s,  v1.4s   \n"  // s0=A(r0c4-r0c7)*B(r4-r7)
          "fmla   v13.4s,  v7.4s,  v1.4s   \n"  // s1=A(r1c4-r1c7)*B(r4-r7)
          "fmla   v14.4s,  v9.4s,  v1.4s   \n"  // s2=A(r2c4-r2c7)*B(r4-r7)
          "fmla   v15.4s,  v11.4s, v1.4s   \n"  // s3=A(r3c4-r3c7)*B(r4-r7)
          // cycle
          "subs   %[loop], %[loop], #1     \n"
          "bne    0b                        \n"
          // add and store
          "faddp  v4.4s,  v12.4s, v13.4s   \n"
          "faddp  v5.4s,  v14.4s, v15.4s   \n"
          "faddp  v6.4s,  v4.4s,  v5.4s    \n"
          "st1    {v6.4s}, [%[s_ptr]]      \n"
          : [loop] "+r"(loop), [a_ptr0] "+r"(row0), [a_ptr1] "+r"(row1),
            [a_ptr2] "+r"(row2), [a_ptr3] "+r"(row3), [b_ptr] "+r"(vec)
          : [s_ptr] "r"(s_ptr)
          : "v0", "v1", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
            "v12", "v13", "v14", "v15", "cc", "memory");
    }
#else   // __aarch64__
    if (loop > 0) {
      asm volatile(
          // set Q12-Q15 to 0
          "vmov.i32   q12,  #0              \n"
          "vmov.i32   q13,  #0              \n"
          "vmov.i32   q14,  #0              \n"
          "vmov.i32   q15,  #0              \n"
          "0:                               \n"
          // load A and B
          "vld1.f32   {d0-d3},   [%[b_ptr]]!  \n"
          "vld1.f32   {d8-d11},  [%[a_ptr0]]! \n"
          "vld1.f32   {d12-d15}, [%[a_ptr1]]! \n"
          "vld1.f32   {d16-d19}, [%[a_ptr2]]! \n"
          "vld1.f32   {d20-d23}, [%[a_ptr3]]! \n"
          "vmla.f32   q12,  q4,  q0         \n"  // s0=A(r0c0-r0c3)*B(r0-r3)
          "vmla.f32   q13,  q6,  q0         \n"  // s1=A(r1c0-r1c3)*B(r0-r3)
          "vmla.f32   q14,  q8,  q0         \n"  // s2=A(r2c0-r2c3)*B(r0-r3)
          "vmla.f32   q15,  q10, q0         \n"  // s3=A(r3c0-r3c3)*B(r0-r3)
          "vmla.f32   q12,  q5,  q1         \n"  // s0=A(r0c4-r0c7)*B(r4-r7)
          "vmla.f32   q13,  q7,  q1         \n"  // s1=A(r1c4-r1c7)*B(r4-r7)
          "vmla.f32   q14,  q9,  q1         \n"  // s2=A(r2c4-r2c7)*B(r4-r7)
          "vmla.f32   q15,  q11, q1         \n"  // s3=A(r3c4-r3c7)*B(r4-r7)
          // cycle
          "subs       %[loop],   #1         \n"
          "bne        0b                    \n"
          // add and store
          "vpadd.f32  d8,   d24,  d25       \n"
          "vpadd.f32  d9,   d26,  d27       \n"
          "vpadd.f32  d10,  d28,  d29       \n"
          "vpadd.f32  d11,  d30,  d31       \n"
          "vpadd.f32  d12,  d8,   d9        \n"
          "vpadd.f32  d13,  d10,  d11       \n"
          "vst1.32    {d12-d13}, [%[s_ptr]] \n"
          : [loop] "+r"(loop), [a_ptr0] "+r"(row0), [a_ptr1] "+r"(row1),
            [a_ptr2] "+r"(row2), [a_ptr3] "+r"(row3), [b_ptr] "+r"(vec)
          : [s_ptr] "r"(s_ptr)
          : "q0", "q1", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
            "q12", "q13", "q14", "q15", "cc", "memory");
    }
#endif  // __aarch64__
    // Fold the vector results in (they stay zero when the asm was skipped)
    // and mark every full chunk of 8 columns as consumed.
    acc0 += lane_sums[0];
    acc1 += lane_sums[1];
    acc2 += lane_sums[2];
    acc3 += lane_sums[3];
    n = N - (N & 0x07);
#endif  // __ARM_NEON

    // Scalar path over full chunks of 8 columns.
    for (; n < N - 7; n += 8) {
      for (int k = 0; k < 8; ++k) {
        acc0 += row0[k] * vec[k];
        acc1 += row1[k] * vec[k];
        acc2 += row2[k] * vec[k];
        acc3 += row3[k] * vec[k];
      }
      row0 += 8;
      row1 += 8;
      row2 += 8;
      row3 += 8;
      vec += 8;
    }

    // Remaining columns, one at a time.
    for (; n < N; ++n) {
      acc0 += row0[0] * vec[0];
      acc1 += row1[0] * vec[0];
      acc2 += row2[0] * vec[0];
      acc3 += row3[0] * vec[0];
      ++row0;
      ++row1;
      ++row2;
      ++row3;
      ++vec;
    }

    dst[0] = alpha * acc0 + beta * dst[0];
    dst[1] = alpha * acc1 + beta * dst[1];
    dst[2] = alpha * acc2 + beta * dst[2];
    dst[3] = alpha * acc3 + beta * dst[3];
  }

  // Rows that did not fit into a group of four.
  const int first_tail_row = M - (M & 0x03);
  for (int m = first_tail_row; m < M; ++m) {
    const float *row = A + m * lda;
    float dot = 0.f;
    for (int n = 0; n < N; ++n) {
      dot += row[n] * B[n];
    }
    C[m] = alpha * dot + beta * C[m];
  }
}
void sgemv_trans_mx1(const int M, const int N, const float alpha,
const float *A, const int lda, const float *B,
const float beta, float *C) {
......@@ -560,7 +780,8 @@ void sgemv_mx1(const bool trans, const int M, const int N, const float alpha,
if (trans) {
sgemv_trans_mx1(M, N, alpha, A, lda, B, beta, C);
} else {
sgemv_notrans_mx1(M, N, alpha, A, lda, B, beta, C);
// sgemv_notrans_mx1(M, N, alpha, A, lda, B, beta, C);
sgemv_notrans_mx1_faster(M, N, alpha, A, lda, B, beta, C);
}
}
......
......@@ -56,6 +56,7 @@ void MatMul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
int M = dim_out[0];
int N = dim_out[1];
int K = (!trans_a) ? dim_a[1] : dim_a[0];
int ldb = (!trans_b) ? dim_b[1] : dim_b[0];
Gemm gemm;
if (trans_a) {
......@@ -71,11 +72,12 @@ void MatMul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
a[index++] = tmp[i * n + j];
}
}
cblas_sgemm(false, false, M, N, K, alpha, a, K, matrix_b.data<float>(), N,
beta, matrix_out->data<float>(), N);
cblas_sgemm(false, trans_b, M, N, K, alpha, a, K, matrix_b.data<float>(),
ldb, beta, matrix_out->data<float>(), N);
} else {
cblas_sgemm(false, false, M, N, K, alpha, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, beta, matrix_out->data<float>(), N);
cblas_sgemm(false, trans_b, M, N, K, alpha, matrix_a.data<float>(), K,
matrix_b.data<float>(), ldb, beta, matrix_out->data<float>(),
N);
}
}
......
......@@ -90,6 +90,11 @@ int TestFcOP() {
auto bias = bias_var->template GetMutable<framework::LoDTensor>();
SetupTensor<S>(bias, bias_shape, -127, 127);
framework::Tensor origin_matrix;
T *origin_inputB_ptr = origin_matrix.mutable_data<T>(inputB_shape);
memcpy(origin_inputB_ptr, inputB->data<T>(),
sizeof(*origin_inputB_ptr) * k * n);
auto scale_var = scope.get()->Var("scale");
auto scale = scale_var->template GetMutable<framework::LoDTensor>();
scale->Resize(framework::make_ddim({1}));
......@@ -105,13 +110,14 @@ int TestFcOP() {
op = new operators::FusionFcOp<CPU, T>("fusion_fc", inputs, outputs, attrs,
scope.get());
op->InferShape();
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
const T *output_data = output->data<T>();
// compare
T *c = static_cast<T *>(memory::Alloc(sizeof(T) * m * n));
T *a = inputA->data<T>();
T *b = inputB->data<T>();
T *b = origin_inputB_ptr;
S *bias_data = bias->data<S>();
for (int32_t i = 0; i < m; ++i) {
for (int32_t j = 0; j < n; ++j) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册