未验证 提交 66d2ae25 编写于 作者: Y yiicy 提交者: GitHub

[ARM] add sgemmc4 common and small kernel, support for winograd, test=develop (#2471)

* unfinish sgemmc4

* finish armv8 sgemmc4

* arm add sgemmc4 with deal with remain

* [ARM] add sgemmc4 small kernel, test=develop
上级 a7f7d49b
......@@ -60,6 +60,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
cc_library(math_arm SRCS
funcs.cc
packed_sgemm.cc
packed_sgemm_c4.cc
sgemm.cc
gemm_prepacked_int8.cc
gemm_s8.cc
......
......@@ -43,6 +43,7 @@
#include "lite/backends/arm/math/negative.h"
#include "lite/backends/arm/math/norm.h"
#include "lite/backends/arm/math/packed_sgemm.h"
#include "lite/backends/arm/math/packed_sgemm_c4.h"
#include "lite/backends/arm/math/pad2d.h"
#include "lite/backends/arm/math/pooling.h"
#include "lite/backends/arm/math/power.h"
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/arm/math/packed_sgemm_c4.h"
#include <arm_neon.h>
namespace paddle {
namespace lite {
namespace arm {
namespace math {
void loadb_c4(float* out,
const float* in,
const int xstart,
const int xend,
const int k_round,
const int n) {
const int xlen = (xend - xstart + NBLOCK_C4 - 1) / NBLOCK_C4 * NBLOCK_C4;
int xloop = xlen / NBLOCK_C4;
const int flag_remain = n < xstart + xlen;
int remain = 0;
int remain4 = 0;
int remain1 = 0;
if (flag_remain) {
remain = (n - xstart) - (xloop - 1) * NBLOCK_C4;
remain4 = remain >> 2;
remain1 = remain & 3;
xloop -= 1;
}
const int ldo = NBLOCK_C4 * k_round;
const int kloop = k_round >> 2;
in += xstart * 4;
if (xloop > 0) {
#pragma omp parallel for
for (int i = 0; i < kloop; ++i) {
float* out_ptr = out + 4 * NBLOCK_C4 * i;
const float* in_ptr = in + i * 4 * n;
for (int j = 0; j < xloop; ++j) {
float* out_p = out_ptr + j * ldo;
#ifdef __aarch64__
asm volatile(
"ld1 {v0.4s, v1.4s}, [%[in]], #32 \n"
"ld1 {v2.4s, v3.4s}, [%[in]], #32 \n"
"st1 {v0.4s, v1.4s}, [%[out]], #32 \n"
"ld1 {v4.4s, v5.4s}, [%[in]], #32 \n"
"st1 {v2.4s, v3.4s}, [%[out]], #32 \n"
"ld1 {v6.4s, v7.4s}, [%[in]], #32 \n"
"st1 {v4.4s, v5.4s}, [%[out]], #32 \n"
"st1 {v6.4s, v7.4s}, [%[out]], #32 \n"
: [in] "+r"(in_ptr), [out] "+r"(out_p)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
#else
asm volatile(
"vld1.32 {d0-d3}, [%[in]]! \n"
"vld1.32 {d4-d7}, [%[in]]! \n"
"vst1.32 {d0-d3}, [%[out]]! \n"
"vld1.32 {d8-d11}, [%[in]]! \n"
"vst1.32 {d4-d7}, [%[out]]! \n"
"vld1.32 {d12-d15}, [%[in]]! \n"
"vst1.32 {d8-d11}, [%[out]]! \n"
"vst1.32 {d12-d15}, [%[out]]! \n"
: [in] "+r"(in_ptr), [out] "+r"(out_p)
:
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
#endif // __aarch674__
}
}
}
float* out_remain4 = out + xloop * k_round * NBLOCK_C4;
const float* in_remain4 = in + xloop * NBLOCK_C4 * 4;
if (remain4) {
#pragma omp parallel for
for (int i = 0; i < kloop; ++i) {
float* out_ptr = out_remain4 + 4 * 4 * i;
const float* in_ptr = in_remain4 + i * 4 * n;
#ifdef __aarch64__
asm volatile(
"ld1 {v0.4s, v1.4s}, [%[in]], #32 \n"
"ld1 {v2.4s, v3.4s}, [%[in]], #32 \n"
"st1 {v0.4s, v1.4s}, [%[out]], #32 \n"
"st1 {v2.4s, v3.4s}, [%[out]], #32 \n"
: [in] "+r"(in_ptr), [out] "+r"(out_ptr)
:
: "v0", "v1", "v2", "v3");
#else
asm volatile(
"vld1.32 {d0-d3}, [%[in]]! \n"
"vld1.32 {d4-d7}, [%[in]]! \n"
"vst1.32 {d0-d3}, [%[out]]! \n"
"vst1.32 {d4-d7}, [%[out]]! \n"
: [in] "+r"(in_ptr), [out] "+r"(out_ptr)
:
: "q0", "q1", "q2", "q3");
#endif // __aarch64__
}
}
float* out_remain1 = out_remain4 + remain4 * k_round * 4;
const float* in_remain1 = in_remain4 + remain4 * 4 * 4;
if (remain1) {
#pragma omp parallel for
for (int i = 0; i < kloop; ++i) {
float* out_ptr = out_remain1 + 4 * remain1 * i;
const float* in_ptr = in_remain1 + i * 4 * n;
for (int j = 0; j < remain1; ++j) {
float32x4_t vin = vld1q_f32(in_ptr);
in_ptr += 4;
vst1q_f32(out_ptr, vin);
out_ptr += 4;
}
}
}
}
void sgemm_prepack_c4_common(int M,
int N,
int K,
const float* A_packed,
const float* B,
float* C,
const float* bias,
bool has_bias,
bool has_relu,
ARMContext* ctx) {
const int m_round = (M + 3) / 4 * 4;
const int k_round = (K + 3) / 4 * 4;
size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024;
int threads = ctx->threads();
auto workspace = ctx->workspace_data<float>();
// l2 = ablock * K * threads + K * bchunk_w + threads * ablock * bchunk_w;
int bchunk_w = (l2_cache - threads * k_round * sizeof(float)) /
((k_round + threads * MBLOCK_C4) * sizeof(float));
bchunk_w = bchunk_w > N ? N : bchunk_w;
bchunk_w = bchunk_w / NBLOCK_C4 * NBLOCK_C4;
bchunk_w = bchunk_w > NBLOCK_C4 ? bchunk_w : NBLOCK_C4;
int bchunk_loop = (N + bchunk_w - 1) / bchunk_w;
const int h_loop = m_round >> 2; // MBLOCK_C4 == 4;
const int kcnt = (k_round + KBLOCK_C4 - 1) / KBLOCK_C4;
const int ldc = N * 4;
const int lda = k_round * 4;
float bias_buf[m_round]; // NOLINT
if (has_bias) {
memcpy(bias_buf, bias, M * sizeof(float));
memset(bias_buf + M, 0, (m_round - M) * sizeof(float));
} else {
memset(bias_buf, 0, m_round * sizeof(float));
}
// bchunk_loop
float* c = C;
for (int n = 0; n < bchunk_loop; ++n) {
int x_start = n * bchunk_w;
int x_end = x_start + bchunk_w;
int w_loop = bchunk_w / NBLOCK_C4;
int flag_remain = 0;
int w_loop4 = 0;
int remain = 0;
if (x_end > N) {
w_loop = (N - x_start) / NBLOCK_C4;
int w_loop_rem = (N - x_start) - w_loop * NBLOCK_C4;
w_loop4 = w_loop_rem >> 2;
remain = w_loop_rem & 3;
x_end = N;
flag_remain = 1;
}
float* bchunk = workspace;
loadb_c4(bchunk, B, x_start, x_end, k_round, N);
float* cchunk = c + n * bchunk_w * 4;
int has_remain = (n == bchunk_loop - 1) && flag_remain;
#pragma omp parallel for num_threads(threads)
for (int h = 0; h < h_loop; ++h) {
float* bias_h = bias_buf + h * 4;
#ifdef __aarch64__
float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t vbias = vld1q_f32(bias_h);
#endif
const float* ablock = A_packed + h * lda;
const float* bblock = bchunk;
float* cblock = cchunk + h * ldc;
for (int w = 0; w < w_loop; ++w) {
int cnt = kcnt;
const float* ablock_ptr = ablock;
// clang-format off
#ifdef __aarch64__
asm volatile(
"prfm pldl1keep, [%[a]] \n"
"prfm pldl1keep, [%[b]] \n"
"prfm pldl1keep, [%[b], #64] \n"
"mov v9.16b, %[vbias].16b \n" /* mov bias to c0*/
"mov v10.16b, %[vbias].16b \n" /* mov bias to c1*/
"mov v11.16b, %[vbias].16b \n" /* mov bias to c2*/
"mov v12.16b, %[vbias].16b \n" /* mov bias to c3*/
/* load a0a1 to v1-v2 */
"ld1 {v1.4s, v2.4s}, [%[a]], #32 \n"
"mov v13.16b, %[vbias].16b \n" /* mov bias to c4*/
"mov v14.16b, %[vbias].16b \n" /* mov bias to c5*/
"mov v15.16b, %[vbias].16b \n" /* mov bias to c6*/
"mov v16.16b, %[vbias].16b \n" /* mov bias to c7*/
"1:\n"
/* load b0b1b2b3 to v5-v8 */
"ld1 {v5.4s, v6.4s}, [%[b]], #32 \n"
"ld1 {v7.4s, v8.4s}, [%[b]], #32 \n"
"prfm pldl1keep, [%[b]] \n"
"fmla v9.4s, v1.4s, v5.s[0] \n"
"fmla v10.4s, v1.4s, v6.s[0] \n"
"fmla v11.4s, v1.4s, v7.s[0] \n"
"fmla v12.4s, v1.4s, v8.s[0] \n"
/* load b4b5b6b7 to v25-v28 */
"ld1 {v25.4s, v26.4s}, [%[b]], #32 \n"
"ld1 {v27.4s, v28.4s}, [%[b]], #32 \n"
"prfm pldl1keep, [%[a], #32] \n"
"fmla v9.4s, v2.4s, v5.s[1] \n"
"fmla v10.4s, v2.4s, v6.s[1] \n"
"fmla v11.4s, v2.4s, v7.s[1] \n"
"fmla v12.4s, v2.4s, v8.s[1] \n"
"prfm pldl1keep, [%[b], #64] \n"
"fmla v13.4s, v1.4s, v25.s[0] \n"
"fmla v14.4s, v1.4s, v26.s[0] \n"
"fmla v15.4s, v1.4s, v27.s[0] \n"
"fmla v16.4s, v1.4s, v28.s[0] \n"
/* load a2a3 to v3-v4 */
"ld1 {v3.4s, v4.4s}, [%[a]], #32 \n"
"prfm pldl1keep, [%[b], #128] \n"
"fmla v13.4s, v2.4s, v25.s[1] \n"
"fmla v14.4s, v2.4s, v26.s[1] \n"
"fmla v15.4s, v2.4s, v27.s[1] \n"
"fmla v16.4s, v2.4s, v28.s[1] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"fmla v9.4s, v3.4s, v5.s[2] \n"
"fmla v10.4s, v3.4s, v6.s[2] \n"
"fmla v11.4s, v3.4s, v7.s[2] \n"
"fmla v12.4s, v3.4s, v8.s[2] \n"
"fmla v13.4s, v3.4s, v25.s[2] \n"
"fmla v14.4s, v3.4s, v26.s[2] \n"
"fmla v15.4s, v3.4s, v27.s[2] \n"
"fmla v16.4s, v3.4s, v28.s[2] \n"
/* load a0a1 to v1-v2 */
"ld1 {v1.4s, v2.4s}, [%[a]], #32 \n"
"fmla v9.4s, v4.4s, v5.s[3] \n"
"fmla v10.4s, v4.4s, v6.s[3] \n"
"fmla v11.4s, v4.4s, v7.s[3] \n"
"fmla v12.4s, v4.4s, v8.s[3] \n"
"fmla v13.4s, v4.4s, v25.s[3] \n"
"fmla v14.4s, v4.4s, v26.s[3] \n"
"fmla v15.4s, v4.4s, v27.s[3] \n"
"fmla v16.4s, v4.4s, v28.s[3] \n"
"bne 1b\n"
"cbz %w[relu], 2f \n"
"fmax v9.4s, v9.4s, %[vzero].4s \n"
"fmax v10.4s, v10.4s, %[vzero].4s \n"
"fmax v11.4s, v11.4s, %[vzero].4s \n"
"fmax v12.4s, v12.4s, %[vzero].4s \n"
"fmax v13.4s, v13.4s, %[vzero].4s \n"
"fmax v14.4s, v14.4s, %[vzero].4s \n"
"fmax v15.4s, v15.4s, %[vzero].4s \n"
"fmax v16.4s, v16.4s, %[vzero].4s \n"
"2:\n"
"st1 {v9.4s, v10.4s, v11.4s, v12.4s}, [%[c]], #64 \n"
"st1 {v13.4s, v14.4s, v15.4s, v16.4s}, [%[c]], #64 \n"
: [a] "+r"(ablock_ptr),
[b] "+r"(bblock),
[c] "+r"(cblock),
[cnt] "+r"(cnt)
: [bias] "r"(bias_h), [relu] "r"(has_relu),
[vbias] "w"(vbias), [vzero] "w" (vzero)
: "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8",
"v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v25", "v26", "v27", "v28", "cc", "memory");
#else
asm volatile(
"vld1.32 {d6-d7}, [%[bias]] \n"
"pld [%[a]] \n"
"pld [%[b]] \n"
"pld [%[b], #64] \n"
"vmov.32 q8, q3 \n" /* mov bias to c0*/
"vmov.32 q9, q3 \n" /* mov bias to c1*/
"vmov.32 q10, q3 \n" /* mov bias to c2*/
"vmov.32 q11, q3 \n" /* mov bias to c3*/
"vld1.32 {d0-d3}, [%[a]]! \n"
"vmov.32 q12, q3 \n" /* mov bias to c4*/
"vmov.32 q13, q3 \n" /* mov bias to c5*/
"vmov.32 q14, q3 \n" /* mov bias to c6*/
"vmov.32 q15, q3 \n" /* mov bias to c7*/
"1:\n"
/* c0c1c2c3 */
"vld1.32 {d8-d11}, [%[b]]! \n"
"vld1.32 {d12-d15}, [%[b]]! \n"
"pld [%[b]] \n"
"vmla.f32 q8, q0, d8[0] \n"
"vmla.f32 q9, q0, d10[0] \n"
"vmla.f32 q10, q0, d12[0] \n"
"vmla.f32 q11, q0, d14[0] \n"
"vld1.32 {d4-d7}, [%[a]]! \n"
"vmla.f32 q8, q1, d8[1] \n"
"vmla.f32 q9, q1, d10[1] \n"
"vmla.f32 q10, q1, d12[1] \n"
"vmla.f32 q11, q1, d14[1] \n"
"pld [%[b], #64] \n"
"vmla.f32 q8, q2, d9[0] \n"
"vmla.f32 q9, q2, d11[0] \n"
"vmla.f32 q10, q2, d13[0] \n"
"vmla.f32 q11, q2, d15[0] \n"
"subs %[cnt], %[cnt], #1 \n"
"vmla.f32 q8, q3, d9[1] \n"
"vmla.f32 q9, q3, d11[1] \n"
"vld1.f32 {d8-d11}, [%[b]]! \n"
"vmla.f32 q10, q3, d13[1] \n"
"vmla.f32 q11, q3, d15[1] \n"
"vld1.32 {d12-d15}, [%[b]]! \n"
/* c4c5c6c7 */
"vmla.f32 q12, q0, d8[0] \n"
"vmla.f32 q13, q0, d10[0] \n"
"vmla.f32 q14, q0, d12[0] \n"
"vmla.f32 q15, q0, d14[0] \n"
"pld [%[a], #32] \n"
"vmla.f32 q12, q1, d8[1] \n"
"vmla.f32 q13, q1, d10[1] \n"
"vmla.f32 q14, q1, d12[1] \n"
"vmla.f32 q15, q1, d14[1] \n"
"vld1.32 {d0-d3}, [%[a]]! \n"
"vmla.f32 q12, q2, d9[0] \n"
"vmla.f32 q13, q2, d11[0] \n"
"vmla.f32 q14, q2, d13[0] \n"
"vmla.f32 q15, q2, d15[0] \n"
"pld [%[b], #64] \n"
"vmla.f32 q12, q3, d9[1] \n"
"vmla.f32 q13, q3, d11[1] \n"
"vmla.f32 q14, q3, d13[1] \n"
"vmla.f32 q15, q3, d15[1] \n"
"bne 1b\n"
"cmp %[relu], #0 \n"
"beq 2f \n"
"vmov.u32 q0, #0 \n"
"vmax.f32 q8, q8, q0 \n"
"vmax.f32 q9, q9, q0 \n"
"vmax.f32 q10, q10, q0 \n"
"vmax.f32 q11, q11, q0 \n"
"vmax.f32 q12, q12, q0 \n"
"vmax.f32 q13, q13, q0 \n"
"vmax.f32 q14, q14, q0 \n"
"vmax.f32 q15, q15, q0 \n"
"2:\n"
"vst1.32 {d16-d19}, [%[c]]! \n"
"vst1.32 {d20-d23}, [%[c]]! \n"
"vst1.32 {d24-d27}, [%[c]]! \n"
"vst1.32 {d28-d31}, [%[c]]! \n"
: [a] "+r"(ablock_ptr),
[b] "+r"(bblock),
[c] "+r"(cblock),
[cnt] "+r"(cnt)
: [bias] "r"(bias_h),
[relu] "r"(has_relu)
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
"q9", "q10", "q11", "q12", "q13", "q14", "q15", "cc", "memory");
#endif
// clang-format on
}
if (has_remain) {
if (w_loop4 > 0) {
int cnt = kcnt;
const float* ablock_ptr = ablock;
// clang-format off
#ifdef __aarch64__
asm volatile(
"prfm pldl1keep, [%[a]] \n"
"prfm pldl1keep, [%[b]] \n"
"mov v9.16b, %[vbias].16b \n" /* mov bias to c0*/
"mov v10.16b, %[vbias].16b \n" /* mov bias to c1*/
"mov v11.16b, %[vbias].16b \n" /* mov bias to c2*/
"mov v12.16b, %[vbias].16b \n" /* mov bias to c3*/
/* load a0a1 to v1-v2 */
"ld1 {v1.4s, v2.4s}, [%[a]], #32 \n"
"1:\n"
/* load b0b1b2b3 to v5-v8 */
"ld1 {v5.4s, v6.4s}, [%[b]], #32 \n"
"ld1 {v7.4s, v8.4s}, [%[b]], #32 \n"
"fmla v9.4s, v1.4s, v5.s[0] \n"
"fmla v10.4s, v1.4s, v6.s[0] \n"
"fmla v11.4s, v1.4s, v7.s[0] \n"
"fmla v12.4s, v1.4s, v8.s[0] \n"
/* load a2a3 to v3-v4 */
"ld1 {v3.4s, v4.4s}, [%[a]], #32 \n"
"prfm pldl1keep, [%[a]] \n"
"fmla v9.4s, v2.4s, v5.s[1] \n"
"fmla v10.4s, v2.4s, v6.s[1] \n"
"fmla v11.4s, v2.4s, v7.s[1] \n"
"fmla v12.4s, v2.4s, v8.s[1] \n"
"prfm pldl1keep, [%[b]] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"fmla v9.4s, v3.4s, v5.s[2] \n"
"fmla v10.4s, v3.4s, v6.s[2] \n"
"fmla v11.4s, v3.4s, v7.s[2] \n"
"fmla v12.4s, v3.4s, v8.s[2] \n"
/* load a0a1 to v1-v2 */
"ld1 {v1.4s, v2.4s}, [%[a]], #32 \n"
"fmla v9.4s, v4.4s, v5.s[3] \n"
"fmla v10.4s, v4.4s, v6.s[3] \n"
"fmla v11.4s, v4.4s, v7.s[3] \n"
"fmla v12.4s, v4.4s, v8.s[3] \n"
"bne 1b\n"
"cbz %w[relu], 2f \n"
"fmax v9.4s, v9.4s, %[vzero].4s \n"
"fmax v10.4s, v10.4s, %[vzero].4s \n"
"fmax v11.4s, v11.4s, %[vzero].4s \n"
"fmax v12.4s, v12.4s, %[vzero].4s \n"
"2:\n"
"st1 {v9.4s, v10.4s, v11.4s, v12.4s}, [%[c]], #64 \n"
: [a] "+r"(ablock_ptr),
[b] "+r"(bblock),
[c] "+r"(cblock),
[cnt] "+r"(cnt)
: [bias] "r"(bias_h),
[relu] "r"(has_relu),
[vbias] "w"(vbias),
[vzero] "w" (vzero)
: "v1", "v2", "v3", "v4", "v5", "v6", "v7",
"v8", "v9", "v10", "v11", "v12", "cc", "memory");
#else
asm volatile(
"pld [%[a]] \n"
"pld [%[b]] \n"
"vld1.32 {d6-d7}, [%[bias]] \n"
"vld1.32 {d0-d3}, [%[a]]! \n" /* load a0 a1 */
"vmov.32 q8, q3 \n" /* mov bias to c0 */
"vmov.32 q9, q3 \n" /* mov bias to c1 */
"vmov.32 q10, q3 \n" /* mov bias to c2 */
"vmov.32 q11, q3 \n" /* mov bias to c3 */
"1:\n"
/* c0c1c2c3 */
"vld1.32 {d8-d11}, [%[b]]! \n"
"vld1.32 {d12-d15}, [%[b]]! \n"
"pld [%[b]] \n"
"vmla.f32 q8, q0, d8[0] \n"
"vmla.f32 q9, q0, d10[0] \n"
"vmla.f32 q10, q0, d12[0] \n"
"vmla.f32 q11, q0, d14[0] \n"
"vld1.32 {d4-d7}, [%[a]]! \n"
"pld [%[a]] \n"
"vmla.f32 q8, q1, d8[1] \n"
"vmla.f32 q9, q1, d10[1] \n"
"vmla.f32 q10, q1, d12[1] \n"
"vmla.f32 q11, q1, d14[1] \n"
"subs %[cnt], %[cnt], #1 \n"
"vmla.f32 q8, q2, d9[0] \n"
"vmla.f32 q9, q2, d11[0] \n"
"vmla.f32 q10, q2, d13[0] \n"
"vmla.f32 q11, q2, d15[0] \n"
"vld1.32 {d0-d3}, [%[a]]! \n"
"vmla.f32 q8, q3, d9[1] \n"
"vmla.f32 q9, q3, d11[1] \n"
"vmla.f32 q10, q3, d13[1] \n"
"vmla.f32 q11, q3, d15[1] \n"
"bne 1b\n"
"cmp %[relu], #0 \n"
"beq 2f \n"
"vmov.u32 q0, #0 \n"
"vmax.f32 q8, q8, q0 \n"
"vmax.f32 q9, q9, q0 \n"
"vmax.f32 q10, q10, q0 \n"
"vmax.f32 q11, q11, q0 \n"
"2:\n"
"vst1.32 {d16-d19}, [%[c]]! \n"
"vst1.32 {d20-d23}, [%[c]]! \n"
: [a] "+r"(ablock_ptr),
[b] "+r"(bblock),
[c] "+r"(cblock),
[cnt] "+r"(cnt)
: [bias] "r"(bias_h), [relu] "r"(has_relu)
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
"q9", "q10", "q11", "cc", "memory");
#endif
// clang-format on
}
if (remain > 0) {
int cnt = kcnt;
const float* ablock_ptr = ablock;
// clang-format off
#ifdef __aarch64__
asm volatile(
"prfm pldl1keep, [%[a]] \n"
"prfm pldl1keep, [%[b]] \n"
"ld1 {v1.4s, v2.4s}, [%[a]], #32 \n"
"cmp %w[remain], #3 \n"
"beq 1f \n"
"cmp %w[remain], #2 \n"
"beq 2f \n"
/* remain 1 */
"mov v9.16b, %[vbias].16b \n" /* mov bias to c0*/
"mov v10.16b, %[vzero].16b \n" /* mov zero to c1*/
"3: \n"
"ld1 {v5.4s}, [%[b]], #16 \n"
"ld1 {v3.4s, v4.4s}, [%[a]], #32 \n"
"fmla v9.4s, v1.4s, v5.s[0] \n"
"fmla v10.4s, v2.4s, v5.s[1] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"ld1 {v1.4s, v2.4s}, [%[a]], #32 \n"
"fmla v9.4s, v3.4s, v5.s[2] \n"
"fmla v10.4s, v4.4s, v5.s[3] \n"
"bne 3b \n"
"fadd v9.4s, v9.4s, v10.4s \n"
"cbz %w[relu], 6f \n"
"fmax v9.4s, v9.4s, %[vzero].4s \n"
"6: \n"
"st1 {v9.4s}, [%[c]], #16 \n"
"b 9f \n"
/* remain 2 */
"2: \n"
"mov v9.16b, %[vbias].16b \n" /* mov bias to c0*/
"mov v10.16b, %[vbias].16b \n" /* mov bias to c1*/
"mov v11.16b, %[vzero].16b \n" /* mov zero to c2*/
"mov v12.16b, %[vzero].16b \n" /* mov zero to c3*/
"4: \n"
"ld1 {v5.4s, v6.4s}, [%[b]], #32 \n"
"ld1 {v3.4s, v4.4s}, [%[a]], #32 \n"
"fmla v9.4s, v1.4s, v5.s[0] \n"
"fmla v10.4s, v1.4s, v6.s[0] \n"
"fmla v11.4s, v2.4s, v5.s[1] \n"
"fmla v12.4s, v2.4s, v6.s[1] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"fmla v9.4s, v3.4s, v5.s[2] \n"
"fmla v10.4s, v3.4s, v6.s[2] \n"
"fmla v11.4s, v4.4s, v5.s[3] \n"
"fmla v12.4s, v4.4s, v6.s[3] \n"
"ld1 {v1.4s, v2.4s}, [%[a]], #32 \n"
"bne 4b \n"
"fadd v9.4s, v9.4s, v11.4s \n"
"fadd v10.4s, v10.4s, v12.4s \n"
"cbz %w[relu], 7f \n"
"fmax v9.4s, v9.4s, %[vzero].4s \n"
"fmax v10.4s, v10.4s, %[vzero].4s \n"
"7: \n"
"st1 {v9.4s, v10.4s}, [%[c]], #32 \n"
"b 9f \n"
/* remain 3 */
"1: \n"
"mov v9.16b, %[vbias].16b \n" /* mov bias to c0*/
"mov v10.16b, %[vbias].16b \n" /* mov bias to c1*/
"mov v11.16b, %[vbias].16b \n" /* mov bias to c2*/
"5: \n"
"ld1 {v5.4s, v6.4s}, [%[b]], #32 \n"
"ld1 {v7.4s}, [%[b]], #16 \n"
"fmla v9.4s, v1.4s, v5.s[0] \n"
"fmla v10.4s, v1.4s, v6.s[0] \n"
"fmla v11.4s, v1.4s, v7.s[0] \n"
"ld1 {v3.4s, v4.4s}, [%[a]], #32 \n"
"fmla v9.4s, v2.4s, v5.s[1] \n"
"fmla v10.4s, v2.4s, v6.s[1] \n"
"fmla v11.4s, v2.4s, v7.s[1] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"fmla v9.4s, v3.4s, v5.s[2] \n"
"fmla v10.4s, v3.4s, v6.s[2] \n"
"fmla v11.4s, v3.4s, v7.s[2] \n"
"prfm pldl1keep, [%[a]] \n"
"fmla v9.4s, v4.4s, v5.s[3] \n"
"fmla v10.4s, v4.4s, v6.s[3] \n"
"fmla v11.4s, v4.4s, v7.s[3] \n"
"ld1 {v1.4s, v2.4s}, [%[a]], #32 \n"
"bne 5b \n"
"cbz %w[relu], 8f \n"
"fmax v9.4s, v9.4s, %[vzero].4s \n"
"fmax v10.4s, v10.4s, %[vzero].4s \n"
"fmax v11.4s, v11.4s, %[vzero].4s \n"
"8: \n"
"st1 {v9.4s, v10.4s}, [%[c]], #32 \n"
"st1 {v11.4s}, [%[c]], #16 \n"
"9:\n"
: [a] "+r"(ablock_ptr),
[b] "+r"(bblock),
[c] "+r"(cblock),
[cnt] "+r"(cnt)
: [bias] "r"(bias_h), [relu] "r"(has_relu),
[remain] "r"(remain), [vbias] "w"(vbias),
[vzero] "w" (vzero)
: "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v9",
"v10", "v11", "v12", "cc","memory");
#else
asm volatile(
"pld [%[a]] \n"
"pld [%[b]] \n"
"vld1.32 {d0-d1}, [%[bias]] \n"
"vld1.32 {d2-d5}, [%[a]]! \n"
"vmov.u32 q15, #0 \n"
"cmp %[remain], #3 \n"
"beq 1f \n"
"cmp %[remain], #2 \n"
"beq 2f \n"
/* remain 1 */
"vmov.32 q9, q0 \n" /* mov bias to c0*/
"vmov.32 q10, q15 \n" /* mov zero to c1*/
"3: \n"
"vld1.32 {d10-d11}, [%[b]]! \n"
"vld1.32 {d6-d9}, [%[a]]! \n"
"vmla.f32 q9, q1, d10[0] \n"
"vmla.f32 q10, q2, d10[1] \n"
"subs %[cnt], %[cnt], #1 \n"
"vld1.32 {d2-d5}, [%[a]]! \n"
"vmla.f32 q9, q3, d11[0] \n"
"vmla.f32 q10, q4, d11[1] \n"
"bne 3b \n"
"vadd.f32 q9, q9, q10 \n"
"cmp %[relu], #0 \n"
"beq 6f \n"
"vmax.f32 q9, q9, q15 \n"
"6: \n"
"vst1.32 {d18-d19}, [%[c]]! \n"
"b 9f \n"
/* remain 2 */
"2: \n"
"vmov.u32 q9, q0 \n" /* mov bias to c0*/
"vmov.u32 q10, q0 \n" /* mov bias to c1*/
"vmov.u32 q11, q15 \n" /* mov zero to c2*/
"vmov.u32 q12, q15 \n" /* mov zero to c3*/
"4: \n"
"vld1.32 {d10-d13}, [%[b]]! \n"
"vld1.32 {d6-d9}, [%[a]]! \n"
"vmla.f32 q9, q1, d10[0] \n"
"vmla.f32 q10, q1, d12[0] \n"
"vmla.f32 q11, q2, d10[1] \n"
"vmla.f32 q12, q2, d12[1] \n"
"subs %[cnt], %[cnt], #1 \n"
"vmla.f32 q9, q3, d11[0] \n"
"vmla.f32 q10, q3, d13[0] \n"
"vmla.f32 q11, q4, d11[1] \n"
"vmla.f32 q12, q4, d13[1] \n"
"vld1.32 {d2-d5}, [%[a]]! \n"
"bne 4b \n"
"vadd.f32 q9, q9, q11 \n"
"vadd.f32 q10, q10, q12 \n"
"cmp %[relu], #0 \n"
"beq 7f \n"
"vmax.f32 q9, q9, q15 \n"
"vmax.f32 q10, q10, q15 \n"
"7: \n"
"vst1.32 {d18-d21}, [%[c]]! \n"
"b 9f \n"
/* remain 3 */
"1: \n"
"vmov.u32 q9, q0 \n" /* mov bias to c0*/
"vmov.u32 q10, q0 \n" /* mov bias to c1*/
"vmov.u32 q11, q0 \n" /* mov bias to c2*/
"5: \n"
"vld1.32 {d10-d13}, [%[b]]! \n"
"vld1.32 {d14-d15}, [%[b]]! \n"
"vmla.f32 q9, q1, d10[0] \n"
"vmla.f32 q10, q1, d12[0] \n"
"vmla.f32 q11, q1, d14[0] \n"
"vld1.32 {d6-d9}, [%[a]]! \n"
"vmla.f32 q9, q2, d10[1] \n"
"vmla.f32 q10, q2, d12[1] \n"
"vmla.f32 q11, q2, d14[1] \n"
"subs %[cnt], %[cnt], #1 \n"
"vmla.f32 q9, q3, d11[0] \n"
"vmla.f32 q10, q3, d13[0] \n"
"vmla.f32 q11, q3, d15[0] \n"
"pld [%[a]] \n"
"vmla.f32 q9, q4, d11[1] \n"
"vmla.f32 q10, q4, d13[1] \n"
"vmla.f32 q11, q4, d15[1] \n"
"vld1.32 {d2-d5}, [%[a]]! \n"
"bne 5b \n"
"cmp %[relu], #0 \n"
"beq 8f \n"
"vmax.f32 q9, q9, q15 \n"
"vmax.f32 q10, q10, q15 \n"
"vmax.f32 q11, q11, q15 \n"
"8: \n"
"vst1.32 {d18-d21}, [%[c]]! \n"
"vst1.32 {d22-d23}, [%[c]]! \n"
"9:\n"
: [a] "+r"(ablock_ptr),
[b] "+r"(bblock),
[c] "+r"(cblock),
[cnt] "+r"(cnt)
: [bias] "r"(bias_h),
[relu] "r"(has_relu),
[remain] "r"(remain)
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q9",
"q10", "q11", "q12", "q15", "cc","memory");
#endif
// clang-format on
}
}
}
}
}
void sgemm_prepack_c4_small(int M,
int N,
int K,
const float* A_packed,
const float* B,
float* C,
const float* bias,
bool has_bias,
bool has_relu,
ARMContext* ctx) {
const int m_round = (M + 3) / 4 * 4;
const int k_round = (K + 3) / 4 * 4;
const int mloop = m_round >> 2;
const int lda = 4 * k_round;
const int ldb_byte = 4 * N * sizeof(float);
const int kcnt = k_round >> 2;
float bias_buf[m_round]; // NOLINT
if (has_bias) {
memcpy(bias_buf, bias, M * sizeof(float));
memset(bias_buf + M, 0, (m_round - M) * sizeof(float));
} else {
memset(bias_buf, 0, m_round * sizeof(float));
}
#ifdef __aarch64__
float32x4_t vzero = vdupq_n_f32(0.f);
#endif
const float* bias_ptr = bias_buf;
for (int m = 0; m < mloop; ++m) {
#ifdef __aarch64__
float32x4_t vbias = vld1q_f32(bias_ptr);
#endif
const float* b = B;
int n = N;
#ifdef __aarch64__
for (; n > 7; n -= 8) {
int cnt = kcnt;
const float* a_ptr = A_packed;
const float* b_ptr = b;
// clang-format off
asm volatile(
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
/* mov bias to c0-c7*/
"mov v8.16b, %[vbias].16b \n"
"mov v9.16b, %[vbias].16b \n"
"mov v10.16b, %[vbias].16b \n"
"mov v11.16b, %[vbias].16b \n"
/* load b0, b1 */
"ld1 {v0.4s, v1.4s}, [%[b]], #32 \n"
"mov v12.16b, %[vbias].16b \n"
"mov v13.16b, %[vbias].16b \n"
"mov v14.16b, %[vbias].16b \n"
"mov v15.16b, %[vbias].16b \n"
"1:\n"
/* load b2, b3 */
"ld1 {v2.4s, v3.4s}, [%[b]], #32 \n"
/* load a2, a3 */
"ld1 {v18.4s, v19.4s}, [%[a]], #32 \n"
"fmla v8.4s, v16.4s, v0.s[0] \n"
"fmla v9.4s, v16.4s, v1.s[0] \n"
"fmla v10.4s, v16.4s, v2.s[0] \n"
"fmla v11.4s, v16.4s, v3.s[0] \n"
"prfm pldl1keep, [%[b]] \n"
"fmla v8.4s, v17.4s, v0.s[1] \n"
"fmla v9.4s, v17.4s, v1.s[1] \n"
"fmla v10.4s, v17.4s, v2.s[1] \n"
"fmla v11.4s, v17.4s, v3.s[1] \n"
/* load b4, b5 */
"ld1 {v4.4s, v5.4s}, [%[b]], #32 \n"
"fmla v8.4s, v18.4s, v0.s[2] \n"
"fmla v9.4s, v18.4s, v1.s[2] \n"
"fmla v10.4s, v18.4s, v2.s[2] \n"
"fmla v11.4s, v18.4s, v3.s[2] \n"
/* load b6, b7 */
"ld1 {v6.4s, v7.4s}, [%[b]], #32 \n"
"fmla v8.4s, v19.4s, v0.s[3] \n"
"fmla v9.4s, v19.4s, v1.s[3] \n"
"fmla v10.4s, v19.4s, v2.s[3] \n"
"fmla v11.4s, v19.4s, v3.s[3] \n"
"sub %[b], %[b], #128 \n"
"fmla v12.4s, v16.4s, v4.s[0] \n"
"fmla v13.4s, v16.4s, v5.s[0] \n"
"fmla v14.4s, v16.4s, v6.s[0] \n"
"fmla v15.4s, v16.4s, v7.s[0] \n"
"add %[b], %[b], %[ldb] \n"
"fmla v12.4s, v17.4s, v4.s[1] \n"
"fmla v13.4s, v17.4s, v5.s[1] \n"
"fmla v14.4s, v17.4s, v6.s[1] \n"
"fmla v15.4s, v17.4s, v7.s[1] \n"
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
"fmla v12.4s, v18.4s, v4.s[2] \n"
"fmla v13.4s, v18.4s, v5.s[2] \n"
"fmla v14.4s, v18.4s, v6.s[2] \n"
"fmla v15.4s, v18.4s, v7.s[2] \n"
/* load b0, b1 */
"ld1 {v0.4s, v1.4s}, [%[b]], #32 \n"
"fmla v12.4s, v19.4s, v4.s[3] \n"
"fmla v13.4s, v19.4s, v5.s[3] \n"
"fmla v14.4s, v19.4s, v6.s[3] \n"
"fmla v15.4s, v19.4s, v7.s[3] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"bne 1b \n"
"cbz %w[relu], 2f \n"
"fmax v8.4s, v8.4s, %[vzero].4s \n"
"fmax v9.4s, v9.4s, %[vzero].4s \n"
"fmax v10.4s, v10.4s, %[vzero].4s \n"
"fmax v11.4s, v11.4s, %[vzero].4s \n"
"fmax v12.4s, v12.4s, %[vzero].4s \n"
"fmax v13.4s, v13.4s, %[vzero].4s \n"
"fmax v14.4s, v14.4s, %[vzero].4s \n"
"fmax v15.4s, v15.4s, %[vzero].4s \n"
"2:\n"
"st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[c]], #64 \n"
"st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[c]], #64 \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [relu] "r" (has_relu),
[ldb] "r" (ldb_byte),
[vbias] "w" (vbias),
[vzero] "w" (vzero)
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "cc", "memory"
);
b += 4 * 8;
}
for (; n > 3; n -= 4) {
int cnt = kcnt;
const float* a_ptr = A_packed;
const float* b_ptr = b;
asm volatile(
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
/* mov bias to c0-c3*/
"mov v8.16b, %[vbias].16b \n"
"mov v9.16b, %[vbias].16b \n"
"mov v10.16b, %[vbias].16b \n"
"mov v11.16b, %[vbias].16b \n"
"1:\n"
/* load b0-b3 */
"ld1 {v0.4s, v1.4s}, [%[b]], #32 \n"
"ld1 {v2.4s, v3.4s}, [%[b]], #32 \n"
/* load a2, a3 */
"ld1 {v18.4s, v19.4s}, [%[a]], #32 \n"
"fmla v8.4s, v16.4s, v0.s[0] \n"
"fmla v9.4s, v16.4s, v1.s[0] \n"
"fmla v10.4s, v16.4s, v2.s[0] \n"
"fmla v11.4s, v16.4s, v3.s[0] \n"
"sub %[b], %[b], #64 \n"
"fmla v8.4s, v17.4s, v0.s[1] \n"
"fmla v9.4s, v17.4s, v1.s[1] \n"
"fmla v10.4s, v17.4s, v2.s[1] \n"
"fmla v11.4s, v17.4s, v3.s[1] \n"
"add %[b], %[b], %[ldb] \n"
"fmla v8.4s, v18.4s, v0.s[2] \n"
"fmla v9.4s, v18.4s, v1.s[2] \n"
"fmla v10.4s, v18.4s, v2.s[2] \n"
"fmla v11.4s, v18.4s, v3.s[2] \n"
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
"fmla v8.4s, v19.4s, v0.s[3] \n"
"fmla v9.4s, v19.4s, v1.s[3] \n"
"fmla v10.4s, v19.4s, v2.s[3] \n"
"fmla v11.4s, v19.4s, v3.s[3] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"bne 1b \n"
"cbz %w[relu], 2f \n"
"fmax v8.4s, v8.4s, %[vzero].4s \n"
"fmax v9.4s, v9.4s, %[vzero].4s \n"
"fmax v10.4s, v10.4s, %[vzero].4s \n"
"fmax v11.4s, v11.4s, %[vzero].4s \n"
"2:\n"
"st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[c]], #64 \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [relu] "r" (has_relu),
[ldb] "r" (ldb_byte),
[vbias] "w" (vbias),
[vzero] "w" (vzero)
: "v0", "v1", "v2", "v3", "v8", "v9",
"v10", "v11", "v16", "v17", "v18",
"v19", "cc", "memory"
);
b += 4 * 4;
}
for (; n > 0; n--) {
int cnt = kcnt;
const float* a_ptr = A_packed;
const float* b_ptr = b;
asm volatile(
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
/* mov bias to c0 */
"mov v8.16b, %[vbias].16b \n"
"mov v9.16b, %[vzero].16b \n"
"1:\n"
/* load b0 */
"ld1 {v0.4s}, [%[b]], #16 \n"
/* load a2, a3 */
"ld1 {v18.4s, v19.4s}, [%[a]], #32 \n"
"fmla v8.4s, v16.4s, v0.s[0] \n"
"fmla v9.4s, v17.4s, v0.s[1] \n"
"sub %[b], %[b], #16 \n"
"subs %w[cnt], %w[cnt], #1 \n"
"add %[b], %[b], %[ldb] \n"
"fmla v8.4s, v18.4s, v0.s[2] \n"
"fmla v9.4s, v19.4s, v0.s[3] \n"
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
"bne 1b \n"
"fadd v8.4s, v8.4s, v9.4s \n"
"cbz %w[relu], 2f \n"
"fmax v8.4s, v8.4s, %[vzero].4s \n"
"2:\n"
"st1 {v8.4s}, [%[c]], #16 \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [relu] "r" (has_relu),
[ldb] "r" (ldb_byte),
[vbias] "w" (vbias),
[vzero] "w" (vzero)
: "v0", "v8", "v9", "v16", "v17",
"v18", "v19", "cc", "memory"
);
b += 4;
}
#else
for (; n > 5; n -= 6) {
int cnt = kcnt;
const float* a_ptr = A_packed;
const float* b_ptr = b;
asm volatile(
"vld1.32 {d8-d9}, [%[bias]] \n"
/* load a0, a1 */
"vld1.32 {d12-d15}, [%[a]]! \n"
/* mov bias to c0-c7*/
"vmov.u32 q10, q4 \n"
"vmov.u32 q11, q4 \n"
"vmov.u32 q12, q4 \n"
/* load b0-b3 */
"vld1.32 {d0-d3}, [%[b]]!\n"
"vld1.32 {d4-d7}, [%[b]]!\n"
"vmov.u32 q13, q4 \n"
"vmov.u32 q14, q4 \n"
"vmov.u32 q15, q4 \n"
"1:\n"
/* load b4, b5 */
"vld1.32 {d8-d11}, [%[b]]! \n"
/* load a2, a3 */
"vld1.32 {d16-d19}, [%[a]]!\n"
"vmla.f32 q10, q6, d0[0] \n"
"vmla.f32 q11, q6, d2[0] \n"
"vmla.f32 q12, q6, d4[0] \n"
"vmla.f32 q13, q6, d6[0] \n"
"vmla.f32 q14, q6, d8[0] \n"
"vmla.f32 q15, q6, d10[0] \n"
"sub %[b], %[b], #96 \n"
"vmla.f32 q10, q7, d0[1] \n"
"vmla.f32 q11, q7, d2[1] \n"
"vmla.f32 q12, q7, d4[1] \n"
"vmla.f32 q13, q7, d6[1] \n"
"vmla.f32 q14, q7, d8[1] \n"
"vmla.f32 q15, q7, d10[1] \n"
"add %[b], %[b], %[ldb] \n"
"pld [%[b]] \n"
/* load a0, a1 */
"vld1.32 {d12-d15}, [%[a]]!\n"
"vmla.f32 q10, q8, d1[0] \n"
"vmla.f32 q11, q8, d3[0] \n"
"vmla.f32 q12, q8, d5[0] \n"
"vmla.f32 q13, q8, d7[0] \n"
"pld [%[b], #64] \n"
"vmla.f32 q10, q9, d1[1] \n"
"vmla.f32 q11, q9, d3[1] \n"
/* load b0, b1 */
"vld1.32 {d0-d3}, [%[b]]! \n"
"vmla.f32 q14, q8, d9[0] \n"
"vmla.f32 q15, q8, d11[0] \n"
"vmla.f32 q12, q9, d5[1] \n"
"vmla.f32 q13, q9, d7[1] \n"
"vmla.f32 q14, q9, d9[1] \n"
"vmla.f32 q15, q9, d11[1] \n"
/* load b2, b3 */
"vld1.32 {d4-d7}, [%[b]]! \n"
"subs %[cnt], %[cnt], #1 \n"
"bne 1b \n"
"cmp %[relu], #0 \n"
"beq 2f \n"
"vmov.u32 q0, #0 \n"
"vmax.f32 q10, q10, q0 \n"
"vmax.f32 q11, q11, q0 \n"
"vmax.f32 q12, q12, q0 \n"
"vmax.f32 q13, q13, q0 \n"
"vmax.f32 q14, q14, q0 \n"
"vmax.f32 q15, q15, q0 \n"
"2: \n"
"vst1.32 {d20-d23}, [%[c]]! \n"
"vst1.32 {d24-d27}, [%[c]]! \n"
"vst1.32 {d28-d31}, [%[c]]! \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [relu] "r" (has_relu),
[ldb] "r" (ldb_byte),
[bias] "r"(bias_ptr)
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
"q10", "q11", "q12", "q13", "q14", "q15", "cc", "memory"
);
b += 4 * 6;
}
for (; n > 3; n -= 4) {
int cnt = kcnt;
const float* a_ptr = A_packed;
const float* b_ptr = b;
asm volatile(
"vld1.32 {d24-d25}, [%[bias]] \n"
/* load a0, a1 */
"vld1.32 {d8-d11}, [%[a]]! \n"
/* mov bias to c0-c3*/
"vmov.u32 q8, q12 \n"
"vmov.u32 q9, q12 \n"
"vmov.u32 q10, q12 \n"
"vmov.u32 q11, q12 \n"
"vmov.u32 q13, #0 \n"
"1:\n"
/* load b0-b3 */
"vld1.32 {d0-d3}, [%[b]]! \n"
"vld1.32 {d4-d7}, [%[b]]! \n"
/* load a2, a3 */
"vld1.32 {d12-d15}, [%[a]]!\n"
"vmla.f32 q8, q4, d0[0] \n"
"vmla.f32 q9, q4, d2[0] \n"
"vmla.f32 q10, q4, d4[0] \n"
"vmla.f32 q11, q4, d6[0] \n"
"sub %[b], %[b], #64 \n"
"vmla.f32 q8, q5, d0[1] \n"
"vmla.f32 q9, q5, d2[1] \n"
"vmla.f32 q10, q5, d4[1] \n"
"vmla.f32 q11, q5, d6[1] \n"
"add %[b], %[b], %[ldb] \n"
"vmla.f32 q8, q6, d1[0] \n"
"vmla.f32 q9, q6, d3[0] \n"
"vmla.f32 q10, q6, d5[0] \n"
"vmla.f32 q11, q6, d7[0] \n"
/* load a0, a1 */
"vld1.32 {d8-d11}, [%[a]]! \n"
"vmla.f32 q8, q7, d1[1] \n"
"vmla.f32 q9, q7, d3[1] \n"
"vmla.f32 q10, q7, d5[1] \n"
"vmla.f32 q11, q7, d7[1] \n"
"subs %[cnt], %[cnt], #1 \n"
"bne 1b \n"
"cmp %[relu], #0 \n"
"beq 2f \n"
"vmax.f32 q8, q8, q13 \n"
"vmax.f32 q9, q9, q13 \n"
"vmax.f32 q10, q10, q13 \n"
"vmax.f32 q11, q11, q13 \n"
"2:\n"
"vst1.32 {d16-d19}, [%[c]]!\n"
"vst1.32 {d20-d23}, [%[c]]!\n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [relu] "r" (has_relu),
[ldb] "r" (ldb_byte),
[bias] "r"(bias_ptr)
: "q0", "q1", "q2", "q3", "q4", "q5",
"q6", "q7", "q8", "q9", "q10", "q11",
"q12", "q13", "cc", "memory"
);
b += 4 * 4;
}
for (; n > 0; n--) {
int cnt = kcnt;
const float* a_ptr = A_packed;
const float* b_ptr = b;
asm volatile(
"vld1.32 {d14-d15}, [%[bias]] \n"
"vmov.u32 q8, #0 \n"
/* load a0, a1 */
"vld1.32 {d2-d5}, [%[a]]! \n"
/* mov bias to c0 */
"vmov.u32 q5, q7 \n"
"vmov.u32 q6, q8 \n"
"1:\n"
/* load b0 */
"vld1.32 {d0-d1}, [%[b]]! \n"
/* load a2, a3 */
"vld1.32 {d6-d9}, [%[a]]! \n"
"vmla.f32 q5, q1, d0[0] \n"
"vmla.f32 q6, q2, d0[1] \n"
"sub %[b], %[b], #16 \n"
"subs %[cnt], %[cnt], #1 \n"
"add %[b], %[b], %[ldb] \n"
"vmla.f32 q5, q3, d1[0] \n"
"vmla.f32 q6, q4, d1[1] \n"
/* load a0, a1 */
"vld1.32 {d2-d5}, [%[a]]! \n"
"bne 1b \n"
"vadd.f32 q5, q5, q6 \n"
"cmp %[relu], #0 \n"
"beq 2f \n"
"vmax.f32 q5, q5, q8 \n"
"2:\n"
"vst1.32 {d10-d11}, [%[c]]!\n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [relu] "r" (has_relu),
[ldb] "r" (ldb_byte),
[bias] "r"(bias_ptr)
: "q0", "q1", "q2", "q3", "q4",
"q5", "q6", "q7", "q8", "cc", "memory"
);
// clang-format on
b += 4;
}
#endif
bias_ptr += 4;
A_packed += lda;
}
}
void sgemm_prepack_c4(int M,
int N,
int K,
const float* A_packed,
const float* B,
float* C,
const float* bias,
bool has_bias,
bool has_relu,
ARMContext* ctx) {
if (N > 16) {
sgemm_prepack_c4_common(
M, N, K, A_packed, B, C, bias, has_bias, has_relu, ctx);
} else {
sgemm_prepack_c4_small(
M, N, K, A_packed, B, C, bias, has_bias, has_relu, ctx);
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include "lite/core/context.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
constexpr int MBLOCK_C4 = 4;
constexpr int NBLOCK_C4 = 8;
constexpr int KBLOCK_C4 = 4;
void sgemm_prepack_c4(int M,
int N,
int K,
const float* A_packed,
const float* B,
float* C,
const float* bias,
bool has_bias,
bool has_relu,
ARMContext* ctx);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
lite_cc_test(sgemm_compute_test SRCS sgemm_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(sgemv_compute_test SRCS sgemv_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(sgemm_c4_compute_test SRCS sgemm_c4_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(gemm_int8_compute_test SRCS gemm_int8_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(gemv_int8_compute_test SRCS gemv_int8_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(conv_compute_test SRCS conv_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "lite/tests/utils/fill_data.h"
#include "lite/tests/utils/naive_math_impl.h"
#ifdef LITE_WITH_ARM
#include "lite/backends/arm/math/funcs.h"
#endif // LITE_WITH_ARM
#include "lite/core/context.h"
#include "lite/core/tensor.h"
#include "lite/tests/utils/tensor_utils.h"
#include "lite/tests/utils/timer.h"
typedef paddle::lite::Tensor Tensor;
using paddle::lite::Timer;
DEFINE_int32(power_mode,
3,
"power mode: "
"0 for POWER_HIGH;"
"1 for POWER_LOW;"
"2 for POWER_FULL;"
"3 for NO_BIND");
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
DEFINE_bool(basic_test, false, "do all tests");
DEFINE_bool(check_result, true, "check the result");
DEFINE_int32(M, 512, "gemm_c4: M");
DEFINE_int32(N, 512, "gemm_c4: N");
DEFINE_int32(K, 512, "gemm_c4: K");
DEFINE_bool(flag_relu, false, "do relu");
DEFINE_bool(flag_bias, false, "with bias");
bool test_sgemm_c4(
int m, int n, int k, bool has_bias, bool has_relu, int cls, int ths) {
int m_round = (m + 3) / 4 * 4;
int k_round = (k + 3) / 4 * 4;
int size_a = m * k;
int size_b = n * k;
int size_a_c4 = m_round * k_round;
int size_b_c4 = k_round * n;
Tensor ta;
Tensor tb;
Tensor ta_c4;
Tensor tb_c4;
Tensor tc;
Tensor tc_basic;
Tensor tc_backup;
Tensor tbias;
ta.Resize({size_a});
tb.Resize({size_b});
ta_c4.Resize({size_a_c4});
tb_c4.Resize({size_b_c4});
tc.Resize({m_round * n});
tc_basic.Resize({m_round * n});
tbias.Resize({m});
ta.set_precision(PRECISION(kFloat));
tb.set_precision(PRECISION(kFloat));
ta_c4.set_precision(PRECISION(kFloat));
tb_c4.set_precision(PRECISION(kFloat));
tc.set_precision(PRECISION(kFloat));
tc_basic.set_precision(PRECISION(kFloat));
tbias.set_precision(PRECISION(kFloat));
fill_tensor_rand(ta, -1.f, 1.f);
fill_tensor_rand(tb, -1.f, 1.f);
fill_tensor_rand(tbias, -1.f, 1.f);
fill_tensor_rand(tc, -1.f, 1.f);
auto da = ta.mutable_data<float>();
auto db = tb.mutable_data<float>();
auto da_c4 = ta_c4.mutable_data<float>();
auto db_c4 = tb_c4.mutable_data<float>();
auto dc_basic = tc_basic.mutable_data<float>();
auto dbias = tbias.mutable_data<float>();
// trans A, B to c4
basic_trans_mat_to_c4(da, da_c4, k, m, k, true);
basic_trans_mat_to_c4(db, db_c4, n, k, n, false);
LOG(INFO) << "sgemm_c4 M: " << m << ", N: " << n << ", K: " << k
<< ", relu: " << (has_relu ? "true" : "false")
<< ", bias: " << (has_bias ? "true" : "false");
if (FLAGS_check_result) {
basic_gemm_c4(false,
false,
m,
n,
k,
1.f,
da,
k,
db,
n,
0.f,
dc_basic,
n,
dbias,
has_bias,
has_relu);
}
Timer t0;
#ifdef LITE_WITH_ARM
//! compute
double ops = 2.0 * m_round * n * k_round;
std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext);
auto& ctx = ctx1->As<paddle::lite::ARMContext>();
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), ths);
auto dc = tc.mutable_data<float>();
for (int j = 0; j < FLAGS_warmup; ++j) {
paddle::lite::arm::math::sgemm_prepack_c4(
m, n, k, da_c4, db_c4, dc, dbias, has_bias, has_relu, &ctx);
}
for (int i = 0; i < FLAGS_repeats; ++i) {
t0.start();
paddle::lite::arm::math::sgemm_prepack_c4(
m, n, k, da_c4, db_c4, dc, dbias, has_bias, has_relu, &ctx);
t0.end();
}
LOG(INFO) << "M: " << m << ", N: " << n << ", K: " << k
<< ", power_mode: " << cls << ", threads: " << ths
<< ", GOPS: " << ops * 1e-9f
<< " GOPS, avg time: " << t0.get_average_ms()
<< " ms, min time: " << t0.get_min_time()
<< " ms, mean GOPs: " << ops * 1e-6f / t0.get_average_ms()
<< " GOPs, max GOPs: " << ops * 1e-6f / t0.get_min_time()
<< " GOPs";
if (FLAGS_check_result) {
double max_ratio = 0;
double max_diff = 0;
tensor_cmp_host(tc_basic, tc, max_ratio, max_diff);
LOG(INFO) << "compare result, max diff: " << max_diff
<< ", max ratio: " << max_ratio;
if (std::abs(max_ratio) > 1e-4f && std::abs(max_diff) > 5e-5f) {
Tensor tdiff;
tdiff.set_precision(PRECISION(kFloat));
tdiff.Resize(tc.dims());
tensor_diff(tc_basic, tc, tdiff);
LOG(INFO) << "a: ";
print_tensor(ta);
LOG(INFO) << "a_c4: ";
print_tensor(ta_c4);
LOG(INFO) << "b: ";
print_tensor(tb);
LOG(INFO) << "b_c4: ";
print_tensor(tb_c4);
LOG(INFO) << "basic result: ";
print_tensor(tc_basic);
LOG(INFO) << "lite result: ";
print_tensor(tc);
LOG(INFO) << "diff result: ";
print_tensor(tdiff);
return false;
}
}
#endif
return true;
}
TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) {
if (FLAGS_basic_test) {
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
LOG(INFO) << "run basic sgemm_c4 test";
for (auto& m : {1, 3, 8, 32, 397}) {
for (auto& n : {1, 2, 3, 4, 13, 141, 789}) {
for (auto& k : {1, 3, 8, 59, 234}) {
for (auto& has_bias : {false, true}) {
for (auto& has_relu : {false, true}) {
for (auto& th : {1, 2, 4}) {
auto flag = test_sgemm_c4(
m, n, k, has_bias, has_relu, FLAGS_power_mode, th);
if (flag) {
LOG(INFO) << "test m = " << m << ", n=" << n << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< " passed\n";
} else {
LOG(FATAL) << "test m = " << m << ", n=" << n << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< " failed\n";
}
}
}
}
}
}
}
}
}
TEST(TestSgemmC4Custom, test_func_sgemm_c4_prepacked_custom) {
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
auto flag = test_sgemm_c4(FLAGS_M,
FLAGS_N,
FLAGS_K,
FLAGS_flag_bias,
FLAGS_flag_relu,
FLAGS_power_mode,
FLAGS_threads);
if (!flag) {
LOG(FATAL) << "test m = " << FLAGS_M << ", n=" << FLAGS_N
<< ", k=" << FLAGS_K << ", bias: " << FLAGS_flag_bias
<< ", relu: " << FLAGS_flag_relu << " failed!!";
}
LOG(INFO) << "test m = " << FLAGS_M << ", n=" << FLAGS_N << ", k=" << FLAGS_K
<< ", bias: " << FLAGS_flag_bias << ", relu: " << FLAGS_flag_relu
<< " passed!!";
}
......@@ -14,6 +14,108 @@
#pragma once
template <typename type>
static void basic_trans_mat_to_c4(const type* input,
type* output,
const int ldin,
const int M,
const int K,
bool pack_k) {
const int m_round = (M + 3) / 4 * 4;
int k_round = (K + 3) / 4 * 4;
if (!pack_k) {
k_round = K;
}
const int m_loop = m_round / 4;
type zero_buf[K];
memset(zero_buf, 0, K * sizeof(type));
for (int i = 0; i < m_loop; ++i) {
const type* in0 = input + i * 4 * ldin;
const type* in1 = in0 + ldin;
const type* in2 = in1 + ldin;
const type* in3 = in2 + ldin;
if (4 * (i + 1) - M > 0) {
switch (4 * (i + 1) - M) {
case 3:
in1 = zero_buf;
case 2:
in2 = zero_buf;
case 1:
in3 = zero_buf;
default:
break;
}
}
for (int j = 0; j < K; ++j) {
*output++ = *in0++;
*output++ = *in1++;
*output++ = *in2++;
*output++ = *in3++;
}
for (int j = K; j < k_round; ++j) {
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
}
}
}
template <typename type, typename type2>
static void basic_gemm_c4(bool trans_a,
bool trans_b,
int m,
int n,
int k,
type2 alpha,
const type* a,
int lda,
const type* b,
int ldb,
type2 beta,
type2* c,
int ldc,
const type2* bias,
bool flag_bias = false,
bool flag_relu = false) {
type2* tmp_c = reinterpret_cast<type2*>(malloc(m * ldc * sizeof(type2)));
memset(tmp_c, 0, m * ldc * sizeof(type2));
#pragma omp parallel for
for (int i = 0; i < m; ++i) {
auto bias_data = static_cast<type2>(0);
if (flag_bias) {
bias_data = bias[i];
}
for (int j = 0; j < n; ++j) {
auto sum = static_cast<type2>(0);
for (int l = 0; l < k; ++l) {
type av;
type bv;
if (trans_a) {
av = a[l * lda + i];
} else {
av = a[i * lda + l];
}
if (trans_b) {
bv = b[j * ldb + l];
} else {
bv = b[l * ldb + j];
}
sum += av * bv;
}
type2 tmp = alpha * sum + beta * tmp_c[i * ldc + j] + bias_data;
if (flag_relu) {
tmp_c[i * ldc + j] = tmp > (type2)0 ? tmp : (type2)0;
} else {
tmp_c[i * ldc + j] = tmp;
}
}
}
//! trans c to c4
basic_trans_mat_to_c4(tmp_c, c, ldc, m, n, false);
free(tmp_c);
}
template <typename type, typename type2>
static void basic_gemm(bool trans_a,
bool trans_b,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册