diff --git a/cmake/cross_compiling/android.cmake b/cmake/cross_compiling/android.cmake index c36057544ab6503befde5642c38c75c25906b585..a12ecdccc1ec7dd117d47720fe72879f490092a4 100644 --- a/cmake/cross_compiling/android.cmake +++ b/cmake/cross_compiling/android.cmake @@ -16,6 +16,8 @@ if(NOT ANDROID) return() endif() +add_definitions(-DLITE_WITH_ANDROID) + if(NOT DEFINED ANDROID_NDK) set(ANDROID_NDK $ENV{NDK_ROOT}) if(NOT ANDROID_NDK) diff --git a/paddle/fluid/lite/CMakeLists.txt b/paddle/fluid/lite/CMakeLists.txt index ba05973a863a2a6b81ec7e9963c36b7c8f85e67a..93c3d9167c3307293450ba583a35a383d1f96fa2 100644 --- a/paddle/fluid/lite/CMakeLists.txt +++ b/paddle/fluid/lite/CMakeLists.txt @@ -118,6 +118,7 @@ endfunction() add_subdirectory(core) add_subdirectory(x86) +add_subdirectory(arm) add_subdirectory(host) add_subdirectory(cuda) add_subdirectory(operators) diff --git a/paddle/fluid/lite/arm/CMakeLists.txt b/paddle/fluid/lite/arm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..8abd04b52338299f75399903aa68fe834ce81d04 --- /dev/null +++ b/paddle/fluid/lite/arm/CMakeLists.txt @@ -0,0 +1,2 @@ + +add_subdirectory(math) diff --git a/paddle/fluid/lite/arm/math/CMakeLists.txt b/paddle/fluid/lite/arm/math/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..278cb54a418d8e546a1ebb26c5664412ca692590 --- /dev/null +++ b/paddle/fluid/lite/arm/math/CMakeLists.txt @@ -0,0 +1,2 @@ + +cc_library(math_arm SRCS funcs.cc packed_sgemm.cc DEPS ${lite_kernel_deps} eigen3) diff --git a/paddle/fluid/lite/arm/math/funcs.cc b/paddle/fluid/lite/arm/math/funcs.cc new file mode 100644 index 0000000000000000000000000000000000000000..ff1bf5b09a989972993a3fe525bd5d6597874057 --- /dev/null +++ b/paddle/fluid/lite/arm/math/funcs.cc @@ -0,0 +1,156 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/arm/math/funcs.h" +#include + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void fill_bias_fc(float *tensor, const float *bias, const int num, + const int channel) { + int cnt = channel >> 4; + int remain = channel & 15; + + for (int j = 0; j < num; ++j) { + const float *ptr_bias = bias; + float *ptr_out = tensor + j * channel; + + float32x4_t vout1; + float32x4_t vout2; + float32x4_t vout3; + float32x4_t vout4; + + for (int i = 0; i < cnt; ++i) { + float32x4_t vin1 = vld1q_f32(ptr_out); + float32x4_t vb1 = vld1q_f32(ptr_bias); + + float32x4_t vin2 = vld1q_f32(ptr_out + 4); + float32x4_t vb2 = vld1q_f32(ptr_bias + 4); + + float32x4_t vin3 = vld1q_f32(ptr_out + 8); + float32x4_t vb3 = vld1q_f32(ptr_bias + 8); + + float32x4_t vin4 = vld1q_f32(ptr_out + 12); + float32x4_t vb4 = vld1q_f32(ptr_bias + 12); + + vout1 = vaddq_f32(vin1, vb1); + vout2 = vaddq_f32(vin2, vb2); + vout3 = vaddq_f32(vin3, vb3); + vout4 = vaddq_f32(vin4, vb4); + + vst1q_f32(ptr_out, vout1); + vst1q_f32(ptr_out + 4, vout2); + vst1q_f32(ptr_out + 8, vout3); + vst1q_f32(ptr_out + 12, vout4); + + ptr_out += 16; + ptr_bias += 16; + } + +#if 0 + if (cnt > 0) { + asm( + "1: \n" + "vld1.32 {d0-d1}, [%[ptr_out]] @ load data\n" + "vld1.32 {d2-d3}, [%[ptr_bias]]! @ load data\n" + "vadd.f32 q2, q0, q1 @ add bias\n" + "vst1.32 {d4-d5}, [%[ptr_out]]! @ store result\n" + "subs %[cnt], #1 @ loop count -1\n" + "bne 1b @ jump to main loop\n" + :[ptr_out] "+r"(ptr_out), [ptr_bias] "+r"(ptr_bias), \ + [cnt] "+r"(cnt) + : + :"q0", "q1", "q2" + ); + } +#endif + for (; remain > 0; remain--) { + *(ptr_out++) += *(ptr_bias++); + } + } +} + +template <> +void fill_bias_fc(int *tensor, const int *bias, const int num, + const int channel) { + int cnt = channel >> 4; + int remain = channel & 15; + + for (int j = 0; j < num; ++j) { + const int *ptr_bias = bias; + int *ptr_out = tensor + j * channel; + + int32x4_t vout1; + int32x4_t vout2; + int32x4_t vout3; + int32x4_t vout4; + + for (int i = 0; i < cnt; ++i) { + int32x4_t vin1 = vld1q_s32(ptr_out); + int32x4_t vb1 = vld1q_s32(ptr_bias); + + int32x4_t vin2 = vld1q_s32(ptr_out + 4); + int32x4_t vb2 = vld1q_s32(ptr_bias + 4); + + int32x4_t vin3 = vld1q_s32(ptr_out + 8); + int32x4_t vb3 = vld1q_s32(ptr_bias + 8); + + int32x4_t vin4 = vld1q_s32(ptr_out + 12); + int32x4_t vb4 = vld1q_s32(ptr_bias + 12); + + vout1 = vaddq_s32(vin1, vb1); + vout2 = vaddq_s32(vin2, vb2); + vout3 = vaddq_s32(vin3, vb3); + vout4 = vaddq_s32(vin4, vb4); + + vst1q_s32(ptr_out, vout1); + vst1q_s32(ptr_out + 4, vout2); + vst1q_s32(ptr_out + 8, vout3); + vst1q_s32(ptr_out + 12, vout4); + + ptr_out += 16; + ptr_bias += 16; + } + +#if 0 + if (cnt > 0) { + asm( + "1: \n" + "vld1.32 {d0-d1}, [%[ptr_out]] @ load data\n" + "vld1.32 {d2-d3}, [%[ptr_bias]]! @ load data\n" + "vadd.s32 q2, q0, q1 @ add bias\n" + "vst1.32 {d4-d5}, [%[ptr_out]]! @ store result\n" + "subs %[cnt], #1 @ loop count -1\n" + "bne 1b @ jump to main loop\n" + :[ptr_out] "+r"(ptr_out), [ptr_bias] "+r"(ptr_bias), \ + [cnt] "+r"(cnt) + : + :"q0", "q1", "q2" + ); + } +#endif + for (; remain > 0; remain--) { + *(ptr_out++) += *(ptr_bias++); + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/arm/math/funcs.h b/paddle/fluid/lite/arm/math/funcs.h new file mode 100644 index 0000000000000000000000000000000000000000..dd3ba2db509971dafd482303129ae4e24479dbdb --- /dev/null +++ b/paddle/fluid/lite/arm/math/funcs.h @@ -0,0 +1,53 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/fluid/lite/arm/math/packed_sgemm.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +void fill_bias_fc(T* tensor, const T* bias, const int num, const int channel); + +template +void fc_compute_eigen(const T* x, int x_h, int x_w, // + const T* w, int w_h, int w_w, // + const T* b, // + T* out) { + using matrix_t = + Eigen::Matrix; + + Eigen::Map X(x, x_h, x_w); + Eigen::Map W(w, w_h, w_w); + Eigen::Map Out(out, x_h, w_w); + + Out = X * W; + + if (b) { + Eigen::Map> B(b, w_w); + Out = Out.array().rowwise() + B.array(); + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/arm/math/packed_sgemm.cc b/paddle/fluid/lite/arm/math/packed_sgemm.cc new file mode 100644 index 0000000000000000000000000000000000000000..c3a5322739076c1bfca9323e619478d4274bc037 --- /dev/null +++ b/paddle/fluid/lite/arm/math/packed_sgemm.cc @@ -0,0 +1,3049 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/arm/math/packed_sgemm.h" +#include + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +#ifdef __aarch64__ +void prepackA_8x12(float *out, const float *in, const int ldin, const int m0, + const int mmax, const int k0, const int kmax); +void prepackA_trans_8x12(float *out, const float *in, const int ldin, + const int m0, const int mmax, const int k0, + const int kmax); +void sgemm_conv_8x12(const float *A_packed, const float *B, const float *bias, + float *C, int M, int N, int K, bool is_bias, bool is_relu, + bool transB, ARMContext *ctx); +#else +// for kA72 +void prepackA_6x8(float *out, const float *in, const int ldin, const int m0, + const int mmax, const int k0, const int kmax); +void prepackA_trans_6x8(float *out, const float *in, const int ldin, + const int m0, const int mmax, const int k0, + const int kmax); +// for kA73 +void prepackA_4x8(float *out, const float *in, const int ldin, const int m0, + const int mmax, const int k0, const int kmax); +void prepackA_trans_4x8(float *out, const float *in, const int ldin, + const int m0, const int mmax, const int k0, + const int kmax); +// for kA72, 6x8 +void sgemm_conv_6x8(const float *A_packed, const float *B, const float *bias, + float *C, int M, int N, int K, bool is_bias, bool is_relu, + bool transB, ARMContext *ctx); +// for kA73, 4x8 +void sgemm_conv_4x8(const float *A_packed, const float *B, const float *bias, + float *C, int M, int N, int K, bool is_bias, bool is_relu, + bool transB, ARMContext *ctx); +#endif // __aarch64__ + +/** + * \brief input data is not transpose + * for arm-v7a, transform data to block x k x 6 layout + * for arm-v8a, transform data to block x k x 8 layout + */ +void prepackA(float *out, const float *in, const int ldin, const int m0, + const int mmax, const int k0, const int kmax, bool is_trans, + ARMContext *ctx) { +#ifdef __aarch64__ + if (is_trans) { + prepackA_trans_8x12(out, in, ldin, m0, mmax, k0, kmax); + } else { + prepackA_8x12(out, in, ldin, m0, mmax, k0, kmax); + } +#else + if (ctx->arch() == kA73) { + if (is_trans) { + prepackA_trans_4x8(out, in, ldin, m0, mmax, k0, kmax); + } else { + prepackA_4x8(out, in, ldin, m0, mmax, k0, kmax); + } + } else { + if (is_trans) { + prepackA_trans_6x8(out, in, ldin, m0, mmax, k0, kmax); + } else { + prepackA_6x8(out, in, ldin, m0, mmax, k0, kmax); + } + } +#endif +} + +void prepackA(TensorLite *tout, const TensorLite &tin, int m, int k, int group, + bool is_trans, ARMContext *ctx) { + int hblock = get_hblock(ctx->arch()); + int m_roundup = hblock * ((m + hblock - 1) / hblock); + int group_size_round_up = ((m_roundup * k + 15) / 16) * 16; + if (tout->numel() < group_size_round_up * group) { + tout->Resize({group_size_round_up * group}); + } + int lda = k; + if (is_trans) { + lda = m; + } + for (int g = 0; g < group; ++g) { + const float *weights_group = tin.data() + g * m * k; + float *weights_trans_ptr = + tout->mutable_data() + g * group_size_round_up; + prepackA(weights_trans_ptr, weights_group, lda, 0, m, 0, k, is_trans, ctx); + } +} + +/// a: m*k b: k*n c: m*n +void sgemm_prepack(const float *A_packed, const float *B, const float *bias, + float *C, int M, int N, int K, bool is_bias, bool is_relu, + bool is_transB, ARMContext *ctx) { +#ifdef __aarch64__ + sgemm_conv_8x12(A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, + ctx); +#else // armv7 + if (ctx->arch() == kA73) { + sgemm_conv_4x8(A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, + ctx); + } else { + sgemm_conv_6x8(A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, + ctx); + } +#endif // arm64 +} + +#ifdef __aarch64__ +void prepackA_8x12(float *out, const float *in, const int ldin, const int m0, + const int mmax, const int k0, const int kmax) { + int x_len = kmax - k0; + uint32_t zerobuff[x_len]; // NOLINT + memset(zerobuff, 0, sizeof(uint32_t) * x_len); + + uint32_t *dout = reinterpret_cast(out); + const uint32_t *inptr = reinterpret_cast(in); + + int stride = x_len * 8; +#pragma omp parallel for + for (int y = m0; y < mmax; y += 8) { + uint32_t *outptr = dout + stride * (y - m0) / 8; + + const uint32_t *inptr0 = inptr + y * ldin + k0; + const uint32_t *inptr1 = inptr0 + ldin; + const uint32_t *inptr2 = inptr1 + ldin; + const uint32_t *inptr3 = inptr2 + ldin; + const uint32_t *inptr4 = inptr3 + ldin; + const uint32_t *inptr5 = inptr4 + ldin; + const uint32_t *inptr6 = inptr5 + ldin; + const uint32_t *inptr7 = inptr6 + ldin; + + asm volatile( + "prfm pldl1keep, [%[ptr0]] \n" + "prfm pldl1keep, [%[ptr0], #64] \n" + "prfm pldl1keep, [%[ptr1]] \n" + "prfm pldl1keep, [%[ptr1], #64] \n" + "prfm pldl1keep, [%[ptr2]] \n" + "prfm pldl1keep, [%[ptr2], #64] \n" + "prfm pldl1keep, [%[ptr3]] \n" + "prfm pldl1keep, [%[ptr3], #64] \n" + "prfm pldl1keep, [%[ptr4]] \n" + "prfm pldl1keep, [%[ptr4], #64] \n" + "prfm pldl1keep, [%[ptr5]] \n" + "prfm pldl1keep, [%[ptr5], #64] \n" + "prfm pldl1keep, [%[ptr6]] \n" + "prfm pldl1keep, [%[ptr6], #64] \n" + "prfm pldl1keep, [%[ptr7]] \n" + "prfm pldl1keep, [%[ptr7], #64] \n" + : + : [ptr0] "r"(inptr0), [ptr1] "r"(inptr1), [ptr2] "r"(inptr2), + [ptr3] "r"(inptr3), [ptr4] "r"(inptr4), [ptr5] "r"(inptr5), + [ptr6] "r"(inptr6), [ptr7] "r"(inptr7) + : "memory"); + + int x = x_len; + //! cope with row index exceed real size, set to zero buffer + if ((y + 7) >= mmax) { + switch ((y + 7) - mmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + default: + break; + } + } + for (; x > 7; x -= 8) { + asm volatile( + // Load up 8 elements (2 vectors) from each of 8 sources. + "LDP q0, q1, [%[inptr0]], #32\n" // q0=A0A1A2A3 + "LDP q2, q3, [%[inptr1]], #32\n" // q2=B0B1B2B3 + "LDP q4, q5, [%[inptr2]], #32\n" // q4=C0C1C2C3 + "ZIP1 v16.4s, v0.4s, v4.4s\n" // q16=A0C0A1C1 + "prfm pldl1keep, [%[inptr0], #128] \n" + "LDP q6, q7, [%[inptr3]], #32\n" // q6=D0D1D2D3 + "ZIP1 v17.4s, v2.4s, v6.4s\n" // q17=B0D0B1D1 + "LDP q8, q9, [%[inptr4]], #32\n" + "LDP q10, q11, [%[inptr5]], #32\n" + "LDP q12, q13, [%[inptr6]], #32\n" + "ZIP1 v18.4s, v8.4s, v12.4s\n" + "prfm pldl1keep, [%[inptr1], #128]\n" + "LDP q14, q15, [%[inptr7]], #32\n" + "ZIP1 v19.4s, v10.4s, v14.4s\n" + + "ZIP1 v20.4s, v16.4s, v17.4s\n" // q20=A0B0C0D0 + "prfm pldl1keep, [%[inptr2], #128]\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + + "ZIP2 v16.4s, v0.4s, v4.4s\n" + "prfm pldl1keep, [%[inptr3], #128]\n" + "ZIP2 v17.4s, v2.4s, v6.4s\n" + "STP q20, q21, [%[outptr]], #32\n" // Write back the first + // element of each source + + "ZIP2 v18.4s, v8.4s, v12.4s\n" + "ZIP2 v19.4s, v10.4s, v14.4s\n" + "STP q22, q23, [%[outptr]], #32\n" // Write back the second + // element of each source + + "ZIP1 v20.4s, v16.4s, v17.4s\n" + "prfm pldl1keep, [%[inptr4], #128]\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + + "ZIP1 v16.4s, v1.4s, v5.4s\n" + "prfm pldl1keep, [%[inptr5], #128]\n" + "ZIP1 v17.4s, v3.4s, v7.4s\n" + "STP q20, q21, [%[outptr]], #32\n" // Third element + + "ZIP1 v18.4s, v9.4s, v13.4s\n" + "ZIP1 v19.4s, v11.4s, v15.4s\n" + "STP q22, q23, [%[outptr]], #32\n" // Fourth element + + "ZIP1 v20.4s, v16.4s, v17.4s\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "prfm pldl1keep, [%[inptr6], #128]\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + + "ZIP2 v16.4s, v1.4s, v5.4s\n" + "ZIP2 v17.4s, v3.4s, v7.4s\n" + "STP q20, q21, [%[outptr]], #32\n" // Fifth element + + "ZIP2 v18.4s, v9.4s, v13.4s\n" + "prfm pldl1keep, [%[inptr7], #128]\n" + "ZIP2 v19.4s, v11.4s, v15.4s\n" + "STP q22, q23, [%[outptr]], #32\n" // Sixth element + + "ZIP1 v20.4s, v16.4s, v17.4s\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + "STP q20, q21, [%[outptr]], #32\n" // Seventh element + + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + "STP q22, q23, [%[outptr]], #32\n" // Eighth element + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "cc", "memory"); + } + + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; + } + } +} + +void prepackA_trans_8x12(float *out, const float *in, const int ldin, + const int m0, const int mmax, const int k0, + const int kmax) { + uint32_t *outptr = reinterpret_cast(out); + const uint32_t *inptr = + reinterpret_cast(in) + k0 * ldin + m0; + + uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + int x_len = mmax - m0; + int y_len = kmax - k0; + int right_remain = x_len - 8 * (x_len / 8); + int right_pad = 8 - right_remain; + if (right_remain == 0) { + right_pad = 0; + } + + uint32_t *outptr_row = outptr; + int stride_out = 8 * y_len; + + uint32x4_t vzero = vdupq_n_u32(0); + uint32x4_t vmask1 = + vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain)); + uint32x4_t vmask2 = + vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain)); + +#pragma omp parallel for + for (int y = 0; y < y_len - 3; y += 4) { + const uint32_t *ptr0 = inptr + y * ldin; + const uint32_t *ptr1 = ptr0 + ldin; + const uint32_t *ptr2 = ptr1 + ldin; + const uint32_t *ptr3 = ptr2 + ldin; + + asm volatile( + "prfm pldl1keep, [%[ptr0]] \n" + "prfm pldl1keep, [%[ptr0], #64] \n" + "prfm pldl1keep, [%[ptr1]] \n" + "prfm pldl1keep, [%[ptr1], #64] \n" + "prfm pldl1keep, [%[ptr2]] \n" + "prfm pldl1keep, [%[ptr2], #64] \n" + "prfm pldl1keep, [%[ptr3]] \n" + "prfm pldl1keep, [%[ptr3], #64] \n" + : + : [ptr0] "r"(ptr0), [ptr1] "r"(ptr1), [ptr2] "r"(ptr2), [ptr3] "r"(ptr3) + : "memory"); + + uint32_t *outptr_row_col = outptr_row + y * 8; + int i = 0; + for (; i < x_len - 7; i += 8) { + uint32x4_t vr00 = vld1q_u32(ptr0); + uint32x4_t vr01 = vld1q_u32(ptr0 + 4); + + uint32x4_t vr10 = vld1q_u32(ptr1); + uint32x4_t vr11 = vld1q_u32(ptr1 + 4); + + vst1q_u32(outptr_row_col, vr00); + vst1q_u32(outptr_row_col + 4, vr01); + + uint32x4_t vr20 = vld1q_u32(ptr2); + uint32x4_t vr21 = vld1q_u32(ptr2 + 4); + + vst1q_u32(outptr_row_col + 8, vr10); + vst1q_u32(outptr_row_col + 12, vr11); + + uint32x4_t vr30 = vld1q_u32(ptr3); + uint32x4_t vr31 = vld1q_u32(ptr3 + 4); + + vst1q_u32(outptr_row_col + 16, vr20); + vst1q_u32(outptr_row_col + 20, vr21); + + vst1q_u32(outptr_row_col + 24, vr30); + vst1q_u32(outptr_row_col + 28, vr31); + + ptr0 += 8; + ptr1 += 8; + ptr2 += 8; + ptr3 += 8; + + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32x4_t vr00 = vld1q_u32(ptr0); + uint32x4_t vr01 = vld1q_u32(ptr0 + 4); + + uint32x4_t vr10 = vld1q_u32(ptr1); + uint32x4_t vr11 = vld1q_u32(ptr1 + 4); + + uint32x4_t vr00_1 = vbslq_u32(vmask1, vr00, vzero); + uint32x4_t vr01_1 = vbslq_u32(vmask2, vr01, vzero); + + uint32x4_t vr20 = vld1q_u32(ptr2); + uint32x4_t vr21 = vld1q_u32(ptr2 + 4); + + vst1q_u32(outptr_row_col, vr00_1); + vst1q_u32(outptr_row_col + 4, vr01_1); + + uint32x4_t vr10_1 = vbslq_u32(vmask1, vr10, vzero); + uint32x4_t vr11_1 = vbslq_u32(vmask2, vr11, vzero); + + uint32x4_t vr30 = vld1q_u32(ptr3); + uint32x4_t vr31 = vld1q_u32(ptr3 + 4); + + vst1q_u32(outptr_row_col + 8, vr10_1); + vst1q_u32(outptr_row_col + 12, vr11_1); + + uint32x4_t vr20_1 = vbslq_u32(vmask1, vr20, vzero); + uint32x4_t vr21_1 = vbslq_u32(vmask2, vr21, vzero); + + uint32x4_t vr30_1 = vbslq_u32(vmask1, vr30, vzero); + uint32x4_t vr31_1 = vbslq_u32(vmask2, vr31, vzero); + + vst1q_u32(outptr_row_col + 16, vr20_1); + vst1q_u32(outptr_row_col + 20, vr21_1); + vst1q_u32(outptr_row_col + 24, vr30_1); + vst1q_u32(outptr_row_col + 28, vr31_1); + } + } + +#pragma omp parallel for + for (int y = 4 * (y_len / 4); y < y_len; ++y) { + const uint32_t *ptr0 = inptr + y * ldin; + uint32_t *outptr_row_col = outptr_row + y * 8; + int i = 0; + for (; i < x_len - 7; i += 8) { + uint32x4_t vr0 = vld1q_u32(ptr0); + uint32x4_t vr1 = vld1q_u32(ptr0 + 4); + vst1q_u32(outptr_row_col, vr0); + vst1q_u32(outptr_row_col + 4, vr1); + + ptr0 += 8; + + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32x4_t vr0 = vld1q_u32(ptr0); + uint32x4_t vr1 = vld1q_u32(ptr0 + 4); + + uint32x4_t vr0_1 = vbslq_u32(vmask1, vr0, vzero); + uint32x4_t vr1_1 = vbslq_u32(vmask2, vr1, vzero); + + vst1q_u32(outptr_row_col, vr0_1); + vst1q_u32(outptr_row_col + 4, vr1_1); + } + } +} + +#else // __aarch64__ +void prepackA_6x8(float* out, const float* in, const int ldin, const int m0, + const int mmax, const int k0, const int kmax) { + int x_len = kmax - k0; + uint32_t zerobuff[x_len]; // NOLINT + memset(zerobuff, 0, sizeof(uint32_t) * x_len); + + uint32_t* dout = reinterpret_cast(out); + const uint32_t* inptr = reinterpret_cast(in); + + uint32_t* outptr = dout; + + //! data A is not transposed, transpose A to k * 6 + for (int y = m0; y < mmax; y += 6) { + const uint32_t* inptr0 = inptr + y * ldin + k0; + const uint32_t* inptr1 = inptr0 + ldin; + const uint32_t* inptr2 = inptr1 + ldin; + const uint32_t* inptr3 = inptr2 + ldin; + const uint32_t* inptr4 = inptr3 + ldin; + const uint32_t* inptr5 = inptr4 + ldin; + + int x = x_len; + //! cope with row index exceed real size, set to zero buffer + if ((y + 5) >= mmax) { + switch ((y + 5) - mmax) { + case 4: + inptr1 = zerobuff; + case 3: + inptr2 = zerobuff; + case 2: + inptr3 = zerobuff; + case 1: + inptr4 = zerobuff; + case 0: + inptr5 = zerobuff; + default: + break; + } + } + + for (; x > 7; x -= 8) { + //! zip load 8 elements (2 neon Q registers) from each of 6 rows + asm volatile( + "vld4.32 {d0-d3}, [%[inptr0]]! @ zip load r0, " + "q0,q1=r00,r04,r01,r05,r02,r06,r03,r07\n" + "vld4.32 {d4-d7}, [%[inptr1]]! @ zip load r1, " + "q2,q3=r10,r14,r11,r15,r12,r16,r13,r17\n" + "vld4.32 {d8-d11}, [%[inptr2]]! @ zip load r2, " + "q4,q5=r20,r24,r21,r25,r22,r26,r23,r27\n" + "vld4.32 {d12-d15}, [%[inptr3]]! @ zip load r3, " + "q6,q7=r30,r34,r31,r35,r32,r36,r33,r37\n" + "vld4.32 {d16-d19}, [%[inptr4]]! @ zip load r4, " + "q8,q9=r40,r44,r41,r45,r42,r46,r43,r47\n" + "vld4.32 {d20-d23}, [%[inptr5]]! @ zip load r5, " + "q10,q11=r50,r54,r51,r55,r52,r56,r53,r57\n" + + "vtrn.32 q0, q2 @ trans data: q0=r00,r10,r01,r11; " + "q2=r04,r14,r05,r15\n" + "vtrn.32 q4, q6 @ trans data: q4=r20,r30,r21,r31; " + "q6=r24,r34,r25,r35\n" + "vtrn.32 q8, q10 @ trans data: q8=r40,r50,r41,r51; " + "q10=r44,r54,r45,r55\n" + + "vswp d1, d8 @ swap d1, d8, q0=r00,r10,r20,r30; " + "q4=r01,r11,r21,r31\n" + "vst1.32 {d0-d1}, [%[outptr]]! @ write q0:r00,r10,r20,r30\n" + "vst1.32 {d16}, [%[outptr]]! @ write d16(q8,low),r40,r50\n" + "vst1.32 {d8-d9}, [%[outptr]]! @ write q4:r01,r11,r21,r31\n" + "vst1.32 {d17}, [%[outptr]]! @ write d16(q8,high),r41,r51\n" + + "vtrn.32 q1, q3 @ trans data: q1=r02,r12,r03,r13; " + "q3=r06,r16,r07,r17\n" + "vtrn.32 q5, q7 @ trans data: q5=r22,r32,r23,r33; " + "q7=r26,r36,r27,r37\n" + "vtrn.32 q9, q11 @ trans data: q9=r42,r52,r43,r53; " + "q11=r46,r56,r47,r57\n" + + "vswp d3, d10 @ swap d3, d10, " + "q1=r02,r12,r22,r32; q5=r03,r13,r23,r33\n" + "vst1.32 {d2-d3}, [%[outptr]]! @ write q1:r02,r12,r22,r32\n" + "vst1.32 {d18}, [%[outptr]]! @ write d18(q9,low),r42,r52\n" + "vst1.32 {d10-d11},[%[outptr]]! @ write q5:r03,r13,r23,r33\n" + "vst1.32 {d19}, [%[outptr]]! @ write d19(q9,high),r43,r53\n" + + "vswp d5, d12 @ swap d5, d12,q2=r04,r14,r24,r34; " + "q6=r05,r15,r25,r35\n" + "vst1.32 {d4-d5}, [%[outptr]]! @ write q2:r04,r14,r24,r34\n" + "vst1.32 {d20}, [%[outptr]]! @ write d20(q10,low),r44,r54\n" + "vst1.32 {d12-d13},[%[outptr]]! @ write q6:r05,r15,r25,r35\n" + "vst1.32 {d21}, [%[outptr]]! @ write d21(q10,high),r45,r55\n" + + "vswp d7, d14 @ swap d7, d14, " + "q3=r06,r16,r26,r36; q7=r07,r17,r27,r37\n" + "vst1.32 {d6-d7}, [%[outptr]]! @ write q3:r06,r16,r26,r36\n" + "vst1.32 {d22}, [%[outptr]]! @ write d22(q11,low),r46,r56\n" + "vst1.32 {d14-d15},[%[outptr]]! @ write q7:r07,r17,r27,r37\n" + "vst1.32 {d23}, [%[outptr]]! @ write d23(q11,high),r47,r57\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [outptr] "+r"(outptr) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "cc", "memory"); + } + + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + } + } +} + +void prepackA_trans_6x8(float* out, const float* in, const int ldin, + const int m0, const int mmax, const int k0, + const int kmax) { + uint32_t* outptr = reinterpret_cast(out); + const uint32_t* inptr = + reinterpret_cast(in) + k0 * ldin + m0; + + uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + int x_len = mmax - m0; + int y_len = kmax - k0; + int right_remain = x_len - 6 * (x_len / 6); + int right_pad = 6 - right_remain; + if (right_remain == 0) { + right_pad = 0; + } + + uint32_t* outptr_row = outptr; + int stride_out = 6 * y_len; + + uint32x4_t vzero = vdupq_n_u32(0); + uint32x4_t vmask1 = + vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain)); + uint32x4_t vmask2 = + vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain)); + +#pragma omp parallel for + for (int y = 0; y < y_len - 3; y += 4) { + const uint32_t* ptr0 = inptr + y * ldin; + const uint32_t* ptr1 = ptr0 + ldin; + const uint32_t* ptr2 = ptr1 + ldin; + const uint32_t* ptr3 = ptr2 + ldin; + + uint32_t* outptr_row_col = outptr_row + y * 6; + int i = 0; + for (; i < x_len - 5; i += 6) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d2}, [%[ptr0]]! @ load r0, 6 elements\n" + "vld1.32 {d4-d6}, [%[ptr1]]! @ load r1, 6 elements\n" + "vst1.32 {d0-d2}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d6}, [%[outptr]]! @ write to output ptr\n" + + "vld1.32 {d0-d2}, [%[ptr2]]! @ load r2, 6 elements\n" + "vld1.32 {d4-d6}, [%[ptr3]]! @ load r3, 6 elements\n" + "vst1.32 {d0-d2}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d6}, [%[outptr]]! @ write to output ptr\n" + : [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3) + : + : "q0", "q1", "q2", "q3", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_pad > 0) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d2}, [%[ptr0]]! @ load r0, 6 elements\n" + "vld1.32 {d4-d6}, [%[ptr1]]! @ load r1, 6 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif d2, %e[vzero], %e[vmask2] @ bit select, pad zero\n" + "vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif d6, %e[vzero], %e[vmask2] @ bit select, pad zero\n" + "vst1.32 {d0-d2}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d6}, [%[outptr]]! @ write to output ptr\n" + + "vld1.32 {d0-d2}, [%[ptr2]]! @ load r2, 8 elements\n" + "vld1.32 {d4-d6}, [%[ptr3]]! @ load r3, 8 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif d2, %e[vzero], %e[vmask2] @ bit select, pad zero\n" + "vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif d6, %e[vzero], %e[vmask2] @ bit select, pad zero\n" + "vst1.32 {d0-d2}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d6}, [%[outptr]]! @ write to output ptr\n" + : [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero) + : "q0", "q1", "q2", "q3", "cc", "memory"); + } + } + +#pragma omp parallel for + for (int y = 4 * (y_len / 4); y < y_len; ++y) { + const uint32_t* ptr0 = inptr + y * ldin; + uint32_t* outptr_row_col = outptr_row + y * 6; + int i = 0; + for (; i < x_len - 5; i += 6) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d2}, [%[ptr0]]! @ load r0, 6 elements\n" + "vst1.32 {d0-d2}, [%[outptr]]! @ write to output ptr\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : + : "q0", "q1", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_pad > 0) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d2}, [%[ptr0]]! @ load r0, 6 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif d2, %e[vzero], %e[vmask2] @ bit select, pad zero\n" + "vst1.32 {d0-d2}, [%[outptr]]! @ write to output ptr\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero) + : "q0", "q1", "cc", "memory"); + } + } +} + +void prepackA_4x8(float* out, const float* in, const int ldin, const int m0, + const int mmax, const int k0, const int kmax) { + int x_len = kmax - k0; + uint32_t zerobuff[x_len]; // NOLINT + memset(zerobuff, 0, sizeof(uint32_t) * x_len); + + uint32_t* dout = reinterpret_cast(out); + const uint32_t* inptr = reinterpret_cast(in); + + uint32_t* outptr = dout; + //! data A is not transposed, transpose A to k * 4 + for (int y = m0; y < mmax; y += 4) { + const uint32_t* inptr0 = inptr + y * ldin + k0; + const uint32_t* inptr1 = inptr0 + ldin; + const uint32_t* inptr2 = inptr1 + ldin; + const uint32_t* inptr3 = inptr2 + ldin; + + int x = x_len; + //! cope with row index exceed real size, set to zero buffer + if ((y + 3) >= mmax) { + switch ((y + 3) - mmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + default: + break; + } + } + + for (; x > 7; x -= 8) { + //! zip load 8 elements (2 neon Q registers) from each of 4 rows + asm volatile( + "vld4.32 {d0-d3}, [%[inptr0]]! @ zip load r0, " + "q0,q1=r00,r04,r01,r05,r02,r06,r03,r07\n" + "vld4.32 {d4-d7}, [%[inptr1]]! @ zip load r1, " + "q2,q3=r10,r14,r11,r15,r12,r16,r13,r17\n" + "vld4.32 {d8-d11}, [%[inptr2]]! @ zip load r2, " + "q4,q5=r20,r24,r21,r25,r22,r26,r23,r27\n" + "vld4.32 {d12-d15}, [%[inptr3]]! @ zip load r3, " + "q6,q7=r30,r34,r31,r35,r32,r36,r33,r37\n" + + "vtrn.32 q0, q2 @ trans data: q0=r00,r10,r01,r11; " + "q2=r04,r14,r05,r15\n" + "vtrn.32 q4, q6 @ trans data: q4=r20,r30,r21,r31; " + "q6=r24,r34,r25,r35\n" + + "vswp d1, d8 @ swap d1, d8, q0=r00,r10,r20,r30; " + "q4=r01,r11,r21,r31\n" + "vst1.32 {d0-d1}, [%[outptr]]! @ write q0:r00,r10,r20,r30\n" + "vst1.32 {d8-d9}, [%[outptr]]! @ write q4:r01,r11,r21,r31\n" + + "vtrn.32 q1, q3 @ trans data: q1=r02,r12,r03,r13; " + "q3=r06,r16,r07,r17\n" + "vtrn.32 q5, q7 @ trans data: q5=r22,r32,r23,r33; " + "q7=r26,r36,r27,r37\n" + + "vswp d3, d10 @ swap d3, d10, " + "q1=r02,r12,r22,r32; q5=r03,r13,r23,r33\n" + "vst1.32 {d2-d3}, [%[outptr]]! @ write q1:r02,r12,r22,r32\n" + "vst1.32 {d10-d11},[%[outptr]]! @ write q5:r03,r13,r23,r33\n" + + "vswp d5, d12 @ swap d5, d12,q2=r04,r14,r24,r34; " + "q6=r05,r15,r25,r35\n" + "vst1.32 {d4-d5}, [%[outptr]]! @ write q2:r04,r14,r24,r34\n" + "vst1.32 {d12-d13},[%[outptr]]! @ write q6:r05,r15,r25,r35\n" + + "vswp d7, d14 @ swap d7, d14, " + "q3=r06,r16,r26,r36; q7=r07,r17,r27,r37\n" + "vst1.32 {d6-d7}, [%[outptr]]! @ write q3:r06,r16,r26,r36\n" + "vst1.32 {d14-d15},[%[outptr]]! @ write q7:r07,r17,r27,r37\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "cc", "memory"); + } + + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + } + } +} + +void prepackA_trans_4x8(float* out, const float* in, const int ldin, + const int m0, const int mmax, const int k0, + const int kmax) { + uint32_t* outptr = reinterpret_cast(out); + const uint32_t* inptr = + reinterpret_cast(in) + k0 * ldin + m0; + + uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + int x_len = mmax - m0; + int y_len = kmax - k0; + int right_remain = x_len - 4 * (x_len / 4); + int right_pad = 4 - right_remain; + if (right_remain == 0) { + right_pad = 0; + } + + uint32_t* outptr_row = outptr; + int stride_out = 4 * y_len; + + uint32x4_t vzero = vdupq_n_u32(0); + uint32x4_t vmask1 = + vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain)); +// uint32x4_t vmask2 = vcltq_u32(vld1q_u32(mask_buffer + 4), +// vdupq_n_u32(right_remain)); + +#pragma omp parallel for + for (int y = 0; y < y_len - 3; y += 4) { + const uint32_t* ptr0 = inptr + y * ldin; + const uint32_t* ptr1 = ptr0 + ldin; + const uint32_t* ptr2 = ptr1 + ldin; + const uint32_t* ptr3 = ptr2 + ldin; + + uint32_t* outptr_row_col = outptr_row + y * 4; + int i = 0; + for (; i < x_len - 3; i += 4) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d1}, [%[ptr0]]! @ load r0, 4 elements\n" + "vld1.32 {d2-d3}, [%[ptr1]]! @ load r1, 4 elements\n" + "vst1.32 {d0-d1}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d2-d3}, [%[outptr]]! @ write to output ptr\n" + + "vld1.32 {d4-d5}, [%[ptr2]]! @ load r2, 4 elements\n" + "vld1.32 {d6-d7}, [%[ptr3]]! @ load r3, 4 elements\n" + "vst1.32 {d4-d5}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d6-d7}, [%[outptr]]! @ write to output ptr\n" + : [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3) + : + : "q0", "q1", "q2", "q3", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_pad > 0) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d1}, [%[ptr0]]! @ load r0, 4 elements\n" + "vld1.32 {d2-d3}, [%[ptr1]]! @ load r1, 4 elements\n" + "vld1.32 {d4-d5}, [%[ptr2]]! @ load r2, 4 elements\n" + "vld1.32 {d6-d7}, [%[ptr3]]! @ load r3, 4 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q1, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q3, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vst1.32 {d0-d1}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d2-d3}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d5}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d6-d7}, [%[outptr]]! @ write to output ptr\n" + : [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3) + : [vmask1] "w"(vmask1), [vzero] "w"(vzero) + : "q0", "q1", "q2", "q3", "cc", "memory"); + } + } + +#pragma omp parallel for + for (int y = 4 * (y_len / 4); y < y_len; ++y) { + const uint32_t* ptr0 = inptr + y * ldin; + uint32_t* outptr_row_col = outptr_row + y * 4; + int i = 0; + for (; i < x_len - 3; i += 4) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d1}, [%[ptr0]]! @ load r0, 4 elements\n" + "vst1.32 {d0-d1}, [%[outptr]]! @ write to output ptr\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : + : "q0", "q1", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_pad > 0) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d1}, [%[ptr0]]! @ load r0, 4 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vst1.32 {d0-d1}, [%[outptr]]! @ write to output ptr\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : [vmask1] "w"(vmask1), [vzero] "w"(vzero) + : "q0", "q1", "cc", "memory"); + } + } +} + +#endif // __aarch64__ + +/** +* \brief input data is transpose +* for arm-v7a, transform data to block x k x 8 layout +* for arm-v8a, transform data to block x k x 12 layout +*/ +#ifdef __aarch64__ +void loadb(float *out, const float *in, const int ldin, const int k0, + const int kmax, const int n0, const int nmax) { + uint32_t *outptr = reinterpret_cast(out); + const uint32_t *inptr = + reinterpret_cast(in) + k0 * ldin + n0; + uint32_t mask_buffer[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + int x_len = nmax - n0; + int y_len = kmax - k0; + int right_remain = x_len - 12 * (x_len / 12); + int right_pad = 12 - right_remain; + const size_t copy_len_remain = sizeof(float) * right_remain; + const size_t copy_len_pad = sizeof(float) * right_pad; + const size_t size_ldin = sizeof(float) * ldin; + + uint32_t *outptr_row = outptr; + int stride_out = 12 * y_len; + + uint32x4_t vzero = vdupq_n_u32(0); + uint32x4_t vmask1 = + vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain)); + uint32x4_t vmask2 = + vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain)); + uint32x4_t vmask3 = + vcltq_u32(vld1q_u32(mask_buffer + 8), vdupq_n_u32(right_remain)); + +#pragma omp parallel for + for (int y = 0; y < y_len - 3; y += 4) { + const uint32_t *ptr0 = inptr + y * ldin; + const uint32_t *ptr1 = ptr0 + ldin; + const uint32_t *ptr2 = ptr1 + ldin; + const uint32_t *ptr3 = ptr2 + ldin; + asm volatile( + "prfm pldl1keep, [%[ptr0]] \n" + "prfm pldl1keep, [%[ptr0], #64] \n" + "prfm pldl1keep, [%[ptr1]] \n" + "prfm pldl1keep, [%[ptr1], #64] \n" + "prfm pldl1keep, [%[ptr2]] \n" + "prfm pldl1keep, [%[ptr2], #64] \n" + "prfm pldl1keep, [%[ptr3]] \n" + "prfm pldl1keep, [%[ptr3], #64] \n" + : + : [ptr0] "r"(ptr0), [ptr1] "r"(ptr1), [ptr2] "r"(ptr2), [ptr3] "r"(ptr3) + : "memory"); + + uint32_t *outptr_row_col = outptr_row + y * 12; + + int i = 0; + for (; i < x_len - 11; i += 12) { + uint32x4_t vr00 = vld1q_u32(ptr0); + uint32x4_t vr01 = vld1q_u32(ptr0 + 4); + uint32x4_t vr02 = vld1q_u32(ptr0 + 8); + + uint32x4_t vr10 = vld1q_u32(ptr1); + uint32x4_t vr11 = vld1q_u32(ptr1 + 4); + uint32x4_t vr12 = vld1q_u32(ptr1 + 8); + + vst1q_u32(outptr_row_col, vr00); + vst1q_u32(outptr_row_col + 4, vr01); + vst1q_u32(outptr_row_col + 8, vr02); + + uint32x4_t vr20 = vld1q_u32(ptr2); + uint32x4_t vr21 = vld1q_u32(ptr2 + 4); + uint32x4_t vr22 = vld1q_u32(ptr2 + 8); + + vst1q_u32(outptr_row_col + 12, vr10); + vst1q_u32(outptr_row_col + 16, vr11); + vst1q_u32(outptr_row_col + 20, vr12); + + uint32x4_t vr30 = vld1q_u32(ptr3); + uint32x4_t vr31 = vld1q_u32(ptr3 + 4); + uint32x4_t vr32 = vld1q_u32(ptr3 + 8); + + vst1q_u32(outptr_row_col + 24, vr20); + vst1q_u32(outptr_row_col + 28, vr21); + vst1q_u32(outptr_row_col + 32, vr22); + + vst1q_u32(outptr_row_col + 36, vr30); + vst1q_u32(outptr_row_col + 40, vr31); + vst1q_u32(outptr_row_col + 44, vr32); + + ptr0 += 12; + ptr1 += 12; + ptr2 += 12; + ptr3 += 12; + + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32x4_t vr00 = vld1q_u32(ptr0); + uint32x4_t vr01 = vld1q_u32(ptr0 + 4); + uint32x4_t vr02 = vld1q_u32(ptr0 + 8); + + uint32x4_t vr10 = vld1q_u32(ptr1); + uint32x4_t vr11 = vld1q_u32(ptr1 + 4); + uint32x4_t vr12 = vld1q_u32(ptr1 + 8); + + uint32x4_t vr00_1 = vbslq_u32(vmask1, vr00, vzero); + uint32x4_t vr01_1 = vbslq_u32(vmask2, vr01, vzero); + uint32x4_t vr02_1 = vbslq_u32(vmask3, vr02, vzero); + + uint32x4_t vr20 = vld1q_u32(ptr2); + uint32x4_t vr21 = vld1q_u32(ptr2 + 4); + uint32x4_t vr22 = vld1q_u32(ptr2 + 8); + + vst1q_u32(outptr_row_col, vr00_1); + vst1q_u32(outptr_row_col + 4, vr01_1); + vst1q_u32(outptr_row_col + 8, vr02_1); + + uint32x4_t vr10_1 = vbslq_u32(vmask1, vr10, vzero); + uint32x4_t vr11_1 = vbslq_u32(vmask2, vr11, vzero); + uint32x4_t vr12_1 = vbslq_u32(vmask3, vr12, vzero); + + uint32x4_t vr30 = vld1q_u32(ptr3); + uint32x4_t vr31 = vld1q_u32(ptr3 + 4); + uint32x4_t vr32 = vld1q_u32(ptr3 + 8); + + vst1q_u32(outptr_row_col + 12, vr10_1); + vst1q_u32(outptr_row_col + 16, vr11_1); + vst1q_u32(outptr_row_col + 20, vr12_1); + + uint32x4_t vr20_1 = vbslq_u32(vmask1, vr20, vzero); + uint32x4_t vr21_1 = vbslq_u32(vmask2, vr21, vzero); + uint32x4_t vr22_1 = vbslq_u32(vmask3, vr22, vzero); + + uint32x4_t vr30_1 = vbslq_u32(vmask1, vr30, vzero); + uint32x4_t vr31_1 = vbslq_u32(vmask2, vr31, vzero); + uint32x4_t vr32_1 = vbslq_u32(vmask3, vr32, vzero); + + vst1q_u32(outptr_row_col + 24, vr20_1); + vst1q_u32(outptr_row_col + 28, vr21_1); + vst1q_u32(outptr_row_col + 32, vr22_1); + + vst1q_u32(outptr_row_col + 36, vr30_1); + vst1q_u32(outptr_row_col + 40, vr31_1); + vst1q_u32(outptr_row_col + 44, vr32_1); + } + } + +#pragma omp parallel for + for (int y = 4 * (y_len / 4); y < y_len; ++y) { + const uint32_t *ptr0 = inptr + y * ldin; + uint32_t *outptr_row_col = outptr_row + y * 12; + + int i = 0; + for (; i < x_len - 11; i += 12) { + uint32x4_t vr0 = vld1q_u32(ptr0); + uint32x4_t vr1 = vld1q_u32(ptr0 + 4); + uint32x4_t vr2 = vld1q_u32(ptr0 + 8); + vst1q_u32(outptr_row_col, vr0); + vst1q_u32(outptr_row_col + 4, vr1); + vst1q_u32(outptr_row_col + 8, vr2); + + ptr0 += 12; + + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32x4_t vr0 = vld1q_u32(ptr0); + uint32x4_t vr1 = vld1q_u32(ptr0 + 4); + uint32x4_t vr2 = vld1q_u32(ptr0 + 8); + + uint32x4_t vr0_1 = vbslq_u32(vmask1, vr0, vzero); + uint32x4_t vr1_1 = vbslq_u32(vmask2, vr1, vzero); + uint32x4_t vr2_1 = vbslq_u32(vmask3, vr2, vzero); + + vst1q_u32(outptr_row_col, vr0_1); + vst1q_u32(outptr_row_col + 4, vr1_1); + vst1q_u32(outptr_row_col + 8, vr2_1); + } + } +} + +void loadb_trans(float *out, const float *in, const int ldin, const int k0, + const int kmax, const int n0, const int nmax) { + int x_len = kmax - k0; + uint32_t zerobuff[x_len]; // NOLINT + memset(zerobuff, 0, sizeof(uint32_t) * x_len); + uint32_t *outptr = reinterpret_cast(out); + const uint32_t *inptr = reinterpret_cast(in); + + //! data B is not transposed, transpose B to k * 12 + for (int y = n0; y < nmax; y += 12) { + const uint32_t *inptr0 = inptr + y * ldin + k0; + const uint32_t *inptr1 = inptr0 + ldin; + const uint32_t *inptr2 = inptr1 + ldin; + const uint32_t *inptr3 = inptr2 + ldin; + const uint32_t *inptr4 = inptr3 + ldin; + const uint32_t *inptr5 = inptr4 + ldin; + const uint32_t *inptr6 = inptr5 + ldin; + const uint32_t *inptr7 = inptr6 + ldin; + const uint32_t *inptr8 = inptr7 + ldin; + const uint32_t *inptr9 = inptr8 + ldin; + const uint32_t *inptr10 = inptr9 + ldin; + const uint32_t *inptr11 = inptr10 + ldin; + + asm volatile( + "prfm pldl1keep, [%[ptr0]] \n" + "prfm pldl1keep, [%[ptr0], #64] \n" + "prfm pldl1keep, [%[ptr1]] \n" + "prfm pldl1keep, [%[ptr1], #64] \n" + "prfm pldl1keep, [%[ptr2]] \n" + "prfm pldl1keep, [%[ptr2], #64] \n" + "prfm pldl1keep, [%[ptr3]] \n" + "prfm pldl1keep, [%[ptr3], #64] \n" + "prfm pldl1keep, [%[ptr4]] \n" + "prfm pldl1keep, [%[ptr4], #64] \n" + "prfm pldl1keep, [%[ptr5]] \n" + "prfm pldl1keep, [%[ptr5], #64] \n" + "prfm pldl1keep, [%[ptr6]] \n" + "prfm pldl1keep, [%[ptr6], #64] \n" + "prfm pldl1keep, [%[ptr7]] \n" + "prfm pldl1keep, [%[ptr7], #64] \n" + "prfm pldl1keep, [%[ptr8]] \n" + "prfm pldl1keep, [%[ptr8], #64] \n" + "prfm pldl1keep, [%[ptr9]] \n" + "prfm pldl1keep, [%[ptr9], #64] \n" + "prfm pldl1keep, [%[ptr10]] \n" + "prfm pldl1keep, [%[ptr10], #64] \n" + "prfm pldl1keep, [%[ptr11]] \n" + "prfm pldl1keep, [%[ptr11], #64] \n" + : + : [ptr0] "r"(inptr0), [ptr1] "r"(inptr1), [ptr2] "r"(inptr2), + [ptr3] "r"(inptr3), [ptr4] "r"(inptr4), [ptr5] "r"(inptr5), + [ptr6] "r"(inptr6), [ptr7] "r"(inptr7), [ptr8] "r"(inptr8), + [ptr9] "r"(inptr9), [ptr10] "r"(inptr10), [ptr11] "r"(inptr11) + : "memory"); + + int x = x_len; + + //! cope with row index exceed real size, set to zero buffer + if ((y + 11) >= nmax) { + switch ((y + 11) - nmax) { + case 10: + inptr1 = zerobuff; + case 9: + inptr2 = zerobuff; + case 8: + inptr3 = zerobuff; + case 7: + inptr4 = zerobuff; + case 6: + inptr5 = zerobuff; + case 5: + inptr6 = zerobuff; + case 4: + inptr7 = zerobuff; + case 3: + inptr8 = zerobuff; + case 2: + inptr9 = zerobuff; + case 1: + inptr10 = zerobuff; + case 0: + inptr11 = zerobuff; + default: + break; + } + } + for (; x > 7; x -= 8) { + asm volatile( + // Load up 12 elements (3 vectors) from each of 8 sources. + "LDP q0, q1, [%[inptr0]], #32\n" // q0=A0A1A2A3 + "LDP q2, q3, [%[inptr1]], #32\n" // q2=B0B1B2B3 + "LDP q4, q5, [%[inptr2]], #32\n" // q4=C0C1C2C3 + "ZIP1 v16.4s, v0.4s, v4.4s\n" // q16=A0C0A1C1 + "prfm pldl1keep, [%[inptr0], #128] \n" + "LDP q6, q7, [%[inptr3]], #32\n" // q6=D0D1D2D3 + "ZIP1 v17.4s, v2.4s, v6.4s\n" // q17=B0D0B1D1 + "LDP q8, q9, [%[inptr4]], #32\n" + "LDP q10, q11, [%[inptr5]], #32\n" + "LDP q12, q13, [%[inptr6]], #32\n" + "ZIP1 v18.4s, v8.4s, v12.4s\n" + "prfm pldl1keep, [%[inptr1], #128]\n" + "LDP q14, q15, [%[inptr7]], #32\n" + "ZIP1 v19.4s, v10.4s, v14.4s\n" + + "ZIP1 v20.4s, v16.4s, v17.4s\n" // q20=A0B0C0D0 + "prfm pldl1keep, [%[inptr2], #128]\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + + "LDP q24, q25, [%[inptr8]], #32\n" // q24=A0A1A2A3 + "LDP q26, q27, [%[inptr9]], #32\n" // q26=B0B1B2B3 + "LDP q28, q29, [%[inptr10]], #32\n" // q28=C0C1C2C3 + "LDP q30, q31, [%[inptr11]], #32\n" // q30=D0D1D2D3 + "prfm pldl1keep, [%[inptr3], #128]\n" + "prfm pldl1keep, [%[inptr4], #128]\n" + "ZIP1 v16.4s, v24.4s, v28.4s\n" // q16=A0C0A1C1 + "ZIP1 v17.4s, v26.4s, v30.4s\n" // q17=B0D0B1D1 + "STP q20, q21, [%[outptr]], #32\n" // Write back the first + // element of each source + "ZIP1 v18.4s, v16.4s, v17.4s\n" // q20=A0B0C0D0 + "ZIP2 v19.4s, v16.4s, v17.4s\n" // q20=A0B0C0D0 + + "ZIP2 v16.4s, v0.4s, v4.4s\n" + "prfm pldl1keep, [%[inptr5], #128]\n" + "ZIP2 v17.4s, v2.4s, v6.4s\n" + "STR q18, [%[outptr]], #16\n" // Write back the second element + // of each source + + "STP q22, q23, [%[outptr]], #32\n" // Write back the second + // element of each source + "ZIP2 v18.4s, v8.4s, v12.4s\n" + "prfm pldl1keep, [%[inptr6], #128]\n" + "STR q19, [%[outptr]], #16\n" // Write back the second element + // of each source + "ZIP2 v19.4s, v10.4s, v14.4s\n" + + "ZIP1 v20.4s, v16.4s, v17.4s\n" + "prfm pldl1keep, [%[inptr7], #128]\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + + "ZIP2 v16.4s, v24.4s, v28.4s\n" // q16=A0C0A1C1 + "ZIP2 v17.4s, v26.4s, v30.4s\n" // q17=B0D0B1D1 + "prfm pldl1keep, [%[inptr8], #128]\n" + "STP q20, q21, [%[outptr]], #32\n" // Third element + "ZIP1 v18.4s, v16.4s, v17.4s\n" + "ZIP2 v19.4s, v16.4s, v17.4s\n" + + "ZIP1 v16.4s, v1.4s, v5.4s\n" + "prfm pldl1keep, [%[inptr9], #128]\n" + "ZIP1 v17.4s, v3.4s, v7.4s\n" + "STR q18, [%[outptr]], #16\n" // Write back the second element + // of each source + + "STP q22, q23, [%[outptr]], #32\n" // Fourth element + "ZIP1 v18.4s, v9.4s, v13.4s\n" + "prfm pldl1keep, [%[inptr10], #128]\n" + "STR q19, [%[outptr]], #16\n" // Write back the second element + // of each source + "ZIP1 v19.4s, v11.4s, v15.4s\n" + + "ZIP1 v20.4s, v16.4s, v17.4s\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "prfm pldl1keep, [%[inptr11], #128]\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + + "ZIP1 v16.4s, v25.4s, v29.4s\n" + "ZIP1 v17.4s, v27.4s, v31.4s\n" + "STP q20, q21, [%[outptr]], #32\n" // Fifth element + "ZIP1 v18.4s, v16.4s, v17.4s\n" + "ZIP2 v19.4s, v16.4s, v17.4s\n" + + "ZIP2 v16.4s, v1.4s, v5.4s\n" + "ZIP2 v17.4s, v3.4s, v7.4s\n" + "STR q18, [%[outptr]], #16\n" + + "STP q22, q23, [%[outptr]], #32\n" // Sixth element + "ZIP2 v18.4s, v9.4s, v13.4s\n" + "STR q19, [%[outptr]], #16\n" // Sixth element + + "ZIP2 v19.4s, v11.4s, v15.4s\n" + "ZIP1 v20.4s, v16.4s, v17.4s\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + + "ZIP2 v16.4s, v25.4s, v29.4s\n" + "ZIP2 v17.4s, v27.4s, v31.4s\n" + "STP q20, q21, [%[outptr]], #32\n" // Seventh element + + "ZIP1 v18.4s, v16.4s, v17.4s\n" + "ZIP2 v19.4s, v16.4s, v17.4s\n" + "STR q18, [%[outptr]], #16\n" + "STP q22, q23, [%[outptr]], #32\n" // Eighth element + "STR q19, [%[outptr]], #16\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8), + [inptr9] "+r"(inptr9), [inptr10] "+r"(inptr10), + [inptr11] "+r"(inptr11), [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", "cc", "memory"); + } + + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; + *outptr++ = *inptr8++; + *outptr++ = *inptr9++; + *outptr++ = *inptr10++; + *outptr++ = *inptr11++; + } + } +} + +#else // __aarch64__ +void loadb(float* out, const float* in, const int ldin, const int k0, + const int kmax, const int n0, const int nmax) { + uint32_t* outptr = reinterpret_cast(out); + const uint32_t* inptr = + reinterpret_cast(in) + k0 * ldin + n0; + uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + int x_len = nmax - n0; + int y_len = kmax - k0; + int right_remain = x_len - 8 * (x_len / 8); + int right_pad = 8 - right_remain; + const size_t copy_len_remain = sizeof(float) * right_remain; + const size_t copy_len_pad = sizeof(float) * right_pad; + const size_t size_ldin = sizeof(float) * ldin; + + uint32_t* outptr_row = outptr; + int stride_out = 8 * y_len; + + uint32x4_t vzero = vdupq_n_u32(0); + uint32x4_t vmask1 = + vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain)); + uint32x4_t vmask2 = + vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain)); + +#pragma omp parallel for + for (int y = 0; y < y_len - 3; y += 4) { + const uint32_t* ptr0 = inptr + y * ldin; + const uint32_t* ptr1 = ptr0 + ldin; + const uint32_t* ptr2 = ptr1 + ldin; + const uint32_t* ptr3 = ptr2 + ldin; + uint32_t* outptr_row_col = outptr_row + y * 8; + int i = 0; + for (; i < x_len - 7; i += 8) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d3}, [%[ptr0]]! @ load r0, 8 elements\n" + "vld1.32 {d4-d7}, [%[ptr1]]! @ load r1, 8 elements\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d7}, [%[outptr]]! @ write to output ptr\n" + + "vld1.32 {d0-d3}, [%[ptr2]]! @ load r2, 8 elements\n" + "vld1.32 {d4-d7}, [%[ptr3]]! @ load r3, 8 elements\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d7}, [%[outptr]]! @ write to output ptr\n" + : [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3) + : + : "q0", "q1", "q2", "q3", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d3}, [%[ptr0]]! @ load r0, 8 elements\n" + "vld1.32 {d4-d7}, [%[ptr1]]! @ load r1, 8 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q1, %q[vzero], %q[vmask2] @ bit select, pad zero\n" + //"vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q3, %q[vzero], %q[vmask2] @ bit select, pad zero\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d7}, [%[outptr]]! @ write to output ptr\n" + + "vld1.32 {d0-d3}, [%[ptr2]]! @ load r2, 8 elements\n" + "vld1.32 {d4-d7}, [%[ptr3]]! @ load r3, 8 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q1, %q[vzero], %q[vmask2] @ bit select, pad zero\n" + //"vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q3, %q[vzero], %q[vmask2] @ bit select, pad zero\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d7}, [%[outptr]]! @ write to output ptr\n" + : [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero) + : "q0", "q1", "q2", "q3", "cc", "memory"); + } + } +#pragma omp parallel for + for (int y = 4 * (y_len / 4); y < y_len; ++y) { + const uint32_t* ptr0 = inptr + y * ldin; + uint32_t* outptr_row_col = outptr_row + y * 8; + int i = 0; + for (; i < x_len - 7; i += 8) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d3}, [%[ptr0]]! @ load r0, 8 elements\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : + : "q0", "q1", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d3}, [%[ptr0]]! @ load r0, 8 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q1, %q[vzero], %q[vmask2] @ bit select, pad zero\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero) + : "q0", "q1", "cc", "memory"); + } + } +} + +void loadb_trans(float* out, const float* in, const int ldin, const int k0, + const int kmax, const int n0, const int nmax) { + int x_len = kmax - k0; + uint32_t zerobuff[x_len]; // NOLINT + memset(zerobuff, 0, sizeof(uint32_t) * x_len); + + uint32_t* outptr = reinterpret_cast(out); + const uint32_t* inptr = reinterpret_cast(in); + //! data B is not transposed, transpose B to k * 8 + for (int y = n0; y < nmax; y += 8) { + const uint32_t* inptr0 = inptr + y * ldin + k0; + const uint32_t* inptr1 = inptr0 + ldin; + const uint32_t* inptr2 = inptr1 + ldin; + const uint32_t* inptr3 = inptr2 + ldin; + const uint32_t* inptr4 = inptr3 + ldin; + const uint32_t* inptr5 = inptr4 + ldin; + const uint32_t* inptr6 = inptr5 + ldin; + const uint32_t* inptr7 = inptr6 + ldin; + + int x = x_len; + + //! cope with row index exceed real size, set to zero buffer + if ((y + 7) >= nmax) { + switch ((y + 7) - nmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + default: + break; + } + } + + for (; x > 7; x -= 8) { + //! zip load 8 elements (2 neon Q registers) from each of 8 rows + asm volatile( + "vld4.32 {d0-d3}, [%[inptr0]]! @ zip load r0, " + "q0,q1=r00,r04,r01,r05,r02,r06,r03,r07\n" + "vld4.32 {d4-d7}, [%[inptr1]]! @ zip load r1, " + "q2,q3=r10,r14,r11,r15,r12,r16,r13,r17\n" + "vtrn.32 q0, q2 @ trans data: q0=r00,r10,r01,r11; " + "q2=r04,r14,r05,r15\n" + "vst1.32 {d0}, [%[outptr]]! @ write d0(q0,low),r00,r10\n" + + "vld4.32 {d8-d11}, [%[inptr2]]! @ zip load r2, " + "q4,q5=r20,r24,r21,r25,r22,r26,r23,r27\n" + "vld4.32 {d12-d15}, [%[inptr3]]! @ zip load r3, " + "q6,q7=r30,r34,r31,r35,r32,r36,r33,r37\n" + "vtrn.32 q4, q6 @ trans data: q4=r20,r30,r21,r31; " + "q6=r24,r34,r25,r35\n" + "vst1.32 {d8}, [%[outptr]]! @ write d8(q4,low),r20,r30\n" + + "vld4.32 {d16-d19}, [%[inptr4]]! @ zip load r4, " + "q8,q9=r40,r44,r41,r45,r42,r46,r43,r47\n" + "vld4.32 {d20-d23}, [%[inptr5]]! @ zip load r5, " + "q10,q11=r50,r54,r51,r55,r52,r56,r53,r57\n" + "vtrn.32 q8, q10 @ trans data: q8=r40,r50,r41,r51; " + "q10=r44,r54,r45,r55\n" + "vst1.32 {d16}, [%[outptr]]! @ write d16(q8,low),r40,r50\n" + + "vld4.32 {d24-d27}, [%[inptr6]]! @ zip load r6, " + "q12,q13=r60,r64,r61,r65,r62,r66,r63,r67\n" + "vld4.32 {d28-d31}, [%[inptr7]]! @ zip load r7, " + "q14,q15=r70,r74,r71,r75,r72,r76,r73,r77\n" + "vtrn.32 q12, q14 @ trans data:q12=r60,r70,r61,r71; " + "q14=r64,r74,r65,r75\n" + "vst1.32 {d24}, [%[outptr]]! @ write d24(q8,low),r60,r70\n" + + //"pld [%[inptr0], #128] @ preload r0 data to cache, fill + // pipeline\n" + "vst1.32 {d1}, [%[outptr]]! @ write d1(q0,high),r01,r11\n" + "vst1.32 {d9}, [%[outptr]]! @ write d9(q4,high),r21,r31\n" + "vst1.32 {d17}, [%[outptr]]! @ write d17(q8,high),r41,r51\n" + "vst1.32 {d25}, [%[outptr]]! @ write d25(q12,high),r61,r71\n" + + "vtrn.32 q1, q3 @ trans data: q1=r02,r12,r03,r13; " + "q3=r06,r16,r07,r17\n" + "vst1.32 {d2}, [%[outptr]]! @ write d2(q1,low),r02,r12\n" + "vtrn.32 q5, q7 @ trans data: q5=r22,r32,r23,r33; " + "q7=r26,r36,r27,r37\n" + "vst1.32 {d10}, [%[outptr]]! @ write d10(q5,low),r22,r32\n" + "vtrn.32 q9, q11 @ trans data: q9=r42,r52,r43,r53; " + "q11=r46,r56,r47,r57\n" + "vst1.32 {d18}, [%[outptr]]! @ write d18(q9,low),r42,r52\n" + "vtrn.32 q13, q15 @ trans data:q13=r62,r72,r63,r73; " + "q15=r66,r76,r67,r77\n" + "vst1.32 {d26}, [%[outptr]]! @ write d18(q9,low),r62,r72\n" + + //"pld [%[inptr1], #128] @ preload r1 data to cache, fill + // pipeline\n" + "vst1.32 {d3}, [%[outptr]]! @ write d3(q1,high),r03,r13\n" + "vst1.32 {d11}, [%[outptr]]! @ write d11(q5,high),r23,r33\n" + "vst1.32 {d19}, [%[outptr]]! @ write d19(q9,high),r43,r53\n" + "vst1.32 {d27}, [%[outptr]]! @ write d27(q13,high),r63,r73\n" + + //"pld [%[inptr2], #128] @ preload r2 data to cache, fill + // pipeline\n" + "vst1.32 {d4}, [%[outptr]]! @ write d4(q2,low),r04,r14\n" + "vst1.32 {d12}, [%[outptr]]! @ write d12(q6,low),r24,r34\n" + "vst1.32 {d20}, [%[outptr]]! @ write d20(q10,low),r44,r54\n" + "vst1.32 {d28}, [%[outptr]]! @ write d28(q14,low),r64,r74\n" + + //"pld [%[inptr3], #128] @ preload r3 data to cache, fill + // pipeline\n" + "vst1.32 {d5}, [%[outptr]]! @ write d5(q2,high),r05,r15\n" + "vst1.32 {d13}, [%[outptr]]! @ write d13(q6,high),r25,r35\n" + "vst1.32 {d21}, [%[outptr]]! @ write d21(q10,high),r45,r55\n" + "vst1.32 {d29}, [%[outptr]]! @ write d29(q14,high),r65,r75\n" + + //"pld [%[inptr4], #128] @ preload r4 data to cache, fill + // pipeline\n" + "vst1.32 {d6}, [%[outptr]]! @ write d6(q3,low),r06,r16\n" + "vst1.32 {d14}, [%[outptr]]! @ write d14(q7,low),r26,r36\n" + "vst1.32 {d22}, [%[outptr]]! @ write d22(q11,low),r46,r56\n" + "vst1.32 {d30}, [%[outptr]]! @ write d30(q15,low),r66,r76\n" + + //"pld [%[inptr5], #128] @ preload r5 data to cache, fill + // pipeline\n" + "vst1.32 {d7}, [%[outptr]]! @ write d7(q3,high),r07,r17\n" + "vst1.32 {d15}, [%[outptr]]! @ write d15(q7,high),r27,r37\n" + "vst1.32 {d23}, [%[outptr]]! @ write d23(q11,high),r47,r57\n" + "vst1.32 {d31}, [%[outptr]]! @ write d31(q15,high),r67,r77\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "q12", "q13", "q14", "q15", "cc", "memory"); + } + + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; + } + } +} + +#endif // __aarch64__ + +#ifdef __aarch64__ +void sgemm_conv_8x12(const float *A_packed, const float *B, const float *bias, + float *C, int M, int N, int K, bool is_bias, bool is_relu, + bool transB, ARMContext *ctx) { + size_t l2_cache = + ctx->l2_cache_size() > 0 ? ctx->l2_cache_size() : 512 * 1024; + float *workspace = ctx->workspace_data(); + int threads = ctx->threads(); + //! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2 + int x_block = (l2_cache - (MBLOCK * K)) / (sizeof(float) * (K + MBLOCK)); + x_block /= NBLOCK; + x_block *= NBLOCK; + int x_num = (N + (x_block - 1)) / x_block; + x_block = (N + x_num - 1) / x_num; + x_block = (x_block + NBLOCK - 1) / NBLOCK; + x_block *= NBLOCK; + x_block = x_block < NBLOCK ? NBLOCK : x_block; + + // unroll 2 loop + int tail_pre = (K & (KBLOCK - 1)); + int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1; + + bool flag_p_remain = false; + int remain = 0; + + //! apanel is pre_compute outside gemm + for (unsigned int x0 = 0; x0 < N; x0 += x_block) { + unsigned int xmax = x0 + x_block; + if (xmax > N) { + xmax = N; + } + int bblocks = (xmax - x0 + NBLOCK - 1) / NBLOCK; + remain = xmax - x0 - (bblocks - 1) * NBLOCK; + if (remain > 0) { + flag_p_remain = true; + } + //! load bpanel + float *b_pannel = workspace; + if (transB) { + loadb_trans(b_pannel, B, K, 0, K, x0, xmax); + } else { + loadb(b_pannel, B, N, 0, K, x0, xmax); + } +#pragma omp parallel for num_threads(threads) + for (unsigned int y = 0; y < M; y += MBLOCK) { + unsigned int ymax = y + MBLOCK; + if (ymax > M) { + ymax = M; + } + + float bias_local[8] = {0}; + if (is_bias) { + bias_local[0] = bias[y]; + bias_local[1] = bias[y + 1]; + bias_local[2] = bias[y + 2]; + bias_local[3] = bias[y + 3]; + bias_local[4] = bias[y + 4]; + bias_local[5] = bias[y + 5]; + bias_local[6] = bias[y + 6]; + bias_local[7] = bias[y + 7]; + } + + float cout0[NBLOCK]; + float cout1[NBLOCK]; + float cout2[NBLOCK]; + float cout3[NBLOCK]; + float cout4[NBLOCK]; + float cout5[NBLOCK]; + float cout6[NBLOCK]; + float cout7[NBLOCK]; + + float *c_ptr0 = C + y * N + x0; + float *c_ptr1 = c_ptr0 + N; + float *c_ptr2 = c_ptr1 + N; + float *c_ptr3 = c_ptr2 + N; + float *c_ptr4 = c_ptr3 + N; + float *c_ptr5 = c_ptr4 + N; + float *c_ptr6 = c_ptr5 + N; + float *c_ptr7 = c_ptr6 + N; + + float *pout0 = c_ptr0; + float *pout1 = c_ptr1; + float *pout2 = c_ptr2; + float *pout3 = c_ptr3; + float *pout4 = c_ptr4; + float *pout5 = c_ptr5; + float *pout6 = c_ptr6; + float *pout7 = c_ptr7; + + const float *a_ptr_l = A_packed + y * K; + const float *b_ptr = b_pannel; + for (int xb = 0; xb < bblocks; xb++) { + if ((y + 7) >= ymax) { + switch ((y + 7) - ymax) { + case 6: + c_ptr1 = cout1; + case 5: + c_ptr2 = cout2; + case 4: + c_ptr3 = cout3; + case 3: + c_ptr4 = cout4; + case 2: + c_ptr5 = cout5; + case 1: + c_ptr6 = cout6; + case 0: + c_ptr7 = cout7; + default: + break; + } + } + if (flag_p_remain && (xb == bblocks - 1)) { + pout0 = c_ptr0; + pout1 = c_ptr1; + pout2 = c_ptr2; + pout3 = c_ptr3; + pout4 = c_ptr4; + pout5 = c_ptr5; + pout6 = c_ptr6; + pout7 = c_ptr7; + + c_ptr0 = cout0; + c_ptr1 = cout1; + c_ptr2 = cout2; + c_ptr3 = cout3; + c_ptr4 = cout4; + c_ptr5 = cout5; + c_ptr6 = cout6; + c_ptr7 = cout7; + } + const float *a_ptr = a_ptr_l; + int tail = tail_pre; + int k = k_pre; + + asm volatile( + // Initialize result registers, load initial operands, prime + // prefetches. + "ldp q2, q3, [%[bias_ptr]]\n" /* load bias to q2, q3*/ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00,a01 to q0, q1*/ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b0, b1 to q4, q5*/ + "dup v8.4s, v2.s[0]\n" /* out0 = 0 */ + "dup v9.4s, v2.s[0]\n" /* out1 = 0*/ + "dup v10.4s, v2.s[0]\n" /* out2 = 0*/ + "dup v11.4s, v2.s[1]\n" /* out3 = 0*/ + "dup v12.4s, v2.s[1]\n" /* out4 = 0*/ + "prfm pldl1keep, [%[b_ptr], #64]\n" /* preload b*/ + "dup v13.4s, v2.s[1]\n" /* out5 = 0*/ + "prfm pldl1keep, [%[a_ptr], #64]\n" /* preload a*/ + "dup v14.4s, v2.s[2]\n" /* out6 = 0*/ + "prfm pldl1keep, [%[b_ptr], #128]\n" /* preload b*/ + "dup v15.4s, v2.s[2]\n" /* out7 = 0*/ + "prfm pldl1keep, [%[a_ptr], #128]\n" /* preload a*/ + "dup v16.4s, v2.s[2]\n" /* out8 = 0*/ + "prfm pldl1keep, [%[b_ptr], #192]\n" /* preload b*/ + "dup v17.4s, v2.s[3]\n" /* out9 = 0*/ + "prfm pldl1keep, [%[b_ptr], #256]\n" /* preload b*/ + "dup v18.4s, v2.s[3]\n" /* out10 = 0*/ + "prfm pldl1keep, [%[a_ptr], #192]\n" /* preload a*/ + "dup v19.4s, v2.s[3]\n" /* out11 = 0*/ + "prfm pldl1keep, [%[b_ptr], #320]\n" /* preload b*/ + "dup v20.4s, v3.s[0]\n" /* out12 = 0*/ + "prfm pldl1keep, [%[a_ptr], #256]\n" /* preload a*/ + "dup v21.4s, v3.s[0]\n" /* out13 = 0*/ + "prfm pldl1keep, [%[b_ptr], #384]\n" /* preload b*/ + "dup v22.4s, v3.s[0]\n" /* out14 = 0*/ + "dup v23.4s, v3.s[1]\n" /* out15 = 0*/ + "dup v24.4s, v3.s[1]\n" /* out16 = 0*/ + "dup v25.4s, v3.s[1]\n" /* out17 = 0*/ + "dup v26.4s, v3.s[2]\n" /* out18 = 0*/ + "dup v27.4s, v3.s[2]\n" /* out19 = 0*/ + "dup v28.4s, v3.s[2]\n" /* out20 = 0*/ + "dup v29.4s, v3.s[3]\n" /* out21 = 0*/ + "dup v30.4s, v3.s[3]\n" /* out22 = 0*/ + "dup v31.4s, v3.s[3]\n" /* out23 = 0*/ + "cbz %w[k], 2f\n" /* check loop count > 0 */ + /* main loop */ + /* unrool 0*/ + "1:\n" /* main loop */ + "fmla v8.4s , v4.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 = + q4 */ + "fmla v11.4s , v4.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 = + q4 */ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b0 to q6, q7 */ + "fmla v14.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 = + q4 */ + "fmla v17.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 = + q4 */ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4 */ + "fmla v20.4s, v4.4s, v1.s[0]\n" /* out4 = b0 * a01[0], b0 = + q4 */ + "fmla v23.4s, v4.4s, v1.s[1]\n" /* out5 = b0 * a01[1], b0 = + q4 */ + "fmla v26.4s, v4.4s, v1.s[2]\n" /* out6 = b0 * a01[2], b0 = + q4 */ + "fmla v29.4s, v4.4s, v1.s[3]\n" /* out7 = b0 * a01[3], b0 = + q4 */ + + "fmla v9.4s, v5.4s, v0.s[0]\n" /* out8 = b1 * a00[0], b1 = + q5 */ + "fmla v12.4s, v5.4s, v0.s[1]\n" /* out9 = b1 * a00[1], b1 = + q5 */ + "fmla v15.4s, v5.4s, v0.s[2]\n" /* out10 = b1 * a00[2], b1 = + q5*/ + "fmla v18.4s, v5.4s, v0.s[3]\n" /* out11 = b1 * a00[3], b1 = + q5*/ + "fmla v21.4s, v5.4s, v1.s[0]\n" /* out12 = b1 * a01[0], b1 = + q5*/ + "fmla v24.4s, v5.4s, v1.s[1]\n" /* out13 = b1 * a01[1], b1 = + q5*/ + "fmla v27.4s, v5.4s, v1.s[2]\n" /* out14 = b1 * a01[2], b1 = + q5*/ + "fmla v30.4s, v5.4s, v1.s[3]\n" /* out15 = b1 * a01[3], b1 = + q5*/ + + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b1, b2 to q4, q5 */ + + "fmla v10.4s, v6.4s, v0.s[0]\n" /* out16 = b2 * a00[0], b2 = + q6*/ + "fmla v13.4s, v6.4s, v0.s[1]\n" /* out17 = b2 * a00[1], b2 = + q6*/ + "prfm pldl1keep, [%[b_ptr], #384]\n" + "fmla v16.4s, v6.4s, v0.s[2]\n" /* out18 = b2 * a00[2], b2 = + q6*/ + "fmla v19.4s, v6.4s, v0.s[3]\n" /* out19 = b2 * a00[3], b2 = + q6*/ + "fmla v22.4s, v6.4s, v1.s[0]\n" /* out20 = b2 * a00[0], b2 = + q6*/ + "fmla v25.4s, v6.4s, v1.s[1]\n" /* out21 = b2 * a00[1], b2 = + q6*/ + "fmla v28.4s, v6.4s, v1.s[2]\n" /* out22 = b2 * a00[2], b2 = + q6*/ + "fmla v31.4s, v6.4s, v1.s[3]\n" /* out23 = b2 * a00[3], b2 = + q6*/ + + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1 */ + + /* unrool 1 */ + "fmla v8.4s , v7.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 = + q7 */ + "fmla v11.4s , v7.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 = + q7 */ + "fmla v14.4s, v7.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 = + q7 */ + "prfm pldl1keep, [%[a_ptr], #256]\n" + "fmla v17.4s, v7.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 = + q7 */ + "fmla v20.4s, v7.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 = + q7 */ + "fmla v23.4s, v7.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q7 + */ + "fmla v26.4s, v7.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 = + q7 */ + "fmla v29.4s, v7.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 = + q7 */ + + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b0, b1 to q6, q7 */ + + "fmla v9.4s, v4.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 = + q4 */ + "fmla v12.4s, v4.4s, v2.s[1]\n" /* out9 = b0 * a10[1], b1 = + q4 */ + "fmla v15.4s, v4.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 = + q4*/ + "fmla v18.4s, v4.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 = + q4*/ + "fmla v21.4s, v4.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 = + q4*/ + "fmla v24.4s, v4.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 = + q4*/ + "fmla v27.4s, v4.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 = + q4*/ + "fmla v30.4s, v4.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 = + q4*/ + + "fmla v10.4s, v5.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 = + q5*/ + "fmla v13.4s, v5.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 = + q5*/ + "fmla v16.4s, v5.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 = + q5*/ + "fmla v19.4s, v5.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 = + q5*/ + "fmla v22.4s, v5.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 = + q5*/ + "fmla v25.4s, v5.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 = + q5*/ + "fmla v28.4s, v5.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 = + q5*/ + "fmla v31.4s, v5.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 = + q5*/ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b2, b0 to q4, q5 */ + /* unrool 2*/ + "fmla v8.4s , v6.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 = + q6 */ + "fmla v11.4s , v6.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 = + q6 */ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4*/ + "fmla v14.4s, v6.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 = + q6*/ + "fmla v17.4s, v6.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 = + q6*/ + "fmla v20.4s, v6.4s, v1.s[0]\n" /* out4 = b0 * a01[0], b0 = + q6*/ + "fmla v23.4s, v6.4s, v1.s[1]\n" /* out5 = b0 * a01[1], b0 = + q6*/ + "fmla v26.4s, v6.4s, v1.s[2]\n" /* out6 = b0 * a01[2], b0 = + q6*/ + "fmla v29.4s, v6.4s, v1.s[3]\n" /* out7 = b0 * a01[3], b0 = + q6*/ + "fmla v9.4s, v7.4s, v0.s[0]\n" /* out8 = b1 * a00[0], b1 = + q7*/ + "fmla v12.4s, v7.4s, v0.s[1]\n" /* out9 = b1 * a00[1], b1 = + q7*/ + "prfm pldl1keep, [%[b_ptr], #384]\n" + "fmla v15.4s, v7.4s, v0.s[2]\n" /* out10 = b1 * a00[2], b1 = + q7*/ + "fmla v18.4s, v7.4s, v0.s[3]\n" /* out11 = b1 * a00[3], b1 = + q7*/ + "fmla v21.4s, v7.4s, v1.s[0]\n" /* out12 = b1 * a01[0], b1 = + q7*/ + "fmla v24.4s, v7.4s, v1.s[1]\n" /* out13 = b1 * a01[1], b1 = + q7*/ + "fmla v27.4s, v7.4s, v1.s[2]\n" /* out14 = b1 * a01[2], b1 = + q7*/ + "fmla v30.4s, v7.4s, v1.s[3]\n" /* out15 = b1 * a01[3], b1 = + q7*/ + + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b1, b2 to q6, q7*/ + + "fmla v10.4s, v4.4s, v0.s[0]\n" /* out16 = b2 * a00[0], b2 = + q4*/ + "fmla v13.4s, v4.4s, v0.s[1]\n" /* out17 = b2 * a00[1], b2 = + q4*/ + "fmla v16.4s, v4.4s, v0.s[2]\n" /* out18 = b2 * a00[2], b2 = + q4*/ + "fmla v19.4s, v4.4s, v0.s[3]\n" /* out19 = b2 * a00[3], b2 = + q4*/ + "fmla v22.4s, v4.4s, v1.s[0]\n" /* out20 = b2 * a00[0], b2 = + q4*/ + "fmla v25.4s, v4.4s, v1.s[1]\n" /* out21 = b2 * a00[1], b2 = + q4*/ + "fmla v28.4s, v4.4s, v1.s[2]\n" /* out22 = b2 * a00[2], b2 = + q4*/ + "fmla v31.4s, v4.4s, v1.s[3]\n" /* out23 = b2 * a00[3], b2 = + q4*/ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1*/ + /* unrool 3*/ + "fmla v8.4s , v5.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 = + q5*/ + "fmla v11.4s , v5.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 = + q5*/ + "fmla v14.4s, v5.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 = + q5*/ + "fmla v17.4s, v5.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 = + q5*/ + "fmla v20.4s, v5.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 = + q5*/ + "fmla v23.4s, v5.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ + "fmla v26.4s, v5.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 = + q5*/ + "fmla v29.4s, v5.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 = + q5*/ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b0, b1 to q4, q5*/ + "fmla v9.4s, v6.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 = + q6*/ + "fmla v12.4s, v6.4s, v2.s[1]\n" /* out9 = b0 * a10[1], b1 = + q6*/ + "prfm pldl1keep, [%[a_ptr], #256]\n" + "fmla v15.4s, v6.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 = + q6*/ + "fmla v18.4s, v6.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 = + q6*/ + "fmla v21.4s, v6.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 = + q6*/ + "fmla v24.4s, v6.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 = + q6*/ + "fmla v27.4s, v6.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 = + q6*/ + "prfm pldl1keep, [%[b_ptr], #384]\n" + "fmla v30.4s, v6.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 = + q6*/ + "fmla v10.4s, v7.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 = + q7*/ + "fmla v13.4s, v7.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 = + q7*/ + "fmla v16.4s, v7.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 = + q7*/ + "fmla v19.4s, v7.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 = + q7*/ + "fmla v22.4s, v7.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 = + q7*/ + "fmla v25.4s, v7.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 = + q7*/ + "subs %w[k], %w[k], #1\n" /* loop count - 1*/ + "fmla v28.4s, v7.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 = + q7*/ + "fmla v31.4s, v7.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 = + q7*/ + "bne 1b\n" + /* Target to use when K is 1 or 2 (i.e. zero iterations of main + loop)*/ + "2:\n" /* process tail*/ + "subs %w[tail], %w[tail], #1\n" /* tail--*/ + "beq 3f\n" /*jump to tail = 1*/ + /* final unrool 0*/ + /* unrool 0, tail > 1*/ + "fmla v8.4s , v4.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 = + q4*/ + "fmla v11.4s , v4.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 = + q4*/ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b0 to q6, q7*/ + "fmla v14.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 = + q4*/ + "fmla v17.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 = + q4*/ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q2, q3*/ + "fmla v20.4s, v4.4s, v1.s[0]\n" /* out4 = b0 * a01[0], b0 = + q4*/ + "fmla v23.4s, v4.4s, v1.s[1]\n" /* out5 = b0 * a01[1], b0 = + q4*/ + "fmla v26.4s, v4.4s, v1.s[2]\n" /* out6 = b0 * a01[2], b0 = + q4*/ + "fmla v29.4s, v4.4s, v1.s[3]\n" /* out7 = b0 * a01[3], b0 = + q4*/ + "subs %w[tail], %w[tail], #1\n" /* tail--*/ + "fmla v9.4s, v5.4s, v0.s[0]\n" /* out8 = b1 * a00[0], b1 = + q5*/ + "fmla v12.4s, v5.4s, v0.s[1]\n" /* out9 = b1 * a00[1], b1 = + q5*/ + "fmla v15.4s, v5.4s, v0.s[2]\n" /* out10 = b1 * a00[2], b1 = + q5*/ + "fmla v18.4s, v5.4s, v0.s[3]\n" /* out11 = b1 * a00[3], b1 = + q5*/ + "fmla v21.4s, v5.4s, v1.s[0]\n" /* out12 = b1 * a01[0], b1 = + q5*/ + "fmla v24.4s, v5.4s, v1.s[1]\n" /* out13 = b1 * a01[1], b1 = + q5*/ + "fmla v27.4s, v5.4s, v1.s[2]\n" /* out14 = b1 * a01[2], b1 = + q5*/ + "fmla v30.4s, v5.4s, v1.s[3]\n" /* out15 = b1 * a01[3], b1 = + q5*/ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b1, b2 to q4, q5*/ + "fmla v10.4s, v6.4s, v0.s[0]\n" /* out16 = b2 * a00[0], b2 = + q6*/ + "fmla v13.4s, v6.4s, v0.s[1]\n" /* out17 = b2 * a00[1], b2 = + q6*/ + "fmla v16.4s, v6.4s, v0.s[2]\n" /* out18 = b2 * a00[2], b2 = + q6*/ + "fmla v19.4s, v6.4s, v0.s[3]\n" /* out19 = b2 * a00[3], b2 = + q6*/ + "fmla v22.4s, v6.4s, v1.s[0]\n" /* out20 = b2 * a00[0], b2 = + q6*/ + "fmla v25.4s, v6.4s, v1.s[1]\n" /* out21 = b2 * a00[1], b2 = + q6*/ + "fmla v28.4s, v6.4s, v1.s[2]\n" /* out22 = b2 * a00[2], b2 = + q6*/ + "fmla v31.4s, v6.4s, v1.s[3]\n" /* out23 = b2 * a00[3], b2 = + q6*/ + "beq 4f\n" /*jump to tail = 2*/ + /* unrool 1, tail > 2*/ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1*/ + "fmla v8.4s , v7.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 = + q7*/ + "fmla v11.4s , v7.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 = + q7*/ + "fmla v14.4s, v7.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 = + q7*/ + "fmla v17.4s, v7.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 = + q7*/ + "fmla v20.4s, v7.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 = + q7*/ + "fmla v23.4s, v7.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q7*/ + "fmla v26.4s, v7.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 = + q7*/ + "fmla v29.4s, v7.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 = + q7*/ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b0, b1 to q6, q7*/ + "fmla v9.4s, v4.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 = + q4*/ + "fmla v12.4s, v4.4s, v2.s[1]\n" /* out9 = b0 * a10[1], b1 = + q4*/ + "fmla v15.4s, v4.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 = + q4*/ + "fmla v18.4s, v4.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 = + q4*/ + "fmla v21.4s, v4.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 = + q4*/ + "fmla v24.4s, v4.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 = + q4*/ + "fmla v27.4s, v4.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 = + q4*/ + "fmla v30.4s, v4.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 = + q4*/ + "subs %w[tail], %w[tail], #1\n" /* tail--*/ + "fmla v10.4s, v5.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 = + q5*/ + "fmla v13.4s, v5.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 = + q5*/ + "fmla v16.4s, v5.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 = + q5*/ + "fmla v19.4s, v5.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 = + q5*/ + "fmla v22.4s, v5.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 = + q5*/ + "fmla v25.4s, v5.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 = + q5*/ + "fmla v28.4s, v5.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 = + q5*/ + "fmla v31.4s, v5.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 = + q5*/ + "beq 5f\n" /*jump to tail = 3*/ + /* unrool 2, tail = 4*/ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b2, b0 to q4, q5*/ + "fmla v8.4s , v6.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 = + q6*/ + "fmla v11.4s , v6.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 = + q6*/ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4*/ + "fmla v14.4s, v6.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 = + q6*/ + "fmla v17.4s, v6.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 = + q6*/ + "fmla v20.4s, v6.4s, v1.s[0]\n" /* out4 = b0 * a01[0], b0 = + q6*/ + "fmla v23.4s, v6.4s, v1.s[1]\n" /* out5 = b0 * a01[1], b0 = + q6*/ + "fmla v26.4s, v6.4s, v1.s[2]\n" /* out6 = b0 * a01[2], b0 = + q6*/ + "fmla v29.4s, v6.4s, v1.s[3]\n" /* out7 = b0 * a01[3], b0 = + q6*/ + "fmla v9.4s, v7.4s, v0.s[0]\n" /* out8 = b1 * a00[0], b1 = + q7*/ + "fmla v12.4s, v7.4s, v0.s[1]\n" /* out9 = b1 * a00[1], b1 = + q7*/ + "fmla v15.4s, v7.4s, v0.s[2]\n" /* out10 = b1 * a00[2], b1 = + q7*/ + "fmla v18.4s, v7.4s, v0.s[3]\n" /* out11 = b1 * a00[3], b1 = + q7*/ + "fmla v21.4s, v7.4s, v1.s[0]\n" /* out12 = b1 * a01[0], b1 = + q7*/ + "fmla v24.4s, v7.4s, v1.s[1]\n" /* out13 = b1 * a01[1], b1 = + q7*/ + "fmla v27.4s, v7.4s, v1.s[2]\n" /* out14 = b1 * a01[2], b1 = + q7*/ + "fmla v30.4s, v7.4s, v1.s[3]\n" /* out15 = b1 * a01[3], b1 = + q7*/ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b1, b2 to q6, q7*/ + "fmla v10.4s, v4.4s, v0.s[0]\n" /* out16 = b2 * a00[0], b2 = + q4*/ + "fmla v13.4s, v4.4s, v0.s[1]\n" /* out17 = b2 * a00[1], b2 = + q4*/ + "fmla v16.4s, v4.4s, v0.s[2]\n" /* out18 = b2 * a00[2], b2 = + q4*/ + "fmla v19.4s, v4.4s, v0.s[3]\n" /* out19 = b2 * a00[3], b2 = + q4*/ + "fmla v22.4s, v4.4s, v1.s[0]\n" /* out20 = b2 * a00[0], b2 = + q4*/ + "fmla v25.4s, v4.4s, v1.s[1]\n" /* out21 = b2 * a00[1], b2 = + q4*/ + "fmla v28.4s, v4.4s, v1.s[2]\n" /* out22 = b2 * a00[2], b2 = + q4*/ + "fmla v31.4s, v4.4s, v1.s[3]\n" /* out23 = b2 * a00[3], b2 = + q4*/ + /* unrool 3, tail = 4*/ + "fmla v8.4s , v5.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 = + q5*/ + "fmla v11.4s , v5.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 = + q5*/ + "fmla v14.4s, v5.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 = + q5*/ + "fmla v17.4s, v5.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 = + q5*/ + "fmla v20.4s, v5.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 = + q5*/ + "fmla v23.4s, v5.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ + "fmla v26.4s, v5.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 = + q5*/ + "fmla v29.4s, v5.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 = + q5*/ + "fmla v9.4s, v6.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 = + q6*/ + "fmla v12.4s, v6.4s, v2.s[1]\n" /* out9 = b1 * a10[1], b1 = + q6*/ + "fmla v15.4s, v6.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 = + q6*/ + "fmla v18.4s, v6.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 = + q6*/ + "fmla v21.4s, v6.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 = + q6*/ + "fmla v24.4s, v6.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 = + q6*/ + "fmla v27.4s, v6.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 = + q6*/ + "fmla v30.4s, v6.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 = + q6*/ + "fmla v10.4s, v7.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 = + q7*/ + "fmla v13.4s, v7.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 = + q7*/ + "fmla v16.4s, v7.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 = + q7*/ + "fmla v19.4s, v7.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 = + q7*/ + "fmla v22.4s, v7.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 = + q7*/ + "fmla v25.4s, v7.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 = + q7*/ + "fmla v28.4s, v7.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 = + q7*/ + "fmla v31.4s, v7.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 = + q7*/ + "b 11f\n" + /* tails==1 final tail*/ + "3: \n" /* tail=1*/ + "ldr q6, [%[b_ptr]], #16\n" /* load b2 to q6*/ + "fmla v8.4s , v4.4s, v0.s[0]\n" /* out0 = b0 * a10[0], b0 = + q5*/ + "fmla v11.4s , v4.4s, v0.s[1]\n" /* out1 = b0 * a10[1], b0 = + q5*/ + "fmla v14.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a10[2], b0 = + q5*/ + "fmla v17.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a10[3], b0 = + q5*/ + "fmla v20.4s, v4.4s, v1.s[0]\n" /* out4 = b0 * a11[0], b0 = + q5*/ + "fmla v23.4s, v4.4s, v1.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ + "fmla v26.4s, v4.4s, v1.s[2]\n" /* out6 = b0 * a11[2], b0 = + q5*/ + "fmla v29.4s, v4.4s, v1.s[3]\n" /* out7 = b0 * a11[3], b0 = + q5*/ + "fmla v9.4s, v5.4s, v0.s[0]\n" /* out8 = b0 * a10[0], b1 = + q6*/ + "fmla v12.4s, v5.4s, v0.s[1]\n" /* out9 = b1 * a10[1], b1 = + q6*/ + "fmla v15.4s, v5.4s, v0.s[2]\n" /* out10 = b1 * a10[2], b1 = + q6*/ + "fmla v18.4s, v5.4s, v0.s[3]\n" /* out11 = b1 * a10[3], b1 = + q6*/ + "fmla v21.4s, v5.4s, v1.s[0]\n" /* out12 = b1 * a10[0], b1 = + q6*/ + "fmla v24.4s, v5.4s, v1.s[1]\n" /* out13 = b1 * a10[1], b1 = + q6*/ + "fmla v27.4s, v5.4s, v1.s[2]\n" /* out14 = b1 * a10[2], b1 = + q6*/ + "fmla v30.4s, v5.4s, v1.s[3]\n" /* out15 = b1 * a10[3], b1 = + q6*/ + "fmla v10.4s, v6.4s, v0.s[0]\n" /* out16 = b2 * a10[0], b2 = + q7*/ + "fmla v13.4s, v6.4s, v0.s[1]\n" /* out17 = b2 * a10[0], b2 = + q7*/ + "fmla v16.4s, v6.4s, v0.s[2]\n" /* out18 = b2 * a10[0], b2 = + q7*/ + "fmla v19.4s, v6.4s, v0.s[3]\n" /* out19 = b2 * a10[0], b2 = + q7*/ + "fmla v22.4s, v6.4s, v1.s[0]\n" /* out20 = b2 * a10[0], b2 = + q7*/ + "fmla v25.4s, v6.4s, v1.s[1]\n" /* out21 = b2 * a10[0], b2 = + q7*/ + "fmla v28.4s, v6.4s, v1.s[2]\n" /* out22 = b2 * a10[0], b2 = + q7*/ + "fmla v31.4s, v6.4s, v1.s[3]\n" /* out23 = b2 * a10[0], b2 = + q7*/ + "b 11f\n" + /* tails==2 final tail*/ + "4:\n" /* tail = 2*/ + "fmla v8.4s , v7.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 = + q5*/ + "fmla v11.4s , v7.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 = + q5*/ + "fmla v14.4s, v7.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 = + q5*/ + "fmla v17.4s, v7.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 = + q5*/ + "fmla v20.4s, v7.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 = + q5*/ + "fmla v23.4s, v7.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ + "fmla v26.4s, v7.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 = + q5*/ + "fmla v29.4s, v7.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 = + q5*/ + "fmla v9.4s, v4.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 = + q6*/ + "fmla v12.4s, v4.4s, v2.s[1]\n" /* out9 = b1 * a10[1], b1 = + q6*/ + "fmla v15.4s, v4.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 = + q6*/ + "fmla v18.4s, v4.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 = + q6*/ + "fmla v21.4s, v4.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 = + q6*/ + "fmla v24.4s, v4.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 = + q6*/ + "fmla v27.4s, v4.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 = + q6*/ + "fmla v30.4s, v4.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 = + q6*/ + "fmla v10.4s, v5.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 = + q7*/ + "fmla v13.4s, v5.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 = + q7*/ + "fmla v16.4s, v5.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 = + q7*/ + "fmla v19.4s, v5.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 = + q7*/ + "fmla v22.4s, v5.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 = + q7*/ + "fmla v25.4s, v5.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 = + q7*/ + "fmla v28.4s, v5.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 = + q7*/ + "fmla v31.4s, v5.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 = + q7*/ + "b 11f\n" + /* tails==3 final tail*/ + "5:\n" /* tail = 3*/ + "ldr q4, [%[b_ptr]], #16\n" /* load b2, b0 to q4*/ + "fmla v8.4s , v6.4s, v0.s[0]\n" /* out0 = b0 * a10[0], b0 = + q5*/ + "fmla v11.4s , v6.4s, v0.s[1]\n" /* out1 = b0 * a10[1], b0 = + q5*/ + "fmla v14.4s, v6.4s, v0.s[2]\n" /* out2 = b0 * a10[2], b0 = + q5*/ + "fmla v17.4s, v6.4s, v0.s[3]\n" /* out3 = b0 * a10[3], b0 = + q5*/ + "fmla v20.4s, v6.4s, v1.s[0]\n" /* out4 = b0 * a11[0], b0 = + q5*/ + "fmla v23.4s, v6.4s, v1.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ + "fmla v26.4s, v6.4s, v1.s[2]\n" /* out6 = b0 * a11[2], b0 = + q5*/ + "fmla v29.4s, v6.4s, v1.s[3]\n" /* out7 = b0 * a11[3], b0 = + q5*/ + "fmla v9.4s, v7.4s, v0.s[0]\n" /* out8 = b0 * a10[0], b1 = + q6*/ + "fmla v12.4s, v7.4s, v0.s[1]\n" /* out9 = b1 * a10[1], b1 = + q6*/ + "fmla v15.4s, v7.4s, v0.s[2]\n" /* out10 = b1 * a10[2], b1 = + q6*/ + "fmla v18.4s, v7.4s, v0.s[3]\n" /* out11 = b1 * a10[3], b1 = + q6*/ + "fmla v21.4s, v7.4s, v1.s[0]\n" /* out12 = b1 * a10[0], b1 = + q6*/ + "fmla v24.4s, v7.4s, v1.s[1]\n" /* out13 = b1 * a10[1], b1 = + q6*/ + "fmla v27.4s, v7.4s, v1.s[2]\n" /* out14 = b1 * a10[2], b1 = + q6*/ + "fmla v30.4s, v7.4s, v1.s[3]\n" /* out15 = b1 * a10[3], b1 = + q6*/ + "fmla v10.4s, v4.4s, v0.s[0]\n" /* out16 = b2 * a10[0], b2 = + q7*/ + "fmla v13.4s, v4.4s, v0.s[1]\n" /* out17 = b2 * a10[0], b2 = + q7*/ + "fmla v16.4s, v4.4s, v0.s[2]\n" /* out18 = b2 * a10[0], b2 = + q7*/ + "fmla v19.4s, v4.4s, v0.s[3]\n" /* out19 = b2 * a10[0], b2 = + q7*/ + "fmla v22.4s, v4.4s, v1.s[0]\n" /* out20 = b2 * a10[0], b2 = + q7*/ + "fmla v25.4s, v4.4s, v1.s[1]\n" /* out21 = b2 * a10[0], b2 = + q7*/ + "fmla v28.4s, v4.4s, v1.s[2]\n" /* out22 = b2 * a10[0], b2 = + q7*/ + "fmla v31.4s, v4.4s, v1.s[3]\n" /* out23 = b2 * a10[0], b2 = + q7*/ + "11: \n" /* check if relu */ + "cbz %w[relu], 12f\n" /* skip relu */ + "movi v2.4s, #0\n" /* for relu*/ + "fmax v8.4s, v8.4s, v2.4s\n" /* relu*/ + "fmax v9.4s, v9.4s, v2.4s\n" /* relu*/ + "fmax v10.4s, v10.4s, v2.4s\n" /* relu*/ + "fmax v11.4s, v11.4s, v2.4s\n" /* relu*/ + "fmax v12.4s, v12.4s, v2.4s\n" /* relu*/ + "fmax v13.4s, v13.4s, v2.4s\n" /* relu*/ + "fmax v14.4s, v14.4s, v2.4s\n" /* relu*/ + "fmax v15.4s, v15.4s, v2.4s\n" /* relu*/ + "fmax v16.4s,v16.4s,v2.4s\n" /* relu*/ + "fmax v17.4s,v17.4s,v2.4s\n" /* relu*/ + "fmax v18.4s, v18.4s, v2.4s\n" /* relu*/ + "fmax v19.4s, v19.4s, v2.4s\n" /* relu*/ + "fmax v20.4s, v20.4s, v2.4s\n" /* relu*/ + "fmax v21.4s, v21.4s, v2.4s\n" /* relu*/ + "fmax v22.4s, v22.4s, v2.4s\n" /* relu*/ + "fmax v23.4s, v23.4s, v2.4s\n" /* relu*/ + "fmax v24.4s,v24.4s,v2.4s\n" /* relu*/ + "fmax v25.4s,v25.4s,v2.4s\n" /* relu*/ + "fmax v26.4s, v26.4s, v2.4s\n" /* relu*/ + "fmax v27.4s, v27.4s, v2.4s\n" /* relu*/ + "fmax v28.4s, v28.4s, v2.4s\n" /* relu*/ + "fmax v29.4s, v29.4s, v2.4s\n" /* relu*/ + "fmax v30.4s, v30.4s, v2.4s\n" /* relu*/ + "fmax v31.4s, v31.4s, v2.4s\n" /* relu*/ + "12: \n" + "st1 {v8.4s, v9.4s, v10.4s},[%[c_ptr0]], #48\n" /* store r0 */ + "st1 {v11.4s, v12.4s, v13.4s},[%[c_ptr1]], #48\n" /* store r1 */ + "st1 {v14.4s, v15.4s, v16.4s},[%[c_ptr2]], #48\n" /* store r2 */ + "st1 {v17.4s, v18.4s, v19.4s},[%[c_ptr3]], #48\n" /* store r3 */ + "st1 {v20.4s, v21.4s, v22.4s},[%[c_ptr4]], #48\n" /* store r4 */ + "st1 {v23.4s, v24.4s, v25.4s},[%[c_ptr5]], #48\n" /* store r5 */ + "st1 {v26.4s, v27.4s, v28.4s},[%[c_ptr6]], #48\n" /* store r6 */ + "st1 {v29.4s, v30.4s, v31.4s},[%[c_ptr7]], #48\n" /* store r7 */ + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), + [tail] "+r"(tail), [c_ptr0] "+r"(c_ptr0), [c_ptr1] "+r"(c_ptr1), + [c_ptr2] "+r"(c_ptr2), [c_ptr3] "+r"(c_ptr3), + [c_ptr4] "+r"(c_ptr4), [c_ptr5] "+r"(c_ptr5), + [c_ptr6] "+r"(c_ptr6), [c_ptr7] "+r"(c_ptr7) + : [bias_ptr] "r"(bias_local), [relu] "r"(is_relu) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31"); + if (flag_p_remain && (xb == bblocks - 1)) { + for (int i = 0; i < remain; ++i) { + *pout0++ = cout0[i]; + *pout1++ = cout1[i]; + *pout2++ = cout2[i]; + *pout3++ = cout3[i]; + *pout4++ = cout4[i]; + *pout5++ = cout5[i]; + *pout6++ = cout6[i]; + *pout7++ = cout7[i]; + } + } + } + } + } +} +#else // __aarch64__ +/** + * \brief gemm with ablock = 6, bblock = 8, output 6x8 + * @param A + * @param B + * @param C + * @param M + * @param N + * @param K + * @param threads + * @param workspace + */ +void sgemm_conv_6x8(const float* A_packed, const float* B, const float* bias, + float* C, int M, int N, int K, bool is_bias, bool is_relu, + bool transB, ARMContext* ctx) { + size_t l2_cache = + ctx->l2_cache_size() > 0 ? ctx->l2_cache_size() : 512 * 1024; + auto* workspace = ctx->workspace_data(); + int threads = ctx->threads(); + //! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2 + int x_block = + (l2_cache - (MBLOCK_OTH * K)) / (sizeof(float) * (K + MBLOCK_OTH)); + x_block /= NBLOCK; + x_block *= NBLOCK; + int x_num = (N + (x_block - 1)) / x_block; + x_block = (N + x_num - 1) / x_num; + x_block = (x_block + NBLOCK - 1) / NBLOCK; + x_block *= NBLOCK; + x_block = x_block < NBLOCK ? NBLOCK : x_block; + + int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1; + int tail_pre = (K & (KBLOCK - 1)); + if (tail_pre == 0) { + tail_pre = KBLOCK; + } + + bool flag_p_remain = false; + int remain = 0; + + //! apanel is pre_compute outside gemm + for (unsigned int x0 = 0; x0 < N; x0 += x_block) { + unsigned int xmax = x0 + x_block; + if (xmax > N) { + xmax = N; + } + int bblocks = (xmax - x0 + NBLOCK - 1) / NBLOCK; + remain = xmax - x0 - (bblocks - 1) * NBLOCK; + if (remain > 0) { + flag_p_remain = true; + } + //! load bpanel + float* b_pannel = workspace; + if (transB) { + loadb_trans(b_pannel, B, K, 0, K, x0, xmax); + } else { + loadb(b_pannel, B, N, 0, K, x0, xmax); + } +#pragma omp parallel for num_threads(threads) + for (unsigned int y = 0; y < M; y += MBLOCK_OTH) { + unsigned int ymax = y + MBLOCK_OTH; + if (ymax > M) { + ymax = M; + } + float* c_ptr0 = C + y * N + x0; + float* c_ptr1 = c_ptr0 + N; + float* c_ptr2 = c_ptr1 + N; + float* c_ptr3 = c_ptr2 + N; + float* c_ptr4 = c_ptr3 + N; + float* c_ptr5 = c_ptr4 + N; + + float* pout0 = c_ptr0; + float* pout1 = c_ptr1; + float* pout2 = c_ptr2; + float* pout3 = c_ptr3; + float* pout4 = c_ptr4; + float* pout5 = c_ptr5; + + float bias_local[6] = {0}; + if (is_bias) { + bias_local[0] = bias[y]; + bias_local[1] = bias[y + 1]; + bias_local[2] = bias[y + 2]; + bias_local[3] = bias[y + 3]; + bias_local[4] = bias[y + 4]; + bias_local[5] = bias[y + 5]; + } + + float cout0[NBLOCK]; + float cout1[NBLOCK]; + float cout2[NBLOCK]; + float cout3[NBLOCK]; + float cout4[NBLOCK]; + float cout5[NBLOCK]; + + const float* a_ptr_l = A_packed + y * K; + const float* b_ptr = b_pannel; + for (int xb = 0; xb < bblocks; xb++) { + if ((y + 5) >= ymax) { + switch ((y + 5) - ymax) { + case 4: + c_ptr1 = cout1; + case 3: + c_ptr2 = cout2; + case 2: + c_ptr3 = cout3; + case 1: + c_ptr4 = cout4; + case 0: + c_ptr5 = cout5; + default: + break; + } + } + if (flag_p_remain && (xb == bblocks - 1)) { + pout0 = c_ptr0; + pout1 = c_ptr1; + pout2 = c_ptr2; + pout3 = c_ptr3; + pout4 = c_ptr4; + pout5 = c_ptr5; + + c_ptr0 = cout0; + c_ptr1 = cout1; + c_ptr2 = cout2; + c_ptr3 = cout3; + c_ptr4 = cout4; + c_ptr5 = cout5; + } + const float* a_ptr = a_ptr_l; + int tails = tail_pre; + int k = k_pre; + asm volatile( + // sgemm 6x8 + "vld1.32 {d2-d4}, [%[bias_ptr]] @ load bias 6 elements\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a0~a3\n" + "pld [%[a_ptr]] @ preload a\n" + "vdup.i32 q12,d4[0] @ out40=0\n" + "pld [%[b_ptr]] @ preload b\n" + "vdup.i32 q13,d4[0] @ out41=0\n" + "pld [%[a_ptr], #64] @ preload a\n" + "vdup.i32 q14,d4[1] @ out50=0\n" + "pld [%[b_ptr], #64] @ preload b\n" + "vdup.i32 q15,d4[1] @ out51=0\n" + "pld [%[a_ptr], #128] @ preload a\n" + "vdup.i32 q4, d2[0] @ out00=0\n" + "pld [%[b_ptr], #128] @ preload b\n" + "vdup.i32 q5, d2[0] @ out01=0\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vdup.i32 q6, d2[1] @ out10=0\n" + "pld [%[a_ptr], #192] @ preload a\n" + "vdup.i32 q7, d2[1] @ out11=0\n" + "pld [%[b_ptr], #192] @ preload a\n" + "vdup.i32 q8, d3[0] @ out20=0\n" + "pld [%[a_ptr], #256] @ preload a\n" + "vdup.i32 q9, d3[0] @ out21=0\n" + "pld [%[b_ptr], #256] @ preload a\n" + "vdup.i32 q10,d3[1] @ out30=0\n" + "pld [%[b_ptr], #320] @ preload b\n" + "vdup.i32 q11,d3[1] @ out31=0\n" + "pld [%[b_ptr], #384] @ preload b\n" + "cmp %[k], #0 @ check weather k is " + "bigger than 0\n" + "beq 0f @ jump to tail\n" + "1: @ main loop for k\n" + /* Unroll 0*/ + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a4, a5, and next " + "a0, " + "a1\n" + "vmla.f32 q4, q2, d0[0] @ out0 += b1 * a0\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "vmla.f32 q6, q2, d0[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d2[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d2[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d0[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d0[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d1[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d1[1] @ out9 += b2 * a3\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a2~a5\n" + "vmla.f32 q13, q3, d2[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d2[1] @ out11 += b2 * a5\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + /* Unroll 1 */ + "vmla.f32 q4, q2, d3[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d3[1] @ out1 += b1 * a1\n" + /*"pld [%[a_ptr], #64] @ preload a\n"*/ + "vmla.f32 q8, q2, d0[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d0[1] @ out3 += b1 * a3\n" + /*"pld [%[b_ptr], #192]\n"*/ + "vmla.f32 q12, q2, d1[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d1[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d3[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d3[1] @ out7 += b2 * a1\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a0~a3\n" + "vmla.f32 q9, q3, d0[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d0[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d1[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d1[1] @ out11 += b2 * a5\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a4, a5, a0, a1\n" + /* Unroll 2 */ + "vmla.f32 q4, q2, d2[0] @ out0 += b1 * a0\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "vmla.f32 q6, q2, d2[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d3[1] @ out3 += b1 * a3\n" + /*"pld [%[a_ptr], #240] @ preload\n"*/ + "vmla.f32 q12, q2, d0[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d0[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d2[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d2[1] @ out7 += b2 * a1\n" + /*"pld [%[b_ptr], #208]\n"*/ + "vmla.f32 q9, q3, d3[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d3[1] @ out9 += b2 * a3\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a2~a5\n" + "vmla.f32 q13, q3, d0[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d0[1] @ out11 += b2 * a5\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + /* Unroll 3 */ + "vmla.f32 q4, q2, d1[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d1[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d2[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d2[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d3[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d3[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d1[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d1[1] @ out7 += b2 * a1\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a0~a3\n" + "vmla.f32 q9, q3, d2[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d2[1] @ out9 += b2 * a3\n" + "subs %[k], %[k], #1 @ k--\n" + "vmla.f32 q13, q3, d3[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d3[1] @ out11 += b2 * a5\n" + "bne 1b @ jump to main " + "loop\n" + "0: @ process tail\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "beq 3f @ jump to tail = " + "1\n" + /* Unroll 0*/ + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "vmla.f32 q4, q2, d0[0] @ out0 += b1 * a0\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a4,5, a0, a1\n" + "vmla.f32 q6, q2, d0[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d2[0] @ out4 += b1 * a4\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "vmla.f32 q14, q2, d2[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d0[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d0[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d1[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d1[1] @ out9 += b2 * a3\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a2~a5\n" + "vmla.f32 q13, q3, d2[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d2[1] @ out11 += b2 * a5\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "beq 4f @ jump to tail==2\n" + /* Unroll 1*/ + "vmla.f32 q4, q2, d3[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d3[1] @ out1 += b1 * a1\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "vmla.f32 q8, q2, d0[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d0[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d1[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d1[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d3[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d3[1] @ out7 += b2 * a1\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a0~a3\n" + "vmla.f32 q9, q3, d0[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d0[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d1[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d1[1] @ out11 += b2 * a5\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "beq 5f @ jump to tail==3\n" + /* Unroll 2 */ + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a4,a5, a0,a1\n" + "vmla.f32 q4, q2, d2[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d2[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d3[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d0[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d0[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d2[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d2[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d3[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d3[1] @ out9 += b2 * a3\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a2~a5\n" + "vmla.f32 q13, q3, d0[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d0[1] @ out11 += b2 * a5\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + /* Unroll 3*/ + "vmla.f32 q4, q2, d1[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d1[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d2[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d2[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d3[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d3[1] @ out5 += b1 * a5\n" + "vmla.f32 q5, q3, d1[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d1[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d2[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d2[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d3[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d3[1] @ out11 += b2 * a5\n" + "b 2f\n" + /* tails==1 final tail*/ + "3: @ tail=1\n" + "vmla.f32 q4, q2, d0[0] @ out0 += b1 * a0\n" + "vld1.32 {d2}, [%[a_ptr] :64]! @ load a4,a5\n" + "vmla.f32 q6, q2, d0[1] @ out1 += b1 * a1\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "vmla.f32 q8, q2, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d2[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d2[1] @ out5 += b1 * a5\n" + "vmla.f32 q5, q3, d0[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d0[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d1[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d1[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d2[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d2[1] @ out11 += b2 * a5\n" + "b 2f @ jump to end\n" + /* tails==2 final tail*/ + "4: @ tail == 2\n" + "vmla.f32 q4, q2, d3[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d3[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d0[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d0[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d1[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d1[1] @ out5 += b1 * a5\n" + "vmla.f32 q5, q3, d3[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d3[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d0[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d0[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d1[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d1[1] @ out11 += b2 * a5\n" + "b 2f @ jump to end\n" + /* tails==3 final tail*/ + "5: @ tail=3\n" + "vmla.f32 q4, q2, d2[0] @ out0 += b1 * a0\n" + "vld1.32 {d0}, [%[a_ptr] :64]! @ load a4,a5\n" + "vmla.f32 q6, q2, d2[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d3[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d0[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d0[1] @ out5 += b1 * a5\n" + "vmla.f32 q5, q3, d2[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d2[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d3[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d3[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d0[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d0[1] @ out11 += b2 * a5\n" + "2: @ check relu\n" + "cmp %[relu], #0 @ check if has relu\n" + "ble 6f @ skip relu if relu <= 0\n" + "vmov.u32 q0, #0 @ for relu\n" + "vmax.f32 q4, q4, q0 @ for relu\n" + "vmax.f32 q5, q5, q0 @ for relu\n" + "vmax.f32 q6, q6, q0 @ for relu\n" + "vmax.f32 q7, q7, q0 @ for relu\n" + "vmax.f32 q8, q8, q0 @ for relu\n" + "vmax.f32 q9, q9, q0 @ for relu\n" + "vmax.f32 q10, q10, q0 @ for relu\n" + "vmax.f32 q11, q11, q0 @ for relu\n" + "vmax.f32 q12, q12, q0 @ for relu\n" + "vmax.f32 q13, q13, q0 @ for relu\n" + "vmax.f32 q14, q14, q0 @ for relu\n" + "vmax.f32 q15, q15, q0 @ for relu\n" + "6: @ store result\n" + "vst1.32 {d8-d11}, [%[c_ptr0]]! @ store r0\n" + "vst1.32 {d12-d15}, [%[c_ptr1]]! @ store r1\n" + "vst1.32 {d16-d19}, [%[c_ptr2]]! @ store r2\n" + "vst1.32 {d20-d23}, [%[c_ptr3]]! @ store r3\n" + "vst1.32 {d24-d27}, [%[c_ptr4]]! @ store r4\n" + "vst1.32 {d28-d31}, [%[c_ptr5]]! @ store r5\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr0] "+r"(c_ptr0), + [c_ptr1] "+r"(c_ptr1), [c_ptr2] "+r"(c_ptr2), + [c_ptr3] "+r"(c_ptr3), [c_ptr4] "+r"(c_ptr4), + [c_ptr5] "+r"(c_ptr5), [k] "+r"(k), [tails] "+r"(tails) + : [bias_ptr] "r"(bias_local), [relu] "r"(is_relu) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "q12", "q13", "q14", "q15", "cc", "memory"); + + if (flag_p_remain && (xb == bblocks - 1)) { + for (int i = 0; i < remain; ++i) { + *pout0++ = cout0[i]; + *pout1++ = cout1[i]; + *pout2++ = cout2[i]; + *pout3++ = cout3[i]; + *pout4++ = cout4[i]; + *pout5++ = cout5[i]; + } + } + } + } + } +} + +void sgemm_conv_4x8(const float* A_packed, const float* B, const float* bias, + float* C, int M, int N, int K, bool is_bias, bool is_relu, + bool transB, ARMContext* ctx) { + size_t l2_cache = + ctx->l2_cache_size() > 0 ? ctx->l2_cache_size() : 512 * 1024; + void* workspace = ctx->get_work_space(); + int threads = ctx->threads(); + //! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2 + int x_block = + (l2_cache - (MBLOCK_A73 * K)) / (sizeof(float) * (K + MBLOCK_A73)); + x_block /= NBLOCK; + x_block *= NBLOCK; + int x_num = (N + (x_block - 1)) / x_block; + x_block = (N + x_num - 1) / x_num; + x_block = (x_block + NBLOCK - 1) / NBLOCK; + x_block *= NBLOCK; + x_block = x_block < NBLOCK ? NBLOCK : x_block; + + int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1; + int tail_pre = (K & (KBLOCK - 1)); + if (tail_pre == 0) { + tail_pre = KBLOCK; + } + + bool flag_p_remain = false; + int remain = 0; + + //! apanel is pre_compute outside gemm + for (unsigned int x0 = 0; x0 < N; x0 += x_block) { + unsigned int xmax = x0 + x_block; + if (xmax > N) { + xmax = N; + } + int bblocks = (xmax - x0 + NBLOCK - 1) / NBLOCK; + remain = xmax - x0 - (bblocks - 1) * NBLOCK; + if (remain > 0) { + flag_p_remain = true; + } + //! load bpanel + float* b_pannel = static_cast(workspace); + if (transB) { + loadb_trans(b_pannel, B, K, 0, K, x0, xmax); + } else { + loadb(b_pannel, B, N, 0, K, x0, xmax); + } +#pragma omp parallel for num_threads(threads) + for (unsigned int y = 0; y < M; y += MBLOCK_A73) { + unsigned int ymax = y + MBLOCK_A73; + if (ymax > M) { + ymax = M; + } + + float cout0[NBLOCK]; + float cout1[NBLOCK]; + float cout2[NBLOCK]; + float cout3[NBLOCK]; + + float bias_local[4] = {0}; + if (is_bias) { + bias_local[0] = bias[y]; + bias_local[1] = bias[y + 1]; + bias_local[2] = bias[y + 2]; + bias_local[3] = bias[y + 3]; + } + + float* c_ptr0 = C + y * N + x0; + float* c_ptr1 = c_ptr0 + N; + float* c_ptr2 = c_ptr1 + N; + float* c_ptr3 = c_ptr2 + N; + + float* pout0 = c_ptr0; + float* pout1 = c_ptr1; + float* pout2 = c_ptr2; + float* pout3 = c_ptr3; + + const float* a_ptr_l = A_packed + y * K; + const float* b_ptr = b_pannel; + for (int xb = 0; xb < bblocks; xb++) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + case 2: + c_ptr1 = cout1; + case 1: + c_ptr2 = cout1; + case 0: + c_ptr3 = cout1; + default: + break; + } + } + if (flag_p_remain && (xb == bblocks - 1)) { + pout0 = c_ptr0; + pout1 = c_ptr1; + pout2 = c_ptr2; + pout3 = c_ptr3; + + c_ptr0 = cout0; + c_ptr1 = cout1; + c_ptr2 = cout2; + c_ptr3 = cout3; + } + const float* a_ptr = a_ptr_l; + int tails = tail_pre; + int k = k_pre; + asm volatile( + "vld1.32 {d4-d5}, [%[bias_ptr]] @ load bias\n" + "vld1.32 {d0-d3}, [%[a_ptr] :128]! @ load a0~a3\n" + "vdup.32 q8, d4[0] @ add bias to out00\n" + "pld [%[a_ptr]] @ preload a, 64byte\n" + "vdup.32 q9, d4[0] @ add bias to out01\n" + "pld [%[b_ptr]] @ preload b\n" + "vdup.32 q10, d4[1] @ add bias to out10\n" + "pld [%[a_ptr], #64] @ preload a\n" + "vdup.32 q11, d4[1] @ add bias to out11\n" + "vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load b1\n" + "vdup.32 q12, d5[0] @ add bias to out20\n" + "pld [%[b_ptr], #64] @ preload b\n" + "vdup.32 q13, d5[0] @ add bias to out21\n" + "pld [%[a_ptr], #128] @ preload a\n" + "vdup.32 q14, d5[1] @ add bias to out30\n" + "pld [%[b_ptr], #128] @ preload b\n" + "vdup.32 q15, d5[1] @ add bias to out31\n" + "pld [%[b_ptr], #192] @ preload b\n" + "cmp %[k], #0 @ check weather k is " + "bigger than 0\n" + "beq 0f @ jump to tail\n" + "1: @ main loop for k\n" + /* Unroll 0*/ + "vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1, b2\n" + "vmla.f32 q8, q4, d0[0] @ out0 += b1 * a0\n" + "vld1.32 {d4-d7}, [%[a_ptr] :128]! @ load next 2xa0~a3\n" + "vmla.f32 q10, q4, d0[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q4, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d0[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d0[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d1[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d1[1] @ out7 += b2 * a3\n" + "vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load next b1, b2\n" + /* Unroll 1 */ + "vmla.f32 q8, q6, d2[0] @ out0 += b1 * a0\n" + "pld [%[b_ptr], #64] @ preload b\n" + "vmla.f32 q10, q6, d2[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q6, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q6, d3[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q7, d2[0] @ out6 += b2 * a0\n" + "vmla.f32 q11, q7, d2[1] @ out7 += b2 * a1\n" + "vmla.f32 q13, q7, d3[0] @ out8 += b2 * a2\n" + "vmla.f32 q15, q7, d3[1] @ out9 += b2 * a3\n" + "vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1,b2\n" + /* Unroll 2 */ + "vmla.f32 q8, q4, d4[0] @ out0 += b1 * a0\n" + "vld1.32 {d0-d3}, [%[a_ptr] :128]! @ load next a0~a3\n" + "vmla.f32 q10, q4, d4[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q4, d5[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d5[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d4[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d4[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d5[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d5[1] @ out7 += b2 * a3\n" + "vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load next b1, b2\n" + /* Unroll 3 */ + "vmla.f32 q8, q6, d6[0] @ out0 += b1 * a0\n" + "pld [%[a_ptr], #64] @ preload a\n" + "vmla.f32 q10, q6, d6[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q6, d7[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q6, d7[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q7, d6[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q7, d6[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q7, d7[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q7, d7[1] @ out7 += b2 * a3\n" + "subs %[k], %[k], #1 @ k--\n" + "bne 1b @ jump to main " + "loop\n" + "0: @ process tail\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "beq 3f @ jump to tail = " + "1\n" + /* Unroll 0*/ + "vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1, b2\n" + "vmla.f32 q8, q4, d0[0] @ out0 += b1 * a0\n" + "vmla.f32 q10, q4, d0[1] @ out1 += b1 * a1\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "vmla.f32 q12, q4, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d0[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d0[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d1[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d1[1] @ out7 += b2 * a3\n" + "beq 4f @ jump to tail==2\n" + /* Unroll 1 */ + "vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load next b1, b2\n" + "vmla.f32 q8, q6, d2[0] @ out0 += b1 * a0\n" + "vld1.32 {d4-d7}, [%[a_ptr] :128]! @ load next 2xa0~a3\n" + "vmla.f32 q10, q6, d2[1] @ out1 += b1 * a1\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "vmla.f32 q12, q6, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q6, d3[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q7, d2[0] @ out6 += b2 * a0\n" + "vmla.f32 q11, q7, d2[1] @ out7 += b2 * a1\n" + "vmla.f32 q13, q7, d3[0] @ out8 += b2 * a2\n" + "vmla.f32 q15, q7, d3[1] @ out9 += b2 * a3\n" + "beq 5f @ jump to tail==3\n" + /* Unroll 2 */ + "vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1,b2\n" + "vmla.f32 q8, q4, d4[0] @ out0 += b1 * a0\n" + "vmla.f32 q10, q4, d4[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q4, d5[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d5[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d4[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d4[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d5[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d5[1] @ out7 += b2 * a3\n" + /* Unroll 3 */ + "vmla.f32 q8, q6, d6[0] @ out0 += b1 * a0\n" + "vmla.f32 q10, q6, d6[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q6, d7[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q6, d7[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q7, d6[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q7, d6[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q7, d7[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q7, d7[1] @ out7 += b2 * a3\n" + "b 2f\n" + /* tails==1 final tail */ + "3: @ tail=1\n" + "vmla.f32 q8, q4, d0[0] @ out0 += b1 * a0\n" + "vmla.f32 q10, q4, d0[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q4, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d0[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d0[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d1[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d1[1] @ out7 += b2 * a3\n" + /*aptr - 16 */ + "sub %[a_ptr], %[a_ptr], #16 @ tail--\n" + "b 2f @ jump to end\n" + /* tails==2 final tail*/ + "4: @ tail == 2\n" + "vmla.f32 q8, q6, d2[0] @ out0 += b1 * a0\n" + "vmla.f32 q10, q6, d2[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q6, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q6, d3[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q7, d2[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q7, d2[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q7, d3[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q7, d3[1] @ out7 += b2 * a3\n" + "b 2f @ jump to end\n" + /* tails==3 final tail*/ + "5: @ tail=3\n" + "vmla.f32 q8, q4, d4[0] @ out0 += b1 * a0\n" + "vmla.f32 q10, q4, d4[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q4, d5[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d5[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d4[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d4[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d5[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d5[1] @ out7 += b2 * a3\n" + /*aptr - 16*/ + "sub %[a_ptr], %[a_ptr], #16 @ tail--\n" + "2: @ check relu\n" + "cmp %[relu], #0 @ check if has relu\n" + "ble 6f @ skip relu if relu <= 0\n" + "vmov.u32 q0, #0 @ for relu\n" + "vmax.f32 q8, q8, q0 @ for relu\n" + "vmax.f32 q9, q9, q0 @ for relu\n" + "vmax.f32 q10, q10, q0 @ for relu\n" + "vmax.f32 q11, q11, q0 @ for relu\n" + "vmax.f32 q12, q12, q0 @ for relu\n" + "vmax.f32 q13, q13, q0 @ for relu\n" + "vmax.f32 q14, q14, q0 @ for relu\n" + "vmax.f32 q15, q15, q0 @ for relu\n" + "6: @ store result\n" + "vst1.32 {d16-d19}, [%[c_ptr0]]! @ store r0\n" + "vst1.32 {d20-d23}, [%[c_ptr1]]! @ store r1\n" + "vst1.32 {d24-d27}, [%[c_ptr2]]! @ store r2\n" + "vst1.32 {d28-d31}, [%[c_ptr3]]! @ store r3\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr0] "+r"(c_ptr0), + [c_ptr1] "+r"(c_ptr1), [c_ptr2] "+r"(c_ptr2), + [c_ptr3] "+r"(c_ptr3), [k] "+r"(k), [tails] "+r"(tails) + : [bias_ptr] "r"(bias_local), [relu] "r"(is_relu) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "q12", "q13", "q14", "q15", "cc", "memory"); + + if (flag_p_remain && (xb == bblocks - 1)) { + for (int i = 0; i < remain; ++i) { + *pout0++ = cout0[i]; + *pout1++ = cout1[i]; + *pout2++ = cout2[i]; + *pout3++ = cout3[i]; + } + } + } + } + } +} +#endif // __aarch64__ + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/arm/math/packed_sgemm.h b/paddle/fluid/lite/arm/math/packed_sgemm.h new file mode 100644 index 0000000000000000000000000000000000000000..160b432c8d80fe126ef44d137b85850b779024c5 --- /dev/null +++ b/paddle/fluid/lite/arm/math/packed_sgemm.h @@ -0,0 +1,60 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/lite/core/context.h" +#include "paddle/fluid/lite/core/cpu_info.h" +#include "paddle/fluid/lite/core/lite_tensor.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +#ifdef __aarch64__ +constexpr int MBLOCK = 8; +constexpr int NBLOCK = 12; +constexpr int KBLOCK = 4; +inline int get_hblock(ARMArch arch) { return MBLOCK; } +#else +constexpr int MBLOCK_A73 = 4; +constexpr int MBLOCK_OTH = 6; +constexpr int NBLOCK = 8; +constexpr int KBLOCK = 4; +inline int get_hblock(ARMArch arch) { + if (arch == kA73) { + return MBLOCK_A73; + } else { + return MBLOCK_OTH; + } +} +#endif // __aarch64__ + +void prepackA(float* out, const float* in, const int ldin, const int m0, + const int mmax, const int k0, const int kmax, bool is_trans, + ARMContext* ctx); + +void prepackA(TensorLite* tout, const TensorLite& tin, int m, int k, int group, + bool is_trans, ARMContext* ctx); + +void sgemm_prepack(const float* A_packed, const float* B, const float* bias, + float* C, int M, int N, int K, bool is_bias, bool is_relu, + bool is_transB, ARMContext* ctx); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt index 25fdf32c1c07f27ecf885117d708442650fe3335..4e55ba74f970f3afab7f32a4d645b9893159b25f 100644 --- a/paddle/fluid/lite/core/CMakeLists.txt +++ b/paddle/fluid/lite/core/CMakeLists.txt @@ -23,7 +23,8 @@ cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite any_l cc_library(variable_lite SRCS variable.cc) cc_library(op_registry_lite SRCS op_registry.cc DEPS framework_proto_lite) cc_library(scope_lite SRCS scope.cc) -cc_library(context_lite SRCS context.cc DEPS any_lite) +cc_library(cpu_info_lite SRCS cpu_info.cc) +cc_library(context_lite SRCS context.cc DEPS ${tensor_lite} any_lite cpu_info_lite) cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite compatible_pb_lite target_wrapper_lite ${tensor_lite}) cc_library(types_lite SRCS types.cc) cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite) diff --git a/paddle/fluid/lite/core/context.cc b/paddle/fluid/lite/core/context.cc index fa01f1d3e19dc91c7d713691546eea13170a2e04..c2dfe2aba955ffc678f1a4cd9ab8a5b87ff751dc 100644 --- a/paddle/fluid/lite/core/context.cc +++ b/paddle/fluid/lite/core/context.cc @@ -12,8 +12,317 @@ // See the License for the specific language governing permissions and // limitations under the License. -// -// Created by chunwei on 19-2-22. -// - #include "paddle/fluid/lite/core/context.h" +#include "paddle/fluid/lite/core/cpu_info.h" + +#ifdef LITE_WITH_ANDROID +#include +#include +#endif +#if __APPLE__ +#include "TargetConditionals.h" +#if TARGET_OS_IPHONE +#include +#include +#include +#endif // TARGET_OS_IPHONE +#endif // __APPLE__ + +namespace paddle { +namespace lite { + +#ifdef LITE_WITH_ARM + +void ARMContext::SetCache(int l1size, int l2size, int l3size) { + DeviceInfo& dev = DeviceInfo::Global(); + int cpu_count = arm_get_cpucount(); + dev.L1_cache_.resize(cpu_count); + dev.L2_cache_.resize(cpu_count); + dev.L3_cache_.resize(cpu_count); + for (int i = 0; i < cpu_count; ++i) { + dev.L1_cache_[i] = l1size; + dev.L2_cache_[i] = l2size; + dev.L3_cache_[i] = l3size; + } + workspace_.Resize({2 * (l1size + l2size)}); +} + +ARMContext::ARMContext() { + active_ids_ = {0}; + mode_ = LITE_POWER_HIGH; + DeviceInfo& dev = DeviceInfo::Global(); + workspace_.Resize( + {static_cast(dev.L2_cache_[active_ids_[0]] / sizeof(float))}); +#ifdef TARGET_IOS + arch_ = APPLE; // use 6x8 +#else + if (dev.big_core_ids_.size() > 0) { + arch_ = dev.archs_[dev.big_core_ids_[0]]; + } +#endif +} + +PowerMode ARMContext::mode() const { return mode_; } + +int ARMContext::threads() const { return active_ids_.size(); } + +ARMContext::ARMContext(const ARMContext& ctx) { + mode_ = ctx.mode_; + active_ids_ = ctx.active_ids_; + workspace_ = ctx.workspace_; + arch_ = ctx.arch_; + count_ = ctx.count_; +} + +ARMContext& ARMContext::operator=(const ARMContext& ctx) { + mode_ = ctx.mode_; + active_ids_ = ctx.active_ids_; + workspace_ = ctx.workspace_; + arch_ = ctx.arch_; + count_ = ctx.count_; + return *this; +} + +void ARMContext::BindDev() { +#ifdef USE_OPENMP + int num_threads = active_ids_.size(); + omp_set_num_threads(num_threads); +#ifdef LITE_WITH_ANDROID + std::vector ssarets; + for (int j = 0; j < num_threads; ++j) { + ssarets.push_back(0); + } +#pragma omp parallel for + for (int i = 0; i < num_threads; i++) { + ssarets[i] = set_sched_affinity(active_ids_); + } + for (int i = 0; i < num_threads; i++) { + if (ssarets[i] != 0) { + LOGE("set cpu affinity failed, cpuID: %d\n", active_ids_[i]); + return; + } + } +#endif // LITE_WITH_ANDROID +#else // USE_OPENMP +#ifdef LITE_WITH_ANDROID + std::vector cpuid1; + cpuid1.push_back(active_ids_[0]); + int ssaret = set_sched_affinity(cpuid1); + if (ssaret != 0) { + printf("set cpu affinity failed, cpuID: %d\n", active_ids_[0]); + return; + } +#endif // LITE_WITH_ANDROID +#endif // USE_OPENMP +} + +void ARMContext::SetRunMode(PowerMode mode, int threads) { + DeviceInfo& dev = DeviceInfo::Global(); + int big_core_size = dev.big_core_ids_.size(); + int small_core_size = dev.little_core_ids_.size(); + if (threads > big_core_size + small_core_size) { + threads = big_core_size + small_core_size; + } +#ifdef USE_OPENMP + count_++; + int shift_num = (count_ / 10) % big_core_size; + switch (mode) { + case LITE_POWER_FULL: + mode_ = mode; + active_ids_.clear(); + for (int i = 0; i < threads; ++i) { + if (i < big_core_size) { + active_ids_.push_back(dev.big_core_ids_[i]); + } else { + active_ids_.push_back(dev.little_core_ids_[i - big_core_size]); + } + } + if (active_ids_.size() == 0) { + active_ids_.push_back(0); + } + break; + case LITE_POWER_HIGH: + active_ids_.clear(); + if (big_core_size > 0) { + mode_ = LITE_POWER_HIGH; + if (threads > big_core_size) { + LOGE("threads: %d, exceed the big cores size: %d\n", threads, + big_core_size); + active_ids_ = dev.big_core_ids_; + } else { + for (int i = 0; i < threads; ++i) { + active_ids_.push_back(dev.big_core_ids_[i]); + } + } + } else { + mode_ = LITE_POWER_LOW; + LOGE("HIGH POWER MODE is not support, switch to little cores\n"); + if (threads > small_core_size) { + active_ids_ = dev.little_core_ids_; + } else { + for (int i = 0; i < threads; ++i) { + active_ids_.push_back(dev.little_core_ids_[i]); + } + } + } + if (active_ids_.size() == 0) { + active_ids_.push_back(0); + } + break; + case LITE_POWER_LOW: + active_ids_.clear(); + if (small_core_size > 0) { + mode_ = LITE_POWER_LOW; + if (threads > small_core_size) { + LOGW("threads: %d, exceed the little cores size: %d\n", threads, + small_core_size); + active_ids_ = dev.little_core_ids_; + } else { + for (int i = 0; i < threads; ++i) { + active_ids_.push_back(dev.little_core_ids_[i]); + } + } + } else { + mode_ = LITE_POWER_HIGH; + LOGW("LOW POWER MODE is not support, switch to big cores\n"); + if (threads > big_core_size) { + active_ids_ = dev.big_core_ids_; + } else { + for (int i = 0; i < threads; ++i) { + active_ids_.push_back(dev.big_core_ids_[i]); + } + } + } + if (active_ids_.size() == 0) { + active_ids_.push_back(0); + } + break; + case LITE_POWER_NO_BIND: + mode_ = LITE_POWER_NO_BIND; + active_ids_.clear(); + if (threads > dev.core_ids_.size()) { + active_ids_.resize(dev.core_ids_.size()); + } else { + active_ids_.resize(threads); + } + break; + case LITE_POWER_RAND_HIGH: + active_ids_.clear(); + if (big_core_size > 0) { + mode_ = LITE_POWER_RAND_HIGH; + if (threads > big_core_size) { + LOGW("threads: %d, exceed the big cores size: %d\n", threads, + big_core_size); + active_ids_ = dev.big_core_ids_; + } else { + for (int i = 0; i < threads; ++i) { + active_ids_.push_back( + dev.big_core_ids_[(i + shift_num) % big_core_size]); + } + } + } else { + mode_ = LITE_POWER_LOW; + LOGW("HIGH POWER MODE is not support, switch to little cores\n"); + if (threads > small_core_size) { + active_ids_ = dev.little_core_ids_; + } else { + for (int i = 0; i < threads; ++i) { + active_ids_.push_back(dev.little_core_ids_[i]); + } + } + } + if (active_ids_.size() == 0) { + active_ids_.push_back(0); + } + break; + case LITE_POWER_RAND_LOW: + active_ids_.clear(); + if (small_core_size > 0) { + mode_ = LITE_POWER_RAND_LOW; + if (threads > small_core_size) { + LOGW("threads: %d, exceed the little cores size: %d\n", threads, + small_core_size); + active_ids_ = dev.little_core_ids_; + } else { + for (int i = 0; i < threads; ++i) { + active_ids_.push_back( + dev.little_core_ids_[(i + shift_num) % small_core_size]); + } + } + } else { + mode_ = LITE_POWER_HIGH; + LOGW("LOW POWER MODE is not support, switch to big cores\n"); + if (threads > big_core_size) { + active_ids_ = dev.big_core_ids_; + } else { + for (int i = 0; i < threads; ++i) { + active_ids_.push_back(dev.big_core_ids_[i]); + } + } + } + if (active_ids_.size() == 0) { + active_ids_.push_back(0); + } + break; + } + //! fix multi-threads LITE_POWER_HIGH mode + if (mode_ == LITE_POWER_NO_BIND || threads > 1) { + int threads = active_ids_.size(); + omp_set_num_threads(threads); + } else { + if (check_online(active_ids_)) { + BindDev(); + } else { + LOG(ERROR) << "core id " << active_ids_[0] + << " is offline, switch to NO BIND MODE"; + int threads = active_ids_.size(); + omp_set_num_threads(threads); + } + } +#else + if (big_core_size > 0) { + active_ids_ = {dev.big_core_ids_[0]}; + } else { + active_ids_ = {0}; + } +#endif + //! alloc memory for sgemm in this context + int temp_mem_size = + DeviceInfo::Global().L2_cache_[active_ids_[0]] / sizeof(float); + workspace_.Resize({temp_mem_size}); + arch_ = DeviceInfo::Global().archs_[active_ids_[0]]; +} + +ARMArch ARMContext::arch() const { return arch_; } + +void ARMContext::SetArch(ARMArch arch) { arch_ = arch; } + +int ARMContext::l1_cache_size() const { + DeviceInfo& dev = DeviceInfo::Global(); + return dev.L1_cache_[active_ids_[0]]; +} + +int ARMContext::l2_cache_size() const { + DeviceInfo& dev = DeviceInfo::Global(); + return dev.L2_cache_[active_ids_[0]]; +} + +int ARMContext::l3_cache_size() const { + DeviceInfo& dev = DeviceInfo::Global(); + return dev.L3_cache_[active_ids_[0]]; +} + +bool ARMContext::ExtendWorkspace(DDimLite dims) { + auto count = dims.product(); + auto old = workspace_.dims(); + if (count == old.product()) { + return false; + } + + workspace_.Resize( + {static_cast(count + l2_cache_size() / sizeof(float))}); + return true; +} +#endif // LITE_WITH_ARM +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/context.h b/paddle/fluid/lite/core/context.h index 01253e0de1952793c3a896de5d97228b7f53e2ec..e09a03f55bdbbc314968ec0cd9d7109806e846d9 100644 --- a/paddle/fluid/lite/core/context.h +++ b/paddle/fluid/lite/core/context.h @@ -26,6 +26,8 @@ #include #include #include +#include "paddle/fluid/lite/core/cpu_info.h" +#include "paddle/fluid/lite/core/lite_tensor.h" #include "paddle/fluid/lite/core/target_wrapper.h" namespace paddle { @@ -34,7 +36,44 @@ namespace lite { struct HostContext {}; #ifdef LITE_WITH_ARM -struct ARMContext {}; + +struct ARMContext { + public: + ARMContext(); + ARMContext(PowerMode mode, int threads); + ARMContext(const ARMContext& ctx); + + ARMContext& operator=(const ARMContext& ctx); + + void SetRunMode(PowerMode mode, int threads); + void SetCache(int l1size, int l2size, int l3size); + void SetArch(ARMArch arch); + void BindDev(); + + PowerMode mode() const; + int threads() const; + ARMArch arch() const; + + template + T* workspace_data() { + return workspace_.mutable_data(); + } + + int l1_cache_size() const; + int l2_cache_size() const; + int l3_cache_size() const; + bool ExtendWorkspace(DDimLite dims); + + private: + // LITE_POWER_HIGH stands for using big cores, + // LITE_POWER_LOW stands for using small core, + // LITE_POWER_FULL stands for using all cores + ARMArch arch_; + PowerMode mode_; + std::vector active_ids_; + TensorLite workspace_; + int64_t count_{0}; +}; #endif #ifdef LITE_WITH_CUDA diff --git a/paddle/fluid/lite/core/cpu_info.cc b/paddle/fluid/lite/core/cpu_info.cc new file mode 100644 index 0000000000000000000000000000000000000000..0336c2d7ac46a1948d756a064a3fe50a0a987f4d --- /dev/null +++ b/paddle/fluid/lite/core/cpu_info.cc @@ -0,0 +1,629 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/cpu_info.h" +#include + +namespace paddle { +namespace lite { + +#ifdef LITE_WITH_ARM + +void DeviceInfo::get_info(DeviceInfo* dev) { + set_default_cache(dev); + dev->compute_core_num_ = arm_get_cpucount(); + dev->max_memory_ = arm_get_meminfo(); + +// get max freq +#ifdef LITE_WITH_ANDROID + std::vector max_freq(dev->compute_core_num_); + for (int i = 0; i < dev->compute_core_num_; ++i) { + max_freq[i] = get_max_freq_khz(i) / 1000; + } + std::string cpu_name = arm_get_cpu_name(); + if (get_cpu_info_from_name(dev, cpu_name) != true) { + arm_sort_cpuid_by_max_frequency(dev->compute_core_num_, &dev->core_ids_, + max_freq, &dev->cluster_ids_); + dev->big_core_ids_.clear(); + dev->little_core_ids_.clear(); + for (int i = 0; i < dev->cluster_ids_.size(); ++i) { + if (dev->cluster_ids_[i] == 0) { + dev->big_core_ids_.push_back(dev->core_ids_[i]); + } else { + dev->little_core_ids_.push_back(dev->core_ids_[i]); + } + } + arm_get_cpu_arch(&dev->archs_); + } + + LOG(INFO) << "ARM multiprocessors number: " << dev->compute_core_num_; + for (int i = 0; i < dev->compute_core_num_; ++i) { + LOG(INFO) << "ARM multiprocessors ID: " << dev->core_ids_[i] + << ", frequence: " << max_freq[i] + << ", cluster ID: " << dev->cluster_ids_[dev->core_ids_[i]] + << ", CPU ARCH: A" << dev->archs_[i]; + } + LOG(INFO) << "L1 DataCache size is: "; + for (int i = 0; i < dev->compute_core_num_; ++i) { + LOG(INFO) << dev->L1_cache_[i] / 1024 << " KB"; + } + LOG(INFO) << "L2 Cache size is: "; + for (int i = 0; i < dev->compute_core_num_; ++i) { + LOG(INFO) << dev->L2_cache_[i] / 1024 << " KB"; + } + LOG(INFO) << "Total memory: " << dev->max_memory_ << "KB"; + + dev->max_freq_ = max_freq[0]; + for (int j = 1; j < dev->compute_core_num_; ++j) { + if (dev->max_freq_ < max_freq[j]) { + dev->max_freq_ = max_freq[j]; + } + } +#elif defined(TARGET_IOS) + arm_get_cpu_arch(&dev->archs_); +#endif +} + +// cache_id : 0 -> L1, 1 -> L2, 2 -> L3 +void set_cache_info(DeviceInfo* cpu_info, int cache_id, int argc, ...) { + va_list arg_ptr; + va_start(arg_ptr, argc); + std::vector* cache; + switch (cache_id) { + case 0: + cache = &cpu_info->L1_cache_; + break; + case 1: + cache = &cpu_info->L2_cache_; + break; + case 2: + cache = &cpu_info->L3_cache_; + break; + default: + break; + } + int core_num = cpu_info->compute_core_num_; + cache->resize(core_num); + if (argc == 1) { + int cache_size = va_arg(arg_ptr, int); + for (int i = 0; i < core_num; ++i) { + (*cache)[i] = cache_size; + } + } else { + int big_core_num = cpu_info->big_core_ids_.size(); + int little_core_num = cpu_info->little_core_ids_.size(); + int big_core_cache_size = va_arg(arg_ptr, int); + int little_core_cache_size = va_arg(arg_ptr, int); + for (int i = 0; i < big_core_num; ++i) { + (*cache)[cpu_info->big_core_ids_[i]] = big_core_cache_size; + } + for (int i = 0; i < little_core_num; ++i) { + (*cache)[cpu_info->little_core_ids_[i]] = little_core_cache_size; + } + } + va_end(arg_ptr); +} + +void set_arch_info(DeviceInfo* cpu_info, int argc, ...) { + va_list arg_ptr; + va_start(arg_ptr, argc); + int core_num = cpu_info->compute_core_num_; + cpu_info->archs_.resize(core_num); + if (argc == 1) { + ARMArch arch = (ARMArch)va_arg(arg_ptr, int); + for (int i = 0; i < core_num; ++i) { + cpu_info->archs_[i] = arch; + } + } else { + ARMArch big_core_arch = (ARMArch)va_arg(arg_ptr, int); + ARMArch little_core_arch = (ARMArch)va_arg(arg_ptr, int); + int big_core_num = cpu_info->big_core_ids_.size(); + int little_core_num = cpu_info->little_core_ids_.size(); + for (int i = 0; i < big_core_num; ++i) { + cpu_info->archs_[cpu_info->big_core_ids_[i]] = big_core_arch; + } + for (int i = 0; i < little_core_num; ++i) { + cpu_info->archs_[cpu_info->little_core_ids_[i]] = little_core_arch; + } + } + va_end(arg_ptr); +} + +bool get_cpu_info_from_name(DeviceInfo* cpu_info, std::string hardware_name) { + /* Snapdragon */ + if (hardware_name.find("SDM845") != std::string::npos) { // 845 + cpu_info->compute_core_num_ = 8; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->big_core_ids_ = {4, 5, 6, 7}; + cpu_info->little_core_ids_ = {0, 1, 2, 3}; + cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0}; + set_arch_info(cpu_info, 2, kA75, kA55); + set_cache_info(cpu_info, 0, 1, 32 * 1024); + set_cache_info(cpu_info, 1, 2, 256 * 1024, 128 * 1024); + set_cache_info(cpu_info, 2, 1, 2048 * 1024); + return true; + + } else if (hardware_name.find("SDM710") != std::string::npos) { // 710 + cpu_info->compute_core_num_ = 8; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->big_core_ids_ = {6, 7}; + cpu_info->little_core_ids_ = {0, 1, 2, 3, 4, 5}; + cpu_info->cluster_ids_ = {1, 1, 1, 1, 1, 1, 0, 0}; + set_arch_info(cpu_info, 2, kA75, kA55); + return true; + } else if (hardware_name.find("MSM8998") != std::string::npos) { // 835 + cpu_info->compute_core_num_ = 8; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->big_core_ids_ = {4, 5, 6, 7}; + cpu_info->little_core_ids_ = {0, 1, 2, 3}; + cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0}; + set_arch_info(cpu_info, 2, kA73, kA53); + set_cache_info(cpu_info, 0, 2, 64 * 1024); + set_cache_info(cpu_info, 1, 2, 1024 * 1024, + /*real cache size is 2M, while that will get bad performace + on conv3x3s1 or gemm, set to 1M or 512K*/ + 1024 * 1024); + return true; + + } else if (hardware_name.find("MSM8996") != std::string::npos) { // 820 + cpu_info->compute_core_num_ = 4; + cpu_info->core_ids_ = {0, 1, 2, 3}; + cpu_info->big_core_ids_ = {2, 3}; + cpu_info->little_core_ids_ = {0, 1}; + cpu_info->cluster_ids_ = {1, 1, 0, 0}; + set_arch_info(cpu_info, 1, kA72); + set_cache_info(cpu_info, 0, 1, 24 * 1024); + set_cache_info(cpu_info, 1, 2, 1024 * 1024, 512 * 1024); + return true; + + } else if (hardware_name.find("SDM660") != std::string::npos || + hardware_name.find("SDM636") != std::string::npos) { // 660, 636 + cpu_info->compute_core_num_ = 8; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->big_core_ids_ = {4, 5, 6, 7}; + cpu_info->little_core_ids_ = {0, 1, 2, 3}; + cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0}; + set_arch_info(cpu_info, 1, kA73); + set_cache_info(cpu_info, 0, 2, 64 * 1024, 32 * 1024); + set_cache_info(cpu_info, 1, 1, 1024 * 1024); + return true; + + } else if (hardware_name.find("MSM8976") != std::string::npos) { // 652,653 + cpu_info->compute_core_num_ = 8; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->big_core_ids_ = {4, 5, 6, 7}; + cpu_info->little_core_ids_ = {0, 1, 2, 3}; + cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0}; + set_arch_info(cpu_info, 2, kA72, kA53); + set_cache_info(cpu_info, 0, 1, 32 * 1024); + set_cache_info(cpu_info, 1, 2, 1024 * 1024, 512 * 1024); + return true; + + } else if (hardware_name.find("MSM8953") != std::string::npos) { // 625 + cpu_info->compute_core_num_ = 8; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->big_core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->little_core_ids_ = {}; + cpu_info->cluster_ids_ = {0, 0, 0, 0, 0, 0, 0, 0}; + set_arch_info(cpu_info, 1, kA53); + set_cache_info(cpu_info, 0, 1, 32 * 1024); + set_cache_info(cpu_info, 1, 1, 1024 * 1024); + return true; + + } else if (hardware_name.find("MSM8939") != std::string::npos) { // 615 + cpu_info->compute_core_num_ = 8; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->big_core_ids_ = {0, 1, 2, 3}; + cpu_info->little_core_ids_ = {4, 5, 6, 7}; + cpu_info->cluster_ids_ = {0, 0, 0, 0, 1, 1, 1, 1}; + set_arch_info(cpu_info, 1, kA53); + set_cache_info(cpu_info, 0, 1, 32 * 1024); + set_cache_info(cpu_info, 1, 2, 512 * 1024, 256 * 1024); + return true; + + /* MediaTek */ + + } else if (hardware_name.find("MT6797") != + std::string::npos) { // X20/X23/X25/X27 + cpu_info->compute_core_num_ = 10; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + cpu_info->big_core_ids_ = {8, 9}; + cpu_info->little_core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->cluster_ids_ = {1, 1, 1, 1, 1, 1, 1, 1, 0, 0}; + set_arch_info(cpu_info, 2, kA72, kA53); + set_cache_info(cpu_info, 0, 1, 32 * 1024); + set_cache_info(cpu_info, 1, 2, 1024 * 1024, 512 * 1024); + return true; + + } else if (hardware_name.find("MT6799") != std::string::npos) { // X30 + cpu_info->compute_core_num_ = 10; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + cpu_info->big_core_ids_ = {8, 9}; + cpu_info->little_core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->cluster_ids_ = {1, 1, 1, 1, 1, 1, 1, 1, 0, 0}; + set_arch_info(cpu_info, 2, kA73, kA53); + return true; + + } else if (hardware_name.find("MT6795") != std::string::npos || + hardware_name.find("MT6762") != std::string::npos || + hardware_name.find("MT6755T") != std::string::npos || + hardware_name.find("MT6755S") != std::string::npos || + hardware_name.find("MT6753") != std::string::npos || + hardware_name.find("MT6752") != std::string::npos || + hardware_name.find("MT6750") != std::string::npos) { + // X10, P22, P15/P18, MT6753, MT6752/MT6752M, MT6750 + cpu_info->compute_core_num_ = 8; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->big_core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->little_core_ids_ = {}; + cpu_info->cluster_ids_ = {0, 0, 0, 0, 0, 0, 0, 0}; + set_arch_info(cpu_info, 1, kA53); + return true; + + } else if (hardware_name.find("MT6758") != std::string::npos || + hardware_name.find("MT6757") != std::string::npos || + hardware_name.find("MT6763") != std::string::npos || + hardware_name.find("MT6755M") != std::string::npos || + hardware_name.find("MT6755") != + std::string::npos) { // P30, P20/P25, P23, P10 + cpu_info->compute_core_num_ = 8; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->big_core_ids_ = {4, 5, 6, 7}; + cpu_info->little_core_ids_ = {0, 1, 2, 3}; + cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0}; + set_arch_info(cpu_info, 1, kA53); + return true; + + } else if (hardware_name.find("MT6771") != std::string::npos) { // P60 + cpu_info->compute_core_num_ = 8; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->big_core_ids_ = {4, 5, 6, 7}; + cpu_info->little_core_ids_ = {0, 1, 2, 3}; + cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0}; + set_arch_info(cpu_info, 2, kA73, kA53); + return true; + + } else if (hardware_name.find("MT6765") != std::string::npos || + hardware_name.find("MT6739") != std::string::npos || + hardware_name.find("MT6738") != std::string::npos || + hardware_name.find("MT6737") != + std::string::npos) { // A22, MT6739, MT6738, MT6767 + cpu_info->compute_core_num_ = 4; + cpu_info->core_ids_ = {0, 1, 2, 3}; + cpu_info->big_core_ids_ = {0, 0, 0, 0}; + cpu_info->little_core_ids_ = {}; + cpu_info->cluster_ids_ = {0, 0, 0, 0}; + set_arch_info(cpu_info, 1, kA53); + return true; + } + return false; +} + +size_t arm_get_meminfo() { +#ifdef LITE_WITH_ANDROID + // get cpu count from /proc/cpuinfo + FILE* fp = fopen("/proc/meminfo", "rb"); + if (!fp) { + return 1; + } + + size_t memsize = 0; + char line[1024]; + while (!feof(fp)) { + char* s = fgets(line, 1024, fp); + if (!s) { + break; + } + sscanf(s, "MemTotal: %d kB", &memsize); + } + + fclose(fp); + + return memsize; +#elif defined(TARGET_IOS) + // to be implemented + printf("not implemented\n"); + return 0; +#endif +} + +int arm_get_cpucount() { +#ifdef LITE_WITH_ANDROID + // get cpu count from /sys/devices/system/cpu/cpunum/uevent + int max_cpu_count = 20; + int count = 0; + for (int i = 0; i < max_cpu_count; ++i) { + char path[256]; + snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%d/uevent", i); + FILE* fp = fopen(path, "rb"); + if (!fp) { + break; + } + count++; + fclose(fp); + } + if (count < 1) { + count = 1; + } + return count; +#elif defined(TARGET_IOS) + int count = 0; + size_t len = sizeof(count); + sysctlbyname("hw.ncpu", &count, &len, NULL, 0); + if (count < 1) { + count = 1; + } + return count; +#else + return 1; +#endif +} + +void arm_get_cpu_arch(std::vector* archs) { +#ifdef LITE_WITH_ANDROID + archs->clear(); + //! get CPU ARCH + FILE* fp = fopen("/proc/cpuinfo", "rb"); + if (!fp) { + return; + } + char line[1024]; + while (!feof(fp)) { + char* s = fgets(line, 1024, fp); + if (!s) { + break; + } + if (strstr(line, "part") != NULL) { + int arch_id = 0; + sscanf(s, "CPU part\t: %x", &arch_id); + switch (arch_id) { + case 0xd03: + archs->push_back(kA53); + break; + case 0xd05: + archs->push_back(kA55); + break; + case 0xd07: + archs->push_back(kA57); + break; + case 0xd08: + archs->push_back(kA72); + break; + case 0xd09: + archs->push_back(kA73); + break; + case 0xd0a: + archs->push_back(kA75); + break; + case 0x800: + // 835 + archs->push_back(kA73); + break; + case 0x205: + // 820 + archs->push_back(kA72); + break; + default: + LOG(ERROR) << "unknow type"; + archs->push_back(kARMArch_UNKOWN); + } + } + } + fclose(fp); + int cpu_count = arm_get_cpucount(); + if (archs->size() < cpu_count) { + for (int i = archs->size(); i < cpu_count; ++i) { + archs->push_back(archs->at(i - 1)); + } + } +#endif +#ifdef TARGET_IOS + int cpu_count = arm_get_cpucount(); + for (int i = 0; i < cpu_count; ++i) { + archs->push_back(APPLE); + } +#endif +} + +#ifdef LITE_WITH_ANDROID + +void set_default_cache(DeviceInfo* dev) { + int cpu_count = arm_get_cpucount(); + dev->L1_cache_.resize(cpu_count); + dev->L2_cache_.resize(cpu_count); + dev->L3_cache_.resize(cpu_count); +#ifdef TARGET_IOS + for (int i = 0; i < cpu_count; ++i) { + dev->L1_cache_[i] = 64 * 1024; + dev->L2_cache_[i] = 2048 * 1024; + dev->L3_cache_[i] = 0; + } +#else + for (int i = 0; i < cpu_count; ++i) { + dev->L1_cache_[i] = 32 * 1024; + dev->L2_cache_[i] = 512 * 1024; + dev->L3_cache_[i] = 0; + } +#endif +} +std::string arm_get_cpu_name() { + FILE* fp = fopen("/proc/cpuinfo", "rb"); + if (!fp) { + return ""; + } + char line[1024]; + while (!feof(fp)) { + char* s = fgets(line, 1024, fp); + if (!s) { + break; + } + if (strstr(line, "Hardware") != NULL) { + fclose(fp); + return std::string(line); + } + } + fclose(fp); + return ""; +} + +int get_max_freq_khz(int cpuid) { + // first try, for all possible cpu + char path[256]; + snprintf(path, sizeof(path), + "/sys/devices/system/cpu/cpufreq/stats/cpu%d/time_in_state", cpuid); + + FILE* fp = fopen(path, "rb"); + + if (!fp) { + // second try, for online cpu + snprintf(path, sizeof(path), + "/sys/devices/system/cpu/cpu%d/cpufreq/stats/time_in_state", + cpuid); + fp = fopen(path, "rb"); + + if (!fp) { + // third try, for online cpu + snprintf(path, sizeof(path), + "/sys/devices/system/cpu/cpu%d/cpufreq/cpuinfo_max_freq", cpuid); + fp = fopen(path, "rb"); + + if (!fp) { + return -1; + } + + int max_freq_khz = -1; + fscanf(fp, "%d", &max_freq_khz); + + fclose(fp); + + return max_freq_khz; + } + } + + int max_freq_khz = 0; + while (!feof(fp)) { + int freq_khz = 0; + int nscan = fscanf(fp, "%d %*d", &freq_khz); + if (nscan != 1) { + break; + } + + if (freq_khz > max_freq_khz) { + max_freq_khz = freq_khz; + } + } + + fclose(fp); + + return max_freq_khz; +} + +int arm_sort_cpuid_by_max_frequency(int cpu_count, std::vector* cpuids, + const std::vector& cpu_freq, + std::vector* cluster_ids) { + if (cpu_count == 0) { + return 0; + } + + cpuids->resize(cpu_count); + cluster_ids->resize(cpu_count); + + for (int i = 0; i < cpu_count; i++) { + cpuids->at(i) = i; + } + + // sort cpuid as big core first + // simple bubble sort + + for (int i = 0; i < cpu_count; i++) { + for (int j = i + 1; j < cpu_count; j++) { + if (cpu_freq[i] < cpu_freq[j]) { + // swap + int tmp = cpuids->at(i); + cpuids->at(i) = cpuids->at(j); + cpuids->at(j) = tmp; + } + } + } + // SMP + int mid_max_freq_khz = + (cpu_freq[cpuids->at(0)] + cpu_freq[cpuids->at(cpu_count - 1)]) / 2; + + for (int i = 0; i < cpu_count; i++) { + cpuids->at(i) = i; + if (cpu_freq[i] >= mid_max_freq_khz) { + cluster_ids->at(i) = 0; + } else { + cluster_ids->at(i) = 1; + } + } + return 0; +} + +int check_online(const std::vector& core_ids) { + if (core_ids.size() == 0) { + return 0; + } + char path[256]; + int online = 1; + for (int i = 0; i < core_ids.size(); ++i) { + snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%d/online", + core_ids[i]); + FILE* fp = fopen(path, "rb"); + if (!fp) { + return 0; + } + int cur_online = 0; + fscanf(fp, "%d", &cur_online); + online &= cur_online; + fclose(fp); + } + return online; +} + +int set_sched_affinity(const std::vector& cpuids) { +// #define CPU_SETSIZE 1024 +// #define __NCPUBITS (8 * sizeof (unsigned long)) +// typedef struct +// { +// unsigned long __bits[CPU_SETSIZE / __NCPUBITS]; +// } cpu_set_t; + +// set affinity for thread +#ifdef __GLIBC__ + pid_t pid = syscall(SYS_gettid); +#else + pid_t pid = gettid(); +#endif + cpu_set_t mask; + CPU_ZERO(&mask); + for (int i = 0; i < cpuids.size(); i++) { + CPU_SET(cpuids[i], &mask); + } + + int syscallret = syscall(__NR_sched_setaffinity, pid, sizeof(mask), &mask); + if (syscallret) { + LOG(ERROR) << "syscall error " << syscallret; + return -1; + } + + return 0; +} + +#endif // LITE_WITH_ANDROID + +#endif // LITE_WITH_ARM + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/cpu_info.h b/paddle/fluid/lite/core/cpu_info.h new file mode 100644 index 0000000000000000000000000000000000000000..23a996f80e0c9ec78e8d9a90088eeea26aa80f1f --- /dev/null +++ b/paddle/fluid/lite/core/cpu_info.h @@ -0,0 +1,125 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/lite/utils/cp_logging.h" + +#ifdef LITE_WITH_ANDROID +#include +#include +#endif + +#if __APPLE__ +#include "TargetConditionals.h" +#if TARGET_OS_IPHONE +#include +#include +#include +#endif // TARGET_OS_IPHONE +#endif // __APPLE__ + +namespace paddle { +namespace lite { + +#ifdef LITE_WITH_ARM + +typedef enum { + LITE_POWER_HIGH = 0, + LITE_POWER_LOW = 1, + LITE_POWER_FULL = 2, + LITE_POWER_NO_BIND = 3, + LITE_POWER_RAND_HIGH = 4, + LITE_POWER_RAND_LOW = 5 +} PowerMode; + +typedef enum { + kAPPLE = 0, + kA53 = 53, + kA55 = 55, + kA57 = 57, + kA72 = 72, + kA73 = 73, + kA75 = 75, + kA76 = 76, + kARMArch_UNKOWN = -1 +} ARMArch; + +class DeviceInfo { + public: + int idx_; + int max_freq_; + int min_freq_; + int generate_arch_; + int compute_core_num_; + int max_memory_; + int sharemem_size_; + + std::string device_name_; + std::string compute_ability_; + + std::vector L1_cache_; + std::vector L2_cache_; + std::vector L3_cache_; + std::vector core_ids_; + std::vector big_core_ids_; + std::vector little_core_ids_; + std::vector cluster_ids_; + std::vector archs_; + + static DeviceInfo& Global() { + static auto* x = new DeviceInfo; + return *x; + } + + static void init_info() { + auto& info = Global(); + get_info(&info); + } + + private: + DeviceInfo() = default; + static void get_info(DeviceInfo* dev); +}; + +size_t arm_get_meminfo(); + +int arm_get_cpucount(); + +void arm_get_cpu_arch(std::vector* archs); + +bool get_cpu_info_from_name(DeviceInfo* cpu_info, std::string hardware_name); + +#ifdef LITE_WITH_ANDROID + +void set_default_cache(DeviceInfo* dev); + +std::string arm_get_cpu_name(); + +int get_max_freq_khz(int cpuid); + +int arm_sort_cpuid_by_max_frequency(int cpu_count, std::vector* cpuids, + const std::vector& cpu_freq, + std::vector* cluster_ids); +int check_online(const std::vector& core_ids); +int set_sched_affinity(const std::vector& cpuids); + +#endif // LITE_WITH_ANDROID + +#endif // LITE_WITH_ARM + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/kernel.h b/paddle/fluid/lite/core/kernel.h index 6846dbb920d2b8ebef0ad1062ff3074ac9409e37..2eee83bd4a51f2d90ddbf81055146403c78314fa 100644 --- a/paddle/fluid/lite/core/kernel.h +++ b/paddle/fluid/lite/core/kernel.h @@ -44,7 +44,7 @@ class KernelBase { virtual void Run() = 0; void SetContext(std::unique_ptr&& ctx) { - context_ = std::move(ctx); + ctx_ = std::move(ctx); } template void SetParam(T param) { @@ -86,7 +86,7 @@ class KernelBase { virtual TargetType target() const = 0; virtual PrecisionType precision() const = 0; virtual DataLayoutType layout() const = 0; - const KernelContext* context() const { return context_.get(); } + const KernelContext* context() const { return ctx_.get(); } virtual std::string name() const = 0; // Short human-readable document. @@ -134,7 +134,7 @@ class KernelBase { void Torch() {} protected: - std::unique_ptr context_; + std::unique_ptr ctx_; mutable operators::param_t param_; // The corresponding op type. std::string op_type_{}; @@ -152,9 +152,6 @@ template class KernelLite : public KernelBase { public: - // Set runtime context. - void SetContext(std::unique_ptr&& ctx) { ctx_ = ctx; } - // Run the kernel. virtual void Run() { CHECK(false) << "Not Implemented"; } @@ -168,9 +165,6 @@ class KernelLite : public KernelBase { KernelLite() = default; virtual ~KernelLite() = default; - - protected: - std::unique_ptr ctx_; }; template diff --git a/paddle/fluid/lite/core/lite_tensor.h b/paddle/fluid/lite/core/lite_tensor.h index 3fe29cc33313e08fe3feca6771a6e7bf40416fd1..433bc6911164f106bbf595b77a5665597bc1ce34 100644 --- a/paddle/fluid/lite/core/lite_tensor.h +++ b/paddle/fluid/lite/core/lite_tensor.h @@ -14,6 +14,7 @@ #pragma once #include +#include // for multiplies #include #include #include @@ -40,6 +41,10 @@ class DDimLite : public DDimBase { size_t size() const { return data_.size(); } bool empty() const { return data_.empty(); } + value_type product() const { + return std::accumulate(std::begin(data_), std::end(data_), 1, + std::multiplies()); + } const std::vector &data() const { return data_; } private: @@ -61,8 +66,10 @@ class TensorLite : public TensorBase { } void Resize(const DDimLite &ddim) { dims_ = ddim; } + void Resize(const std::vector &x) { dims_ = DDimLite(x); } const DDimLite &dims() const { return dims_; } + int64_t numel() const { return dims_.product(); } const LoD &lod() const { return lod_; } LoD *mutable_lod() { return &lod_; } diff --git a/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc b/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc index 3d2012306b9fd4c89a72f2e29223be75e847c204..1852fc2fcbee3fedb09835a8c6d4c2bd67705a53 100644 --- a/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc +++ b/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc @@ -32,7 +32,6 @@ class RuntimeContextAssignPass : public StmtPass { if (!node.IsStmt()) continue; auto& inst = node.AsStmt(); - switch (inst.picked_kernel().target()) { case TARGET(kHost): case TARGET(kX86): @@ -42,6 +41,11 @@ class RuntimeContextAssignPass : public StmtPass { case TARGET(kCUDA): inst.picked_kernel().SetContext(NewCudaContext()); break; +#endif +#ifdef LITE_WITH_ARM + case TARGET(kARM): + inst.picked_kernel().SetContext(NewARMContext()); + break; #endif default: LOG(FATAL) << "unsupported target " @@ -54,9 +58,18 @@ class RuntimeContextAssignPass : public StmtPass { std::unique_ptr ctx(new KernelContext); ctx->As(); // Some initialization here. + return ctx; } +#ifdef LITE_WITH_ARM + std::unique_ptr NewARMContext() { + DeviceInfo::init_info(); + std::unique_ptr ctx(new KernelContext); + ctx->As(); + return ctx; + } +#endif #ifdef LITE_WITH_CUDA std::unique_ptr NewCudaContext() { std::unique_ptr ctx(new KernelContext); @@ -66,9 +79,7 @@ class RuntimeContextAssignPass : public StmtPass { cuda.blas_fp32 = cublas_fp32_; return ctx; } -#endif -#ifdef LITE_WITH_CUDA void InitCudaBlas() { cublas_fp32_ = std::make_shared>(); } diff --git a/paddle/fluid/lite/kernels/CMakeLists.txt b/paddle/fluid/lite/kernels/CMakeLists.txt index 0708e7d9a04b318b37d586f58c984156000620a5..ce22ba1216664cdf539ee4f576016adc389622ca 100644 --- a/paddle/fluid/lite/kernels/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/CMakeLists.txt @@ -1,5 +1,5 @@ message(STATUS "add lite kernels") -set(lite_kernel_deps type_system kernel_lite op_lite op_registry_lite ${tensor_lite}) +set(lite_kernel_deps type_system kernel_lite op_lite op_registry_lite context_lite ${tensor_lite}) add_subdirectory(host) add_subdirectory(arm) add_subdirectory(cuda) diff --git a/paddle/fluid/lite/kernels/arm/CMakeLists.txt b/paddle/fluid/lite/kernels/arm/CMakeLists.txt index b5fc0bdea8a719648d606bbd215cb1564834183a..75dc9fe43adf169f129572a642cf19b285541996 100644 --- a/paddle/fluid/lite/kernels/arm/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/arm/CMakeLists.txt @@ -4,11 +4,13 @@ endif() message(STATUS "compile with lite ARM kernels") -cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} eigen3) +cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps}) cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3) cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} eigen3) +lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm eigen3) + set(arm_kernels fc_compute_arm relu_compute_arm diff --git a/paddle/fluid/lite/kernels/arm/fc_compute.cc b/paddle/fluid/lite/kernels/arm/fc_compute.cc index 6b7060227d8d40b5b75276879fb9ce8e2abd7cdc..b26551e0533a5ae68c930cc1b9512ba0ca13253a 100644 --- a/paddle/fluid/lite/kernels/arm/fc_compute.cc +++ b/paddle/fluid/lite/kernels/arm/fc_compute.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/fluid/lite/kernels/arm/fc_compute.h" -#include +#include "paddle/fluid/lite/arm/math/funcs.h" #include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/type_system.h" @@ -22,24 +22,42 @@ namespace lite { namespace kernels { namespace arm { -// NOTE should use pure std C++ implementation. void FcCompute::Run() { auto& param = this->Param(); + auto x_dims = param.input->dims(); + auto w_dims = param.w->dims(); - CHECK_GE(param.input->dims().size(), 2UL); + CHECK_GE(x_dims.size(), 2UL); + CHECK_EQ(w_dims.size(), 2UL); CHECK_EQ(param.output->dims().size(), 2UL); - fc_compute_eigen( - param.input->data(), // x - param.input->dims().Slice(0, param.in_num_col_dims).production(), - param.input->dims() - .Slice(param.in_num_col_dims, param.input->dims().size()) - .production(), - param.w->data(), // w - param.w->dims()[1], // w_w - param.w->dims()[0], // w_h - param.bias->data(), // b - param.output->mutable_data()); + const auto* i_data = param.input->data(); + const auto* w_data = param.w->data(); + const auto* b_data = param.bias ? param.bias->data() : nullptr; + auto* o_data = param.output->mutable_data(); + + int x_h = x_dims.Slice(0, param.in_num_col_dims).production(); + int x_w = x_dims.Slice(param.in_num_col_dims, x_dims.size()).production(); + int n = w_dims[1]; + CHECK_EQ(x_w, static_cast(w_dims[0])); + auto& ctx = this->ctx_->template As(); + if (x_h > 1) { + float* packed_in = static_cast(ctx.workspace_data()) + + ctx.l2_cache_size() / sizeof(float); + lite::arm::math::prepackA(packed_in, i_data, x_w, 0, x_h, 0, x_w, false, + &ctx); + lite::arm::math::sgemm_prepack(packed_in, w_data, b_data, o_data, x_h, n, + x_w, false, false, false, &ctx); + + if (param.bias) { + CHECK_EQ(param.bias->numel(), n); + lite::arm::math::fill_bias_fc(o_data, b_data, x_h, n); + } + } else { + // use sgemmv + // sgemv((const float*)weights, (const float*)din, (float*)dout, + // false, n, x_w, _param->_flag_bias, (float*)bias, false); + } } TargetType FcCompute::target() const { return TARGET(kARM); } diff --git a/paddle/fluid/lite/kernels/arm/fc_compute.h b/paddle/fluid/lite/kernels/arm/fc_compute.h index 36f3e0723124169905bba40fcd209a516dfd0dce..414517843354f638ed37f54ef596dc6db53193ce 100644 --- a/paddle/fluid/lite/kernels/arm/fc_compute.h +++ b/paddle/fluid/lite/kernels/arm/fc_compute.h @@ -13,7 +13,6 @@ // limitations under the License. #pragma once -#include #include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/operators/fc_op.h" @@ -34,52 +33,6 @@ class FcCompute : public KernelLite { virtual ~FcCompute() = default; }; -template -void fc_compute_eigen(const T* x, int x_w, int x_h, // - const T* w, int w_w, int w_h, // - const T* b, // - T* out) { - using matrix_t = - Eigen::Matrix; - - Eigen::Map X(x, x_h, x_w); - Eigen::Map W(w, w_h, w_w); - Eigen::Map Out(out, x_h, w_h); - - Out = X * W.transpose(); - - if (b) { - Eigen::Map> B(b, w_h); - Out = Out.array().rowwise() + B.transpose().array(); - } -} - -template -__attribute__((optimize("unroll-loops"))) // -T dot(const T* x, const T* y, int dim) { - T out{}; - for (int i = 0; i < dim; i++) { - out += x[i] * y[i]; - } - return out; -} - -template -void fc_compute_naive(const T* x, int x_w, int x_h, // - const T* w, int w_w, int w_h, // - const T* b, // - T* out) { - CHECK_EQ(x_w, w_w); - // out shape: (x_h, w_w) - memset(out, 0, x_h * w_h * sizeof(T)); - - for (int r = 0; r < x_h; r++) { - for (int c = 0; c < w_h; c++) { - out[r * w_h + c] = dot(&x[r * x_w], &w[c * w_w], w_w) + b[c]; - } - } -} - } // namespace arm } // namespace kernels } // namespace lite diff --git a/paddle/fluid/lite/kernels/arm/fc_compute_test.cc b/paddle/fluid/lite/kernels/arm/fc_compute_test.cc index 5f5de8a89de9eed74716fe97c034903898801f4e..1949a3a1eb1133df12bd5d9176265237fbba94e5 100644 --- a/paddle/fluid/lite/kernels/arm/fc_compute_test.cc +++ b/paddle/fluid/lite/kernels/arm/fc_compute_test.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/lite/kernels/arm/fc_compute.h" #include #include +#include "paddle/fluid/lite/arm/math/funcs.h" #include "paddle/fluid/lite/core/op_registry.h" namespace paddle { @@ -22,60 +23,79 @@ namespace lite { namespace kernels { namespace arm { -TEST(fc_compute_naive, test) { - lite::Tensor x, w, b, out, out1; - const int batch_size = 2; +TEST(fc_arm, retrive_op) { + auto fc = + KernelRegistry::Global().Create("fc"); + ASSERT_FALSE(fc.empty()); + ASSERT_TRUE(fc.front()); +} + +TEST(fc_arm, init) { + FcCompute fc; + ASSERT_EQ(fc.precision(), PRECISION(kFloat)); + ASSERT_EQ(fc.target(), TARGET(kARM)); +} + +TEST(fc_arm, compare_test) { + lite::Tensor x, w, b, out, ref; + constexpr int batch_size = 2; x.Resize({batch_size, 3}); - w.Resize({4, 3}); + w.Resize({3, 4}); b.Resize({1, 4}); out.Resize({batch_size, 4}); - out1.Resize({batch_size, 4}); + ref.Resize({batch_size, 4}); auto x_data = x.mutable_data(); auto w_data = w.mutable_data(); auto b_data = b.mutable_data(); auto out_data = out.mutable_data(); - auto out_data1 = out1.mutable_data(); + auto ref_data = ref.mutable_data(); - for (int i = 0; i < product(x.dims()); i++) x_data[i] = i; - for (int i = 0; i < product(w.dims()); i++) w_data[i] = i; - for (int i = 0; i < product(b.dims()); i++) b_data[i] = i; - - fc_compute_naive(x_data, 3, batch_size, // - w_data, 3, 4, // - b_data, out_data); - fc_compute_eigen(x_data, 3, batch_size, // - w_data, 3, 4, // - b_data, out_data1); - - for (int i = 0; i < product(out.dims()); i++) { - EXPECT_NEAR(out_data[0], out_data1[0], 1e-6); + for (int64_t i = 0; i < x.dims().product(); i++) { + x_data[i] = static_cast(i); + } + for (int64_t i = 0; i < w.dims().product(); i++) { + w_data[i] = static_cast(i); + } + for (int64_t i = 0; i < b.dims().product(); i++) { + b_data[i] = static_cast(i); } -} -TEST(fc_arm, init) { + // TODO(TJ): enable bias soon + b_data = nullptr; + lite::arm::math::fc_compute_eigen(x_data, batch_size, 3, // + w_data, 3, 4, // + b_data, ref_data); + + // fc compute kernel FcCompute fc; - ASSERT_EQ(fc.precision(), PRECISION(kFloat)); - ASSERT_EQ(fc.target(), TARGET(kARM)); -} + operators::FcParam param; -TEST(fc_arm, algorithm) { - using matrix_t = Eigen::Matrix; - using matrix_map_t = Eigen::Map; + param.in_num_col_dims = 1; + param.input = &x; + param.w = &w; + param.bias = nullptr; + param.output = &out; + param.in_mat_dims = x.dims(); - // dim 10, 20 - std::vector input(10 * 20); - std::vector w(20 * 20); - std::vector output(10 * 20); + DeviceInfo::init_info(); + std::unique_ptr ctx(new KernelContext); + ctx->As(); + fc.SetParam(param); + fc.SetContext(std::move(ctx)); + fc.Run(); - Eigen::Map input_mat(input.data(), 10, 20); - Eigen::Map weight_mat(w.data(), 20, 20); - matrix_map_t output_mat(output.data(), 10, 20); + VLOG(3) << "output vs ref"; + for (int i = 0; i < out.dims().product(); i++) { + VLOG(3) << out_data[i] << " vs " << ref_data[i]; + } - output_mat = weight_mat.transpose() * input_mat; + for (int i = 0; i < out.dims().product(); ++i) { + EXPECT_NEAR(out_data[i], ref_data[i], 1e-5); + } } -TEST(fc_arm, compute) { +TEST(fc_arm, num_col_dims) { FcCompute fc; operators::FcParam param; @@ -84,20 +104,28 @@ TEST(fc_arm, compute) { lite::Tensor bias; lite::Tensor output; - x.Resize(DDim(std::vector({1, 10, 20}))); - w.Resize(DDim(std::vector({20, 20}))); - bias.Resize(DDim(std::vector({1, 10}))); - output.Resize(DDim(std::vector({10, 20}))); + x.Resize({1, 2, 3}); + w.Resize({3, 4}); + bias.Resize({1, 4}); + output.Resize({2, 4}); auto* x_data = x.mutable_data(); auto* w_data = w.mutable_data(); auto* bias_data = bias.mutable_data(); auto* output_data = output.mutable_data(); - for (int i = 0; i < 10 * 20; i++) x_data[i] = i; - for (int i = 0; i < 20 * 20; i++) w_data[i] = i; - for (int i = 0; i < 10; i++) bias_data[i] = i; - for (int i = 0; i < 10 * 20; i++) output_data[i] = 0; + for (int64_t i = 0; i < x.dims().product(); i++) { + x_data[i] = static_cast(i); + } + for (int64_t i = 0; i < w.dims().product(); i++) { + w_data[i] = static_cast(i); + } + for (int64_t i = 0; i < bias.dims().product(); i++) { + bias_data[i] = static_cast(i); + } + for (int64_t i = 0; i < output.dims().product(); i++) { + output_data[i] = static_cast(i); + } param.in_num_col_dims = 2; param.input = &x; @@ -106,20 +134,13 @@ TEST(fc_arm, compute) { param.output = &output; param.in_mat_dims = x.dims(); + std::unique_ptr ctx(new KernelContext); + ctx->As(); + DeviceInfo::init_info(); + fc.SetParam(param); + fc.SetContext(std::move(ctx)); fc.Run(); - - LOG(INFO) << "x"; - for (int i = 0; i < 10 * 20; i++) LOG(INFO) << x_data[i]; - - LOG(INFO) << "output:"; - for (int i = 0; i < 10 * 20; i++) LOG(INFO) << output.data()[i]; -} - -TEST(fc, retrive_op) { - auto fc = - KernelRegistry::Global().Create("fc"); - ASSERT_TRUE(fc); } } // namespace arm diff --git a/paddle/fluid/lite/kernels/cuda/mul_compute.h b/paddle/fluid/lite/kernels/cuda/mul_compute.h index 597d84683268b49e9b6311ad89b1ce1e3e0b0874..5eb30bf8dfd3293f3ea9b6ed84ef0528ff3ef426 100644 --- a/paddle/fluid/lite/kernels/cuda/mul_compute.h +++ b/paddle/fluid/lite/kernels/cuda/mul_compute.h @@ -35,8 +35,8 @@ class MulCompute : public KernelLite { using param_t = operators::MulParam; void Run() override { - CHECK(context_) << "running context should be set first"; - auto& context = context_->As(); + CHECK(ctx_) << "running context should be set first"; + auto& context = ctx_->As(); CHECK(context.blas_fp32) << "blas should init first"; /* auto& blas = *context.blas_fp32; diff --git a/paddle/fluid/lite/kernels/x86/activation_compute.cc b/paddle/fluid/lite/kernels/x86/activation_compute.cc index 3001a98da118f2107245c184ad04a9920660c8c6..79f3829b61b19d652ae17e6d637c0cdc11ae739b 100644 --- a/paddle/fluid/lite/kernels/x86/activation_compute.cc +++ b/paddle/fluid/lite/kernels/x86/activation_compute.cc @@ -60,7 +60,7 @@ class SquareCompute : public KernelLite { using param_t = operators::ActivationParam; void Run() override { - auto& context = context_->As(); + auto& context = ctx_->As(); auto& param = *param_.get_mutable(); CHECK(context.x86_device_context); @@ -82,7 +82,7 @@ class SquareGradCompute : public KernelLite { using param_t = operators::ActivationGradParam; void Run() override { - auto& context = context_->As(); + auto& context = ctx_->As(); auto& param = *param_.get_mutable(); CHECK(context.x86_device_context); param.X_grad->template mutable_data(); diff --git a/paddle/fluid/lite/kernels/x86/elementwise_compute.cc b/paddle/fluid/lite/kernels/x86/elementwise_compute.cc index d4ead92e431e65013670ce81f207456cd3c3760a..e2ca9a52df61df82de9eb581b4fb882fb129a69a 100644 --- a/paddle/fluid/lite/kernels/x86/elementwise_compute.cc +++ b/paddle/fluid/lite/kernels/x86/elementwise_compute.cc @@ -38,7 +38,7 @@ class ElementwiseSubCompute void Run() override { auto& param = *param_.get_mutable(); - auto& context = context_->As(); + auto& context = ctx_->As(); CHECK(context.x86_device_context); param.Out->template mutable_data(); diff --git a/paddle/fluid/lite/tools/build.sh b/paddle/fluid/lite/tools/build.sh index e3c639f18c0d7a058ec4853ff80f8544ae21a28a..37a04f901bc6515fea1f06a3fe44d3520decc511 100755 --- a/paddle/fluid/lite/tools/build.sh +++ b/paddle/fluid/lite/tools/build.sh @@ -22,7 +22,9 @@ function cmake_arm { -DLITE_WITH_CUDA=OFF \ -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \ -DWITH_TESTING=ON \ + -DWITH_MKL=OFF \ -DWITH_MKLDNN=OFF + make cxx_api_lite_bin -j8 } function build {