diff --git a/lite/backends/arm/math/CMakeLists.txt b/lite/backends/arm/math/CMakeLists.txt index cbbcf49a5fd55dabd6b072bc6b3b2e3f9bb91a13..cea1bf04677a3cc19ded5a20311518155760de69 100644 --- a/lite/backends/arm/math/CMakeLists.txt +++ b/lite/backends/arm/math/CMakeLists.txt @@ -60,6 +60,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR) cc_library(math_arm SRCS funcs.cc packed_sgemm.cc + packed_sgemm_c4.cc sgemm.cc gemm_prepacked_int8.cc gemm_s8.cc diff --git a/lite/backends/arm/math/funcs.h b/lite/backends/arm/math/funcs.h index d8ef6ff47d0392ac15caf2d94b7c53ff63659da2..2d07e908c229c8d300ad64510d72fc12f8374fea 100644 --- a/lite/backends/arm/math/funcs.h +++ b/lite/backends/arm/math/funcs.h @@ -43,6 +43,7 @@ #include "lite/backends/arm/math/negative.h" #include "lite/backends/arm/math/norm.h" #include "lite/backends/arm/math/packed_sgemm.h" +#include "lite/backends/arm/math/packed_sgemm_c4.h" #include "lite/backends/arm/math/pad2d.h" #include "lite/backends/arm/math/pooling.h" #include "lite/backends/arm/math/power.h" diff --git a/lite/backends/arm/math/packed_sgemm_c4.cc b/lite/backends/arm/math/packed_sgemm_c4.cc new file mode 100644 index 0000000000000000000000000000000000000000..677490502e643fae3bc8149933e9936880711f0d --- /dev/null +++ b/lite/backends/arm/math/packed_sgemm_c4.cc @@ -0,0 +1,1155 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/arm/math/packed_sgemm_c4.h" +#include + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void loadb_c4(float* out, + const float* in, + const int xstart, + const int xend, + const int k_round, + const int n) { + const int xlen = (xend - xstart + NBLOCK_C4 - 1) / NBLOCK_C4 * NBLOCK_C4; + int xloop = xlen / NBLOCK_C4; + const int flag_remain = n < xstart + xlen; + int remain = 0; + int remain4 = 0; + int remain1 = 0; + if (flag_remain) { + remain = (n - xstart) - (xloop - 1) * NBLOCK_C4; + remain4 = remain >> 2; + remain1 = remain & 3; + xloop -= 1; + } + const int ldo = NBLOCK_C4 * k_round; + const int kloop = k_round >> 2; + in += xstart * 4; + if (xloop > 0) { +#pragma omp parallel for + for (int i = 0; i < kloop; ++i) { + float* out_ptr = out + 4 * NBLOCK_C4 * i; + const float* in_ptr = in + i * 4 * n; + for (int j = 0; j < xloop; ++j) { + float* out_p = out_ptr + j * ldo; +#ifdef __aarch64__ + asm volatile( + "ld1 {v0.4s, v1.4s}, [%[in]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[in]], #32 \n" + "st1 {v0.4s, v1.4s}, [%[out]], #32 \n" + "ld1 {v4.4s, v5.4s}, [%[in]], #32 \n" + "st1 {v2.4s, v3.4s}, [%[out]], #32 \n" + "ld1 {v6.4s, v7.4s}, [%[in]], #32 \n" + "st1 {v4.4s, v5.4s}, [%[out]], #32 \n" + "st1 {v6.4s, v7.4s}, [%[out]], #32 \n" + : [in] "+r"(in_ptr), [out] "+r"(out_p) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[in]]! \n" + "vld1.32 {d4-d7}, [%[in]]! \n" + "vst1.32 {d0-d3}, [%[out]]! \n" + "vld1.32 {d8-d11}, [%[in]]! \n" + "vst1.32 {d4-d7}, [%[out]]! 
\n" + "vld1.32 {d12-d15}, [%[in]]! \n" + "vst1.32 {d8-d11}, [%[out]]! \n" + "vst1.32 {d12-d15}, [%[out]]! \n" + : [in] "+r"(in_ptr), [out] "+r"(out_p) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); +#endif // __aarch674__ + } + } + } + float* out_remain4 = out + xloop * k_round * NBLOCK_C4; + const float* in_remain4 = in + xloop * NBLOCK_C4 * 4; + if (remain4) { +#pragma omp parallel for + for (int i = 0; i < kloop; ++i) { + float* out_ptr = out_remain4 + 4 * 4 * i; + const float* in_ptr = in_remain4 + i * 4 * n; +#ifdef __aarch64__ + asm volatile( + "ld1 {v0.4s, v1.4s}, [%[in]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[in]], #32 \n" + "st1 {v0.4s, v1.4s}, [%[out]], #32 \n" + "st1 {v2.4s, v3.4s}, [%[out]], #32 \n" + : [in] "+r"(in_ptr), [out] "+r"(out_ptr) + : + : "v0", "v1", "v2", "v3"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[in]]! \n" + "vld1.32 {d4-d7}, [%[in]]! \n" + "vst1.32 {d0-d3}, [%[out]]! \n" + "vst1.32 {d4-d7}, [%[out]]! \n" + : [in] "+r"(in_ptr), [out] "+r"(out_ptr) + : + : "q0", "q1", "q2", "q3"); +#endif // __aarch64__ + } + } + float* out_remain1 = out_remain4 + remain4 * k_round * 4; + const float* in_remain1 = in_remain4 + remain4 * 4 * 4; + if (remain1) { +#pragma omp parallel for + for (int i = 0; i < kloop; ++i) { + float* out_ptr = out_remain1 + 4 * remain1 * i; + const float* in_ptr = in_remain1 + i * 4 * n; + for (int j = 0; j < remain1; ++j) { + float32x4_t vin = vld1q_f32(in_ptr); + in_ptr += 4; + vst1q_f32(out_ptr, vin); + out_ptr += 4; + } + } + } +} + +void sgemm_prepack_c4_common(int M, + int N, + int K, + const float* A_packed, + const float* B, + float* C, + const float* bias, + bool has_bias, + bool has_relu, + ARMContext* ctx) { + const int m_round = (M + 3) / 4 * 4; + const int k_round = (K + 3) / 4 * 4; + size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024; + int threads = ctx->threads(); + auto workspace = ctx->workspace_data(); + // l2 = ablock * K * threads + K * bchunk_w + threads * ablock * bchunk_w; + int bchunk_w = (l2_cache - threads * k_round * sizeof(float)) / + ((k_round + threads * MBLOCK_C4) * sizeof(float)); + bchunk_w = bchunk_w > N ? N : bchunk_w; + bchunk_w = bchunk_w / NBLOCK_C4 * NBLOCK_C4; + bchunk_w = bchunk_w > NBLOCK_C4 ? 
bchunk_w : NBLOCK_C4; + int bchunk_loop = (N + bchunk_w - 1) / bchunk_w; + + const int h_loop = m_round >> 2; // MBLOCK_C4 == 4; + const int kcnt = (k_round + KBLOCK_C4 - 1) / KBLOCK_C4; + const int ldc = N * 4; + const int lda = k_round * 4; + float bias_buf[m_round]; // NOLINT + if (has_bias) { + memcpy(bias_buf, bias, M * sizeof(float)); + memset(bias_buf + M, 0, (m_round - M) * sizeof(float)); + } else { + memset(bias_buf, 0, m_round * sizeof(float)); + } + // bchunk_loop + float* c = C; + for (int n = 0; n < bchunk_loop; ++n) { + int x_start = n * bchunk_w; + int x_end = x_start + bchunk_w; + int w_loop = bchunk_w / NBLOCK_C4; + int flag_remain = 0; + int w_loop4 = 0; + int remain = 0; + if (x_end > N) { + w_loop = (N - x_start) / NBLOCK_C4; + int w_loop_rem = (N - x_start) - w_loop * NBLOCK_C4; + w_loop4 = w_loop_rem >> 2; + remain = w_loop_rem & 3; + x_end = N; + flag_remain = 1; + } + float* bchunk = workspace; + loadb_c4(bchunk, B, x_start, x_end, k_round, N); + float* cchunk = c + n * bchunk_w * 4; + int has_remain = (n == bchunk_loop - 1) && flag_remain; +#pragma omp parallel for num_threads(threads) + for (int h = 0; h < h_loop; ++h) { + float* bias_h = bias_buf + h * 4; +#ifdef __aarch64__ + float32x4_t vzero = vdupq_n_f32(0.f); + float32x4_t vbias = vld1q_f32(bias_h); +#endif + const float* ablock = A_packed + h * lda; + const float* bblock = bchunk; + float* cblock = cchunk + h * ldc; + for (int w = 0; w < w_loop; ++w) { + int cnt = kcnt; + const float* ablock_ptr = ablock; +// clang-format off +#ifdef __aarch64__ + asm volatile( + "prfm pldl1keep, [%[a]] \n" + "prfm pldl1keep, [%[b]] \n" + "prfm pldl1keep, [%[b], #64] \n" + "mov v9.16b, %[vbias].16b \n" /* mov bias to c0*/ + "mov v10.16b, %[vbias].16b \n" /* mov bias to c1*/ + "mov v11.16b, %[vbias].16b \n" /* mov bias to c2*/ + "mov v12.16b, %[vbias].16b \n" /* mov bias to c3*/ + /* load a0a1 to v1-v2 */ + "ld1 {v1.4s, v2.4s}, [%[a]], #32 \n" + "mov v13.16b, %[vbias].16b \n" /* mov bias to c4*/ + "mov v14.16b, %[vbias].16b \n" /* mov bias to c5*/ + "mov v15.16b, %[vbias].16b \n" /* mov bias to c6*/ + "mov v16.16b, %[vbias].16b \n" /* mov bias to c7*/ + "1:\n" + /* load b0b1b2b3 to v5-v8 */ + "ld1 {v5.4s, v6.4s}, [%[b]], #32 \n" + "ld1 {v7.4s, v8.4s}, [%[b]], #32 \n" + "prfm pldl1keep, [%[b]] \n" + "fmla v9.4s, v1.4s, v5.s[0] \n" + "fmla v10.4s, v1.4s, v6.s[0] \n" + "fmla v11.4s, v1.4s, v7.s[0] \n" + "fmla v12.4s, v1.4s, v8.s[0] \n" + /* load b4b5b6b7 to v25-v28 */ + "ld1 {v25.4s, v26.4s}, [%[b]], #32 \n" + "ld1 {v27.4s, v28.4s}, [%[b]], #32 \n" + "prfm pldl1keep, [%[a], #32] \n" + "fmla v9.4s, v2.4s, v5.s[1] \n" + "fmla v10.4s, v2.4s, v6.s[1] \n" + "fmla v11.4s, v2.4s, v7.s[1] \n" + "fmla v12.4s, v2.4s, v8.s[1] \n" + "prfm pldl1keep, [%[b], #64] \n" + "fmla v13.4s, v1.4s, v25.s[0] \n" + "fmla v14.4s, v1.4s, v26.s[0] \n" + "fmla v15.4s, v1.4s, v27.s[0] \n" + "fmla v16.4s, v1.4s, v28.s[0] \n" + /* load a2a3 to v3-v4 */ + "ld1 {v3.4s, v4.4s}, [%[a]], #32 \n" + "prfm pldl1keep, [%[b], #128] \n" + "fmla v13.4s, v2.4s, v25.s[1] \n" + "fmla v14.4s, v2.4s, v26.s[1] \n" + "fmla v15.4s, v2.4s, v27.s[1] \n" + "fmla v16.4s, v2.4s, v28.s[1] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "fmla v9.4s, v3.4s, v5.s[2] \n" + "fmla v10.4s, v3.4s, v6.s[2] \n" + "fmla v11.4s, v3.4s, v7.s[2] \n" + "fmla v12.4s, v3.4s, v8.s[2] \n" + "fmla v13.4s, v3.4s, v25.s[2] \n" + "fmla v14.4s, v3.4s, v26.s[2] \n" + "fmla v15.4s, v3.4s, v27.s[2] \n" + "fmla v16.4s, v3.4s, v28.s[2] \n" + /* load a0a1 to v1-v2 */ + "ld1 {v1.4s, v2.4s}, [%[a]], #32 \n" + "fmla v9.4s, 
v4.4s, v5.s[3] \n" + "fmla v10.4s, v4.4s, v6.s[3] \n" + "fmla v11.4s, v4.4s, v7.s[3] \n" + "fmla v12.4s, v4.4s, v8.s[3] \n" + + "fmla v13.4s, v4.4s, v25.s[3] \n" + "fmla v14.4s, v4.4s, v26.s[3] \n" + "fmla v15.4s, v4.4s, v27.s[3] \n" + "fmla v16.4s, v4.4s, v28.s[3] \n" + "bne 1b\n" + "cbz %w[relu], 2f \n" + "fmax v9.4s, v9.4s, %[vzero].4s \n" + "fmax v10.4s, v10.4s, %[vzero].4s \n" + "fmax v11.4s, v11.4s, %[vzero].4s \n" + "fmax v12.4s, v12.4s, %[vzero].4s \n" + "fmax v13.4s, v13.4s, %[vzero].4s \n" + "fmax v14.4s, v14.4s, %[vzero].4s \n" + "fmax v15.4s, v15.4s, %[vzero].4s \n" + "fmax v16.4s, v16.4s, %[vzero].4s \n" + "2:\n" + "st1 {v9.4s, v10.4s, v11.4s, v12.4s}, [%[c]], #64 \n" + "st1 {v13.4s, v14.4s, v15.4s, v16.4s}, [%[c]], #64 \n" + : [a] "+r"(ablock_ptr), + [b] "+r"(bblock), + [c] "+r"(cblock), + [cnt] "+r"(cnt) + : [bias] "r"(bias_h), [relu] "r"(has_relu), + [vbias] "w"(vbias), [vzero] "w" (vzero) + : "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", + "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v25", "v26", "v27", "v28", "cc", "memory"); +#else + asm volatile( + "vld1.32 {d6-d7}, [%[bias]] \n" + "pld [%[a]] \n" + "pld [%[b]] \n" + "pld [%[b], #64] \n" + "vmov.32 q8, q3 \n" /* mov bias to c0*/ + "vmov.32 q9, q3 \n" /* mov bias to c1*/ + "vmov.32 q10, q3 \n" /* mov bias to c2*/ + "vmov.32 q11, q3 \n" /* mov bias to c3*/ + "vld1.32 {d0-d3}, [%[a]]! \n" + "vmov.32 q12, q3 \n" /* mov bias to c4*/ + "vmov.32 q13, q3 \n" /* mov bias to c5*/ + "vmov.32 q14, q3 \n" /* mov bias to c6*/ + "vmov.32 q15, q3 \n" /* mov bias to c7*/ + "1:\n" + /* c0c1c2c3 */ + "vld1.32 {d8-d11}, [%[b]]! \n" + "vld1.32 {d12-d15}, [%[b]]! \n" + "pld [%[b]] \n" + "vmla.f32 q8, q0, d8[0] \n" + "vmla.f32 q9, q0, d10[0] \n" + "vmla.f32 q10, q0, d12[0] \n" + "vmla.f32 q11, q0, d14[0] \n" + "vld1.32 {d4-d7}, [%[a]]! \n" + "vmla.f32 q8, q1, d8[1] \n" + "vmla.f32 q9, q1, d10[1] \n" + "vmla.f32 q10, q1, d12[1] \n" + "vmla.f32 q11, q1, d14[1] \n" + "pld [%[b], #64] \n" + "vmla.f32 q8, q2, d9[0] \n" + "vmla.f32 q9, q2, d11[0] \n" + "vmla.f32 q10, q2, d13[0] \n" + "vmla.f32 q11, q2, d15[0] \n" + "subs %[cnt], %[cnt], #1 \n" + "vmla.f32 q8, q3, d9[1] \n" + "vmla.f32 q9, q3, d11[1] \n" + "vld1.f32 {d8-d11}, [%[b]]! \n" + "vmla.f32 q10, q3, d13[1] \n" + "vmla.f32 q11, q3, d15[1] \n" + "vld1.32 {d12-d15}, [%[b]]! \n" + /* c4c5c6c7 */ + "vmla.f32 q12, q0, d8[0] \n" + "vmla.f32 q13, q0, d10[0] \n" + "vmla.f32 q14, q0, d12[0] \n" + "vmla.f32 q15, q0, d14[0] \n" + "pld [%[a], #32] \n" + "vmla.f32 q12, q1, d8[1] \n" + "vmla.f32 q13, q1, d10[1] \n" + "vmla.f32 q14, q1, d12[1] \n" + "vmla.f32 q15, q1, d14[1] \n" + "vld1.32 {d0-d3}, [%[a]]! \n" + "vmla.f32 q12, q2, d9[0] \n" + "vmla.f32 q13, q2, d11[0] \n" + "vmla.f32 q14, q2, d13[0] \n" + "vmla.f32 q15, q2, d15[0] \n" + "pld [%[b], #64] \n" + "vmla.f32 q12, q3, d9[1] \n" + "vmla.f32 q13, q3, d11[1] \n" + "vmla.f32 q14, q3, d13[1] \n" + "vmla.f32 q15, q3, d15[1] \n" + "bne 1b\n" + "cmp %[relu], #0 \n" + "beq 2f \n" + "vmov.u32 q0, #0 \n" + "vmax.f32 q8, q8, q0 \n" + "vmax.f32 q9, q9, q0 \n" + "vmax.f32 q10, q10, q0 \n" + "vmax.f32 q11, q11, q0 \n" + "vmax.f32 q12, q12, q0 \n" + "vmax.f32 q13, q13, q0 \n" + "vmax.f32 q14, q14, q0 \n" + "vmax.f32 q15, q15, q0 \n" + "2:\n" + "vst1.32 {d16-d19}, [%[c]]! \n" + "vst1.32 {d20-d23}, [%[c]]! \n" + "vst1.32 {d24-d27}, [%[c]]! \n" + "vst1.32 {d28-d31}, [%[c]]! 
\n" + : [a] "+r"(ablock_ptr), + [b] "+r"(bblock), + [c] "+r"(cblock), + [cnt] "+r"(cnt) + : [bias] "r"(bias_h), + [relu] "r"(has_relu) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q11", "q12", "q13", "q14", "q15", "cc", "memory"); +#endif + // clang-format on + } + if (has_remain) { + if (w_loop4 > 0) { + int cnt = kcnt; + const float* ablock_ptr = ablock; +// clang-format off +#ifdef __aarch64__ + asm volatile( + "prfm pldl1keep, [%[a]] \n" + "prfm pldl1keep, [%[b]] \n" + "mov v9.16b, %[vbias].16b \n" /* mov bias to c0*/ + "mov v10.16b, %[vbias].16b \n" /* mov bias to c1*/ + "mov v11.16b, %[vbias].16b \n" /* mov bias to c2*/ + "mov v12.16b, %[vbias].16b \n" /* mov bias to c3*/ + /* load a0a1 to v1-v2 */ + "ld1 {v1.4s, v2.4s}, [%[a]], #32 \n" + "1:\n" + /* load b0b1b2b3 to v5-v8 */ + "ld1 {v5.4s, v6.4s}, [%[b]], #32 \n" + "ld1 {v7.4s, v8.4s}, [%[b]], #32 \n" + "fmla v9.4s, v1.4s, v5.s[0] \n" + "fmla v10.4s, v1.4s, v6.s[0] \n" + "fmla v11.4s, v1.4s, v7.s[0] \n" + "fmla v12.4s, v1.4s, v8.s[0] \n" + /* load a2a3 to v3-v4 */ + "ld1 {v3.4s, v4.4s}, [%[a]], #32 \n" + "prfm pldl1keep, [%[a]] \n" + "fmla v9.4s, v2.4s, v5.s[1] \n" + "fmla v10.4s, v2.4s, v6.s[1] \n" + "fmla v11.4s, v2.4s, v7.s[1] \n" + "fmla v12.4s, v2.4s, v8.s[1] \n" + "prfm pldl1keep, [%[b]] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "fmla v9.4s, v3.4s, v5.s[2] \n" + "fmla v10.4s, v3.4s, v6.s[2] \n" + "fmla v11.4s, v3.4s, v7.s[2] \n" + "fmla v12.4s, v3.4s, v8.s[2] \n" + /* load a0a1 to v1-v2 */ + "ld1 {v1.4s, v2.4s}, [%[a]], #32 \n" + "fmla v9.4s, v4.4s, v5.s[3] \n" + "fmla v10.4s, v4.4s, v6.s[3] \n" + "fmla v11.4s, v4.4s, v7.s[3] \n" + "fmla v12.4s, v4.4s, v8.s[3] \n" + "bne 1b\n" + "cbz %w[relu], 2f \n" + "fmax v9.4s, v9.4s, %[vzero].4s \n" + "fmax v10.4s, v10.4s, %[vzero].4s \n" + "fmax v11.4s, v11.4s, %[vzero].4s \n" + "fmax v12.4s, v12.4s, %[vzero].4s \n" + "2:\n" + "st1 {v9.4s, v10.4s, v11.4s, v12.4s}, [%[c]], #64 \n" + : [a] "+r"(ablock_ptr), + [b] "+r"(bblock), + [c] "+r"(cblock), + [cnt] "+r"(cnt) + : [bias] "r"(bias_h), + [relu] "r"(has_relu), + [vbias] "w"(vbias), + [vzero] "w" (vzero) + : "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "cc", "memory"); +#else + asm volatile( + "pld [%[a]] \n" + "pld [%[b]] \n" + "vld1.32 {d6-d7}, [%[bias]] \n" + "vld1.32 {d0-d3}, [%[a]]! \n" /* load a0 a1 */ + "vmov.32 q8, q3 \n" /* mov bias to c0 */ + "vmov.32 q9, q3 \n" /* mov bias to c1 */ + "vmov.32 q10, q3 \n" /* mov bias to c2 */ + "vmov.32 q11, q3 \n" /* mov bias to c3 */ + "1:\n" + /* c0c1c2c3 */ + "vld1.32 {d8-d11}, [%[b]]! \n" + "vld1.32 {d12-d15}, [%[b]]! \n" + "pld [%[b]] \n" + "vmla.f32 q8, q0, d8[0] \n" + "vmla.f32 q9, q0, d10[0] \n" + "vmla.f32 q10, q0, d12[0] \n" + "vmla.f32 q11, q0, d14[0] \n" + "vld1.32 {d4-d7}, [%[a]]! \n" + "pld [%[a]] \n" + "vmla.f32 q8, q1, d8[1] \n" + "vmla.f32 q9, q1, d10[1] \n" + "vmla.f32 q10, q1, d12[1] \n" + "vmla.f32 q11, q1, d14[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "vmla.f32 q8, q2, d9[0] \n" + "vmla.f32 q9, q2, d11[0] \n" + "vmla.f32 q10, q2, d13[0] \n" + "vmla.f32 q11, q2, d15[0] \n" + "vld1.32 {d0-d3}, [%[a]]! \n" + "vmla.f32 q8, q3, d9[1] \n" + "vmla.f32 q9, q3, d11[1] \n" + "vmla.f32 q10, q3, d13[1] \n" + "vmla.f32 q11, q3, d15[1] \n" + "bne 1b\n" + "cmp %[relu], #0 \n" + "beq 2f \n" + "vmov.u32 q0, #0 \n" + "vmax.f32 q8, q8, q0 \n" + "vmax.f32 q9, q9, q0 \n" + "vmax.f32 q10, q10, q0 \n" + "vmax.f32 q11, q11, q0 \n" + "2:\n" + "vst1.32 {d16-d19}, [%[c]]! \n" + "vst1.32 {d20-d23}, [%[c]]! 
\n" + : [a] "+r"(ablock_ptr), + [b] "+r"(bblock), + [c] "+r"(cblock), + [cnt] "+r"(cnt) + : [bias] "r"(bias_h), [relu] "r"(has_relu) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q11", "cc", "memory"); +#endif + // clang-format on + } + if (remain > 0) { + int cnt = kcnt; + const float* ablock_ptr = ablock; +// clang-format off +#ifdef __aarch64__ + asm volatile( + "prfm pldl1keep, [%[a]] \n" + "prfm pldl1keep, [%[b]] \n" + "ld1 {v1.4s, v2.4s}, [%[a]], #32 \n" + "cmp %w[remain], #3 \n" + "beq 1f \n" + "cmp %w[remain], #2 \n" + "beq 2f \n" + /* remain 1 */ + "mov v9.16b, %[vbias].16b \n" /* mov bias to c0*/ + "mov v10.16b, %[vzero].16b \n" /* mov zero to c1*/ + "3: \n" + "ld1 {v5.4s}, [%[b]], #16 \n" + "ld1 {v3.4s, v4.4s}, [%[a]], #32 \n" + "fmla v9.4s, v1.4s, v5.s[0] \n" + "fmla v10.4s, v2.4s, v5.s[1] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "ld1 {v1.4s, v2.4s}, [%[a]], #32 \n" + "fmla v9.4s, v3.4s, v5.s[2] \n" + "fmla v10.4s, v4.4s, v5.s[3] \n" + "bne 3b \n" + "fadd v9.4s, v9.4s, v10.4s \n" + "cbz %w[relu], 6f \n" + "fmax v9.4s, v9.4s, %[vzero].4s \n" + "6: \n" + "st1 {v9.4s}, [%[c]], #16 \n" + "b 9f \n" + /* remain 2 */ + "2: \n" + "mov v9.16b, %[vbias].16b \n" /* mov bias to c0*/ + "mov v10.16b, %[vbias].16b \n" /* mov bias to c1*/ + "mov v11.16b, %[vzero].16b \n" /* mov zero to c2*/ + "mov v12.16b, %[vzero].16b \n" /* mov zero to c3*/ + "4: \n" + "ld1 {v5.4s, v6.4s}, [%[b]], #32 \n" + "ld1 {v3.4s, v4.4s}, [%[a]], #32 \n" + "fmla v9.4s, v1.4s, v5.s[0] \n" + "fmla v10.4s, v1.4s, v6.s[0] \n" + "fmla v11.4s, v2.4s, v5.s[1] \n" + "fmla v12.4s, v2.4s, v6.s[1] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "fmla v9.4s, v3.4s, v5.s[2] \n" + "fmla v10.4s, v3.4s, v6.s[2] \n" + "fmla v11.4s, v4.4s, v5.s[3] \n" + "fmla v12.4s, v4.4s, v6.s[3] \n" + "ld1 {v1.4s, v2.4s}, [%[a]], #32 \n" + "bne 4b \n" + "fadd v9.4s, v9.4s, v11.4s \n" + "fadd v10.4s, v10.4s, v12.4s \n" + "cbz %w[relu], 7f \n" + "fmax v9.4s, v9.4s, %[vzero].4s \n" + "fmax v10.4s, v10.4s, %[vzero].4s \n" + "7: \n" + "st1 {v9.4s, v10.4s}, [%[c]], #32 \n" + "b 9f \n" + /* remain 3 */ + "1: \n" + "mov v9.16b, %[vbias].16b \n" /* mov bias to c0*/ + "mov v10.16b, %[vbias].16b \n" /* mov bias to c1*/ + "mov v11.16b, %[vbias].16b \n" /* mov bias to c2*/ + "5: \n" + "ld1 {v5.4s, v6.4s}, [%[b]], #32 \n" + "ld1 {v7.4s}, [%[b]], #16 \n" + "fmla v9.4s, v1.4s, v5.s[0] \n" + "fmla v10.4s, v1.4s, v6.s[0] \n" + "fmla v11.4s, v1.4s, v7.s[0] \n" + "ld1 {v3.4s, v4.4s}, [%[a]], #32 \n" + "fmla v9.4s, v2.4s, v5.s[1] \n" + "fmla v10.4s, v2.4s, v6.s[1] \n" + "fmla v11.4s, v2.4s, v7.s[1] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "fmla v9.4s, v3.4s, v5.s[2] \n" + "fmla v10.4s, v3.4s, v6.s[2] \n" + "fmla v11.4s, v3.4s, v7.s[2] \n" + "prfm pldl1keep, [%[a]] \n" + "fmla v9.4s, v4.4s, v5.s[3] \n" + "fmla v10.4s, v4.4s, v6.s[3] \n" + "fmla v11.4s, v4.4s, v7.s[3] \n" + "ld1 {v1.4s, v2.4s}, [%[a]], #32 \n" + "bne 5b \n" + "cbz %w[relu], 8f \n" + "fmax v9.4s, v9.4s, %[vzero].4s \n" + "fmax v10.4s, v10.4s, %[vzero].4s \n" + "fmax v11.4s, v11.4s, %[vzero].4s \n" + "8: \n" + "st1 {v9.4s, v10.4s}, [%[c]], #32 \n" + "st1 {v11.4s}, [%[c]], #16 \n" + "9:\n" + : [a] "+r"(ablock_ptr), + [b] "+r"(bblock), + [c] "+r"(cblock), + [cnt] "+r"(cnt) + : [bias] "r"(bias_h), [relu] "r"(has_relu), + [remain] "r"(remain), [vbias] "w"(vbias), + [vzero] "w" (vzero) + : "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v9", + "v10", "v11", "v12", "cc","memory"); +#else + asm volatile( + "pld [%[a]] \n" + "pld [%[b]] \n" + "vld1.32 {d0-d1}, [%[bias]] \n" + "vld1.32 {d2-d5}, [%[a]]! 
\n" + "vmov.u32 q15, #0 \n" + "cmp %[remain], #3 \n" + "beq 1f \n" + "cmp %[remain], #2 \n" + "beq 2f \n" + /* remain 1 */ + "vmov.32 q9, q0 \n" /* mov bias to c0*/ + "vmov.32 q10, q15 \n" /* mov zero to c1*/ + "3: \n" + "vld1.32 {d10-d11}, [%[b]]! \n" + "vld1.32 {d6-d9}, [%[a]]! \n" + "vmla.f32 q9, q1, d10[0] \n" + "vmla.f32 q10, q2, d10[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "vld1.32 {d2-d5}, [%[a]]! \n" + "vmla.f32 q9, q3, d11[0] \n" + "vmla.f32 q10, q4, d11[1] \n" + "bne 3b \n" + "vadd.f32 q9, q9, q10 \n" + "cmp %[relu], #0 \n" + "beq 6f \n" + "vmax.f32 q9, q9, q15 \n" + "6: \n" + "vst1.32 {d18-d19}, [%[c]]! \n" + "b 9f \n" + /* remain 2 */ + "2: \n" + "vmov.u32 q9, q0 \n" /* mov bias to c0*/ + "vmov.u32 q10, q0 \n" /* mov bias to c1*/ + "vmov.u32 q11, q15 \n" /* mov zero to c2*/ + "vmov.u32 q12, q15 \n" /* mov zero to c3*/ + "4: \n" + "vld1.32 {d10-d13}, [%[b]]! \n" + "vld1.32 {d6-d9}, [%[a]]! \n" + "vmla.f32 q9, q1, d10[0] \n" + "vmla.f32 q10, q1, d12[0] \n" + "vmla.f32 q11, q2, d10[1] \n" + "vmla.f32 q12, q2, d12[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "vmla.f32 q9, q3, d11[0] \n" + "vmla.f32 q10, q3, d13[0] \n" + "vmla.f32 q11, q4, d11[1] \n" + "vmla.f32 q12, q4, d13[1] \n" + "vld1.32 {d2-d5}, [%[a]]! \n" + "bne 4b \n" + "vadd.f32 q9, q9, q11 \n" + "vadd.f32 q10, q10, q12 \n" + "cmp %[relu], #0 \n" + "beq 7f \n" + "vmax.f32 q9, q9, q15 \n" + "vmax.f32 q10, q10, q15 \n" + "7: \n" + "vst1.32 {d18-d21}, [%[c]]! \n" + "b 9f \n" + /* remain 3 */ + "1: \n" + "vmov.u32 q9, q0 \n" /* mov bias to c0*/ + "vmov.u32 q10, q0 \n" /* mov bias to c1*/ + "vmov.u32 q11, q0 \n" /* mov bias to c2*/ + "5: \n" + "vld1.32 {d10-d13}, [%[b]]! \n" + "vld1.32 {d14-d15}, [%[b]]! \n" + "vmla.f32 q9, q1, d10[0] \n" + "vmla.f32 q10, q1, d12[0] \n" + "vmla.f32 q11, q1, d14[0] \n" + "vld1.32 {d6-d9}, [%[a]]! \n" + "vmla.f32 q9, q2, d10[1] \n" + "vmla.f32 q10, q2, d12[1] \n" + "vmla.f32 q11, q2, d14[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "vmla.f32 q9, q3, d11[0] \n" + "vmla.f32 q10, q3, d13[0] \n" + "vmla.f32 q11, q3, d15[0] \n" + "pld [%[a]] \n" + "vmla.f32 q9, q4, d11[1] \n" + "vmla.f32 q10, q4, d13[1] \n" + "vmla.f32 q11, q4, d15[1] \n" + "vld1.32 {d2-d5}, [%[a]]! \n" + "bne 5b \n" + "cmp %[relu], #0 \n" + "beq 8f \n" + "vmax.f32 q9, q9, q15 \n" + "vmax.f32 q10, q10, q15 \n" + "vmax.f32 q11, q11, q15 \n" + "8: \n" + "vst1.32 {d18-d21}, [%[c]]! \n" + "vst1.32 {d22-d23}, [%[c]]! 
\n" + "9:\n" + : [a] "+r"(ablock_ptr), + [b] "+r"(bblock), + [c] "+r"(cblock), + [cnt] "+r"(cnt) + : [bias] "r"(bias_h), + [relu] "r"(has_relu), + [remain] "r"(remain) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q9", + "q10", "q11", "q12", "q15", "cc","memory"); +#endif + // clang-format on + } + } + } + } +} + +void sgemm_prepack_c4_small(int M, + int N, + int K, + const float* A_packed, + const float* B, + float* C, + const float* bias, + bool has_bias, + bool has_relu, + ARMContext* ctx) { + const int m_round = (M + 3) / 4 * 4; + const int k_round = (K + 3) / 4 * 4; + const int mloop = m_round >> 2; + const int lda = 4 * k_round; + const int ldb_byte = 4 * N * sizeof(float); + const int kcnt = k_round >> 2; + float bias_buf[m_round]; // NOLINT + if (has_bias) { + memcpy(bias_buf, bias, M * sizeof(float)); + memset(bias_buf + M, 0, (m_round - M) * sizeof(float)); + } else { + memset(bias_buf, 0, m_round * sizeof(float)); + } +#ifdef __aarch64__ + float32x4_t vzero = vdupq_n_f32(0.f); +#endif + const float* bias_ptr = bias_buf; + for (int m = 0; m < mloop; ++m) { +#ifdef __aarch64__ + float32x4_t vbias = vld1q_f32(bias_ptr); +#endif + const float* b = B; + int n = N; +#ifdef __aarch64__ + for (; n > 7; n -= 8) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + // clang-format off + asm volatile( + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + /* mov bias to c0-c7*/ + "mov v8.16b, %[vbias].16b \n" + "mov v9.16b, %[vbias].16b \n" + "mov v10.16b, %[vbias].16b \n" + "mov v11.16b, %[vbias].16b \n" + /* load b0, b1 */ + "ld1 {v0.4s, v1.4s}, [%[b]], #32 \n" + "mov v12.16b, %[vbias].16b \n" + "mov v13.16b, %[vbias].16b \n" + "mov v14.16b, %[vbias].16b \n" + "mov v15.16b, %[vbias].16b \n" + "1:\n" + /* load b2, b3 */ + "ld1 {v2.4s, v3.4s}, [%[b]], #32 \n" + /* load a2, a3 */ + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "fmla v8.4s, v16.4s, v0.s[0] \n" + "fmla v9.4s, v16.4s, v1.s[0] \n" + "fmla v10.4s, v16.4s, v2.s[0] \n" + "fmla v11.4s, v16.4s, v3.s[0] \n" + "prfm pldl1keep, [%[b]] \n" + "fmla v8.4s, v17.4s, v0.s[1] \n" + "fmla v9.4s, v17.4s, v1.s[1] \n" + "fmla v10.4s, v17.4s, v2.s[1] \n" + "fmla v11.4s, v17.4s, v3.s[1] \n" + /* load b4, b5 */ + "ld1 {v4.4s, v5.4s}, [%[b]], #32 \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v18.4s, v1.s[2] \n" + "fmla v10.4s, v18.4s, v2.s[2] \n" + "fmla v11.4s, v18.4s, v3.s[2] \n" + /* load b6, b7 */ + "ld1 {v6.4s, v7.4s}, [%[b]], #32 \n" + "fmla v8.4s, v19.4s, v0.s[3] \n" + "fmla v9.4s, v19.4s, v1.s[3] \n" + "fmla v10.4s, v19.4s, v2.s[3] \n" + "fmla v11.4s, v19.4s, v3.s[3] \n" + "sub %[b], %[b], #128 \n" + "fmla v12.4s, v16.4s, v4.s[0] \n" + "fmla v13.4s, v16.4s, v5.s[0] \n" + "fmla v14.4s, v16.4s, v6.s[0] \n" + "fmla v15.4s, v16.4s, v7.s[0] \n" + "add %[b], %[b], %[ldb] \n" + "fmla v12.4s, v17.4s, v4.s[1] \n" + "fmla v13.4s, v17.4s, v5.s[1] \n" + "fmla v14.4s, v17.4s, v6.s[1] \n" + "fmla v15.4s, v17.4s, v7.s[1] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "fmla v12.4s, v18.4s, v4.s[2] \n" + "fmla v13.4s, v18.4s, v5.s[2] \n" + "fmla v14.4s, v18.4s, v6.s[2] \n" + "fmla v15.4s, v18.4s, v7.s[2] \n" + /* load b0, b1 */ + "ld1 {v0.4s, v1.4s}, [%[b]], #32 \n" + "fmla v12.4s, v19.4s, v4.s[3] \n" + "fmla v13.4s, v19.4s, v5.s[3] \n" + "fmla v14.4s, v19.4s, v6.s[3] \n" + "fmla v15.4s, v19.4s, v7.s[3] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "bne 1b \n" + "cbz %w[relu], 2f \n" + "fmax v8.4s, v8.4s, %[vzero].4s \n" + "fmax v9.4s, v9.4s, %[vzero].4s \n" + "fmax v10.4s, v10.4s, %[vzero].4s \n" + 
"fmax v11.4s, v11.4s, %[vzero].4s \n" + "fmax v12.4s, v12.4s, %[vzero].4s \n" + "fmax v13.4s, v13.4s, %[vzero].4s \n" + "fmax v14.4s, v14.4s, %[vzero].4s \n" + "fmax v15.4s, v15.4s, %[vzero].4s \n" + "2:\n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[c]], #64 \n" + "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[c]], #64 \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [relu] "r" (has_relu), + [ldb] "r" (ldb_byte), + [vbias] "w" (vbias), + [vzero] "w" (vzero) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "cc", "memory" + ); + b += 4 * 8; + } + for (; n > 3; n -= 4) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + asm volatile( + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + /* mov bias to c0-c3*/ + "mov v8.16b, %[vbias].16b \n" + "mov v9.16b, %[vbias].16b \n" + "mov v10.16b, %[vbias].16b \n" + "mov v11.16b, %[vbias].16b \n" + "1:\n" + /* load b0-b3 */ + "ld1 {v0.4s, v1.4s}, [%[b]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[b]], #32 \n" + /* load a2, a3 */ + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "fmla v8.4s, v16.4s, v0.s[0] \n" + "fmla v9.4s, v16.4s, v1.s[0] \n" + "fmla v10.4s, v16.4s, v2.s[0] \n" + "fmla v11.4s, v16.4s, v3.s[0] \n" + "sub %[b], %[b], #64 \n" + "fmla v8.4s, v17.4s, v0.s[1] \n" + "fmla v9.4s, v17.4s, v1.s[1] \n" + "fmla v10.4s, v17.4s, v2.s[1] \n" + "fmla v11.4s, v17.4s, v3.s[1] \n" + "add %[b], %[b], %[ldb] \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v18.4s, v1.s[2] \n" + "fmla v10.4s, v18.4s, v2.s[2] \n" + "fmla v11.4s, v18.4s, v3.s[2] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "fmla v8.4s, v19.4s, v0.s[3] \n" + "fmla v9.4s, v19.4s, v1.s[3] \n" + "fmla v10.4s, v19.4s, v2.s[3] \n" + "fmla v11.4s, v19.4s, v3.s[3] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "bne 1b \n" + "cbz %w[relu], 2f \n" + "fmax v8.4s, v8.4s, %[vzero].4s \n" + "fmax v9.4s, v9.4s, %[vzero].4s \n" + "fmax v10.4s, v10.4s, %[vzero].4s \n" + "fmax v11.4s, v11.4s, %[vzero].4s \n" + "2:\n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[c]], #64 \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [relu] "r" (has_relu), + [ldb] "r" (ldb_byte), + [vbias] "w" (vbias), + [vzero] "w" (vzero) + : "v0", "v1", "v2", "v3", "v8", "v9", + "v10", "v11", "v16", "v17", "v18", + "v19", "cc", "memory" + ); + b += 4 * 4; + } + for (; n > 0; n--) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + asm volatile( + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + /* mov bias to c0 */ + "mov v8.16b, %[vbias].16b \n" + "mov v9.16b, %[vzero].16b \n" + "1:\n" + /* load b0 */ + "ld1 {v0.4s}, [%[b]], #16 \n" + /* load a2, a3 */ + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "fmla v8.4s, v16.4s, v0.s[0] \n" + "fmla v9.4s, v17.4s, v0.s[1] \n" + "sub %[b], %[b], #16 \n" + "subs %w[cnt], %w[cnt], #1 \n" + "add %[b], %[b], %[ldb] \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v19.4s, v0.s[3] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "bne 1b \n" + "fadd v8.4s, v8.4s, v9.4s \n" + "cbz %w[relu], 2f \n" + "fmax v8.4s, v8.4s, %[vzero].4s \n" + "2:\n" + "st1 {v8.4s}, [%[c]], #16 \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [relu] "r" (has_relu), + [ldb] "r" (ldb_byte), + [vbias] "w" (vbias), + [vzero] "w" (vzero) + : "v0", "v8", "v9", "v16", "v17", + "v18", "v19", "cc", "memory" + ); + b += 4; + } +#else + for (; n > 5; n -= 
6) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + asm volatile( + "vld1.32 {d8-d9}, [%[bias]] \n" + /* load a0, a1 */ + "vld1.32 {d12-d15}, [%[a]]! \n" + /* mov bias to c0-c7*/ + "vmov.u32 q10, q4 \n" + "vmov.u32 q11, q4 \n" + "vmov.u32 q12, q4 \n" + /* load b0-b3 */ + "vld1.32 {d0-d3}, [%[b]]!\n" + "vld1.32 {d4-d7}, [%[b]]!\n" + "vmov.u32 q13, q4 \n" + "vmov.u32 q14, q4 \n" + "vmov.u32 q15, q4 \n" + "1:\n" + /* load b4, b5 */ + "vld1.32 {d8-d11}, [%[b]]! \n" + /* load a2, a3 */ + "vld1.32 {d16-d19}, [%[a]]!\n" + "vmla.f32 q10, q6, d0[0] \n" + "vmla.f32 q11, q6, d2[0] \n" + "vmla.f32 q12, q6, d4[0] \n" + "vmla.f32 q13, q6, d6[0] \n" + "vmla.f32 q14, q6, d8[0] \n" + "vmla.f32 q15, q6, d10[0] \n" + "sub %[b], %[b], #96 \n" + "vmla.f32 q10, q7, d0[1] \n" + "vmla.f32 q11, q7, d2[1] \n" + "vmla.f32 q12, q7, d4[1] \n" + "vmla.f32 q13, q7, d6[1] \n" + "vmla.f32 q14, q7, d8[1] \n" + "vmla.f32 q15, q7, d10[1] \n" + "add %[b], %[b], %[ldb] \n" + "pld [%[b]] \n" + /* load a0, a1 */ + "vld1.32 {d12-d15}, [%[a]]!\n" + "vmla.f32 q10, q8, d1[0] \n" + "vmla.f32 q11, q8, d3[0] \n" + "vmla.f32 q12, q8, d5[0] \n" + "vmla.f32 q13, q8, d7[0] \n" + "pld [%[b], #64] \n" + "vmla.f32 q10, q9, d1[1] \n" + "vmla.f32 q11, q9, d3[1] \n" + /* load b0, b1 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vmla.f32 q14, q8, d9[0] \n" + "vmla.f32 q15, q8, d11[0] \n" + "vmla.f32 q12, q9, d5[1] \n" + "vmla.f32 q13, q9, d7[1] \n" + "vmla.f32 q14, q9, d9[1] \n" + "vmla.f32 q15, q9, d11[1] \n" + /* load b2, b3 */ + "vld1.32 {d4-d7}, [%[b]]! \n" + "subs %[cnt], %[cnt], #1 \n" + "bne 1b \n" + "cmp %[relu], #0 \n" + "beq 2f \n" + "vmov.u32 q0, #0 \n" + "vmax.f32 q10, q10, q0 \n" + "vmax.f32 q11, q11, q0 \n" + "vmax.f32 q12, q12, q0 \n" + "vmax.f32 q13, q13, q0 \n" + "vmax.f32 q14, q14, q0 \n" + "vmax.f32 q15, q15, q0 \n" + "2: \n" + "vst1.32 {d20-d23}, [%[c]]! \n" + "vst1.32 {d24-d27}, [%[c]]! \n" + "vst1.32 {d28-d31}, [%[c]]! \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [relu] "r" (has_relu), + [ldb] "r" (ldb_byte), + [bias] "r"(bias_ptr) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", + "q10", "q11", "q12", "q13", "q14", "q15", "cc", "memory" + ); + b += 4 * 6; + } + for (; n > 3; n -= 4) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + asm volatile( + "vld1.32 {d24-d25}, [%[bias]] \n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + /* mov bias to c0-c3*/ + "vmov.u32 q8, q12 \n" + "vmov.u32 q9, q12 \n" + "vmov.u32 q10, q12 \n" + "vmov.u32 q11, q12 \n" + "vmov.u32 q13, #0 \n" + "1:\n" + /* load b0-b3 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vld1.32 {d4-d7}, [%[b]]! \n" + /* load a2, a3 */ + "vld1.32 {d12-d15}, [%[a]]!\n" + "vmla.f32 q8, q4, d0[0] \n" + "vmla.f32 q9, q4, d2[0] \n" + "vmla.f32 q10, q4, d4[0] \n" + "vmla.f32 q11, q4, d6[0] \n" + "sub %[b], %[b], #64 \n" + "vmla.f32 q8, q5, d0[1] \n" + "vmla.f32 q9, q5, d2[1] \n" + "vmla.f32 q10, q5, d4[1] \n" + "vmla.f32 q11, q5, d6[1] \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q8, q6, d1[0] \n" + "vmla.f32 q9, q6, d3[0] \n" + "vmla.f32 q10, q6, d5[0] \n" + "vmla.f32 q11, q6, d7[0] \n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! 
\n" + "vmla.f32 q8, q7, d1[1] \n" + "vmla.f32 q9, q7, d3[1] \n" + "vmla.f32 q10, q7, d5[1] \n" + "vmla.f32 q11, q7, d7[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "bne 1b \n" + "cmp %[relu], #0 \n" + "beq 2f \n" + "vmax.f32 q8, q8, q13 \n" + "vmax.f32 q9, q9, q13 \n" + "vmax.f32 q10, q10, q13 \n" + "vmax.f32 q11, q11, q13 \n" + "2:\n" + "vst1.32 {d16-d19}, [%[c]]!\n" + "vst1.32 {d20-d23}, [%[c]]!\n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [relu] "r" (has_relu), + [ldb] "r" (ldb_byte), + [bias] "r"(bias_ptr) + : "q0", "q1", "q2", "q3", "q4", "q5", + "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "cc", "memory" + ); + b += 4 * 4; + } + for (; n > 0; n--) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + asm volatile( + "vld1.32 {d14-d15}, [%[bias]] \n" + "vmov.u32 q8, #0 \n" + /* load a0, a1 */ + "vld1.32 {d2-d5}, [%[a]]! \n" + /* mov bias to c0 */ + "vmov.u32 q5, q7 \n" + "vmov.u32 q6, q8 \n" + "1:\n" + /* load b0 */ + "vld1.32 {d0-d1}, [%[b]]! \n" + /* load a2, a3 */ + "vld1.32 {d6-d9}, [%[a]]! \n" + "vmla.f32 q5, q1, d0[0] \n" + "vmla.f32 q6, q2, d0[1] \n" + "sub %[b], %[b], #16 \n" + "subs %[cnt], %[cnt], #1 \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q5, q3, d1[0] \n" + "vmla.f32 q6, q4, d1[1] \n" + /* load a0, a1 */ + "vld1.32 {d2-d5}, [%[a]]! \n" + "bne 1b \n" + "vadd.f32 q5, q5, q6 \n" + "cmp %[relu], #0 \n" + "beq 2f \n" + "vmax.f32 q5, q5, q8 \n" + "2:\n" + "vst1.32 {d10-d11}, [%[c]]!\n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [relu] "r" (has_relu), + [ldb] "r" (ldb_byte), + [bias] "r"(bias_ptr) + : "q0", "q1", "q2", "q3", "q4", + "q5", "q6", "q7", "q8", "cc", "memory" + ); + // clang-format on + b += 4; + } +#endif + bias_ptr += 4; + A_packed += lda; + } +} + +void sgemm_prepack_c4(int M, + int N, + int K, + const float* A_packed, + const float* B, + float* C, + const float* bias, + bool has_bias, + bool has_relu, + ARMContext* ctx) { + if (N > 16) { + sgemm_prepack_c4_common( + M, N, K, A_packed, B, C, bias, has_bias, has_relu, ctx); + } else { + sgemm_prepack_c4_small( + M, N, K, A_packed, B, C, bias, has_bias, has_relu, ctx); + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/packed_sgemm_c4.h b/lite/backends/arm/math/packed_sgemm_c4.h new file mode 100644 index 0000000000000000000000000000000000000000..0b88de36d75ebbafe851d657cfa3b291ebdf353f --- /dev/null +++ b/lite/backends/arm/math/packed_sgemm_c4.h @@ -0,0 +1,43 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "lite/core/context.h" +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +constexpr int MBLOCK_C4 = 4; +constexpr int NBLOCK_C4 = 8; +constexpr int KBLOCK_C4 = 4; + +void sgemm_prepack_c4(int M, + int N, + int K, + const float* A_packed, + const float* B, + float* C, + const float* bias, + bool has_bias, + bool has_relu, + ARMContext* ctx); +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/tests/math/CMakeLists.txt b/lite/tests/math/CMakeLists.txt index d2acd14c83352c40e66781a9152a2f619918ddf2..b199f655239150438ecba881d5e1e4fa1e5dfa31 100644 --- a/lite/tests/math/CMakeLists.txt +++ b/lite/tests/math/CMakeLists.txt @@ -1,6 +1,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH_ARM)) lite_cc_test(sgemm_compute_test SRCS sgemm_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(sgemv_compute_test SRCS sgemv_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(sgemm_c4_compute_test SRCS sgemm_c4_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(gemm_int8_compute_test SRCS gemm_int8_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(gemv_int8_compute_test SRCS gemv_int8_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(conv_compute_test SRCS conv_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) diff --git a/lite/tests/math/sgemm_c4_compute_test.cc b/lite/tests/math/sgemm_c4_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5fcc54f33833b52266b09b91d36fe8b74447b95e --- /dev/null +++ b/lite/tests/math/sgemm_c4_compute_test.cc @@ -0,0 +1,236 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include "lite/tests/utils/fill_data.h" +#include "lite/tests/utils/naive_math_impl.h" +#ifdef LITE_WITH_ARM +#include "lite/backends/arm/math/funcs.h" +#endif // LITE_WITH_ARM +#include "lite/core/context.h" +#include "lite/core/tensor.h" +#include "lite/tests/utils/tensor_utils.h" +#include "lite/tests/utils/timer.h" + +typedef paddle::lite::Tensor Tensor; +using paddle::lite::Timer; + +DEFINE_int32(power_mode, + 3, + "power mode: " + "0 for POWER_HIGH;" + "1 for POWER_LOW;" + "2 for POWER_FULL;" + "3 for NO_BIND"); +DEFINE_int32(threads, 1, "threads num"); +DEFINE_int32(warmup, 0, "warmup times"); +DEFINE_int32(repeats, 1, "repeats times"); +DEFINE_bool(basic_test, false, "do all tests"); +DEFINE_bool(check_result, true, "check the result"); + +DEFINE_int32(M, 512, "gemm_c4: M"); +DEFINE_int32(N, 512, "gemm_c4: N"); +DEFINE_int32(K, 512, "gemm_c4: K"); + +DEFINE_bool(flag_relu, false, "do relu"); +DEFINE_bool(flag_bias, false, "with bias"); + +bool test_sgemm_c4( + int m, int n, int k, bool has_bias, bool has_relu, int cls, int ths) { + int m_round = (m + 3) / 4 * 4; + int k_round = (k + 3) / 4 * 4; + int size_a = m * k; + int size_b = n * k; + int size_a_c4 = m_round * k_round; + int size_b_c4 = k_round * n; + + Tensor ta; + Tensor tb; + Tensor ta_c4; + Tensor tb_c4; + Tensor tc; + Tensor tc_basic; + Tensor tc_backup; + Tensor tbias; + + ta.Resize({size_a}); + tb.Resize({size_b}); + ta_c4.Resize({size_a_c4}); + tb_c4.Resize({size_b_c4}); + tc.Resize({m_round * n}); + tc_basic.Resize({m_round * n}); + tbias.Resize({m}); + + ta.set_precision(PRECISION(kFloat)); + tb.set_precision(PRECISION(kFloat)); + ta_c4.set_precision(PRECISION(kFloat)); + tb_c4.set_precision(PRECISION(kFloat)); + tc.set_precision(PRECISION(kFloat)); + tc_basic.set_precision(PRECISION(kFloat)); + tbias.set_precision(PRECISION(kFloat)); + + fill_tensor_rand(ta, -1.f, 1.f); + fill_tensor_rand(tb, -1.f, 1.f); + fill_tensor_rand(tbias, -1.f, 1.f); + fill_tensor_rand(tc, -1.f, 1.f); + + auto da = ta.mutable_data(); + auto db = tb.mutable_data(); + auto da_c4 = ta_c4.mutable_data(); + auto db_c4 = tb_c4.mutable_data(); + auto dc_basic = tc_basic.mutable_data(); + auto dbias = tbias.mutable_data(); + + // trans A, B to c4 + basic_trans_mat_to_c4(da, da_c4, k, m, k, true); + basic_trans_mat_to_c4(db, db_c4, n, k, n, false); + + LOG(INFO) << "sgemm_c4 M: " << m << ", N: " << n << ", K: " << k + << ", relu: " << (has_relu ? "true" : "false") + << ", bias: " << (has_bias ? "true" : "false"); + + if (FLAGS_check_result) { + basic_gemm_c4(false, + false, + m, + n, + k, + 1.f, + da, + k, + db, + n, + 0.f, + dc_basic, + n, + dbias, + has_bias, + has_relu); + } + Timer t0; +#ifdef LITE_WITH_ARM + //! 
compute + double ops = 2.0 * m_round * n * k_round; + std::unique_ptr ctx1( + new paddle::lite::KernelContext); + auto& ctx = ctx1->As(); + ctx.SetRunMode(static_cast(cls), ths); + auto dc = tc.mutable_data(); + for (int j = 0; j < FLAGS_warmup; ++j) { + paddle::lite::arm::math::sgemm_prepack_c4( + m, n, k, da_c4, db_c4, dc, dbias, has_bias, has_relu, &ctx); + } + + for (int i = 0; i < FLAGS_repeats; ++i) { + t0.start(); + paddle::lite::arm::math::sgemm_prepack_c4( + m, n, k, da_c4, db_c4, dc, dbias, has_bias, has_relu, &ctx); + t0.end(); + } + LOG(INFO) << "M: " << m << ", N: " << n << ", K: " << k + << ", power_mode: " << cls << ", threads: " << ths + << ", GOPS: " << ops * 1e-9f + << " GOPS, avg time: " << t0.get_average_ms() + << " ms, min time: " << t0.get_min_time() + << " ms, mean GOPs: " << ops * 1e-6f / t0.get_average_ms() + << " GOPs, max GOPs: " << ops * 1e-6f / t0.get_min_time() + << " GOPs"; + + if (FLAGS_check_result) { + double max_ratio = 0; + double max_diff = 0; + tensor_cmp_host(tc_basic, tc, max_ratio, max_diff); + LOG(INFO) << "compare result, max diff: " << max_diff + << ", max ratio: " << max_ratio; + if (std::abs(max_ratio) > 1e-4f && std::abs(max_diff) > 5e-5f) { + Tensor tdiff; + tdiff.set_precision(PRECISION(kFloat)); + tdiff.Resize(tc.dims()); + tensor_diff(tc_basic, tc, tdiff); + LOG(INFO) << "a: "; + print_tensor(ta); + LOG(INFO) << "a_c4: "; + print_tensor(ta_c4); + LOG(INFO) << "b: "; + print_tensor(tb); + LOG(INFO) << "b_c4: "; + print_tensor(tb_c4); + LOG(INFO) << "basic result: "; + print_tensor(tc_basic); + LOG(INFO) << "lite result: "; + print_tensor(tc); + LOG(INFO) << "diff result: "; + print_tensor(tdiff); + return false; + } + } +#endif + return true; +} + +TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) { + if (FLAGS_basic_test) { +#ifdef LITE_WITH_ARM + paddle::lite::DeviceInfo::Init(); +#endif + LOG(INFO) << "run basic sgemm_c4 test"; + for (auto& m : {1, 3, 8, 32, 397}) { + for (auto& n : {1, 2, 3, 4, 13, 141, 789}) { + for (auto& k : {1, 3, 8, 59, 234}) { + for (auto& has_bias : {false, true}) { + for (auto& has_relu : {false, true}) { + for (auto& th : {1, 2, 4}) { + auto flag = test_sgemm_c4( + m, n, k, has_bias, has_relu, FLAGS_power_mode, th); + if (flag) { + LOG(INFO) << "test m = " << m << ", n=" << n << ", k=" << k + << ", bias: " << (has_bias ? "true" : "false") + << ", relu: " << (has_relu ? "true" : "false") + << " passed\n"; + } else { + LOG(FATAL) << "test m = " << m << ", n=" << n << ", k=" << k + << ", bias: " << (has_bias ? "true" : "false") + << ", relu: " << (has_relu ? 
"true" : "false") + << " failed\n"; + } + } + } + } + } + } + } + } +} + +TEST(TestSgemmC4Custom, test_func_sgemm_c4_prepacked_custom) { +#ifdef LITE_WITH_ARM + paddle::lite::DeviceInfo::Init(); +#endif + auto flag = test_sgemm_c4(FLAGS_M, + FLAGS_N, + FLAGS_K, + FLAGS_flag_bias, + FLAGS_flag_relu, + FLAGS_power_mode, + FLAGS_threads); + if (!flag) { + LOG(FATAL) << "test m = " << FLAGS_M << ", n=" << FLAGS_N + << ", k=" << FLAGS_K << ", bias: " << FLAGS_flag_bias + << ", relu: " << FLAGS_flag_relu << " failed!!"; + } + LOG(INFO) << "test m = " << FLAGS_M << ", n=" << FLAGS_N << ", k=" << FLAGS_K + << ", bias: " << FLAGS_flag_bias << ", relu: " << FLAGS_flag_relu + << " passed!!"; +} diff --git a/lite/tests/utils/naive_math_impl.h b/lite/tests/utils/naive_math_impl.h index 846126ac247ee685bd8772ede87635c45b52f79a..6d8c57c1375d37098bc9099319b0348e289f116d 100644 --- a/lite/tests/utils/naive_math_impl.h +++ b/lite/tests/utils/naive_math_impl.h @@ -14,6 +14,108 @@ #pragma once +template +static void basic_trans_mat_to_c4(const type* input, + type* output, + const int ldin, + const int M, + const int K, + bool pack_k) { + const int m_round = (M + 3) / 4 * 4; + int k_round = (K + 3) / 4 * 4; + if (!pack_k) { + k_round = K; + } + const int m_loop = m_round / 4; + type zero_buf[K]; + memset(zero_buf, 0, K * sizeof(type)); + for (int i = 0; i < m_loop; ++i) { + const type* in0 = input + i * 4 * ldin; + const type* in1 = in0 + ldin; + const type* in2 = in1 + ldin; + const type* in3 = in2 + ldin; + if (4 * (i + 1) - M > 0) { + switch (4 * (i + 1) - M) { + case 3: + in1 = zero_buf; + case 2: + in2 = zero_buf; + case 1: + in3 = zero_buf; + default: + break; + } + } + for (int j = 0; j < K; ++j) { + *output++ = *in0++; + *output++ = *in1++; + *output++ = *in2++; + *output++ = *in3++; + } + for (int j = K; j < k_round; ++j) { + *output++ = static_cast(0); + *output++ = static_cast(0); + *output++ = static_cast(0); + *output++ = static_cast(0); + } + } +} + +template +static void basic_gemm_c4(bool trans_a, + bool trans_b, + int m, + int n, + int k, + type2 alpha, + const type* a, + int lda, + const type* b, + int ldb, + type2 beta, + type2* c, + int ldc, + const type2* bias, + bool flag_bias = false, + bool flag_relu = false) { + type2* tmp_c = reinterpret_cast(malloc(m * ldc * sizeof(type2))); + memset(tmp_c, 0, m * ldc * sizeof(type2)); +#pragma omp parallel for + for (int i = 0; i < m; ++i) { + auto bias_data = static_cast(0); + if (flag_bias) { + bias_data = bias[i]; + } + for (int j = 0; j < n; ++j) { + auto sum = static_cast(0); + for (int l = 0; l < k; ++l) { + type av; + type bv; + if (trans_a) { + av = a[l * lda + i]; + } else { + av = a[i * lda + l]; + } + if (trans_b) { + bv = b[j * ldb + l]; + } else { + bv = b[l * ldb + j]; + } + sum += av * bv; + } + type2 tmp = alpha * sum + beta * tmp_c[i * ldc + j] + bias_data; + if (flag_relu) { + tmp_c[i * ldc + j] = tmp > (type2)0 ? tmp : (type2)0; + } else { + tmp_c[i * ldc + j] = tmp; + } + } + } + //! trans c to c4 + basic_trans_mat_to_c4(tmp_c, c, ldc, m, n, false); + free(tmp_c); +} + template static void basic_gemm(bool trans_a, bool trans_b,