未验证 提交 4b253569 编写于 作者: T tensor-tang 提交者: GitHub

[Lite] enable fc kernel (#17674)

* add fc unit test

* refine eigen fc
add cpu info, arm context
init packed sgemm

* enable packed sgemm

* add arm math

* pass fc ut

* follow comments
上级 e170ea03
......@@ -16,6 +16,8 @@ if(NOT ANDROID)
return()
endif()
add_definitions(-DLITE_WITH_ANDROID)
if(NOT DEFINED ANDROID_NDK)
set(ANDROID_NDK $ENV{NDK_ROOT})
if(NOT ANDROID_NDK)
......
......@@ -118,6 +118,7 @@ endfunction()
add_subdirectory(core)
add_subdirectory(x86)
add_subdirectory(arm)
add_subdirectory(host)
add_subdirectory(cuda)
add_subdirectory(operators)
......
cc_library(math_arm SRCS funcs.cc packed_sgemm.cc DEPS ${lite_kernel_deps} eigen3)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/funcs.h"
#include <arm_neon.h>
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <>
void fill_bias_fc<float>(float *tensor, const float *bias, const int num,
const int channel) {
int cnt = channel >> 4;
int remain = channel & 15;
for (int j = 0; j < num; ++j) {
const float *ptr_bias = bias;
float *ptr_out = tensor + j * channel;
float32x4_t vout1;
float32x4_t vout2;
float32x4_t vout3;
float32x4_t vout4;
for (int i = 0; i < cnt; ++i) {
float32x4_t vin1 = vld1q_f32(ptr_out);
float32x4_t vb1 = vld1q_f32(ptr_bias);
float32x4_t vin2 = vld1q_f32(ptr_out + 4);
float32x4_t vb2 = vld1q_f32(ptr_bias + 4);
float32x4_t vin3 = vld1q_f32(ptr_out + 8);
float32x4_t vb3 = vld1q_f32(ptr_bias + 8);
float32x4_t vin4 = vld1q_f32(ptr_out + 12);
float32x4_t vb4 = vld1q_f32(ptr_bias + 12);
vout1 = vaddq_f32(vin1, vb1);
vout2 = vaddq_f32(vin2, vb2);
vout3 = vaddq_f32(vin3, vb3);
vout4 = vaddq_f32(vin4, vb4);
vst1q_f32(ptr_out, vout1);
vst1q_f32(ptr_out + 4, vout2);
vst1q_f32(ptr_out + 8, vout3);
vst1q_f32(ptr_out + 12, vout4);
ptr_out += 16;
ptr_bias += 16;
}
#if 0
if (cnt > 0) {
asm(
"1: \n"
"vld1.32 {d0-d1}, [%[ptr_out]] @ load data\n"
"vld1.32 {d2-d3}, [%[ptr_bias]]! @ load data\n"
"vadd.f32 q2, q0, q1 @ add bias\n"
"vst1.32 {d4-d5}, [%[ptr_out]]! @ store result\n"
"subs %[cnt], #1 @ loop count -1\n"
"bne 1b @ jump to main loop\n"
:[ptr_out] "+r"(ptr_out), [ptr_bias] "+r"(ptr_bias), \
[cnt] "+r"(cnt)
:
:"q0", "q1", "q2"
);
}
#endif
for (; remain > 0; remain--) {
*(ptr_out++) += *(ptr_bias++);
}
}
}
template <>
void fill_bias_fc<int>(int *tensor, const int *bias, const int num,
const int channel) {
int cnt = channel >> 4;
int remain = channel & 15;
for (int j = 0; j < num; ++j) {
const int *ptr_bias = bias;
int *ptr_out = tensor + j * channel;
int32x4_t vout1;
int32x4_t vout2;
int32x4_t vout3;
int32x4_t vout4;
for (int i = 0; i < cnt; ++i) {
int32x4_t vin1 = vld1q_s32(ptr_out);
int32x4_t vb1 = vld1q_s32(ptr_bias);
int32x4_t vin2 = vld1q_s32(ptr_out + 4);
int32x4_t vb2 = vld1q_s32(ptr_bias + 4);
int32x4_t vin3 = vld1q_s32(ptr_out + 8);
int32x4_t vb3 = vld1q_s32(ptr_bias + 8);
int32x4_t vin4 = vld1q_s32(ptr_out + 12);
int32x4_t vb4 = vld1q_s32(ptr_bias + 12);
vout1 = vaddq_s32(vin1, vb1);
vout2 = vaddq_s32(vin2, vb2);
vout3 = vaddq_s32(vin3, vb3);
vout4 = vaddq_s32(vin4, vb4);
vst1q_s32(ptr_out, vout1);
vst1q_s32(ptr_out + 4, vout2);
vst1q_s32(ptr_out + 8, vout3);
vst1q_s32(ptr_out + 12, vout4);
ptr_out += 16;
ptr_bias += 16;
}
#if 0
if (cnt > 0) {
asm(
"1: \n"
"vld1.32 {d0-d1}, [%[ptr_out]] @ load data\n"
"vld1.32 {d2-d3}, [%[ptr_bias]]! @ load data\n"
"vadd.s32 q2, q0, q1 @ add bias\n"
"vst1.32 {d4-d5}, [%[ptr_out]]! @ store result\n"
"subs %[cnt], #1 @ loop count -1\n"
"bne 1b @ jump to main loop\n"
:[ptr_out] "+r"(ptr_out), [ptr_bias] "+r"(ptr_bias), \
[cnt] "+r"(cnt)
:
:"q0", "q1", "q2"
);
}
#endif
for (; remain > 0; remain--) {
*(ptr_out++) += *(ptr_bias++);
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include <cmath>
#include "paddle/fluid/lite/arm/math/packed_sgemm.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void fill_bias_fc(T* tensor, const T* bias, const int num, const int channel);
template <typename T>
void fc_compute_eigen(const T* x, int x_h, int x_w, //
const T* w, int w_h, int w_w, //
const T* b, //
T* out) {
using matrix_t =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
Eigen::Map<const matrix_t> X(x, x_h, x_w);
Eigen::Map<const matrix_t> W(w, w_h, w_w);
Eigen::Map<matrix_t> Out(out, x_h, w_w);
Out = X * W;
if (b) {
Eigen::Map<const Eigen::Matrix<T, 1, Eigen::Dynamic>> B(b, w_w);
Out = Out.array().rowwise() + B.array();
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/packed_sgemm.h"
#include <arm_neon.h>
namespace paddle {
namespace lite {
namespace arm {
namespace math {
#ifdef __aarch64__
void prepackA_8x12(float *out, const float *in, const int ldin, const int m0,
const int mmax, const int k0, const int kmax);
void prepackA_trans_8x12(float *out, const float *in, const int ldin,
const int m0, const int mmax, const int k0,
const int kmax);
void sgemm_conv_8x12(const float *A_packed, const float *B, const float *bias,
float *C, int M, int N, int K, bool is_bias, bool is_relu,
bool transB, ARMContext *ctx);
#else
// for kA72
void prepackA_6x8(float *out, const float *in, const int ldin, const int m0,
const int mmax, const int k0, const int kmax);
void prepackA_trans_6x8(float *out, const float *in, const int ldin,
const int m0, const int mmax, const int k0,
const int kmax);
// for kA73
void prepackA_4x8(float *out, const float *in, const int ldin, const int m0,
const int mmax, const int k0, const int kmax);
void prepackA_trans_4x8(float *out, const float *in, const int ldin,
const int m0, const int mmax, const int k0,
const int kmax);
// for kA72, 6x8
void sgemm_conv_6x8(const float *A_packed, const float *B, const float *bias,
float *C, int M, int N, int K, bool is_bias, bool is_relu,
bool transB, ARMContext *ctx);
// for kA73, 4x8
void sgemm_conv_4x8(const float *A_packed, const float *B, const float *bias,
float *C, int M, int N, int K, bool is_bias, bool is_relu,
bool transB, ARMContext *ctx);
#endif // __aarch64__
/**
* \brief input data is not transpose
* for arm-v7a, transform data to block x k x 6 layout
* for arm-v8a, transform data to block x k x 8 layout
*/
void prepackA(float *out, const float *in, const int ldin, const int m0,
const int mmax, const int k0, const int kmax, bool is_trans,
ARMContext *ctx) {
#ifdef __aarch64__
if (is_trans) {
prepackA_trans_8x12(out, in, ldin, m0, mmax, k0, kmax);
} else {
prepackA_8x12(out, in, ldin, m0, mmax, k0, kmax);
}
#else
if (ctx->arch() == kA73) {
if (is_trans) {
prepackA_trans_4x8(out, in, ldin, m0, mmax, k0, kmax);
} else {
prepackA_4x8(out, in, ldin, m0, mmax, k0, kmax);
}
} else {
if (is_trans) {
prepackA_trans_6x8(out, in, ldin, m0, mmax, k0, kmax);
} else {
prepackA_6x8(out, in, ldin, m0, mmax, k0, kmax);
}
}
#endif
}
void prepackA(TensorLite *tout, const TensorLite &tin, int m, int k, int group,
bool is_trans, ARMContext *ctx) {
int hblock = get_hblock(ctx->arch());
int m_roundup = hblock * ((m + hblock - 1) / hblock);
int group_size_round_up = ((m_roundup * k + 15) / 16) * 16;
if (tout->numel() < group_size_round_up * group) {
tout->Resize({group_size_round_up * group});
}
int lda = k;
if (is_trans) {
lda = m;
}
for (int g = 0; g < group; ++g) {
const float *weights_group = tin.data<float>() + g * m * k;
float *weights_trans_ptr =
tout->mutable_data<float>() + g * group_size_round_up;
prepackA(weights_trans_ptr, weights_group, lda, 0, m, 0, k, is_trans, ctx);
}
}
/// a: m*k b: k*n c: m*n
void sgemm_prepack(const float *A_packed, const float *B, const float *bias,
float *C, int M, int N, int K, bool is_bias, bool is_relu,
bool is_transB, ARMContext *ctx) {
#ifdef __aarch64__
sgemm_conv_8x12(A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB,
ctx);
#else // armv7
if (ctx->arch() == kA73) {
sgemm_conv_4x8(A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB,
ctx);
} else {
sgemm_conv_6x8(A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB,
ctx);
}
#endif // arm64
}
#ifdef __aarch64__
void prepackA_8x12(float *out, const float *in, const int ldin, const int m0,
const int mmax, const int k0, const int kmax) {
int x_len = kmax - k0;
uint32_t zerobuff[x_len]; // NOLINT
memset(zerobuff, 0, sizeof(uint32_t) * x_len);
uint32_t *dout = reinterpret_cast<uint32_t *>(out);
const uint32_t *inptr = reinterpret_cast<const uint32_t *>(in);
int stride = x_len * 8;
#pragma omp parallel for
for (int y = m0; y < mmax; y += 8) {
uint32_t *outptr = dout + stride * (y - m0) / 8;
const uint32_t *inptr0 = inptr + y * ldin + k0;
const uint32_t *inptr1 = inptr0 + ldin;
const uint32_t *inptr2 = inptr1 + ldin;
const uint32_t *inptr3 = inptr2 + ldin;
const uint32_t *inptr4 = inptr3 + ldin;
const uint32_t *inptr5 = inptr4 + ldin;
const uint32_t *inptr6 = inptr5 + ldin;
const uint32_t *inptr7 = inptr6 + ldin;
asm volatile(
"prfm pldl1keep, [%[ptr0]] \n"
"prfm pldl1keep, [%[ptr0], #64] \n"
"prfm pldl1keep, [%[ptr1]] \n"
"prfm pldl1keep, [%[ptr1], #64] \n"
"prfm pldl1keep, [%[ptr2]] \n"
"prfm pldl1keep, [%[ptr2], #64] \n"
"prfm pldl1keep, [%[ptr3]] \n"
"prfm pldl1keep, [%[ptr3], #64] \n"
"prfm pldl1keep, [%[ptr4]] \n"
"prfm pldl1keep, [%[ptr4], #64] \n"
"prfm pldl1keep, [%[ptr5]] \n"
"prfm pldl1keep, [%[ptr5], #64] \n"
"prfm pldl1keep, [%[ptr6]] \n"
"prfm pldl1keep, [%[ptr6], #64] \n"
"prfm pldl1keep, [%[ptr7]] \n"
"prfm pldl1keep, [%[ptr7], #64] \n"
:
: [ptr0] "r"(inptr0), [ptr1] "r"(inptr1), [ptr2] "r"(inptr2),
[ptr3] "r"(inptr3), [ptr4] "r"(inptr4), [ptr5] "r"(inptr5),
[ptr6] "r"(inptr6), [ptr7] "r"(inptr7)
: "memory");
int x = x_len;
//! cope with row index exceed real size, set to zero buffer
if ((y + 7) >= mmax) {
switch ((y + 7) - mmax) {
case 6:
inptr1 = zerobuff;
case 5:
inptr2 = zerobuff;
case 4:
inptr3 = zerobuff;
case 3:
inptr4 = zerobuff;
case 2:
inptr5 = zerobuff;
case 1:
inptr6 = zerobuff;
case 0:
inptr7 = zerobuff;
default:
break;
}
}
for (; x > 7; x -= 8) {
asm volatile(
// Load up 8 elements (2 vectors) from each of 8 sources.
"LDP q0, q1, [%[inptr0]], #32\n" // q0=A0A1A2A3
"LDP q2, q3, [%[inptr1]], #32\n" // q2=B0B1B2B3
"LDP q4, q5, [%[inptr2]], #32\n" // q4=C0C1C2C3
"ZIP1 v16.4s, v0.4s, v4.4s\n" // q16=A0C0A1C1
"prfm pldl1keep, [%[inptr0], #128] \n"
"LDP q6, q7, [%[inptr3]], #32\n" // q6=D0D1D2D3
"ZIP1 v17.4s, v2.4s, v6.4s\n" // q17=B0D0B1D1
"LDP q8, q9, [%[inptr4]], #32\n"
"LDP q10, q11, [%[inptr5]], #32\n"
"LDP q12, q13, [%[inptr6]], #32\n"
"ZIP1 v18.4s, v8.4s, v12.4s\n"
"prfm pldl1keep, [%[inptr1], #128]\n"
"LDP q14, q15, [%[inptr7]], #32\n"
"ZIP1 v19.4s, v10.4s, v14.4s\n"
"ZIP1 v20.4s, v16.4s, v17.4s\n" // q20=A0B0C0D0
"prfm pldl1keep, [%[inptr2], #128]\n"
"ZIP1 v21.4s, v18.4s, v19.4s\n"
"ZIP2 v22.4s, v16.4s, v17.4s\n"
"ZIP2 v23.4s, v18.4s, v19.4s\n"
"ZIP2 v16.4s, v0.4s, v4.4s\n"
"prfm pldl1keep, [%[inptr3], #128]\n"
"ZIP2 v17.4s, v2.4s, v6.4s\n"
"STP q20, q21, [%[outptr]], #32\n" // Write back the first
// element of each source
"ZIP2 v18.4s, v8.4s, v12.4s\n"
"ZIP2 v19.4s, v10.4s, v14.4s\n"
"STP q22, q23, [%[outptr]], #32\n" // Write back the second
// element of each source
"ZIP1 v20.4s, v16.4s, v17.4s\n"
"prfm pldl1keep, [%[inptr4], #128]\n"
"ZIP1 v21.4s, v18.4s, v19.4s\n"
"ZIP2 v22.4s, v16.4s, v17.4s\n"
"ZIP2 v23.4s, v18.4s, v19.4s\n"
"ZIP1 v16.4s, v1.4s, v5.4s\n"
"prfm pldl1keep, [%[inptr5], #128]\n"
"ZIP1 v17.4s, v3.4s, v7.4s\n"
"STP q20, q21, [%[outptr]], #32\n" // Third element
"ZIP1 v18.4s, v9.4s, v13.4s\n"
"ZIP1 v19.4s, v11.4s, v15.4s\n"
"STP q22, q23, [%[outptr]], #32\n" // Fourth element
"ZIP1 v20.4s, v16.4s, v17.4s\n"
"ZIP1 v21.4s, v18.4s, v19.4s\n"
"ZIP2 v22.4s, v16.4s, v17.4s\n"
"prfm pldl1keep, [%[inptr6], #128]\n"
"ZIP2 v23.4s, v18.4s, v19.4s\n"
"ZIP2 v16.4s, v1.4s, v5.4s\n"
"ZIP2 v17.4s, v3.4s, v7.4s\n"
"STP q20, q21, [%[outptr]], #32\n" // Fifth element
"ZIP2 v18.4s, v9.4s, v13.4s\n"
"prfm pldl1keep, [%[inptr7], #128]\n"
"ZIP2 v19.4s, v11.4s, v15.4s\n"
"STP q22, q23, [%[outptr]], #32\n" // Sixth element
"ZIP1 v20.4s, v16.4s, v17.4s\n"
"ZIP1 v21.4s, v18.4s, v19.4s\n"
"STP q20, q21, [%[outptr]], #32\n" // Seventh element
"ZIP2 v22.4s, v16.4s, v17.4s\n"
"ZIP2 v23.4s, v18.4s, v19.4s\n"
"STP q22, q23, [%[outptr]], #32\n" // Eighth element
: [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2),
[inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5),
[inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "cc", "memory");
}
for (; x > 0; x--) {
*outptr++ = *inptr0++;
*outptr++ = *inptr1++;
*outptr++ = *inptr2++;
*outptr++ = *inptr3++;
*outptr++ = *inptr4++;
*outptr++ = *inptr5++;
*outptr++ = *inptr6++;
*outptr++ = *inptr7++;
}
}
}
void prepackA_trans_8x12(float *out, const float *in, const int ldin,
const int m0, const int mmax, const int k0,
const int kmax) {
uint32_t *outptr = reinterpret_cast<uint32_t *>(out);
const uint32_t *inptr =
reinterpret_cast<const uint32_t *>(in) + k0 * ldin + m0;
uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7};
int x_len = mmax - m0;
int y_len = kmax - k0;
int right_remain = x_len - 8 * (x_len / 8);
int right_pad = 8 - right_remain;
if (right_remain == 0) {
right_pad = 0;
}
uint32_t *outptr_row = outptr;
int stride_out = 8 * y_len;
uint32x4_t vzero = vdupq_n_u32(0);
uint32x4_t vmask1 =
vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain));
uint32x4_t vmask2 =
vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain));
#pragma omp parallel for
for (int y = 0; y < y_len - 3; y += 4) {
const uint32_t *ptr0 = inptr + y * ldin;
const uint32_t *ptr1 = ptr0 + ldin;
const uint32_t *ptr2 = ptr1 + ldin;
const uint32_t *ptr3 = ptr2 + ldin;
asm volatile(
"prfm pldl1keep, [%[ptr0]] \n"
"prfm pldl1keep, [%[ptr0], #64] \n"
"prfm pldl1keep, [%[ptr1]] \n"
"prfm pldl1keep, [%[ptr1], #64] \n"
"prfm pldl1keep, [%[ptr2]] \n"
"prfm pldl1keep, [%[ptr2], #64] \n"
"prfm pldl1keep, [%[ptr3]] \n"
"prfm pldl1keep, [%[ptr3], #64] \n"
:
: [ptr0] "r"(ptr0), [ptr1] "r"(ptr1), [ptr2] "r"(ptr2), [ptr3] "r"(ptr3)
: "memory");
uint32_t *outptr_row_col = outptr_row + y * 8;
int i = 0;
for (; i < x_len - 7; i += 8) {
uint32x4_t vr00 = vld1q_u32(ptr0);
uint32x4_t vr01 = vld1q_u32(ptr0 + 4);
uint32x4_t vr10 = vld1q_u32(ptr1);
uint32x4_t vr11 = vld1q_u32(ptr1 + 4);
vst1q_u32(outptr_row_col, vr00);
vst1q_u32(outptr_row_col + 4, vr01);
uint32x4_t vr20 = vld1q_u32(ptr2);
uint32x4_t vr21 = vld1q_u32(ptr2 + 4);
vst1q_u32(outptr_row_col + 8, vr10);
vst1q_u32(outptr_row_col + 12, vr11);
uint32x4_t vr30 = vld1q_u32(ptr3);
uint32x4_t vr31 = vld1q_u32(ptr3 + 4);
vst1q_u32(outptr_row_col + 16, vr20);
vst1q_u32(outptr_row_col + 20, vr21);
vst1q_u32(outptr_row_col + 24, vr30);
vst1q_u32(outptr_row_col + 28, vr31);
ptr0 += 8;
ptr1 += 8;
ptr2 += 8;
ptr3 += 8;
outptr_row_col += stride_out;
}
if (right_remain > 0) {
uint32x4_t vr00 = vld1q_u32(ptr0);
uint32x4_t vr01 = vld1q_u32(ptr0 + 4);
uint32x4_t vr10 = vld1q_u32(ptr1);
uint32x4_t vr11 = vld1q_u32(ptr1 + 4);
uint32x4_t vr00_1 = vbslq_u32(vmask1, vr00, vzero);
uint32x4_t vr01_1 = vbslq_u32(vmask2, vr01, vzero);
uint32x4_t vr20 = vld1q_u32(ptr2);
uint32x4_t vr21 = vld1q_u32(ptr2 + 4);
vst1q_u32(outptr_row_col, vr00_1);
vst1q_u32(outptr_row_col + 4, vr01_1);
uint32x4_t vr10_1 = vbslq_u32(vmask1, vr10, vzero);
uint32x4_t vr11_1 = vbslq_u32(vmask2, vr11, vzero);
uint32x4_t vr30 = vld1q_u32(ptr3);
uint32x4_t vr31 = vld1q_u32(ptr3 + 4);
vst1q_u32(outptr_row_col + 8, vr10_1);
vst1q_u32(outptr_row_col + 12, vr11_1);
uint32x4_t vr20_1 = vbslq_u32(vmask1, vr20, vzero);
uint32x4_t vr21_1 = vbslq_u32(vmask2, vr21, vzero);
uint32x4_t vr30_1 = vbslq_u32(vmask1, vr30, vzero);
uint32x4_t vr31_1 = vbslq_u32(vmask2, vr31, vzero);
vst1q_u32(outptr_row_col + 16, vr20_1);
vst1q_u32(outptr_row_col + 20, vr21_1);
vst1q_u32(outptr_row_col + 24, vr30_1);
vst1q_u32(outptr_row_col + 28, vr31_1);
}
}
#pragma omp parallel for
for (int y = 4 * (y_len / 4); y < y_len; ++y) {
const uint32_t *ptr0 = inptr + y * ldin;
uint32_t *outptr_row_col = outptr_row + y * 8;
int i = 0;
for (; i < x_len - 7; i += 8) {
uint32x4_t vr0 = vld1q_u32(ptr0);
uint32x4_t vr1 = vld1q_u32(ptr0 + 4);
vst1q_u32(outptr_row_col, vr0);
vst1q_u32(outptr_row_col + 4, vr1);
ptr0 += 8;
outptr_row_col += stride_out;
}
if (right_remain > 0) {
uint32x4_t vr0 = vld1q_u32(ptr0);
uint32x4_t vr1 = vld1q_u32(ptr0 + 4);
uint32x4_t vr0_1 = vbslq_u32(vmask1, vr0, vzero);
uint32x4_t vr1_1 = vbslq_u32(vmask2, vr1, vzero);
vst1q_u32(outptr_row_col, vr0_1);
vst1q_u32(outptr_row_col + 4, vr1_1);
}
}
}
#else // __aarch64__
void prepackA_6x8(float* out, const float* in, const int ldin, const int m0,
const int mmax, const int k0, const int kmax) {
int x_len = kmax - k0;
uint32_t zerobuff[x_len]; // NOLINT
memset(zerobuff, 0, sizeof(uint32_t) * x_len);
uint32_t* dout = reinterpret_cast<uint32_t*>(out);
const uint32_t* inptr = reinterpret_cast<const uint32_t*>(in);
uint32_t* outptr = dout;
//! data A is not transposed, transpose A to k * 6
for (int y = m0; y < mmax; y += 6) {
const uint32_t* inptr0 = inptr + y * ldin + k0;
const uint32_t* inptr1 = inptr0 + ldin;
const uint32_t* inptr2 = inptr1 + ldin;
const uint32_t* inptr3 = inptr2 + ldin;
const uint32_t* inptr4 = inptr3 + ldin;
const uint32_t* inptr5 = inptr4 + ldin;
int x = x_len;
//! cope with row index exceed real size, set to zero buffer
if ((y + 5) >= mmax) {
switch ((y + 5) - mmax) {
case 4:
inptr1 = zerobuff;
case 3:
inptr2 = zerobuff;
case 2:
inptr3 = zerobuff;
case 1:
inptr4 = zerobuff;
case 0:
inptr5 = zerobuff;
default:
break;
}
}
for (; x > 7; x -= 8) {
//! zip load 8 elements (2 neon Q registers) from each of 6 rows
asm volatile(
"vld4.32 {d0-d3}, [%[inptr0]]! @ zip load r0, "
"q0,q1=r00,r04,r01,r05,r02,r06,r03,r07\n"
"vld4.32 {d4-d7}, [%[inptr1]]! @ zip load r1, "
"q2,q3=r10,r14,r11,r15,r12,r16,r13,r17\n"
"vld4.32 {d8-d11}, [%[inptr2]]! @ zip load r2, "
"q4,q5=r20,r24,r21,r25,r22,r26,r23,r27\n"
"vld4.32 {d12-d15}, [%[inptr3]]! @ zip load r3, "
"q6,q7=r30,r34,r31,r35,r32,r36,r33,r37\n"
"vld4.32 {d16-d19}, [%[inptr4]]! @ zip load r4, "
"q8,q9=r40,r44,r41,r45,r42,r46,r43,r47\n"
"vld4.32 {d20-d23}, [%[inptr5]]! @ zip load r5, "
"q10,q11=r50,r54,r51,r55,r52,r56,r53,r57\n"
"vtrn.32 q0, q2 @ trans data: q0=r00,r10,r01,r11; "
"q2=r04,r14,r05,r15\n"
"vtrn.32 q4, q6 @ trans data: q4=r20,r30,r21,r31; "
"q6=r24,r34,r25,r35\n"
"vtrn.32 q8, q10 @ trans data: q8=r40,r50,r41,r51; "
"q10=r44,r54,r45,r55\n"
"vswp d1, d8 @ swap d1, d8, q0=r00,r10,r20,r30; "
"q4=r01,r11,r21,r31\n"
"vst1.32 {d0-d1}, [%[outptr]]! @ write q0:r00,r10,r20,r30\n"
"vst1.32 {d16}, [%[outptr]]! @ write d16(q8,low),r40,r50\n"
"vst1.32 {d8-d9}, [%[outptr]]! @ write q4:r01,r11,r21,r31\n"
"vst1.32 {d17}, [%[outptr]]! @ write d16(q8,high),r41,r51\n"
"vtrn.32 q1, q3 @ trans data: q1=r02,r12,r03,r13; "
"q3=r06,r16,r07,r17\n"
"vtrn.32 q5, q7 @ trans data: q5=r22,r32,r23,r33; "
"q7=r26,r36,r27,r37\n"
"vtrn.32 q9, q11 @ trans data: q9=r42,r52,r43,r53; "
"q11=r46,r56,r47,r57\n"
"vswp d3, d10 @ swap d3, d10, "
"q1=r02,r12,r22,r32; q5=r03,r13,r23,r33\n"
"vst1.32 {d2-d3}, [%[outptr]]! @ write q1:r02,r12,r22,r32\n"
"vst1.32 {d18}, [%[outptr]]! @ write d18(q9,low),r42,r52\n"
"vst1.32 {d10-d11},[%[outptr]]! @ write q5:r03,r13,r23,r33\n"
"vst1.32 {d19}, [%[outptr]]! @ write d19(q9,high),r43,r53\n"
"vswp d5, d12 @ swap d5, d12,q2=r04,r14,r24,r34; "
"q6=r05,r15,r25,r35\n"
"vst1.32 {d4-d5}, [%[outptr]]! @ write q2:r04,r14,r24,r34\n"
"vst1.32 {d20}, [%[outptr]]! @ write d20(q10,low),r44,r54\n"
"vst1.32 {d12-d13},[%[outptr]]! @ write q6:r05,r15,r25,r35\n"
"vst1.32 {d21}, [%[outptr]]! @ write d21(q10,high),r45,r55\n"
"vswp d7, d14 @ swap d7, d14, "
"q3=r06,r16,r26,r36; q7=r07,r17,r27,r37\n"
"vst1.32 {d6-d7}, [%[outptr]]! @ write q3:r06,r16,r26,r36\n"
"vst1.32 {d22}, [%[outptr]]! @ write d22(q11,low),r46,r56\n"
"vst1.32 {d14-d15},[%[outptr]]! @ write q7:r07,r17,r27,r37\n"
"vst1.32 {d23}, [%[outptr]]! @ write d23(q11,high),r47,r57\n"
: [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2),
[inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5),
[outptr] "+r"(outptr)
:
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
"q11", "cc", "memory");
}
for (; x > 0; x--) {
*outptr++ = *inptr0++;
*outptr++ = *inptr1++;
*outptr++ = *inptr2++;
*outptr++ = *inptr3++;
*outptr++ = *inptr4++;
*outptr++ = *inptr5++;
}
}
}
void prepackA_trans_6x8(float* out, const float* in, const int ldin,
const int m0, const int mmax, const int k0,
const int kmax) {
uint32_t* outptr = reinterpret_cast<uint32_t*>(out);
const uint32_t* inptr =
reinterpret_cast<const uint32_t*>(in) + k0 * ldin + m0;
uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7};
int x_len = mmax - m0;
int y_len = kmax - k0;
int right_remain = x_len - 6 * (x_len / 6);
int right_pad = 6 - right_remain;
if (right_remain == 0) {
right_pad = 0;
}
uint32_t* outptr_row = outptr;
int stride_out = 6 * y_len;
uint32x4_t vzero = vdupq_n_u32(0);
uint32x4_t vmask1 =
vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain));
uint32x4_t vmask2 =
vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain));
#pragma omp parallel for
for (int y = 0; y < y_len - 3; y += 4) {
const uint32_t* ptr0 = inptr + y * ldin;
const uint32_t* ptr1 = ptr0 + ldin;
const uint32_t* ptr2 = ptr1 + ldin;
const uint32_t* ptr3 = ptr2 + ldin;
uint32_t* outptr_row_col = outptr_row + y * 6;
int i = 0;
for (; i < x_len - 5; i += 6) {
uint32_t* ptr_out = outptr_row_col;
asm volatile(
"vld1.32 {d0-d2}, [%[ptr0]]! @ load r0, 6 elements\n"
"vld1.32 {d4-d6}, [%[ptr1]]! @ load r1, 6 elements\n"
"vst1.32 {d0-d2}, [%[outptr]]! @ write to output ptr\n"
"vst1.32 {d4-d6}, [%[outptr]]! @ write to output ptr\n"
"vld1.32 {d0-d2}, [%[ptr2]]! @ load r2, 6 elements\n"
"vld1.32 {d4-d6}, [%[ptr3]]! @ load r3, 6 elements\n"
"vst1.32 {d0-d2}, [%[outptr]]! @ write to output ptr\n"
"vst1.32 {d4-d6}, [%[outptr]]! @ write to output ptr\n"
: [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1),
[ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3)
:
: "q0", "q1", "q2", "q3", "cc", "memory");
outptr_row_col += stride_out;
}
if (right_pad > 0) {
uint32_t* ptr_out = outptr_row_col;
asm volatile(
"vld1.32 {d0-d2}, [%[ptr0]]! @ load r0, 6 elements\n"
"vld1.32 {d4-d6}, [%[ptr1]]! @ load r1, 6 elements\n"
"vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n"
"vbif d2, %e[vzero], %e[vmask2] @ bit select, pad zero\n"
"vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n"
"vbif d6, %e[vzero], %e[vmask2] @ bit select, pad zero\n"
"vst1.32 {d0-d2}, [%[outptr]]! @ write to output ptr\n"
"vst1.32 {d4-d6}, [%[outptr]]! @ write to output ptr\n"
"vld1.32 {d0-d2}, [%[ptr2]]! @ load r2, 8 elements\n"
"vld1.32 {d4-d6}, [%[ptr3]]! @ load r3, 8 elements\n"
"vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n"
"vbif d2, %e[vzero], %e[vmask2] @ bit select, pad zero\n"
"vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n"
"vbif d6, %e[vzero], %e[vmask2] @ bit select, pad zero\n"
"vst1.32 {d0-d2}, [%[outptr]]! @ write to output ptr\n"
"vst1.32 {d4-d6}, [%[outptr]]! @ write to output ptr\n"
: [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1),
[ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3)
: [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero)
: "q0", "q1", "q2", "q3", "cc", "memory");
}
}
#pragma omp parallel for
for (int y = 4 * (y_len / 4); y < y_len; ++y) {
const uint32_t* ptr0 = inptr + y * ldin;
uint32_t* outptr_row_col = outptr_row + y * 6;
int i = 0;
for (; i < x_len - 5; i += 6) {
uint32_t* ptr_out = outptr_row_col;
asm volatile(
"vld1.32 {d0-d2}, [%[ptr0]]! @ load r0, 6 elements\n"
"vst1.32 {d0-d2}, [%[outptr]]! @ write to output ptr\n"
: [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out)
:
: "q0", "q1", "cc", "memory");
outptr_row_col += stride_out;
}
if (right_pad > 0) {
uint32_t* ptr_out = outptr_row_col;
asm volatile(
"vld1.32 {d0-d2}, [%[ptr0]]! @ load r0, 6 elements\n"
"vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n"
"vbif d2, %e[vzero], %e[vmask2] @ bit select, pad zero\n"
"vst1.32 {d0-d2}, [%[outptr]]! @ write to output ptr\n"
: [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out)
: [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero)
: "q0", "q1", "cc", "memory");
}
}
}
void prepackA_4x8(float* out, const float* in, const int ldin, const int m0,
const int mmax, const int k0, const int kmax) {
int x_len = kmax - k0;
uint32_t zerobuff[x_len]; // NOLINT
memset(zerobuff, 0, sizeof(uint32_t) * x_len);
uint32_t* dout = reinterpret_cast<uint32_t*>(out);
const uint32_t* inptr = reinterpret_cast<const uint32_t*>(in);
uint32_t* outptr = dout;
//! data A is not transposed, transpose A to k * 4
for (int y = m0; y < mmax; y += 4) {
const uint32_t* inptr0 = inptr + y * ldin + k0;
const uint32_t* inptr1 = inptr0 + ldin;
const uint32_t* inptr2 = inptr1 + ldin;
const uint32_t* inptr3 = inptr2 + ldin;
int x = x_len;
//! cope with row index exceed real size, set to zero buffer
if ((y + 3) >= mmax) {
switch ((y + 3) - mmax) {
case 2:
inptr1 = zerobuff;
case 1:
inptr2 = zerobuff;
case 0:
inptr3 = zerobuff;
default:
break;
}
}
for (; x > 7; x -= 8) {
//! zip load 8 elements (2 neon Q registers) from each of 4 rows
asm volatile(
"vld4.32 {d0-d3}, [%[inptr0]]! @ zip load r0, "
"q0,q1=r00,r04,r01,r05,r02,r06,r03,r07\n"
"vld4.32 {d4-d7}, [%[inptr1]]! @ zip load r1, "
"q2,q3=r10,r14,r11,r15,r12,r16,r13,r17\n"
"vld4.32 {d8-d11}, [%[inptr2]]! @ zip load r2, "
"q4,q5=r20,r24,r21,r25,r22,r26,r23,r27\n"
"vld4.32 {d12-d15}, [%[inptr3]]! @ zip load r3, "
"q6,q7=r30,r34,r31,r35,r32,r36,r33,r37\n"
"vtrn.32 q0, q2 @ trans data: q0=r00,r10,r01,r11; "
"q2=r04,r14,r05,r15\n"
"vtrn.32 q4, q6 @ trans data: q4=r20,r30,r21,r31; "
"q6=r24,r34,r25,r35\n"
"vswp d1, d8 @ swap d1, d8, q0=r00,r10,r20,r30; "
"q4=r01,r11,r21,r31\n"
"vst1.32 {d0-d1}, [%[outptr]]! @ write q0:r00,r10,r20,r30\n"
"vst1.32 {d8-d9}, [%[outptr]]! @ write q4:r01,r11,r21,r31\n"
"vtrn.32 q1, q3 @ trans data: q1=r02,r12,r03,r13; "
"q3=r06,r16,r07,r17\n"
"vtrn.32 q5, q7 @ trans data: q5=r22,r32,r23,r33; "
"q7=r26,r36,r27,r37\n"
"vswp d3, d10 @ swap d3, d10, "
"q1=r02,r12,r22,r32; q5=r03,r13,r23,r33\n"
"vst1.32 {d2-d3}, [%[outptr]]! @ write q1:r02,r12,r22,r32\n"
"vst1.32 {d10-d11},[%[outptr]]! @ write q5:r03,r13,r23,r33\n"
"vswp d5, d12 @ swap d5, d12,q2=r04,r14,r24,r34; "
"q6=r05,r15,r25,r35\n"
"vst1.32 {d4-d5}, [%[outptr]]! @ write q2:r04,r14,r24,r34\n"
"vst1.32 {d12-d13},[%[outptr]]! @ write q6:r05,r15,r25,r35\n"
"vswp d7, d14 @ swap d7, d14, "
"q3=r06,r16,r26,r36; q7=r07,r17,r27,r37\n"
"vst1.32 {d6-d7}, [%[outptr]]! @ write q3:r06,r16,r26,r36\n"
"vst1.32 {d14-d15},[%[outptr]]! @ write q7:r07,r17,r27,r37\n"
: [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2),
[inptr3] "+r"(inptr3), [outptr] "+r"(outptr)
:
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
"q11", "cc", "memory");
}
for (; x > 0; x--) {
*outptr++ = *inptr0++;
*outptr++ = *inptr1++;
*outptr++ = *inptr2++;
*outptr++ = *inptr3++;
}
}
}
void prepackA_trans_4x8(float* out, const float* in, const int ldin,
const int m0, const int mmax, const int k0,
const int kmax) {
uint32_t* outptr = reinterpret_cast<uint32_t*>(out);
const uint32_t* inptr =
reinterpret_cast<const uint32_t*>(in) + k0 * ldin + m0;
uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7};
int x_len = mmax - m0;
int y_len = kmax - k0;
int right_remain = x_len - 4 * (x_len / 4);
int right_pad = 4 - right_remain;
if (right_remain == 0) {
right_pad = 0;
}
uint32_t* outptr_row = outptr;
int stride_out = 4 * y_len;
uint32x4_t vzero = vdupq_n_u32(0);
uint32x4_t vmask1 =
vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain));
// uint32x4_t vmask2 = vcltq_u32(vld1q_u32(mask_buffer + 4),
// vdupq_n_u32(right_remain));
#pragma omp parallel for
for (int y = 0; y < y_len - 3; y += 4) {
const uint32_t* ptr0 = inptr + y * ldin;
const uint32_t* ptr1 = ptr0 + ldin;
const uint32_t* ptr2 = ptr1 + ldin;
const uint32_t* ptr3 = ptr2 + ldin;
uint32_t* outptr_row_col = outptr_row + y * 4;
int i = 0;
for (; i < x_len - 3; i += 4) {
uint32_t* ptr_out = outptr_row_col;
asm volatile(
"vld1.32 {d0-d1}, [%[ptr0]]! @ load r0, 4 elements\n"
"vld1.32 {d2-d3}, [%[ptr1]]! @ load r1, 4 elements\n"
"vst1.32 {d0-d1}, [%[outptr]]! @ write to output ptr\n"
"vst1.32 {d2-d3}, [%[outptr]]! @ write to output ptr\n"
"vld1.32 {d4-d5}, [%[ptr2]]! @ load r2, 4 elements\n"
"vld1.32 {d6-d7}, [%[ptr3]]! @ load r3, 4 elements\n"
"vst1.32 {d4-d5}, [%[outptr]]! @ write to output ptr\n"
"vst1.32 {d6-d7}, [%[outptr]]! @ write to output ptr\n"
: [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1),
[ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3)
:
: "q0", "q1", "q2", "q3", "cc", "memory");
outptr_row_col += stride_out;
}
if (right_pad > 0) {
uint32_t* ptr_out = outptr_row_col;
asm volatile(
"vld1.32 {d0-d1}, [%[ptr0]]! @ load r0, 4 elements\n"
"vld1.32 {d2-d3}, [%[ptr1]]! @ load r1, 4 elements\n"
"vld1.32 {d4-d5}, [%[ptr2]]! @ load r2, 4 elements\n"
"vld1.32 {d6-d7}, [%[ptr3]]! @ load r3, 4 elements\n"
"vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n"
"vbif q1, %q[vzero], %q[vmask1] @ bit select, pad zero\n"
"vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n"
"vbif q3, %q[vzero], %q[vmask1] @ bit select, pad zero\n"
"vst1.32 {d0-d1}, [%[outptr]]! @ write to output ptr\n"
"vst1.32 {d2-d3}, [%[outptr]]! @ write to output ptr\n"
"vst1.32 {d4-d5}, [%[outptr]]! @ write to output ptr\n"
"vst1.32 {d6-d7}, [%[outptr]]! @ write to output ptr\n"
: [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1),
[ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3)
: [vmask1] "w"(vmask1), [vzero] "w"(vzero)
: "q0", "q1", "q2", "q3", "cc", "memory");
}
}
#pragma omp parallel for
for (int y = 4 * (y_len / 4); y < y_len; ++y) {
const uint32_t* ptr0 = inptr + y * ldin;
uint32_t* outptr_row_col = outptr_row + y * 4;
int i = 0;
for (; i < x_len - 3; i += 4) {
uint32_t* ptr_out = outptr_row_col;
asm volatile(
"vld1.32 {d0-d1}, [%[ptr0]]! @ load r0, 4 elements\n"
"vst1.32 {d0-d1}, [%[outptr]]! @ write to output ptr\n"
: [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out)
:
: "q0", "q1", "cc", "memory");
outptr_row_col += stride_out;
}
if (right_pad > 0) {
uint32_t* ptr_out = outptr_row_col;
asm volatile(
"vld1.32 {d0-d1}, [%[ptr0]]! @ load r0, 4 elements\n"
"vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n"
"vst1.32 {d0-d1}, [%[outptr]]! @ write to output ptr\n"
: [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out)
: [vmask1] "w"(vmask1), [vzero] "w"(vzero)
: "q0", "q1", "cc", "memory");
}
}
}
#endif // __aarch64__
/**
* \brief input data is transpose
* for arm-v7a, transform data to block x k x 8 layout
* for arm-v8a, transform data to block x k x 12 layout
*/
#ifdef __aarch64__
void loadb(float *out, const float *in, const int ldin, const int k0,
const int kmax, const int n0, const int nmax) {
uint32_t *outptr = reinterpret_cast<uint32_t *>(out);
const uint32_t *inptr =
reinterpret_cast<const uint32_t *>(in) + k0 * ldin + n0;
uint32_t mask_buffer[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
int x_len = nmax - n0;
int y_len = kmax - k0;
int right_remain = x_len - 12 * (x_len / 12);
int right_pad = 12 - right_remain;
const size_t copy_len_remain = sizeof(float) * right_remain;
const size_t copy_len_pad = sizeof(float) * right_pad;
const size_t size_ldin = sizeof(float) * ldin;
uint32_t *outptr_row = outptr;
int stride_out = 12 * y_len;
uint32x4_t vzero = vdupq_n_u32(0);
uint32x4_t vmask1 =
vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain));
uint32x4_t vmask2 =
vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain));
uint32x4_t vmask3 =
vcltq_u32(vld1q_u32(mask_buffer + 8), vdupq_n_u32(right_remain));
#pragma omp parallel for
for (int y = 0; y < y_len - 3; y += 4) {
const uint32_t *ptr0 = inptr + y * ldin;
const uint32_t *ptr1 = ptr0 + ldin;
const uint32_t *ptr2 = ptr1 + ldin;
const uint32_t *ptr3 = ptr2 + ldin;
asm volatile(
"prfm pldl1keep, [%[ptr0]] \n"
"prfm pldl1keep, [%[ptr0], #64] \n"
"prfm pldl1keep, [%[ptr1]] \n"
"prfm pldl1keep, [%[ptr1], #64] \n"
"prfm pldl1keep, [%[ptr2]] \n"
"prfm pldl1keep, [%[ptr2], #64] \n"
"prfm pldl1keep, [%[ptr3]] \n"
"prfm pldl1keep, [%[ptr3], #64] \n"
:
: [ptr0] "r"(ptr0), [ptr1] "r"(ptr1), [ptr2] "r"(ptr2), [ptr3] "r"(ptr3)
: "memory");
uint32_t *outptr_row_col = outptr_row + y * 12;
int i = 0;
for (; i < x_len - 11; i += 12) {
uint32x4_t vr00 = vld1q_u32(ptr0);
uint32x4_t vr01 = vld1q_u32(ptr0 + 4);
uint32x4_t vr02 = vld1q_u32(ptr0 + 8);
uint32x4_t vr10 = vld1q_u32(ptr1);
uint32x4_t vr11 = vld1q_u32(ptr1 + 4);
uint32x4_t vr12 = vld1q_u32(ptr1 + 8);
vst1q_u32(outptr_row_col, vr00);
vst1q_u32(outptr_row_col + 4, vr01);
vst1q_u32(outptr_row_col + 8, vr02);
uint32x4_t vr20 = vld1q_u32(ptr2);
uint32x4_t vr21 = vld1q_u32(ptr2 + 4);
uint32x4_t vr22 = vld1q_u32(ptr2 + 8);
vst1q_u32(outptr_row_col + 12, vr10);
vst1q_u32(outptr_row_col + 16, vr11);
vst1q_u32(outptr_row_col + 20, vr12);
uint32x4_t vr30 = vld1q_u32(ptr3);
uint32x4_t vr31 = vld1q_u32(ptr3 + 4);
uint32x4_t vr32 = vld1q_u32(ptr3 + 8);
vst1q_u32(outptr_row_col + 24, vr20);
vst1q_u32(outptr_row_col + 28, vr21);
vst1q_u32(outptr_row_col + 32, vr22);
vst1q_u32(outptr_row_col + 36, vr30);
vst1q_u32(outptr_row_col + 40, vr31);
vst1q_u32(outptr_row_col + 44, vr32);
ptr0 += 12;
ptr1 += 12;
ptr2 += 12;
ptr3 += 12;
outptr_row_col += stride_out;
}
if (right_remain > 0) {
uint32x4_t vr00 = vld1q_u32(ptr0);
uint32x4_t vr01 = vld1q_u32(ptr0 + 4);
uint32x4_t vr02 = vld1q_u32(ptr0 + 8);
uint32x4_t vr10 = vld1q_u32(ptr1);
uint32x4_t vr11 = vld1q_u32(ptr1 + 4);
uint32x4_t vr12 = vld1q_u32(ptr1 + 8);
uint32x4_t vr00_1 = vbslq_u32(vmask1, vr00, vzero);
uint32x4_t vr01_1 = vbslq_u32(vmask2, vr01, vzero);
uint32x4_t vr02_1 = vbslq_u32(vmask3, vr02, vzero);
uint32x4_t vr20 = vld1q_u32(ptr2);
uint32x4_t vr21 = vld1q_u32(ptr2 + 4);
uint32x4_t vr22 = vld1q_u32(ptr2 + 8);
vst1q_u32(outptr_row_col, vr00_1);
vst1q_u32(outptr_row_col + 4, vr01_1);
vst1q_u32(outptr_row_col + 8, vr02_1);
uint32x4_t vr10_1 = vbslq_u32(vmask1, vr10, vzero);
uint32x4_t vr11_1 = vbslq_u32(vmask2, vr11, vzero);
uint32x4_t vr12_1 = vbslq_u32(vmask3, vr12, vzero);
uint32x4_t vr30 = vld1q_u32(ptr3);
uint32x4_t vr31 = vld1q_u32(ptr3 + 4);
uint32x4_t vr32 = vld1q_u32(ptr3 + 8);
vst1q_u32(outptr_row_col + 12, vr10_1);
vst1q_u32(outptr_row_col + 16, vr11_1);
vst1q_u32(outptr_row_col + 20, vr12_1);
uint32x4_t vr20_1 = vbslq_u32(vmask1, vr20, vzero);
uint32x4_t vr21_1 = vbslq_u32(vmask2, vr21, vzero);
uint32x4_t vr22_1 = vbslq_u32(vmask3, vr22, vzero);
uint32x4_t vr30_1 = vbslq_u32(vmask1, vr30, vzero);
uint32x4_t vr31_1 = vbslq_u32(vmask2, vr31, vzero);
uint32x4_t vr32_1 = vbslq_u32(vmask3, vr32, vzero);
vst1q_u32(outptr_row_col + 24, vr20_1);
vst1q_u32(outptr_row_col + 28, vr21_1);
vst1q_u32(outptr_row_col + 32, vr22_1);
vst1q_u32(outptr_row_col + 36, vr30_1);
vst1q_u32(outptr_row_col + 40, vr31_1);
vst1q_u32(outptr_row_col + 44, vr32_1);
}
}
#pragma omp parallel for
for (int y = 4 * (y_len / 4); y < y_len; ++y) {
const uint32_t *ptr0 = inptr + y * ldin;
uint32_t *outptr_row_col = outptr_row + y * 12;
int i = 0;
for (; i < x_len - 11; i += 12) {
uint32x4_t vr0 = vld1q_u32(ptr0);
uint32x4_t vr1 = vld1q_u32(ptr0 + 4);
uint32x4_t vr2 = vld1q_u32(ptr0 + 8);
vst1q_u32(outptr_row_col, vr0);
vst1q_u32(outptr_row_col + 4, vr1);
vst1q_u32(outptr_row_col + 8, vr2);
ptr0 += 12;
outptr_row_col += stride_out;
}
if (right_remain > 0) {
uint32x4_t vr0 = vld1q_u32(ptr0);
uint32x4_t vr1 = vld1q_u32(ptr0 + 4);
uint32x4_t vr2 = vld1q_u32(ptr0 + 8);
uint32x4_t vr0_1 = vbslq_u32(vmask1, vr0, vzero);
uint32x4_t vr1_1 = vbslq_u32(vmask2, vr1, vzero);
uint32x4_t vr2_1 = vbslq_u32(vmask3, vr2, vzero);
vst1q_u32(outptr_row_col, vr0_1);
vst1q_u32(outptr_row_col + 4, vr1_1);
vst1q_u32(outptr_row_col + 8, vr2_1);
}
}
}
void loadb_trans(float *out, const float *in, const int ldin, const int k0,
const int kmax, const int n0, const int nmax) {
int x_len = kmax - k0;
uint32_t zerobuff[x_len]; // NOLINT
memset(zerobuff, 0, sizeof(uint32_t) * x_len);
uint32_t *outptr = reinterpret_cast<uint32_t *>(out);
const uint32_t *inptr = reinterpret_cast<const uint32_t *>(in);
//! data B is not transposed, transpose B to k * 12
for (int y = n0; y < nmax; y += 12) {
const uint32_t *inptr0 = inptr + y * ldin + k0;
const uint32_t *inptr1 = inptr0 + ldin;
const uint32_t *inptr2 = inptr1 + ldin;
const uint32_t *inptr3 = inptr2 + ldin;
const uint32_t *inptr4 = inptr3 + ldin;
const uint32_t *inptr5 = inptr4 + ldin;
const uint32_t *inptr6 = inptr5 + ldin;
const uint32_t *inptr7 = inptr6 + ldin;
const uint32_t *inptr8 = inptr7 + ldin;
const uint32_t *inptr9 = inptr8 + ldin;
const uint32_t *inptr10 = inptr9 + ldin;
const uint32_t *inptr11 = inptr10 + ldin;
asm volatile(
"prfm pldl1keep, [%[ptr0]] \n"
"prfm pldl1keep, [%[ptr0], #64] \n"
"prfm pldl1keep, [%[ptr1]] \n"
"prfm pldl1keep, [%[ptr1], #64] \n"
"prfm pldl1keep, [%[ptr2]] \n"
"prfm pldl1keep, [%[ptr2], #64] \n"
"prfm pldl1keep, [%[ptr3]] \n"
"prfm pldl1keep, [%[ptr3], #64] \n"
"prfm pldl1keep, [%[ptr4]] \n"
"prfm pldl1keep, [%[ptr4], #64] \n"
"prfm pldl1keep, [%[ptr5]] \n"
"prfm pldl1keep, [%[ptr5], #64] \n"
"prfm pldl1keep, [%[ptr6]] \n"
"prfm pldl1keep, [%[ptr6], #64] \n"
"prfm pldl1keep, [%[ptr7]] \n"
"prfm pldl1keep, [%[ptr7], #64] \n"
"prfm pldl1keep, [%[ptr8]] \n"
"prfm pldl1keep, [%[ptr8], #64] \n"
"prfm pldl1keep, [%[ptr9]] \n"
"prfm pldl1keep, [%[ptr9], #64] \n"
"prfm pldl1keep, [%[ptr10]] \n"
"prfm pldl1keep, [%[ptr10], #64] \n"
"prfm pldl1keep, [%[ptr11]] \n"
"prfm pldl1keep, [%[ptr11], #64] \n"
:
: [ptr0] "r"(inptr0), [ptr1] "r"(inptr1), [ptr2] "r"(inptr2),
[ptr3] "r"(inptr3), [ptr4] "r"(inptr4), [ptr5] "r"(inptr5),
[ptr6] "r"(inptr6), [ptr7] "r"(inptr7), [ptr8] "r"(inptr8),
[ptr9] "r"(inptr9), [ptr10] "r"(inptr10), [ptr11] "r"(inptr11)
: "memory");
int x = x_len;
//! cope with row index exceed real size, set to zero buffer
if ((y + 11) >= nmax) {
switch ((y + 11) - nmax) {
case 10:
inptr1 = zerobuff;
case 9:
inptr2 = zerobuff;
case 8:
inptr3 = zerobuff;
case 7:
inptr4 = zerobuff;
case 6:
inptr5 = zerobuff;
case 5:
inptr6 = zerobuff;
case 4:
inptr7 = zerobuff;
case 3:
inptr8 = zerobuff;
case 2:
inptr9 = zerobuff;
case 1:
inptr10 = zerobuff;
case 0:
inptr11 = zerobuff;
default:
break;
}
}
for (; x > 7; x -= 8) {
asm volatile(
// Load up 12 elements (3 vectors) from each of 8 sources.
"LDP q0, q1, [%[inptr0]], #32\n" // q0=A0A1A2A3
"LDP q2, q3, [%[inptr1]], #32\n" // q2=B0B1B2B3
"LDP q4, q5, [%[inptr2]], #32\n" // q4=C0C1C2C3
"ZIP1 v16.4s, v0.4s, v4.4s\n" // q16=A0C0A1C1
"prfm pldl1keep, [%[inptr0], #128] \n"
"LDP q6, q7, [%[inptr3]], #32\n" // q6=D0D1D2D3
"ZIP1 v17.4s, v2.4s, v6.4s\n" // q17=B0D0B1D1
"LDP q8, q9, [%[inptr4]], #32\n"
"LDP q10, q11, [%[inptr5]], #32\n"
"LDP q12, q13, [%[inptr6]], #32\n"
"ZIP1 v18.4s, v8.4s, v12.4s\n"
"prfm pldl1keep, [%[inptr1], #128]\n"
"LDP q14, q15, [%[inptr7]], #32\n"
"ZIP1 v19.4s, v10.4s, v14.4s\n"
"ZIP1 v20.4s, v16.4s, v17.4s\n" // q20=A0B0C0D0
"prfm pldl1keep, [%[inptr2], #128]\n"
"ZIP1 v21.4s, v18.4s, v19.4s\n"
"ZIP2 v22.4s, v16.4s, v17.4s\n"
"ZIP2 v23.4s, v18.4s, v19.4s\n"
"LDP q24, q25, [%[inptr8]], #32\n" // q24=A0A1A2A3
"LDP q26, q27, [%[inptr9]], #32\n" // q26=B0B1B2B3
"LDP q28, q29, [%[inptr10]], #32\n" // q28=C0C1C2C3
"LDP q30, q31, [%[inptr11]], #32\n" // q30=D0D1D2D3
"prfm pldl1keep, [%[inptr3], #128]\n"
"prfm pldl1keep, [%[inptr4], #128]\n"
"ZIP1 v16.4s, v24.4s, v28.4s\n" // q16=A0C0A1C1
"ZIP1 v17.4s, v26.4s, v30.4s\n" // q17=B0D0B1D1
"STP q20, q21, [%[outptr]], #32\n" // Write back the first
// element of each source
"ZIP1 v18.4s, v16.4s, v17.4s\n" // q20=A0B0C0D0
"ZIP2 v19.4s, v16.4s, v17.4s\n" // q20=A0B0C0D0
"ZIP2 v16.4s, v0.4s, v4.4s\n"
"prfm pldl1keep, [%[inptr5], #128]\n"
"ZIP2 v17.4s, v2.4s, v6.4s\n"
"STR q18, [%[outptr]], #16\n" // Write back the second element
// of each source
"STP q22, q23, [%[outptr]], #32\n" // Write back the second
// element of each source
"ZIP2 v18.4s, v8.4s, v12.4s\n"
"prfm pldl1keep, [%[inptr6], #128]\n"
"STR q19, [%[outptr]], #16\n" // Write back the second element
// of each source
"ZIP2 v19.4s, v10.4s, v14.4s\n"
"ZIP1 v20.4s, v16.4s, v17.4s\n"
"prfm pldl1keep, [%[inptr7], #128]\n"
"ZIP1 v21.4s, v18.4s, v19.4s\n"
"ZIP2 v22.4s, v16.4s, v17.4s\n"
"ZIP2 v23.4s, v18.4s, v19.4s\n"
"ZIP2 v16.4s, v24.4s, v28.4s\n" // q16=A0C0A1C1
"ZIP2 v17.4s, v26.4s, v30.4s\n" // q17=B0D0B1D1
"prfm pldl1keep, [%[inptr8], #128]\n"
"STP q20, q21, [%[outptr]], #32\n" // Third element
"ZIP1 v18.4s, v16.4s, v17.4s\n"
"ZIP2 v19.4s, v16.4s, v17.4s\n"
"ZIP1 v16.4s, v1.4s, v5.4s\n"
"prfm pldl1keep, [%[inptr9], #128]\n"
"ZIP1 v17.4s, v3.4s, v7.4s\n"
"STR q18, [%[outptr]], #16\n" // Write back the second element
// of each source
"STP q22, q23, [%[outptr]], #32\n" // Fourth element
"ZIP1 v18.4s, v9.4s, v13.4s\n"
"prfm pldl1keep, [%[inptr10], #128]\n"
"STR q19, [%[outptr]], #16\n" // Write back the second element
// of each source
"ZIP1 v19.4s, v11.4s, v15.4s\n"
"ZIP1 v20.4s, v16.4s, v17.4s\n"
"ZIP1 v21.4s, v18.4s, v19.4s\n"
"ZIP2 v22.4s, v16.4s, v17.4s\n"
"prfm pldl1keep, [%[inptr11], #128]\n"
"ZIP2 v23.4s, v18.4s, v19.4s\n"
"ZIP1 v16.4s, v25.4s, v29.4s\n"
"ZIP1 v17.4s, v27.4s, v31.4s\n"
"STP q20, q21, [%[outptr]], #32\n" // Fifth element
"ZIP1 v18.4s, v16.4s, v17.4s\n"
"ZIP2 v19.4s, v16.4s, v17.4s\n"
"ZIP2 v16.4s, v1.4s, v5.4s\n"
"ZIP2 v17.4s, v3.4s, v7.4s\n"
"STR q18, [%[outptr]], #16\n"
"STP q22, q23, [%[outptr]], #32\n" // Sixth element
"ZIP2 v18.4s, v9.4s, v13.4s\n"
"STR q19, [%[outptr]], #16\n" // Sixth element
"ZIP2 v19.4s, v11.4s, v15.4s\n"
"ZIP1 v20.4s, v16.4s, v17.4s\n"
"ZIP1 v21.4s, v18.4s, v19.4s\n"
"ZIP2 v22.4s, v16.4s, v17.4s\n"
"ZIP2 v23.4s, v18.4s, v19.4s\n"
"ZIP2 v16.4s, v25.4s, v29.4s\n"
"ZIP2 v17.4s, v27.4s, v31.4s\n"
"STP q20, q21, [%[outptr]], #32\n" // Seventh element
"ZIP1 v18.4s, v16.4s, v17.4s\n"
"ZIP2 v19.4s, v16.4s, v17.4s\n"
"STR q18, [%[outptr]], #16\n"
"STP q22, q23, [%[outptr]], #32\n" // Eighth element
"STR q19, [%[outptr]], #16\n"
: [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2),
[inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5),
[inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8),
[inptr9] "+r"(inptr9), [inptr10] "+r"(inptr10),
[inptr11] "+r"(inptr11), [outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "cc", "memory");
}
for (; x > 0; x--) {
*outptr++ = *inptr0++;
*outptr++ = *inptr1++;
*outptr++ = *inptr2++;
*outptr++ = *inptr3++;
*outptr++ = *inptr4++;
*outptr++ = *inptr5++;
*outptr++ = *inptr6++;
*outptr++ = *inptr7++;
*outptr++ = *inptr8++;
*outptr++ = *inptr9++;
*outptr++ = *inptr10++;
*outptr++ = *inptr11++;
}
}
}
#else // __aarch64__
void loadb(float* out, const float* in, const int ldin, const int k0,
const int kmax, const int n0, const int nmax) {
uint32_t* outptr = reinterpret_cast<uint32_t*>(out);
const uint32_t* inptr =
reinterpret_cast<const uint32_t*>(in) + k0 * ldin + n0;
uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7};
int x_len = nmax - n0;
int y_len = kmax - k0;
int right_remain = x_len - 8 * (x_len / 8);
int right_pad = 8 - right_remain;
const size_t copy_len_remain = sizeof(float) * right_remain;
const size_t copy_len_pad = sizeof(float) * right_pad;
const size_t size_ldin = sizeof(float) * ldin;
uint32_t* outptr_row = outptr;
int stride_out = 8 * y_len;
uint32x4_t vzero = vdupq_n_u32(0);
uint32x4_t vmask1 =
vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain));
uint32x4_t vmask2 =
vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain));
#pragma omp parallel for
for (int y = 0; y < y_len - 3; y += 4) {
const uint32_t* ptr0 = inptr + y * ldin;
const uint32_t* ptr1 = ptr0 + ldin;
const uint32_t* ptr2 = ptr1 + ldin;
const uint32_t* ptr3 = ptr2 + ldin;
uint32_t* outptr_row_col = outptr_row + y * 8;
int i = 0;
for (; i < x_len - 7; i += 8) {
uint32_t* ptr_out = outptr_row_col;
asm volatile(
"vld1.32 {d0-d3}, [%[ptr0]]! @ load r0, 8 elements\n"
"vld1.32 {d4-d7}, [%[ptr1]]! @ load r1, 8 elements\n"
"vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n"
"vst1.32 {d4-d7}, [%[outptr]]! @ write to output ptr\n"
"vld1.32 {d0-d3}, [%[ptr2]]! @ load r2, 8 elements\n"
"vld1.32 {d4-d7}, [%[ptr3]]! @ load r3, 8 elements\n"
"vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n"
"vst1.32 {d4-d7}, [%[outptr]]! @ write to output ptr\n"
: [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1),
[ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3)
:
: "q0", "q1", "q2", "q3", "cc", "memory");
outptr_row_col += stride_out;
}
if (right_remain > 0) {
uint32_t* ptr_out = outptr_row_col;
asm volatile(
"vld1.32 {d0-d3}, [%[ptr0]]! @ load r0, 8 elements\n"
"vld1.32 {d4-d7}, [%[ptr1]]! @ load r1, 8 elements\n"
"vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n"
"vbif q1, %q[vzero], %q[vmask2] @ bit select, pad zero\n"
//"vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n"
"vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n"
"vbif q3, %q[vzero], %q[vmask2] @ bit select, pad zero\n"
"vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n"
"vst1.32 {d4-d7}, [%[outptr]]! @ write to output ptr\n"
"vld1.32 {d0-d3}, [%[ptr2]]! @ load r2, 8 elements\n"
"vld1.32 {d4-d7}, [%[ptr3]]! @ load r3, 8 elements\n"
"vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n"
"vbif q1, %q[vzero], %q[vmask2] @ bit select, pad zero\n"
//"vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n"
"vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n"
"vbif q3, %q[vzero], %q[vmask2] @ bit select, pad zero\n"
"vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n"
"vst1.32 {d4-d7}, [%[outptr]]! @ write to output ptr\n"
: [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1),
[ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3)
: [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero)
: "q0", "q1", "q2", "q3", "cc", "memory");
}
}
#pragma omp parallel for
for (int y = 4 * (y_len / 4); y < y_len; ++y) {
const uint32_t* ptr0 = inptr + y * ldin;
uint32_t* outptr_row_col = outptr_row + y * 8;
int i = 0;
for (; i < x_len - 7; i += 8) {
uint32_t* ptr_out = outptr_row_col;
asm volatile(
"vld1.32 {d0-d3}, [%[ptr0]]! @ load r0, 8 elements\n"
"vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n"
: [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out)
:
: "q0", "q1", "cc", "memory");
outptr_row_col += stride_out;
}
if (right_remain > 0) {
uint32_t* ptr_out = outptr_row_col;
asm volatile(
"vld1.32 {d0-d3}, [%[ptr0]]! @ load r0, 8 elements\n"
"vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n"
"vbif q1, %q[vzero], %q[vmask2] @ bit select, pad zero\n"
"vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n"
: [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out)
: [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero)
: "q0", "q1", "cc", "memory");
}
}
}
void loadb_trans(float* out, const float* in, const int ldin, const int k0,
const int kmax, const int n0, const int nmax) {
int x_len = kmax - k0;
uint32_t zerobuff[x_len]; // NOLINT
memset(zerobuff, 0, sizeof(uint32_t) * x_len);
uint32_t* outptr = reinterpret_cast<uint32_t*>(out);
const uint32_t* inptr = reinterpret_cast<const uint32_t*>(in);
//! data B is not transposed, transpose B to k * 8
for (int y = n0; y < nmax; y += 8) {
const uint32_t* inptr0 = inptr + y * ldin + k0;
const uint32_t* inptr1 = inptr0 + ldin;
const uint32_t* inptr2 = inptr1 + ldin;
const uint32_t* inptr3 = inptr2 + ldin;
const uint32_t* inptr4 = inptr3 + ldin;
const uint32_t* inptr5 = inptr4 + ldin;
const uint32_t* inptr6 = inptr5 + ldin;
const uint32_t* inptr7 = inptr6 + ldin;
int x = x_len;
//! cope with row index exceed real size, set to zero buffer
if ((y + 7) >= nmax) {
switch ((y + 7) - nmax) {
case 6:
inptr1 = zerobuff;
case 5:
inptr2 = zerobuff;
case 4:
inptr3 = zerobuff;
case 3:
inptr4 = zerobuff;
case 2:
inptr5 = zerobuff;
case 1:
inptr6 = zerobuff;
case 0:
inptr7 = zerobuff;
default:
break;
}
}
for (; x > 7; x -= 8) {
//! zip load 8 elements (2 neon Q registers) from each of 8 rows
asm volatile(
"vld4.32 {d0-d3}, [%[inptr0]]! @ zip load r0, "
"q0,q1=r00,r04,r01,r05,r02,r06,r03,r07\n"
"vld4.32 {d4-d7}, [%[inptr1]]! @ zip load r1, "
"q2,q3=r10,r14,r11,r15,r12,r16,r13,r17\n"
"vtrn.32 q0, q2 @ trans data: q0=r00,r10,r01,r11; "
"q2=r04,r14,r05,r15\n"
"vst1.32 {d0}, [%[outptr]]! @ write d0(q0,low),r00,r10\n"
"vld4.32 {d8-d11}, [%[inptr2]]! @ zip load r2, "
"q4,q5=r20,r24,r21,r25,r22,r26,r23,r27\n"
"vld4.32 {d12-d15}, [%[inptr3]]! @ zip load r3, "
"q6,q7=r30,r34,r31,r35,r32,r36,r33,r37\n"
"vtrn.32 q4, q6 @ trans data: q4=r20,r30,r21,r31; "
"q6=r24,r34,r25,r35\n"
"vst1.32 {d8}, [%[outptr]]! @ write d8(q4,low),r20,r30\n"
"vld4.32 {d16-d19}, [%[inptr4]]! @ zip load r4, "
"q8,q9=r40,r44,r41,r45,r42,r46,r43,r47\n"
"vld4.32 {d20-d23}, [%[inptr5]]! @ zip load r5, "
"q10,q11=r50,r54,r51,r55,r52,r56,r53,r57\n"
"vtrn.32 q8, q10 @ trans data: q8=r40,r50,r41,r51; "
"q10=r44,r54,r45,r55\n"
"vst1.32 {d16}, [%[outptr]]! @ write d16(q8,low),r40,r50\n"
"vld4.32 {d24-d27}, [%[inptr6]]! @ zip load r6, "
"q12,q13=r60,r64,r61,r65,r62,r66,r63,r67\n"
"vld4.32 {d28-d31}, [%[inptr7]]! @ zip load r7, "
"q14,q15=r70,r74,r71,r75,r72,r76,r73,r77\n"
"vtrn.32 q12, q14 @ trans data:q12=r60,r70,r61,r71; "
"q14=r64,r74,r65,r75\n"
"vst1.32 {d24}, [%[outptr]]! @ write d24(q8,low),r60,r70\n"
//"pld [%[inptr0], #128] @ preload r0 data to cache, fill
// pipeline\n"
"vst1.32 {d1}, [%[outptr]]! @ write d1(q0,high),r01,r11\n"
"vst1.32 {d9}, [%[outptr]]! @ write d9(q4,high),r21,r31\n"
"vst1.32 {d17}, [%[outptr]]! @ write d17(q8,high),r41,r51\n"
"vst1.32 {d25}, [%[outptr]]! @ write d25(q12,high),r61,r71\n"
"vtrn.32 q1, q3 @ trans data: q1=r02,r12,r03,r13; "
"q3=r06,r16,r07,r17\n"
"vst1.32 {d2}, [%[outptr]]! @ write d2(q1,low),r02,r12\n"
"vtrn.32 q5, q7 @ trans data: q5=r22,r32,r23,r33; "
"q7=r26,r36,r27,r37\n"
"vst1.32 {d10}, [%[outptr]]! @ write d10(q5,low),r22,r32\n"
"vtrn.32 q9, q11 @ trans data: q9=r42,r52,r43,r53; "
"q11=r46,r56,r47,r57\n"
"vst1.32 {d18}, [%[outptr]]! @ write d18(q9,low),r42,r52\n"
"vtrn.32 q13, q15 @ trans data:q13=r62,r72,r63,r73; "
"q15=r66,r76,r67,r77\n"
"vst1.32 {d26}, [%[outptr]]! @ write d18(q9,low),r62,r72\n"
//"pld [%[inptr1], #128] @ preload r1 data to cache, fill
// pipeline\n"
"vst1.32 {d3}, [%[outptr]]! @ write d3(q1,high),r03,r13\n"
"vst1.32 {d11}, [%[outptr]]! @ write d11(q5,high),r23,r33\n"
"vst1.32 {d19}, [%[outptr]]! @ write d19(q9,high),r43,r53\n"
"vst1.32 {d27}, [%[outptr]]! @ write d27(q13,high),r63,r73\n"
//"pld [%[inptr2], #128] @ preload r2 data to cache, fill
// pipeline\n"
"vst1.32 {d4}, [%[outptr]]! @ write d4(q2,low),r04,r14\n"
"vst1.32 {d12}, [%[outptr]]! @ write d12(q6,low),r24,r34\n"
"vst1.32 {d20}, [%[outptr]]! @ write d20(q10,low),r44,r54\n"
"vst1.32 {d28}, [%[outptr]]! @ write d28(q14,low),r64,r74\n"
//"pld [%[inptr3], #128] @ preload r3 data to cache, fill
// pipeline\n"
"vst1.32 {d5}, [%[outptr]]! @ write d5(q2,high),r05,r15\n"
"vst1.32 {d13}, [%[outptr]]! @ write d13(q6,high),r25,r35\n"
"vst1.32 {d21}, [%[outptr]]! @ write d21(q10,high),r45,r55\n"
"vst1.32 {d29}, [%[outptr]]! @ write d29(q14,high),r65,r75\n"
//"pld [%[inptr4], #128] @ preload r4 data to cache, fill
// pipeline\n"
"vst1.32 {d6}, [%[outptr]]! @ write d6(q3,low),r06,r16\n"
"vst1.32 {d14}, [%[outptr]]! @ write d14(q7,low),r26,r36\n"
"vst1.32 {d22}, [%[outptr]]! @ write d22(q11,low),r46,r56\n"
"vst1.32 {d30}, [%[outptr]]! @ write d30(q15,low),r66,r76\n"
//"pld [%[inptr5], #128] @ preload r5 data to cache, fill
// pipeline\n"
"vst1.32 {d7}, [%[outptr]]! @ write d7(q3,high),r07,r17\n"
"vst1.32 {d15}, [%[outptr]]! @ write d15(q7,high),r27,r37\n"
"vst1.32 {d23}, [%[outptr]]! @ write d23(q11,high),r47,r57\n"
"vst1.32 {d31}, [%[outptr]]! @ write d31(q15,high),r67,r77\n"
: [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2),
[inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5),
[inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr)
:
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
"q11", "q12", "q13", "q14", "q15", "cc", "memory");
}
for (; x > 0; x--) {
*outptr++ = *inptr0++;
*outptr++ = *inptr1++;
*outptr++ = *inptr2++;
*outptr++ = *inptr3++;
*outptr++ = *inptr4++;
*outptr++ = *inptr5++;
*outptr++ = *inptr6++;
*outptr++ = *inptr7++;
}
}
}
#endif // __aarch64__
#ifdef __aarch64__
void sgemm_conv_8x12(const float *A_packed, const float *B, const float *bias,
float *C, int M, int N, int K, bool is_bias, bool is_relu,
bool transB, ARMContext *ctx) {
size_t l2_cache =
ctx->l2_cache_size() > 0 ? ctx->l2_cache_size() : 512 * 1024;
float *workspace = ctx->workspace_data<float>();
int threads = ctx->threads();
//! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
int x_block = (l2_cache - (MBLOCK * K)) / (sizeof(float) * (K + MBLOCK));
x_block /= NBLOCK;
x_block *= NBLOCK;
int x_num = (N + (x_block - 1)) / x_block;
x_block = (N + x_num - 1) / x_num;
x_block = (x_block + NBLOCK - 1) / NBLOCK;
x_block *= NBLOCK;
x_block = x_block < NBLOCK ? NBLOCK : x_block;
// unroll 2 loop
int tail_pre = (K & (KBLOCK - 1));
int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1;
bool flag_p_remain = false;
int remain = 0;
//! apanel is pre_compute outside gemm
for (unsigned int x0 = 0; x0 < N; x0 += x_block) {
unsigned int xmax = x0 + x_block;
if (xmax > N) {
xmax = N;
}
int bblocks = (xmax - x0 + NBLOCK - 1) / NBLOCK;
remain = xmax - x0 - (bblocks - 1) * NBLOCK;
if (remain > 0) {
flag_p_remain = true;
}
//! load bpanel
float *b_pannel = workspace;
if (transB) {
loadb_trans(b_pannel, B, K, 0, K, x0, xmax);
} else {
loadb(b_pannel, B, N, 0, K, x0, xmax);
}
#pragma omp parallel for num_threads(threads)
for (unsigned int y = 0; y < M; y += MBLOCK) {
unsigned int ymax = y + MBLOCK;
if (ymax > M) {
ymax = M;
}
float bias_local[8] = {0};
if (is_bias) {
bias_local[0] = bias[y];
bias_local[1] = bias[y + 1];
bias_local[2] = bias[y + 2];
bias_local[3] = bias[y + 3];
bias_local[4] = bias[y + 4];
bias_local[5] = bias[y + 5];
bias_local[6] = bias[y + 6];
bias_local[7] = bias[y + 7];
}
float cout0[NBLOCK];
float cout1[NBLOCK];
float cout2[NBLOCK];
float cout3[NBLOCK];
float cout4[NBLOCK];
float cout5[NBLOCK];
float cout6[NBLOCK];
float cout7[NBLOCK];
float *c_ptr0 = C + y * N + x0;
float *c_ptr1 = c_ptr0 + N;
float *c_ptr2 = c_ptr1 + N;
float *c_ptr3 = c_ptr2 + N;
float *c_ptr4 = c_ptr3 + N;
float *c_ptr5 = c_ptr4 + N;
float *c_ptr6 = c_ptr5 + N;
float *c_ptr7 = c_ptr6 + N;
float *pout0 = c_ptr0;
float *pout1 = c_ptr1;
float *pout2 = c_ptr2;
float *pout3 = c_ptr3;
float *pout4 = c_ptr4;
float *pout5 = c_ptr5;
float *pout6 = c_ptr6;
float *pout7 = c_ptr7;
const float *a_ptr_l = A_packed + y * K;
const float *b_ptr = b_pannel;
for (int xb = 0; xb < bblocks; xb++) {
if ((y + 7) >= ymax) {
switch ((y + 7) - ymax) {
case 6:
c_ptr1 = cout1;
case 5:
c_ptr2 = cout2;
case 4:
c_ptr3 = cout3;
case 3:
c_ptr4 = cout4;
case 2:
c_ptr5 = cout5;
case 1:
c_ptr6 = cout6;
case 0:
c_ptr7 = cout7;
default:
break;
}
}
if (flag_p_remain && (xb == bblocks - 1)) {
pout0 = c_ptr0;
pout1 = c_ptr1;
pout2 = c_ptr2;
pout3 = c_ptr3;
pout4 = c_ptr4;
pout5 = c_ptr5;
pout6 = c_ptr6;
pout7 = c_ptr7;
c_ptr0 = cout0;
c_ptr1 = cout1;
c_ptr2 = cout2;
c_ptr3 = cout3;
c_ptr4 = cout4;
c_ptr5 = cout5;
c_ptr6 = cout6;
c_ptr7 = cout7;
}
const float *a_ptr = a_ptr_l;
int tail = tail_pre;
int k = k_pre;
asm volatile(
// Initialize result registers, load initial operands, prime
// prefetches.
"ldp q2, q3, [%[bias_ptr]]\n" /* load bias to q2, q3*/
"ldp q0, q1, [%[a_ptr]], #32\n" /* load a00,a01 to q0, q1*/
"ldp q4, q5, [%[b_ptr]], #32\n" /* load b0, b1 to q4, q5*/
"dup v8.4s, v2.s[0]\n" /* out0 = 0 */
"dup v9.4s, v2.s[0]\n" /* out1 = 0*/
"dup v10.4s, v2.s[0]\n" /* out2 = 0*/
"dup v11.4s, v2.s[1]\n" /* out3 = 0*/
"dup v12.4s, v2.s[1]\n" /* out4 = 0*/
"prfm pldl1keep, [%[b_ptr], #64]\n" /* preload b*/
"dup v13.4s, v2.s[1]\n" /* out5 = 0*/
"prfm pldl1keep, [%[a_ptr], #64]\n" /* preload a*/
"dup v14.4s, v2.s[2]\n" /* out6 = 0*/
"prfm pldl1keep, [%[b_ptr], #128]\n" /* preload b*/
"dup v15.4s, v2.s[2]\n" /* out7 = 0*/
"prfm pldl1keep, [%[a_ptr], #128]\n" /* preload a*/
"dup v16.4s, v2.s[2]\n" /* out8 = 0*/
"prfm pldl1keep, [%[b_ptr], #192]\n" /* preload b*/
"dup v17.4s, v2.s[3]\n" /* out9 = 0*/
"prfm pldl1keep, [%[b_ptr], #256]\n" /* preload b*/
"dup v18.4s, v2.s[3]\n" /* out10 = 0*/
"prfm pldl1keep, [%[a_ptr], #192]\n" /* preload a*/
"dup v19.4s, v2.s[3]\n" /* out11 = 0*/
"prfm pldl1keep, [%[b_ptr], #320]\n" /* preload b*/
"dup v20.4s, v3.s[0]\n" /* out12 = 0*/
"prfm pldl1keep, [%[a_ptr], #256]\n" /* preload a*/
"dup v21.4s, v3.s[0]\n" /* out13 = 0*/
"prfm pldl1keep, [%[b_ptr], #384]\n" /* preload b*/
"dup v22.4s, v3.s[0]\n" /* out14 = 0*/
"dup v23.4s, v3.s[1]\n" /* out15 = 0*/
"dup v24.4s, v3.s[1]\n" /* out16 = 0*/
"dup v25.4s, v3.s[1]\n" /* out17 = 0*/
"dup v26.4s, v3.s[2]\n" /* out18 = 0*/
"dup v27.4s, v3.s[2]\n" /* out19 = 0*/
"dup v28.4s, v3.s[2]\n" /* out20 = 0*/
"dup v29.4s, v3.s[3]\n" /* out21 = 0*/
"dup v30.4s, v3.s[3]\n" /* out22 = 0*/
"dup v31.4s, v3.s[3]\n" /* out23 = 0*/
"cbz %w[k], 2f\n" /* check loop count > 0 */
/* main loop */
/* unrool 0*/
"1:\n" /* main loop */
"fmla v8.4s , v4.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 =
q4 */
"fmla v11.4s , v4.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 =
q4 */
"ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b0 to q6, q7 */
"fmla v14.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 =
q4 */
"fmla v17.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 =
q4 */
"ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4 */
"fmla v20.4s, v4.4s, v1.s[0]\n" /* out4 = b0 * a01[0], b0 =
q4 */
"fmla v23.4s, v4.4s, v1.s[1]\n" /* out5 = b0 * a01[1], b0 =
q4 */
"fmla v26.4s, v4.4s, v1.s[2]\n" /* out6 = b0 * a01[2], b0 =
q4 */
"fmla v29.4s, v4.4s, v1.s[3]\n" /* out7 = b0 * a01[3], b0 =
q4 */
"fmla v9.4s, v5.4s, v0.s[0]\n" /* out8 = b1 * a00[0], b1 =
q5 */
"fmla v12.4s, v5.4s, v0.s[1]\n" /* out9 = b1 * a00[1], b1 =
q5 */
"fmla v15.4s, v5.4s, v0.s[2]\n" /* out10 = b1 * a00[2], b1 =
q5*/
"fmla v18.4s, v5.4s, v0.s[3]\n" /* out11 = b1 * a00[3], b1 =
q5*/
"fmla v21.4s, v5.4s, v1.s[0]\n" /* out12 = b1 * a01[0], b1 =
q5*/
"fmla v24.4s, v5.4s, v1.s[1]\n" /* out13 = b1 * a01[1], b1 =
q5*/
"fmla v27.4s, v5.4s, v1.s[2]\n" /* out14 = b1 * a01[2], b1 =
q5*/
"fmla v30.4s, v5.4s, v1.s[3]\n" /* out15 = b1 * a01[3], b1 =
q5*/
"ldp q4, q5, [%[b_ptr]], #32\n" /* load b1, b2 to q4, q5 */
"fmla v10.4s, v6.4s, v0.s[0]\n" /* out16 = b2 * a00[0], b2 =
q6*/
"fmla v13.4s, v6.4s, v0.s[1]\n" /* out17 = b2 * a00[1], b2 =
q6*/
"prfm pldl1keep, [%[b_ptr], #384]\n"
"fmla v16.4s, v6.4s, v0.s[2]\n" /* out18 = b2 * a00[2], b2 =
q6*/
"fmla v19.4s, v6.4s, v0.s[3]\n" /* out19 = b2 * a00[3], b2 =
q6*/
"fmla v22.4s, v6.4s, v1.s[0]\n" /* out20 = b2 * a00[0], b2 =
q6*/
"fmla v25.4s, v6.4s, v1.s[1]\n" /* out21 = b2 * a00[1], b2 =
q6*/
"fmla v28.4s, v6.4s, v1.s[2]\n" /* out22 = b2 * a00[2], b2 =
q6*/
"fmla v31.4s, v6.4s, v1.s[3]\n" /* out23 = b2 * a00[3], b2 =
q6*/
"ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1 */
/* unrool 1 */
"fmla v8.4s , v7.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 =
q7 */
"fmla v11.4s , v7.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 =
q7 */
"fmla v14.4s, v7.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 =
q7 */
"prfm pldl1keep, [%[a_ptr], #256]\n"
"fmla v17.4s, v7.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 =
q7 */
"fmla v20.4s, v7.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 =
q7 */
"fmla v23.4s, v7.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q7
*/
"fmla v26.4s, v7.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 =
q7 */
"fmla v29.4s, v7.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 =
q7 */
"ldp q6, q7, [%[b_ptr]], #32\n" /* load b0, b1 to q6, q7 */
"fmla v9.4s, v4.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 =
q4 */
"fmla v12.4s, v4.4s, v2.s[1]\n" /* out9 = b0 * a10[1], b1 =
q4 */
"fmla v15.4s, v4.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 =
q4*/
"fmla v18.4s, v4.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 =
q4*/
"fmla v21.4s, v4.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 =
q4*/
"fmla v24.4s, v4.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 =
q4*/
"fmla v27.4s, v4.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 =
q4*/
"fmla v30.4s, v4.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 =
q4*/
"fmla v10.4s, v5.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 =
q5*/
"fmla v13.4s, v5.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 =
q5*/
"fmla v16.4s, v5.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 =
q5*/
"fmla v19.4s, v5.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 =
q5*/
"fmla v22.4s, v5.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 =
q5*/
"fmla v25.4s, v5.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 =
q5*/
"fmla v28.4s, v5.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 =
q5*/
"fmla v31.4s, v5.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 =
q5*/
"ldp q4, q5, [%[b_ptr]], #32\n" /* load b2, b0 to q4, q5 */
/* unrool 2*/
"fmla v8.4s , v6.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 =
q6 */
"fmla v11.4s , v6.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 =
q6 */
"ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4*/
"fmla v14.4s, v6.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 =
q6*/
"fmla v17.4s, v6.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 =
q6*/
"fmla v20.4s, v6.4s, v1.s[0]\n" /* out4 = b0 * a01[0], b0 =
q6*/
"fmla v23.4s, v6.4s, v1.s[1]\n" /* out5 = b0 * a01[1], b0 =
q6*/
"fmla v26.4s, v6.4s, v1.s[2]\n" /* out6 = b0 * a01[2], b0 =
q6*/
"fmla v29.4s, v6.4s, v1.s[3]\n" /* out7 = b0 * a01[3], b0 =
q6*/
"fmla v9.4s, v7.4s, v0.s[0]\n" /* out8 = b1 * a00[0], b1 =
q7*/
"fmla v12.4s, v7.4s, v0.s[1]\n" /* out9 = b1 * a00[1], b1 =
q7*/
"prfm pldl1keep, [%[b_ptr], #384]\n"
"fmla v15.4s, v7.4s, v0.s[2]\n" /* out10 = b1 * a00[2], b1 =
q7*/
"fmla v18.4s, v7.4s, v0.s[3]\n" /* out11 = b1 * a00[3], b1 =
q7*/
"fmla v21.4s, v7.4s, v1.s[0]\n" /* out12 = b1 * a01[0], b1 =
q7*/
"fmla v24.4s, v7.4s, v1.s[1]\n" /* out13 = b1 * a01[1], b1 =
q7*/
"fmla v27.4s, v7.4s, v1.s[2]\n" /* out14 = b1 * a01[2], b1 =
q7*/
"fmla v30.4s, v7.4s, v1.s[3]\n" /* out15 = b1 * a01[3], b1 =
q7*/
"ldp q6, q7, [%[b_ptr]], #32\n" /* load b1, b2 to q6, q7*/
"fmla v10.4s, v4.4s, v0.s[0]\n" /* out16 = b2 * a00[0], b2 =
q4*/
"fmla v13.4s, v4.4s, v0.s[1]\n" /* out17 = b2 * a00[1], b2 =
q4*/
"fmla v16.4s, v4.4s, v0.s[2]\n" /* out18 = b2 * a00[2], b2 =
q4*/
"fmla v19.4s, v4.4s, v0.s[3]\n" /* out19 = b2 * a00[3], b2 =
q4*/
"fmla v22.4s, v4.4s, v1.s[0]\n" /* out20 = b2 * a00[0], b2 =
q4*/
"fmla v25.4s, v4.4s, v1.s[1]\n" /* out21 = b2 * a00[1], b2 =
q4*/
"fmla v28.4s, v4.4s, v1.s[2]\n" /* out22 = b2 * a00[2], b2 =
q4*/
"fmla v31.4s, v4.4s, v1.s[3]\n" /* out23 = b2 * a00[3], b2 =
q4*/
"ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1*/
/* unrool 3*/
"fmla v8.4s , v5.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 =
q5*/
"fmla v11.4s , v5.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 =
q5*/
"fmla v14.4s, v5.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 =
q5*/
"fmla v17.4s, v5.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 =
q5*/
"fmla v20.4s, v5.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 =
q5*/
"fmla v23.4s, v5.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/
"fmla v26.4s, v5.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 =
q5*/
"fmla v29.4s, v5.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 =
q5*/
"ldp q4, q5, [%[b_ptr]], #32\n" /* load b0, b1 to q4, q5*/
"fmla v9.4s, v6.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 =
q6*/
"fmla v12.4s, v6.4s, v2.s[1]\n" /* out9 = b0 * a10[1], b1 =
q6*/
"prfm pldl1keep, [%[a_ptr], #256]\n"
"fmla v15.4s, v6.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 =
q6*/
"fmla v18.4s, v6.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 =
q6*/
"fmla v21.4s, v6.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 =
q6*/
"fmla v24.4s, v6.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 =
q6*/
"fmla v27.4s, v6.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 =
q6*/
"prfm pldl1keep, [%[b_ptr], #384]\n"
"fmla v30.4s, v6.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 =
q6*/
"fmla v10.4s, v7.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 =
q7*/
"fmla v13.4s, v7.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 =
q7*/
"fmla v16.4s, v7.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 =
q7*/
"fmla v19.4s, v7.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 =
q7*/
"fmla v22.4s, v7.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 =
q7*/
"fmla v25.4s, v7.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 =
q7*/
"subs %w[k], %w[k], #1\n" /* loop count - 1*/
"fmla v28.4s, v7.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 =
q7*/
"fmla v31.4s, v7.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 =
q7*/
"bne 1b\n"
/* Target to use when K is 1 or 2 (i.e. zero iterations of main
loop)*/
"2:\n" /* process tail*/
"subs %w[tail], %w[tail], #1\n" /* tail--*/
"beq 3f\n" /*jump to tail = 1*/
/* final unrool 0*/
/* unrool 0, tail > 1*/
"fmla v8.4s , v4.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 =
q4*/
"fmla v11.4s , v4.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 =
q4*/
"ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b0 to q6, q7*/
"fmla v14.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 =
q4*/
"fmla v17.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 =
q4*/
"ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q2, q3*/
"fmla v20.4s, v4.4s, v1.s[0]\n" /* out4 = b0 * a01[0], b0 =
q4*/
"fmla v23.4s, v4.4s, v1.s[1]\n" /* out5 = b0 * a01[1], b0 =
q4*/
"fmla v26.4s, v4.4s, v1.s[2]\n" /* out6 = b0 * a01[2], b0 =
q4*/
"fmla v29.4s, v4.4s, v1.s[3]\n" /* out7 = b0 * a01[3], b0 =
q4*/
"subs %w[tail], %w[tail], #1\n" /* tail--*/
"fmla v9.4s, v5.4s, v0.s[0]\n" /* out8 = b1 * a00[0], b1 =
q5*/
"fmla v12.4s, v5.4s, v0.s[1]\n" /* out9 = b1 * a00[1], b1 =
q5*/
"fmla v15.4s, v5.4s, v0.s[2]\n" /* out10 = b1 * a00[2], b1 =
q5*/
"fmla v18.4s, v5.4s, v0.s[3]\n" /* out11 = b1 * a00[3], b1 =
q5*/
"fmla v21.4s, v5.4s, v1.s[0]\n" /* out12 = b1 * a01[0], b1 =
q5*/
"fmla v24.4s, v5.4s, v1.s[1]\n" /* out13 = b1 * a01[1], b1 =
q5*/
"fmla v27.4s, v5.4s, v1.s[2]\n" /* out14 = b1 * a01[2], b1 =
q5*/
"fmla v30.4s, v5.4s, v1.s[3]\n" /* out15 = b1 * a01[3], b1 =
q5*/
"ldp q4, q5, [%[b_ptr]], #32\n" /* load b1, b2 to q4, q5*/
"fmla v10.4s, v6.4s, v0.s[0]\n" /* out16 = b2 * a00[0], b2 =
q6*/
"fmla v13.4s, v6.4s, v0.s[1]\n" /* out17 = b2 * a00[1], b2 =
q6*/
"fmla v16.4s, v6.4s, v0.s[2]\n" /* out18 = b2 * a00[2], b2 =
q6*/
"fmla v19.4s, v6.4s, v0.s[3]\n" /* out19 = b2 * a00[3], b2 =
q6*/
"fmla v22.4s, v6.4s, v1.s[0]\n" /* out20 = b2 * a00[0], b2 =
q6*/
"fmla v25.4s, v6.4s, v1.s[1]\n" /* out21 = b2 * a00[1], b2 =
q6*/
"fmla v28.4s, v6.4s, v1.s[2]\n" /* out22 = b2 * a00[2], b2 =
q6*/
"fmla v31.4s, v6.4s, v1.s[3]\n" /* out23 = b2 * a00[3], b2 =
q6*/
"beq 4f\n" /*jump to tail = 2*/
/* unrool 1, tail > 2*/
"ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1*/
"fmla v8.4s , v7.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 =
q7*/
"fmla v11.4s , v7.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 =
q7*/
"fmla v14.4s, v7.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 =
q7*/
"fmla v17.4s, v7.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 =
q7*/
"fmla v20.4s, v7.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 =
q7*/
"fmla v23.4s, v7.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q7*/
"fmla v26.4s, v7.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 =
q7*/
"fmla v29.4s, v7.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 =
q7*/
"ldp q6, q7, [%[b_ptr]], #32\n" /* load b0, b1 to q6, q7*/
"fmla v9.4s, v4.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 =
q4*/
"fmla v12.4s, v4.4s, v2.s[1]\n" /* out9 = b0 * a10[1], b1 =
q4*/
"fmla v15.4s, v4.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 =
q4*/
"fmla v18.4s, v4.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 =
q4*/
"fmla v21.4s, v4.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 =
q4*/
"fmla v24.4s, v4.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 =
q4*/
"fmla v27.4s, v4.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 =
q4*/
"fmla v30.4s, v4.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 =
q4*/
"subs %w[tail], %w[tail], #1\n" /* tail--*/
"fmla v10.4s, v5.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 =
q5*/
"fmla v13.4s, v5.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 =
q5*/
"fmla v16.4s, v5.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 =
q5*/
"fmla v19.4s, v5.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 =
q5*/
"fmla v22.4s, v5.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 =
q5*/
"fmla v25.4s, v5.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 =
q5*/
"fmla v28.4s, v5.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 =
q5*/
"fmla v31.4s, v5.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 =
q5*/
"beq 5f\n" /*jump to tail = 3*/
/* unrool 2, tail = 4*/
"ldp q4, q5, [%[b_ptr]], #32\n" /* load b2, b0 to q4, q5*/
"fmla v8.4s , v6.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 =
q6*/
"fmla v11.4s , v6.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 =
q6*/
"ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4*/
"fmla v14.4s, v6.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 =
q6*/
"fmla v17.4s, v6.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 =
q6*/
"fmla v20.4s, v6.4s, v1.s[0]\n" /* out4 = b0 * a01[0], b0 =
q6*/
"fmla v23.4s, v6.4s, v1.s[1]\n" /* out5 = b0 * a01[1], b0 =
q6*/
"fmla v26.4s, v6.4s, v1.s[2]\n" /* out6 = b0 * a01[2], b0 =
q6*/
"fmla v29.4s, v6.4s, v1.s[3]\n" /* out7 = b0 * a01[3], b0 =
q6*/
"fmla v9.4s, v7.4s, v0.s[0]\n" /* out8 = b1 * a00[0], b1 =
q7*/
"fmla v12.4s, v7.4s, v0.s[1]\n" /* out9 = b1 * a00[1], b1 =
q7*/
"fmla v15.4s, v7.4s, v0.s[2]\n" /* out10 = b1 * a00[2], b1 =
q7*/
"fmla v18.4s, v7.4s, v0.s[3]\n" /* out11 = b1 * a00[3], b1 =
q7*/
"fmla v21.4s, v7.4s, v1.s[0]\n" /* out12 = b1 * a01[0], b1 =
q7*/
"fmla v24.4s, v7.4s, v1.s[1]\n" /* out13 = b1 * a01[1], b1 =
q7*/
"fmla v27.4s, v7.4s, v1.s[2]\n" /* out14 = b1 * a01[2], b1 =
q7*/
"fmla v30.4s, v7.4s, v1.s[3]\n" /* out15 = b1 * a01[3], b1 =
q7*/
"ldp q6, q7, [%[b_ptr]], #32\n" /* load b1, b2 to q6, q7*/
"fmla v10.4s, v4.4s, v0.s[0]\n" /* out16 = b2 * a00[0], b2 =
q4*/
"fmla v13.4s, v4.4s, v0.s[1]\n" /* out17 = b2 * a00[1], b2 =
q4*/
"fmla v16.4s, v4.4s, v0.s[2]\n" /* out18 = b2 * a00[2], b2 =
q4*/
"fmla v19.4s, v4.4s, v0.s[3]\n" /* out19 = b2 * a00[3], b2 =
q4*/
"fmla v22.4s, v4.4s, v1.s[0]\n" /* out20 = b2 * a00[0], b2 =
q4*/
"fmla v25.4s, v4.4s, v1.s[1]\n" /* out21 = b2 * a00[1], b2 =
q4*/
"fmla v28.4s, v4.4s, v1.s[2]\n" /* out22 = b2 * a00[2], b2 =
q4*/
"fmla v31.4s, v4.4s, v1.s[3]\n" /* out23 = b2 * a00[3], b2 =
q4*/
/* unrool 3, tail = 4*/
"fmla v8.4s , v5.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 =
q5*/
"fmla v11.4s , v5.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 =
q5*/
"fmla v14.4s, v5.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 =
q5*/
"fmla v17.4s, v5.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 =
q5*/
"fmla v20.4s, v5.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 =
q5*/
"fmla v23.4s, v5.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/
"fmla v26.4s, v5.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 =
q5*/
"fmla v29.4s, v5.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 =
q5*/
"fmla v9.4s, v6.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 =
q6*/
"fmla v12.4s, v6.4s, v2.s[1]\n" /* out9 = b1 * a10[1], b1 =
q6*/
"fmla v15.4s, v6.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 =
q6*/
"fmla v18.4s, v6.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 =
q6*/
"fmla v21.4s, v6.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 =
q6*/
"fmla v24.4s, v6.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 =
q6*/
"fmla v27.4s, v6.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 =
q6*/
"fmla v30.4s, v6.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 =
q6*/
"fmla v10.4s, v7.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 =
q7*/
"fmla v13.4s, v7.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 =
q7*/
"fmla v16.4s, v7.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 =
q7*/
"fmla v19.4s, v7.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 =
q7*/
"fmla v22.4s, v7.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 =
q7*/
"fmla v25.4s, v7.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 =
q7*/
"fmla v28.4s, v7.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 =
q7*/
"fmla v31.4s, v7.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 =
q7*/
"b 11f\n"
/* tails==1 final tail*/
"3: \n" /* tail=1*/
"ldr q6, [%[b_ptr]], #16\n" /* load b2 to q6*/
"fmla v8.4s , v4.4s, v0.s[0]\n" /* out0 = b0 * a10[0], b0 =
q5*/
"fmla v11.4s , v4.4s, v0.s[1]\n" /* out1 = b0 * a10[1], b0 =
q5*/
"fmla v14.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a10[2], b0 =
q5*/
"fmla v17.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a10[3], b0 =
q5*/
"fmla v20.4s, v4.4s, v1.s[0]\n" /* out4 = b0 * a11[0], b0 =
q5*/
"fmla v23.4s, v4.4s, v1.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/
"fmla v26.4s, v4.4s, v1.s[2]\n" /* out6 = b0 * a11[2], b0 =
q5*/
"fmla v29.4s, v4.4s, v1.s[3]\n" /* out7 = b0 * a11[3], b0 =
q5*/
"fmla v9.4s, v5.4s, v0.s[0]\n" /* out8 = b0 * a10[0], b1 =
q6*/
"fmla v12.4s, v5.4s, v0.s[1]\n" /* out9 = b1 * a10[1], b1 =
q6*/
"fmla v15.4s, v5.4s, v0.s[2]\n" /* out10 = b1 * a10[2], b1 =
q6*/
"fmla v18.4s, v5.4s, v0.s[3]\n" /* out11 = b1 * a10[3], b1 =
q6*/
"fmla v21.4s, v5.4s, v1.s[0]\n" /* out12 = b1 * a10[0], b1 =
q6*/
"fmla v24.4s, v5.4s, v1.s[1]\n" /* out13 = b1 * a10[1], b1 =
q6*/
"fmla v27.4s, v5.4s, v1.s[2]\n" /* out14 = b1 * a10[2], b1 =
q6*/
"fmla v30.4s, v5.4s, v1.s[3]\n" /* out15 = b1 * a10[3], b1 =
q6*/
"fmla v10.4s, v6.4s, v0.s[0]\n" /* out16 = b2 * a10[0], b2 =
q7*/
"fmla v13.4s, v6.4s, v0.s[1]\n" /* out17 = b2 * a10[0], b2 =
q7*/
"fmla v16.4s, v6.4s, v0.s[2]\n" /* out18 = b2 * a10[0], b2 =
q7*/
"fmla v19.4s, v6.4s, v0.s[3]\n" /* out19 = b2 * a10[0], b2 =
q7*/
"fmla v22.4s, v6.4s, v1.s[0]\n" /* out20 = b2 * a10[0], b2 =
q7*/
"fmla v25.4s, v6.4s, v1.s[1]\n" /* out21 = b2 * a10[0], b2 =
q7*/
"fmla v28.4s, v6.4s, v1.s[2]\n" /* out22 = b2 * a10[0], b2 =
q7*/
"fmla v31.4s, v6.4s, v1.s[3]\n" /* out23 = b2 * a10[0], b2 =
q7*/
"b 11f\n"
/* tails==2 final tail*/
"4:\n" /* tail = 2*/
"fmla v8.4s , v7.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 =
q5*/
"fmla v11.4s , v7.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 =
q5*/
"fmla v14.4s, v7.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 =
q5*/
"fmla v17.4s, v7.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 =
q5*/
"fmla v20.4s, v7.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 =
q5*/
"fmla v23.4s, v7.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/
"fmla v26.4s, v7.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 =
q5*/
"fmla v29.4s, v7.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 =
q5*/
"fmla v9.4s, v4.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 =
q6*/
"fmla v12.4s, v4.4s, v2.s[1]\n" /* out9 = b1 * a10[1], b1 =
q6*/
"fmla v15.4s, v4.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 =
q6*/
"fmla v18.4s, v4.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 =
q6*/
"fmla v21.4s, v4.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 =
q6*/
"fmla v24.4s, v4.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 =
q6*/
"fmla v27.4s, v4.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 =
q6*/
"fmla v30.4s, v4.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 =
q6*/
"fmla v10.4s, v5.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 =
q7*/
"fmla v13.4s, v5.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 =
q7*/
"fmla v16.4s, v5.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 =
q7*/
"fmla v19.4s, v5.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 =
q7*/
"fmla v22.4s, v5.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 =
q7*/
"fmla v25.4s, v5.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 =
q7*/
"fmla v28.4s, v5.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 =
q7*/
"fmla v31.4s, v5.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 =
q7*/
"b 11f\n"
/* tails==3 final tail*/
"5:\n" /* tail = 3*/
"ldr q4, [%[b_ptr]], #16\n" /* load b2, b0 to q4*/
"fmla v8.4s , v6.4s, v0.s[0]\n" /* out0 = b0 * a10[0], b0 =
q5*/
"fmla v11.4s , v6.4s, v0.s[1]\n" /* out1 = b0 * a10[1], b0 =
q5*/
"fmla v14.4s, v6.4s, v0.s[2]\n" /* out2 = b0 * a10[2], b0 =
q5*/
"fmla v17.4s, v6.4s, v0.s[3]\n" /* out3 = b0 * a10[3], b0 =
q5*/
"fmla v20.4s, v6.4s, v1.s[0]\n" /* out4 = b0 * a11[0], b0 =
q5*/
"fmla v23.4s, v6.4s, v1.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/
"fmla v26.4s, v6.4s, v1.s[2]\n" /* out6 = b0 * a11[2], b0 =
q5*/
"fmla v29.4s, v6.4s, v1.s[3]\n" /* out7 = b0 * a11[3], b0 =
q5*/
"fmla v9.4s, v7.4s, v0.s[0]\n" /* out8 = b0 * a10[0], b1 =
q6*/
"fmla v12.4s, v7.4s, v0.s[1]\n" /* out9 = b1 * a10[1], b1 =
q6*/
"fmla v15.4s, v7.4s, v0.s[2]\n" /* out10 = b1 * a10[2], b1 =
q6*/
"fmla v18.4s, v7.4s, v0.s[3]\n" /* out11 = b1 * a10[3], b1 =
q6*/
"fmla v21.4s, v7.4s, v1.s[0]\n" /* out12 = b1 * a10[0], b1 =
q6*/
"fmla v24.4s, v7.4s, v1.s[1]\n" /* out13 = b1 * a10[1], b1 =
q6*/
"fmla v27.4s, v7.4s, v1.s[2]\n" /* out14 = b1 * a10[2], b1 =
q6*/
"fmla v30.4s, v7.4s, v1.s[3]\n" /* out15 = b1 * a10[3], b1 =
q6*/
"fmla v10.4s, v4.4s, v0.s[0]\n" /* out16 = b2 * a10[0], b2 =
q7*/
"fmla v13.4s, v4.4s, v0.s[1]\n" /* out17 = b2 * a10[0], b2 =
q7*/
"fmla v16.4s, v4.4s, v0.s[2]\n" /* out18 = b2 * a10[0], b2 =
q7*/
"fmla v19.4s, v4.4s, v0.s[3]\n" /* out19 = b2 * a10[0], b2 =
q7*/
"fmla v22.4s, v4.4s, v1.s[0]\n" /* out20 = b2 * a10[0], b2 =
q7*/
"fmla v25.4s, v4.4s, v1.s[1]\n" /* out21 = b2 * a10[0], b2 =
q7*/
"fmla v28.4s, v4.4s, v1.s[2]\n" /* out22 = b2 * a10[0], b2 =
q7*/
"fmla v31.4s, v4.4s, v1.s[3]\n" /* out23 = b2 * a10[0], b2 =
q7*/
"11: \n" /* check if relu */
"cbz %w[relu], 12f\n" /* skip relu */
"movi v2.4s, #0\n" /* for relu*/
"fmax v8.4s, v8.4s, v2.4s\n" /* relu*/
"fmax v9.4s, v9.4s, v2.4s\n" /* relu*/
"fmax v10.4s, v10.4s, v2.4s\n" /* relu*/
"fmax v11.4s, v11.4s, v2.4s\n" /* relu*/
"fmax v12.4s, v12.4s, v2.4s\n" /* relu*/
"fmax v13.4s, v13.4s, v2.4s\n" /* relu*/
"fmax v14.4s, v14.4s, v2.4s\n" /* relu*/
"fmax v15.4s, v15.4s, v2.4s\n" /* relu*/
"fmax v16.4s,v16.4s,v2.4s\n" /* relu*/
"fmax v17.4s,v17.4s,v2.4s\n" /* relu*/
"fmax v18.4s, v18.4s, v2.4s\n" /* relu*/
"fmax v19.4s, v19.4s, v2.4s\n" /* relu*/
"fmax v20.4s, v20.4s, v2.4s\n" /* relu*/
"fmax v21.4s, v21.4s, v2.4s\n" /* relu*/
"fmax v22.4s, v22.4s, v2.4s\n" /* relu*/
"fmax v23.4s, v23.4s, v2.4s\n" /* relu*/
"fmax v24.4s,v24.4s,v2.4s\n" /* relu*/
"fmax v25.4s,v25.4s,v2.4s\n" /* relu*/
"fmax v26.4s, v26.4s, v2.4s\n" /* relu*/
"fmax v27.4s, v27.4s, v2.4s\n" /* relu*/
"fmax v28.4s, v28.4s, v2.4s\n" /* relu*/
"fmax v29.4s, v29.4s, v2.4s\n" /* relu*/
"fmax v30.4s, v30.4s, v2.4s\n" /* relu*/
"fmax v31.4s, v31.4s, v2.4s\n" /* relu*/
"12: \n"
"st1 {v8.4s, v9.4s, v10.4s},[%[c_ptr0]], #48\n" /* store r0 */
"st1 {v11.4s, v12.4s, v13.4s},[%[c_ptr1]], #48\n" /* store r1 */
"st1 {v14.4s, v15.4s, v16.4s},[%[c_ptr2]], #48\n" /* store r2 */
"st1 {v17.4s, v18.4s, v19.4s},[%[c_ptr3]], #48\n" /* store r3 */
"st1 {v20.4s, v21.4s, v22.4s},[%[c_ptr4]], #48\n" /* store r4 */
"st1 {v23.4s, v24.4s, v25.4s},[%[c_ptr5]], #48\n" /* store r5 */
"st1 {v26.4s, v27.4s, v28.4s},[%[c_ptr6]], #48\n" /* store r6 */
"st1 {v29.4s, v30.4s, v31.4s},[%[c_ptr7]], #48\n" /* store r7 */
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k),
[tail] "+r"(tail), [c_ptr0] "+r"(c_ptr0), [c_ptr1] "+r"(c_ptr1),
[c_ptr2] "+r"(c_ptr2), [c_ptr3] "+r"(c_ptr3),
[c_ptr4] "+r"(c_ptr4), [c_ptr5] "+r"(c_ptr5),
[c_ptr6] "+r"(c_ptr6), [c_ptr7] "+r"(c_ptr7)
: [bias_ptr] "r"(bias_local), [relu] "r"(is_relu)
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31");
if (flag_p_remain && (xb == bblocks - 1)) {
for (int i = 0; i < remain; ++i) {
*pout0++ = cout0[i];
*pout1++ = cout1[i];
*pout2++ = cout2[i];
*pout3++ = cout3[i];
*pout4++ = cout4[i];
*pout5++ = cout5[i];
*pout6++ = cout6[i];
*pout7++ = cout7[i];
}
}
}
}
}
}
#else // __aarch64__
/**
* \brief gemm with ablock = 6, bblock = 8, output 6x8
* @param A
* @param B
* @param C
* @param M
* @param N
* @param K
* @param threads
* @param workspace
*/
void sgemm_conv_6x8(const float* A_packed, const float* B, const float* bias,
float* C, int M, int N, int K, bool is_bias, bool is_relu,
bool transB, ARMContext* ctx) {
size_t l2_cache =
ctx->l2_cache_size() > 0 ? ctx->l2_cache_size() : 512 * 1024;
auto* workspace = ctx->workspace_data<float>();
int threads = ctx->threads();
//! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
int x_block =
(l2_cache - (MBLOCK_OTH * K)) / (sizeof(float) * (K + MBLOCK_OTH));
x_block /= NBLOCK;
x_block *= NBLOCK;
int x_num = (N + (x_block - 1)) / x_block;
x_block = (N + x_num - 1) / x_num;
x_block = (x_block + NBLOCK - 1) / NBLOCK;
x_block *= NBLOCK;
x_block = x_block < NBLOCK ? NBLOCK : x_block;
int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1;
int tail_pre = (K & (KBLOCK - 1));
if (tail_pre == 0) {
tail_pre = KBLOCK;
}
bool flag_p_remain = false;
int remain = 0;
//! apanel is pre_compute outside gemm
for (unsigned int x0 = 0; x0 < N; x0 += x_block) {
unsigned int xmax = x0 + x_block;
if (xmax > N) {
xmax = N;
}
int bblocks = (xmax - x0 + NBLOCK - 1) / NBLOCK;
remain = xmax - x0 - (bblocks - 1) * NBLOCK;
if (remain > 0) {
flag_p_remain = true;
}
//! load bpanel
float* b_pannel = workspace;
if (transB) {
loadb_trans(b_pannel, B, K, 0, K, x0, xmax);
} else {
loadb(b_pannel, B, N, 0, K, x0, xmax);
}
#pragma omp parallel for num_threads(threads)
for (unsigned int y = 0; y < M; y += MBLOCK_OTH) {
unsigned int ymax = y + MBLOCK_OTH;
if (ymax > M) {
ymax = M;
}
float* c_ptr0 = C + y * N + x0;
float* c_ptr1 = c_ptr0 + N;
float* c_ptr2 = c_ptr1 + N;
float* c_ptr3 = c_ptr2 + N;
float* c_ptr4 = c_ptr3 + N;
float* c_ptr5 = c_ptr4 + N;
float* pout0 = c_ptr0;
float* pout1 = c_ptr1;
float* pout2 = c_ptr2;
float* pout3 = c_ptr3;
float* pout4 = c_ptr4;
float* pout5 = c_ptr5;
float bias_local[6] = {0};
if (is_bias) {
bias_local[0] = bias[y];
bias_local[1] = bias[y + 1];
bias_local[2] = bias[y + 2];
bias_local[3] = bias[y + 3];
bias_local[4] = bias[y + 4];
bias_local[5] = bias[y + 5];
}
float cout0[NBLOCK];
float cout1[NBLOCK];
float cout2[NBLOCK];
float cout3[NBLOCK];
float cout4[NBLOCK];
float cout5[NBLOCK];
const float* a_ptr_l = A_packed + y * K;
const float* b_ptr = b_pannel;
for (int xb = 0; xb < bblocks; xb++) {
if ((y + 5) >= ymax) {
switch ((y + 5) - ymax) {
case 4:
c_ptr1 = cout1;
case 3:
c_ptr2 = cout2;
case 2:
c_ptr3 = cout3;
case 1:
c_ptr4 = cout4;
case 0:
c_ptr5 = cout5;
default:
break;
}
}
if (flag_p_remain && (xb == bblocks - 1)) {
pout0 = c_ptr0;
pout1 = c_ptr1;
pout2 = c_ptr2;
pout3 = c_ptr3;
pout4 = c_ptr4;
pout5 = c_ptr5;
c_ptr0 = cout0;
c_ptr1 = cout1;
c_ptr2 = cout2;
c_ptr3 = cout3;
c_ptr4 = cout4;
c_ptr5 = cout5;
}
const float* a_ptr = a_ptr_l;
int tails = tail_pre;
int k = k_pre;
asm volatile(
// sgemm 6x8
"vld1.32 {d2-d4}, [%[bias_ptr]] @ load bias 6 elements\n"
"vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a0~a3\n"
"pld [%[a_ptr]] @ preload a\n"
"vdup.i32 q12,d4[0] @ out40=0\n"
"pld [%[b_ptr]] @ preload b\n"
"vdup.i32 q13,d4[0] @ out41=0\n"
"pld [%[a_ptr], #64] @ preload a\n"
"vdup.i32 q14,d4[1] @ out50=0\n"
"pld [%[b_ptr], #64] @ preload b\n"
"vdup.i32 q15,d4[1] @ out51=0\n"
"pld [%[a_ptr], #128] @ preload a\n"
"vdup.i32 q4, d2[0] @ out00=0\n"
"pld [%[b_ptr], #128] @ preload b\n"
"vdup.i32 q5, d2[0] @ out01=0\n"
"vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n"
"vdup.i32 q6, d2[1] @ out10=0\n"
"pld [%[a_ptr], #192] @ preload a\n"
"vdup.i32 q7, d2[1] @ out11=0\n"
"pld [%[b_ptr], #192] @ preload a\n"
"vdup.i32 q8, d3[0] @ out20=0\n"
"pld [%[a_ptr], #256] @ preload a\n"
"vdup.i32 q9, d3[0] @ out21=0\n"
"pld [%[b_ptr], #256] @ preload a\n"
"vdup.i32 q10,d3[1] @ out30=0\n"
"pld [%[b_ptr], #320] @ preload b\n"
"vdup.i32 q11,d3[1] @ out31=0\n"
"pld [%[b_ptr], #384] @ preload b\n"
"cmp %[k], #0 @ check weather k is "
"bigger than 0\n"
"beq 0f @ jump to tail\n"
"1: @ main loop for k\n"
/* Unroll 0*/
"vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a4, a5, and next "
"a0, "
"a1\n"
"vmla.f32 q4, q2, d0[0] @ out0 += b1 * a0\n"
"vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n"
"vmla.f32 q6, q2, d0[1] @ out1 += b1 * a1\n"
"vmla.f32 q8, q2, d1[0] @ out2 += b1 * a2\n"
"vmla.f32 q10, q2, d1[1] @ out3 += b1 * a3\n"
"vmla.f32 q12, q2, d2[0] @ out4 += b1 * a4\n"
"vmla.f32 q14, q2, d2[1] @ out5 += b1 * a5\n"
"vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n"
"vmla.f32 q5, q3, d0[0] @ out6 += b2 * a0\n"
"vmla.f32 q7, q3, d0[1] @ out7 += b2 * a1\n"
"vmla.f32 q9, q3, d1[0] @ out8 += b2 * a2\n"
"vmla.f32 q11, q3, d1[1] @ out9 += b2 * a3\n"
"vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a2~a5\n"
"vmla.f32 q13, q3, d2[0] @ out10 += b2 * a4\n"
"vmla.f32 q15, q3, d2[1] @ out11 += b2 * a5\n"
"vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n"
/* Unroll 1 */
"vmla.f32 q4, q2, d3[0] @ out0 += b1 * a0\n"
"vmla.f32 q6, q2, d3[1] @ out1 += b1 * a1\n"
/*"pld [%[a_ptr], #64] @ preload a\n"*/
"vmla.f32 q8, q2, d0[0] @ out2 += b1 * a2\n"
"vmla.f32 q10, q2, d0[1] @ out3 += b1 * a3\n"
/*"pld [%[b_ptr], #192]\n"*/
"vmla.f32 q12, q2, d1[0] @ out4 += b1 * a4\n"
"vmla.f32 q14, q2, d1[1] @ out5 += b1 * a5\n"
"vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n"
"vmla.f32 q5, q3, d3[0] @ out6 += b2 * a0\n"
"vmla.f32 q7, q3, d3[1] @ out7 += b2 * a1\n"
"vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a0~a3\n"
"vmla.f32 q9, q3, d0[0] @ out8 += b2 * a2\n"
"vmla.f32 q11, q3, d0[1] @ out9 += b2 * a3\n"
"vmla.f32 q13, q3, d1[0] @ out10 += b2 * a4\n"
"vmla.f32 q15, q3, d1[1] @ out11 += b2 * a5\n"
"vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a4, a5, a0, a1\n"
/* Unroll 2 */
"vmla.f32 q4, q2, d2[0] @ out0 += b1 * a0\n"
"vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n"
"vmla.f32 q6, q2, d2[1] @ out1 += b1 * a1\n"
"vmla.f32 q8, q2, d3[0] @ out2 += b1 * a2\n"
"vmla.f32 q10, q2, d3[1] @ out3 += b1 * a3\n"
/*"pld [%[a_ptr], #240] @ preload\n"*/
"vmla.f32 q12, q2, d0[0] @ out4 += b1 * a4\n"
"vmla.f32 q14, q2, d0[1] @ out5 += b1 * a5\n"
"vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n"
"vmla.f32 q5, q3, d2[0] @ out6 += b2 * a0\n"
"vmla.f32 q7, q3, d2[1] @ out7 += b2 * a1\n"
/*"pld [%[b_ptr], #208]\n"*/
"vmla.f32 q9, q3, d3[0] @ out8 += b2 * a2\n"
"vmla.f32 q11, q3, d3[1] @ out9 += b2 * a3\n"
"vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a2~a5\n"
"vmla.f32 q13, q3, d0[0] @ out10 += b2 * a4\n"
"vmla.f32 q15, q3, d0[1] @ out11 += b2 * a5\n"
"vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n"
/* Unroll 3 */
"vmla.f32 q4, q2, d1[0] @ out0 += b1 * a0\n"
"vmla.f32 q6, q2, d1[1] @ out1 += b1 * a1\n"
"vmla.f32 q8, q2, d2[0] @ out2 += b1 * a2\n"
"vmla.f32 q10, q2, d2[1] @ out3 += b1 * a3\n"
"vmla.f32 q12, q2, d3[0] @ out4 += b1 * a4\n"
"vmla.f32 q14, q2, d3[1] @ out5 += b1 * a5\n"
"vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n"
"vmla.f32 q5, q3, d1[0] @ out6 += b2 * a0\n"
"vmla.f32 q7, q3, d1[1] @ out7 += b2 * a1\n"
"vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a0~a3\n"
"vmla.f32 q9, q3, d2[0] @ out8 += b2 * a2\n"
"vmla.f32 q11, q3, d2[1] @ out9 += b2 * a3\n"
"subs %[k], %[k], #1 @ k--\n"
"vmla.f32 q13, q3, d3[0] @ out10 += b2 * a4\n"
"vmla.f32 q15, q3, d3[1] @ out11 += b2 * a5\n"
"bne 1b @ jump to main "
"loop\n"
"0: @ process tail\n"
"subs %[tails], %[tails], #1 @ tail--\n"
"beq 3f @ jump to tail = "
"1\n"
/* Unroll 0*/
"vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n"
"vmla.f32 q4, q2, d0[0] @ out0 += b1 * a0\n"
"vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a4,5, a0, a1\n"
"vmla.f32 q6, q2, d0[1] @ out1 += b1 * a1\n"
"vmla.f32 q8, q2, d1[0] @ out2 += b1 * a2\n"
"vmla.f32 q10, q2, d1[1] @ out3 += b1 * a3\n"
"vmla.f32 q12, q2, d2[0] @ out4 += b1 * a4\n"
"subs %[tails], %[tails], #1 @ tail--\n"
"vmla.f32 q14, q2, d2[1] @ out5 += b1 * a5\n"
"vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n"
"vmla.f32 q5, q3, d0[0] @ out6 += b2 * a0\n"
"vmla.f32 q7, q3, d0[1] @ out7 += b2 * a1\n"
"vmla.f32 q9, q3, d1[0] @ out8 += b2 * a2\n"
"vmla.f32 q11, q3, d1[1] @ out9 += b2 * a3\n"
"vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a2~a5\n"
"vmla.f32 q13, q3, d2[0] @ out10 += b2 * a4\n"
"vmla.f32 q15, q3, d2[1] @ out11 += b2 * a5\n"
"vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n"
"beq 4f @ jump to tail==2\n"
/* Unroll 1*/
"vmla.f32 q4, q2, d3[0] @ out0 += b1 * a0\n"
"vmla.f32 q6, q2, d3[1] @ out1 += b1 * a1\n"
"subs %[tails], %[tails], #1 @ tail--\n"
"vmla.f32 q8, q2, d0[0] @ out2 += b1 * a2\n"
"vmla.f32 q10, q2, d0[1] @ out3 += b1 * a3\n"
"vmla.f32 q12, q2, d1[0] @ out4 += b1 * a4\n"
"vmla.f32 q14, q2, d1[1] @ out5 += b1 * a5\n"
"vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n"
"vmla.f32 q5, q3, d3[0] @ out6 += b2 * a0\n"
"vmla.f32 q7, q3, d3[1] @ out7 += b2 * a1\n"
"vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a0~a3\n"
"vmla.f32 q9, q3, d0[0] @ out8 += b2 * a2\n"
"vmla.f32 q11, q3, d0[1] @ out9 += b2 * a3\n"
"vmla.f32 q13, q3, d1[0] @ out10 += b2 * a4\n"
"vmla.f32 q15, q3, d1[1] @ out11 += b2 * a5\n"
"vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n"
"beq 5f @ jump to tail==3\n"
/* Unroll 2 */
"vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a4,a5, a0,a1\n"
"vmla.f32 q4, q2, d2[0] @ out0 += b1 * a0\n"
"vmla.f32 q6, q2, d2[1] @ out1 += b1 * a1\n"
"vmla.f32 q8, q2, d3[0] @ out2 += b1 * a2\n"
"vmla.f32 q10, q2, d3[1] @ out3 += b1 * a3\n"
"vmla.f32 q12, q2, d0[0] @ out4 += b1 * a4\n"
"vmla.f32 q14, q2, d0[1] @ out5 += b1 * a5\n"
"vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n"
"vmla.f32 q5, q3, d2[0] @ out6 += b2 * a0\n"
"vmla.f32 q7, q3, d2[1] @ out7 += b2 * a1\n"
"vmla.f32 q9, q3, d3[0] @ out8 += b2 * a2\n"
"vmla.f32 q11, q3, d3[1] @ out9 += b2 * a3\n"
"vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a2~a5\n"
"vmla.f32 q13, q3, d0[0] @ out10 += b2 * a4\n"
"vmla.f32 q15, q3, d0[1] @ out11 += b2 * a5\n"
"vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n"
/* Unroll 3*/
"vmla.f32 q4, q2, d1[0] @ out0 += b1 * a0\n"
"vmla.f32 q6, q2, d1[1] @ out1 += b1 * a1\n"
"vmla.f32 q8, q2, d2[0] @ out2 += b1 * a2\n"
"vmla.f32 q10, q2, d2[1] @ out3 += b1 * a3\n"
"vmla.f32 q12, q2, d3[0] @ out4 += b1 * a4\n"
"vmla.f32 q14, q2, d3[1] @ out5 += b1 * a5\n"
"vmla.f32 q5, q3, d1[0] @ out6 += b2 * a0\n"
"vmla.f32 q7, q3, d1[1] @ out7 += b2 * a1\n"
"vmla.f32 q9, q3, d2[0] @ out8 += b2 * a2\n"
"vmla.f32 q11, q3, d2[1] @ out9 += b2 * a3\n"
"vmla.f32 q13, q3, d3[0] @ out10 += b2 * a4\n"
"vmla.f32 q15, q3, d3[1] @ out11 += b2 * a5\n"
"b 2f\n"
/* tails==1 final tail*/
"3: @ tail=1\n"
"vmla.f32 q4, q2, d0[0] @ out0 += b1 * a0\n"
"vld1.32 {d2}, [%[a_ptr] :64]! @ load a4,a5\n"
"vmla.f32 q6, q2, d0[1] @ out1 += b1 * a1\n"
"vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n"
"vmla.f32 q8, q2, d1[0] @ out2 += b1 * a2\n"
"vmla.f32 q10, q2, d1[1] @ out3 += b1 * a3\n"
"vmla.f32 q12, q2, d2[0] @ out4 += b1 * a4\n"
"vmla.f32 q14, q2, d2[1] @ out5 += b1 * a5\n"
"vmla.f32 q5, q3, d0[0] @ out6 += b2 * a0\n"
"vmla.f32 q7, q3, d0[1] @ out7 += b2 * a1\n"
"vmla.f32 q9, q3, d1[0] @ out8 += b2 * a2\n"
"vmla.f32 q11, q3, d1[1] @ out9 += b2 * a3\n"
"vmla.f32 q13, q3, d2[0] @ out10 += b2 * a4\n"
"vmla.f32 q15, q3, d2[1] @ out11 += b2 * a5\n"
"b 2f @ jump to end\n"
/* tails==2 final tail*/
"4: @ tail == 2\n"
"vmla.f32 q4, q2, d3[0] @ out0 += b1 * a0\n"
"vmla.f32 q6, q2, d3[1] @ out1 += b1 * a1\n"
"vmla.f32 q8, q2, d0[0] @ out2 += b1 * a2\n"
"vmla.f32 q10, q2, d0[1] @ out3 += b1 * a3\n"
"vmla.f32 q12, q2, d1[0] @ out4 += b1 * a4\n"
"vmla.f32 q14, q2, d1[1] @ out5 += b1 * a5\n"
"vmla.f32 q5, q3, d3[0] @ out6 += b2 * a0\n"
"vmla.f32 q7, q3, d3[1] @ out7 += b2 * a1\n"
"vmla.f32 q9, q3, d0[0] @ out8 += b2 * a2\n"
"vmla.f32 q11, q3, d0[1] @ out9 += b2 * a3\n"
"vmla.f32 q13, q3, d1[0] @ out10 += b2 * a4\n"
"vmla.f32 q15, q3, d1[1] @ out11 += b2 * a5\n"
"b 2f @ jump to end\n"
/* tails==3 final tail*/
"5: @ tail=3\n"
"vmla.f32 q4, q2, d2[0] @ out0 += b1 * a0\n"
"vld1.32 {d0}, [%[a_ptr] :64]! @ load a4,a5\n"
"vmla.f32 q6, q2, d2[1] @ out1 += b1 * a1\n"
"vmla.f32 q8, q2, d3[0] @ out2 += b1 * a2\n"
"vmla.f32 q10, q2, d3[1] @ out3 += b1 * a3\n"
"vmla.f32 q12, q2, d0[0] @ out4 += b1 * a4\n"
"vmla.f32 q14, q2, d0[1] @ out5 += b1 * a5\n"
"vmla.f32 q5, q3, d2[0] @ out6 += b2 * a0\n"
"vmla.f32 q7, q3, d2[1] @ out7 += b2 * a1\n"
"vmla.f32 q9, q3, d3[0] @ out8 += b2 * a2\n"
"vmla.f32 q11, q3, d3[1] @ out9 += b2 * a3\n"
"vmla.f32 q13, q3, d0[0] @ out10 += b2 * a4\n"
"vmla.f32 q15, q3, d0[1] @ out11 += b2 * a5\n"
"2: @ check relu\n"
"cmp %[relu], #0 @ check if has relu\n"
"ble 6f @ skip relu if relu <= 0\n"
"vmov.u32 q0, #0 @ for relu\n"
"vmax.f32 q4, q4, q0 @ for relu\n"
"vmax.f32 q5, q5, q0 @ for relu\n"
"vmax.f32 q6, q6, q0 @ for relu\n"
"vmax.f32 q7, q7, q0 @ for relu\n"
"vmax.f32 q8, q8, q0 @ for relu\n"
"vmax.f32 q9, q9, q0 @ for relu\n"
"vmax.f32 q10, q10, q0 @ for relu\n"
"vmax.f32 q11, q11, q0 @ for relu\n"
"vmax.f32 q12, q12, q0 @ for relu\n"
"vmax.f32 q13, q13, q0 @ for relu\n"
"vmax.f32 q14, q14, q0 @ for relu\n"
"vmax.f32 q15, q15, q0 @ for relu\n"
"6: @ store result\n"
"vst1.32 {d8-d11}, [%[c_ptr0]]! @ store r0\n"
"vst1.32 {d12-d15}, [%[c_ptr1]]! @ store r1\n"
"vst1.32 {d16-d19}, [%[c_ptr2]]! @ store r2\n"
"vst1.32 {d20-d23}, [%[c_ptr3]]! @ store r3\n"
"vst1.32 {d24-d27}, [%[c_ptr4]]! @ store r4\n"
"vst1.32 {d28-d31}, [%[c_ptr5]]! @ store r5\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr0] "+r"(c_ptr0),
[c_ptr1] "+r"(c_ptr1), [c_ptr2] "+r"(c_ptr2),
[c_ptr3] "+r"(c_ptr3), [c_ptr4] "+r"(c_ptr4),
[c_ptr5] "+r"(c_ptr5), [k] "+r"(k), [tails] "+r"(tails)
: [bias_ptr] "r"(bias_local), [relu] "r"(is_relu)
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
"q11", "q12", "q13", "q14", "q15", "cc", "memory");
if (flag_p_remain && (xb == bblocks - 1)) {
for (int i = 0; i < remain; ++i) {
*pout0++ = cout0[i];
*pout1++ = cout1[i];
*pout2++ = cout2[i];
*pout3++ = cout3[i];
*pout4++ = cout4[i];
*pout5++ = cout5[i];
}
}
}
}
}
}
void sgemm_conv_4x8(const float* A_packed, const float* B, const float* bias,
float* C, int M, int N, int K, bool is_bias, bool is_relu,
bool transB, ARMContext* ctx) {
size_t l2_cache =
ctx->l2_cache_size() > 0 ? ctx->l2_cache_size() : 512 * 1024;
void* workspace = ctx->get_work_space();
int threads = ctx->threads();
//! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
int x_block =
(l2_cache - (MBLOCK_A73 * K)) / (sizeof(float) * (K + MBLOCK_A73));
x_block /= NBLOCK;
x_block *= NBLOCK;
int x_num = (N + (x_block - 1)) / x_block;
x_block = (N + x_num - 1) / x_num;
x_block = (x_block + NBLOCK - 1) / NBLOCK;
x_block *= NBLOCK;
x_block = x_block < NBLOCK ? NBLOCK : x_block;
int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1;
int tail_pre = (K & (KBLOCK - 1));
if (tail_pre == 0) {
tail_pre = KBLOCK;
}
bool flag_p_remain = false;
int remain = 0;
//! apanel is pre_compute outside gemm
for (unsigned int x0 = 0; x0 < N; x0 += x_block) {
unsigned int xmax = x0 + x_block;
if (xmax > N) {
xmax = N;
}
int bblocks = (xmax - x0 + NBLOCK - 1) / NBLOCK;
remain = xmax - x0 - (bblocks - 1) * NBLOCK;
if (remain > 0) {
flag_p_remain = true;
}
//! load bpanel
float* b_pannel = static_cast<float*>(workspace);
if (transB) {
loadb_trans(b_pannel, B, K, 0, K, x0, xmax);
} else {
loadb(b_pannel, B, N, 0, K, x0, xmax);
}
#pragma omp parallel for num_threads(threads)
for (unsigned int y = 0; y < M; y += MBLOCK_A73) {
unsigned int ymax = y + MBLOCK_A73;
if (ymax > M) {
ymax = M;
}
float cout0[NBLOCK];
float cout1[NBLOCK];
float cout2[NBLOCK];
float cout3[NBLOCK];
float bias_local[4] = {0};
if (is_bias) {
bias_local[0] = bias[y];
bias_local[1] = bias[y + 1];
bias_local[2] = bias[y + 2];
bias_local[3] = bias[y + 3];
}
float* c_ptr0 = C + y * N + x0;
float* c_ptr1 = c_ptr0 + N;
float* c_ptr2 = c_ptr1 + N;
float* c_ptr3 = c_ptr2 + N;
float* pout0 = c_ptr0;
float* pout1 = c_ptr1;
float* pout2 = c_ptr2;
float* pout3 = c_ptr3;
const float* a_ptr_l = A_packed + y * K;
const float* b_ptr = b_pannel;
for (int xb = 0; xb < bblocks; xb++) {
if ((y + 3) >= ymax) {
switch ((y + 3) - ymax) {
case 2:
c_ptr1 = cout1;
case 1:
c_ptr2 = cout1;
case 0:
c_ptr3 = cout1;
default:
break;
}
}
if (flag_p_remain && (xb == bblocks - 1)) {
pout0 = c_ptr0;
pout1 = c_ptr1;
pout2 = c_ptr2;
pout3 = c_ptr3;
c_ptr0 = cout0;
c_ptr1 = cout1;
c_ptr2 = cout2;
c_ptr3 = cout3;
}
const float* a_ptr = a_ptr_l;
int tails = tail_pre;
int k = k_pre;
asm volatile(
"vld1.32 {d4-d5}, [%[bias_ptr]] @ load bias\n"
"vld1.32 {d0-d3}, [%[a_ptr] :128]! @ load a0~a3\n"
"vdup.32 q8, d4[0] @ add bias to out00\n"
"pld [%[a_ptr]] @ preload a, 64byte\n"
"vdup.32 q9, d4[0] @ add bias to out01\n"
"pld [%[b_ptr]] @ preload b\n"
"vdup.32 q10, d4[1] @ add bias to out10\n"
"pld [%[a_ptr], #64] @ preload a\n"
"vdup.32 q11, d4[1] @ add bias to out11\n"
"vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load b1\n"
"vdup.32 q12, d5[0] @ add bias to out20\n"
"pld [%[b_ptr], #64] @ preload b\n"
"vdup.32 q13, d5[0] @ add bias to out21\n"
"pld [%[a_ptr], #128] @ preload a\n"
"vdup.32 q14, d5[1] @ add bias to out30\n"
"pld [%[b_ptr], #128] @ preload b\n"
"vdup.32 q15, d5[1] @ add bias to out31\n"
"pld [%[b_ptr], #192] @ preload b\n"
"cmp %[k], #0 @ check weather k is "
"bigger than 0\n"
"beq 0f @ jump to tail\n"
"1: @ main loop for k\n"
/* Unroll 0*/
"vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1, b2\n"
"vmla.f32 q8, q4, d0[0] @ out0 += b1 * a0\n"
"vld1.32 {d4-d7}, [%[a_ptr] :128]! @ load next 2xa0~a3\n"
"vmla.f32 q10, q4, d0[1] @ out1 += b1 * a1\n"
"vmla.f32 q12, q4, d1[0] @ out2 += b1 * a2\n"
"vmla.f32 q14, q4, d1[1] @ out3 += b1 * a3\n"
"vmla.f32 q9, q5, d0[0] @ out4 += b2 * a0\n"
"vmla.f32 q11, q5, d0[1] @ out5 += b2 * a1\n"
"vmla.f32 q13, q5, d1[0] @ out6 += b2 * a2\n"
"vmla.f32 q15, q5, d1[1] @ out7 += b2 * a3\n"
"vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load next b1, b2\n"
/* Unroll 1 */
"vmla.f32 q8, q6, d2[0] @ out0 += b1 * a0\n"
"pld [%[b_ptr], #64] @ preload b\n"
"vmla.f32 q10, q6, d2[1] @ out1 += b1 * a1\n"
"vmla.f32 q12, q6, d3[0] @ out2 += b1 * a2\n"
"vmla.f32 q14, q6, d3[1] @ out3 += b1 * a3\n"
"vmla.f32 q9, q7, d2[0] @ out6 += b2 * a0\n"
"vmla.f32 q11, q7, d2[1] @ out7 += b2 * a1\n"
"vmla.f32 q13, q7, d3[0] @ out8 += b2 * a2\n"
"vmla.f32 q15, q7, d3[1] @ out9 += b2 * a3\n"
"vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1,b2\n"
/* Unroll 2 */
"vmla.f32 q8, q4, d4[0] @ out0 += b1 * a0\n"
"vld1.32 {d0-d3}, [%[a_ptr] :128]! @ load next a0~a3\n"
"vmla.f32 q10, q4, d4[1] @ out1 += b1 * a1\n"
"vmla.f32 q12, q4, d5[0] @ out2 += b1 * a2\n"
"vmla.f32 q14, q4, d5[1] @ out3 += b1 * a3\n"
"vmla.f32 q9, q5, d4[0] @ out4 += b2 * a0\n"
"vmla.f32 q11, q5, d4[1] @ out5 += b2 * a1\n"
"vmla.f32 q13, q5, d5[0] @ out6 += b2 * a2\n"
"vmla.f32 q15, q5, d5[1] @ out7 += b2 * a3\n"
"vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load next b1, b2\n"
/* Unroll 3 */
"vmla.f32 q8, q6, d6[0] @ out0 += b1 * a0\n"
"pld [%[a_ptr], #64] @ preload a\n"
"vmla.f32 q10, q6, d6[1] @ out1 += b1 * a1\n"
"vmla.f32 q12, q6, d7[0] @ out2 += b1 * a2\n"
"vmla.f32 q14, q6, d7[1] @ out3 += b1 * a3\n"
"vmla.f32 q9, q7, d6[0] @ out4 += b2 * a0\n"
"vmla.f32 q11, q7, d6[1] @ out5 += b2 * a1\n"
"vmla.f32 q13, q7, d7[0] @ out6 += b2 * a2\n"
"vmla.f32 q15, q7, d7[1] @ out7 += b2 * a3\n"
"subs %[k], %[k], #1 @ k--\n"
"bne 1b @ jump to main "
"loop\n"
"0: @ process tail\n"
"subs %[tails], %[tails], #1 @ tail--\n"
"beq 3f @ jump to tail = "
"1\n"
/* Unroll 0*/
"vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1, b2\n"
"vmla.f32 q8, q4, d0[0] @ out0 += b1 * a0\n"
"vmla.f32 q10, q4, d0[1] @ out1 += b1 * a1\n"
"subs %[tails], %[tails], #1 @ tail--\n"
"vmla.f32 q12, q4, d1[0] @ out2 += b1 * a2\n"
"vmla.f32 q14, q4, d1[1] @ out3 += b1 * a3\n"
"vmla.f32 q9, q5, d0[0] @ out4 += b2 * a0\n"
"vmla.f32 q11, q5, d0[1] @ out5 += b2 * a1\n"
"vmla.f32 q13, q5, d1[0] @ out6 += b2 * a2\n"
"vmla.f32 q15, q5, d1[1] @ out7 += b2 * a3\n"
"beq 4f @ jump to tail==2\n"
/* Unroll 1 */
"vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load next b1, b2\n"
"vmla.f32 q8, q6, d2[0] @ out0 += b1 * a0\n"
"vld1.32 {d4-d7}, [%[a_ptr] :128]! @ load next 2xa0~a3\n"
"vmla.f32 q10, q6, d2[1] @ out1 += b1 * a1\n"
"subs %[tails], %[tails], #1 @ tail--\n"
"vmla.f32 q12, q6, d3[0] @ out2 += b1 * a2\n"
"vmla.f32 q14, q6, d3[1] @ out3 += b1 * a3\n"
"vmla.f32 q9, q7, d2[0] @ out6 += b2 * a0\n"
"vmla.f32 q11, q7, d2[1] @ out7 += b2 * a1\n"
"vmla.f32 q13, q7, d3[0] @ out8 += b2 * a2\n"
"vmla.f32 q15, q7, d3[1] @ out9 += b2 * a3\n"
"beq 5f @ jump to tail==3\n"
/* Unroll 2 */
"vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1,b2\n"
"vmla.f32 q8, q4, d4[0] @ out0 += b1 * a0\n"
"vmla.f32 q10, q4, d4[1] @ out1 += b1 * a1\n"
"vmla.f32 q12, q4, d5[0] @ out2 += b1 * a2\n"
"vmla.f32 q14, q4, d5[1] @ out3 += b1 * a3\n"
"vmla.f32 q9, q5, d4[0] @ out4 += b2 * a0\n"
"vmla.f32 q11, q5, d4[1] @ out5 += b2 * a1\n"
"vmla.f32 q13, q5, d5[0] @ out6 += b2 * a2\n"
"vmla.f32 q15, q5, d5[1] @ out7 += b2 * a3\n"
/* Unroll 3 */
"vmla.f32 q8, q6, d6[0] @ out0 += b1 * a0\n"
"vmla.f32 q10, q6, d6[1] @ out1 += b1 * a1\n"
"vmla.f32 q12, q6, d7[0] @ out2 += b1 * a2\n"
"vmla.f32 q14, q6, d7[1] @ out3 += b1 * a3\n"
"vmla.f32 q9, q7, d6[0] @ out4 += b2 * a0\n"
"vmla.f32 q11, q7, d6[1] @ out5 += b2 * a1\n"
"vmla.f32 q13, q7, d7[0] @ out6 += b2 * a2\n"
"vmla.f32 q15, q7, d7[1] @ out7 += b2 * a3\n"
"b 2f\n"
/* tails==1 final tail */
"3: @ tail=1\n"
"vmla.f32 q8, q4, d0[0] @ out0 += b1 * a0\n"
"vmla.f32 q10, q4, d0[1] @ out1 += b1 * a1\n"
"vmla.f32 q12, q4, d1[0] @ out2 += b1 * a2\n"
"vmla.f32 q14, q4, d1[1] @ out3 += b1 * a3\n"
"vmla.f32 q9, q5, d0[0] @ out4 += b2 * a0\n"
"vmla.f32 q11, q5, d0[1] @ out5 += b2 * a1\n"
"vmla.f32 q13, q5, d1[0] @ out6 += b2 * a2\n"
"vmla.f32 q15, q5, d1[1] @ out7 += b2 * a3\n"
/*aptr - 16 */
"sub %[a_ptr], %[a_ptr], #16 @ tail--\n"
"b 2f @ jump to end\n"
/* tails==2 final tail*/
"4: @ tail == 2\n"
"vmla.f32 q8, q6, d2[0] @ out0 += b1 * a0\n"
"vmla.f32 q10, q6, d2[1] @ out1 += b1 * a1\n"
"vmla.f32 q12, q6, d3[0] @ out2 += b1 * a2\n"
"vmla.f32 q14, q6, d3[1] @ out3 += b1 * a3\n"
"vmla.f32 q9, q7, d2[0] @ out4 += b2 * a0\n"
"vmla.f32 q11, q7, d2[1] @ out5 += b2 * a1\n"
"vmla.f32 q13, q7, d3[0] @ out6 += b2 * a2\n"
"vmla.f32 q15, q7, d3[1] @ out7 += b2 * a3\n"
"b 2f @ jump to end\n"
/* tails==3 final tail*/
"5: @ tail=3\n"
"vmla.f32 q8, q4, d4[0] @ out0 += b1 * a0\n"
"vmla.f32 q10, q4, d4[1] @ out1 += b1 * a1\n"
"vmla.f32 q12, q4, d5[0] @ out2 += b1 * a2\n"
"vmla.f32 q14, q4, d5[1] @ out3 += b1 * a3\n"
"vmla.f32 q9, q5, d4[0] @ out4 += b2 * a0\n"
"vmla.f32 q11, q5, d4[1] @ out5 += b2 * a1\n"
"vmla.f32 q13, q5, d5[0] @ out6 += b2 * a2\n"
"vmla.f32 q15, q5, d5[1] @ out7 += b2 * a3\n"
/*aptr - 16*/
"sub %[a_ptr], %[a_ptr], #16 @ tail--\n"
"2: @ check relu\n"
"cmp %[relu], #0 @ check if has relu\n"
"ble 6f @ skip relu if relu <= 0\n"
"vmov.u32 q0, #0 @ for relu\n"
"vmax.f32 q8, q8, q0 @ for relu\n"
"vmax.f32 q9, q9, q0 @ for relu\n"
"vmax.f32 q10, q10, q0 @ for relu\n"
"vmax.f32 q11, q11, q0 @ for relu\n"
"vmax.f32 q12, q12, q0 @ for relu\n"
"vmax.f32 q13, q13, q0 @ for relu\n"
"vmax.f32 q14, q14, q0 @ for relu\n"
"vmax.f32 q15, q15, q0 @ for relu\n"
"6: @ store result\n"
"vst1.32 {d16-d19}, [%[c_ptr0]]! @ store r0\n"
"vst1.32 {d20-d23}, [%[c_ptr1]]! @ store r1\n"
"vst1.32 {d24-d27}, [%[c_ptr2]]! @ store r2\n"
"vst1.32 {d28-d31}, [%[c_ptr3]]! @ store r3\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr0] "+r"(c_ptr0),
[c_ptr1] "+r"(c_ptr1), [c_ptr2] "+r"(c_ptr2),
[c_ptr3] "+r"(c_ptr3), [k] "+r"(k), [tails] "+r"(tails)
: [bias_ptr] "r"(bias_local), [relu] "r"(is_relu)
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
"q11", "q12", "q13", "q14", "q15", "cc", "memory");
if (flag_p_remain && (xb == bblocks - 1)) {
for (int i = 0; i < remain; ++i) {
*pout0++ = cout0[i];
*pout1++ = cout1[i];
*pout2++ = cout2[i];
*pout3++ = cout3[i];
}
}
}
}
}
}
#endif // __aarch64__
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/cpu_info.h"
#include "paddle/fluid/lite/core/lite_tensor.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
#ifdef __aarch64__
constexpr int MBLOCK = 8;
constexpr int NBLOCK = 12;
constexpr int KBLOCK = 4;
inline int get_hblock(ARMArch arch) { return MBLOCK; }
#else
constexpr int MBLOCK_A73 = 4;
constexpr int MBLOCK_OTH = 6;
constexpr int NBLOCK = 8;
constexpr int KBLOCK = 4;
inline int get_hblock(ARMArch arch) {
if (arch == kA73) {
return MBLOCK_A73;
} else {
return MBLOCK_OTH;
}
}
#endif // __aarch64__
void prepackA(float* out, const float* in, const int ldin, const int m0,
const int mmax, const int k0, const int kmax, bool is_trans,
ARMContext* ctx);
void prepackA(TensorLite* tout, const TensorLite& tin, int m, int k, int group,
bool is_trans, ARMContext* ctx);
void sgemm_prepack(const float* A_packed, const float* B, const float* bias,
float* C, int M, int N, int K, bool is_bias, bool is_relu,
bool is_transB, ARMContext* ctx);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
......@@ -23,7 +23,8 @@ cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite any_l
cc_library(variable_lite SRCS variable.cc)
cc_library(op_registry_lite SRCS op_registry.cc DEPS framework_proto_lite)
cc_library(scope_lite SRCS scope.cc)
cc_library(context_lite SRCS context.cc DEPS any_lite)
cc_library(cpu_info_lite SRCS cpu_info.cc)
cc_library(context_lite SRCS context.cc DEPS ${tensor_lite} any_lite cpu_info_lite)
cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite compatible_pb_lite target_wrapper_lite ${tensor_lite})
cc_library(types_lite SRCS types.cc)
cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite)
......
......@@ -12,8 +12,317 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Created by chunwei on 19-2-22.
//
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/cpu_info.h"
#ifdef LITE_WITH_ANDROID
#include <sys/syscall.h>
#include <unistd.h>
#endif
#if __APPLE__
#include "TargetConditionals.h"
#if TARGET_OS_IPHONE
#include <mach/machine.h>
#include <sys/sysctl.h>
#include <sys/types.h>
#endif // TARGET_OS_IPHONE
#endif // __APPLE__
namespace paddle {
namespace lite {
#ifdef LITE_WITH_ARM
void ARMContext::SetCache(int l1size, int l2size, int l3size) {
DeviceInfo& dev = DeviceInfo::Global();
int cpu_count = arm_get_cpucount();
dev.L1_cache_.resize(cpu_count);
dev.L2_cache_.resize(cpu_count);
dev.L3_cache_.resize(cpu_count);
for (int i = 0; i < cpu_count; ++i) {
dev.L1_cache_[i] = l1size;
dev.L2_cache_[i] = l2size;
dev.L3_cache_[i] = l3size;
}
workspace_.Resize({2 * (l1size + l2size)});
}
ARMContext::ARMContext() {
active_ids_ = {0};
mode_ = LITE_POWER_HIGH;
DeviceInfo& dev = DeviceInfo::Global();
workspace_.Resize(
{static_cast<int64_t>(dev.L2_cache_[active_ids_[0]] / sizeof(float))});
#ifdef TARGET_IOS
arch_ = APPLE; // use 6x8
#else
if (dev.big_core_ids_.size() > 0) {
arch_ = dev.archs_[dev.big_core_ids_[0]];
}
#endif
}
PowerMode ARMContext::mode() const { return mode_; }
int ARMContext::threads() const { return active_ids_.size(); }
ARMContext::ARMContext(const ARMContext& ctx) {
mode_ = ctx.mode_;
active_ids_ = ctx.active_ids_;
workspace_ = ctx.workspace_;
arch_ = ctx.arch_;
count_ = ctx.count_;
}
ARMContext& ARMContext::operator=(const ARMContext& ctx) {
mode_ = ctx.mode_;
active_ids_ = ctx.active_ids_;
workspace_ = ctx.workspace_;
arch_ = ctx.arch_;
count_ = ctx.count_;
return *this;
}
void ARMContext::BindDev() {
#ifdef USE_OPENMP
int num_threads = active_ids_.size();
omp_set_num_threads(num_threads);
#ifdef LITE_WITH_ANDROID
std::vector<int> ssarets;
for (int j = 0; j < num_threads; ++j) {
ssarets.push_back(0);
}
#pragma omp parallel for
for (int i = 0; i < num_threads; i++) {
ssarets[i] = set_sched_affinity(active_ids_);
}
for (int i = 0; i < num_threads; i++) {
if (ssarets[i] != 0) {
LOGE("set cpu affinity failed, cpuID: %d\n", active_ids_[i]);
return;
}
}
#endif // LITE_WITH_ANDROID
#else // USE_OPENMP
#ifdef LITE_WITH_ANDROID
std::vector<int> cpuid1;
cpuid1.push_back(active_ids_[0]);
int ssaret = set_sched_affinity(cpuid1);
if (ssaret != 0) {
printf("set cpu affinity failed, cpuID: %d\n", active_ids_[0]);
return;
}
#endif // LITE_WITH_ANDROID
#endif // USE_OPENMP
}
void ARMContext::SetRunMode(PowerMode mode, int threads) {
DeviceInfo& dev = DeviceInfo::Global();
int big_core_size = dev.big_core_ids_.size();
int small_core_size = dev.little_core_ids_.size();
if (threads > big_core_size + small_core_size) {
threads = big_core_size + small_core_size;
}
#ifdef USE_OPENMP
count_++;
int shift_num = (count_ / 10) % big_core_size;
switch (mode) {
case LITE_POWER_FULL:
mode_ = mode;
active_ids_.clear();
for (int i = 0; i < threads; ++i) {
if (i < big_core_size) {
active_ids_.push_back(dev.big_core_ids_[i]);
} else {
active_ids_.push_back(dev.little_core_ids_[i - big_core_size]);
}
}
if (active_ids_.size() == 0) {
active_ids_.push_back(0);
}
break;
case LITE_POWER_HIGH:
active_ids_.clear();
if (big_core_size > 0) {
mode_ = LITE_POWER_HIGH;
if (threads > big_core_size) {
LOGE("threads: %d, exceed the big cores size: %d\n", threads,
big_core_size);
active_ids_ = dev.big_core_ids_;
} else {
for (int i = 0; i < threads; ++i) {
active_ids_.push_back(dev.big_core_ids_[i]);
}
}
} else {
mode_ = LITE_POWER_LOW;
LOGE("HIGH POWER MODE is not support, switch to little cores\n");
if (threads > small_core_size) {
active_ids_ = dev.little_core_ids_;
} else {
for (int i = 0; i < threads; ++i) {
active_ids_.push_back(dev.little_core_ids_[i]);
}
}
}
if (active_ids_.size() == 0) {
active_ids_.push_back(0);
}
break;
case LITE_POWER_LOW:
active_ids_.clear();
if (small_core_size > 0) {
mode_ = LITE_POWER_LOW;
if (threads > small_core_size) {
LOGW("threads: %d, exceed the little cores size: %d\n", threads,
small_core_size);
active_ids_ = dev.little_core_ids_;
} else {
for (int i = 0; i < threads; ++i) {
active_ids_.push_back(dev.little_core_ids_[i]);
}
}
} else {
mode_ = LITE_POWER_HIGH;
LOGW("LOW POWER MODE is not support, switch to big cores\n");
if (threads > big_core_size) {
active_ids_ = dev.big_core_ids_;
} else {
for (int i = 0; i < threads; ++i) {
active_ids_.push_back(dev.big_core_ids_[i]);
}
}
}
if (active_ids_.size() == 0) {
active_ids_.push_back(0);
}
break;
case LITE_POWER_NO_BIND:
mode_ = LITE_POWER_NO_BIND;
active_ids_.clear();
if (threads > dev.core_ids_.size()) {
active_ids_.resize(dev.core_ids_.size());
} else {
active_ids_.resize(threads);
}
break;
case LITE_POWER_RAND_HIGH:
active_ids_.clear();
if (big_core_size > 0) {
mode_ = LITE_POWER_RAND_HIGH;
if (threads > big_core_size) {
LOGW("threads: %d, exceed the big cores size: %d\n", threads,
big_core_size);
active_ids_ = dev.big_core_ids_;
} else {
for (int i = 0; i < threads; ++i) {
active_ids_.push_back(
dev.big_core_ids_[(i + shift_num) % big_core_size]);
}
}
} else {
mode_ = LITE_POWER_LOW;
LOGW("HIGH POWER MODE is not support, switch to little cores\n");
if (threads > small_core_size) {
active_ids_ = dev.little_core_ids_;
} else {
for (int i = 0; i < threads; ++i) {
active_ids_.push_back(dev.little_core_ids_[i]);
}
}
}
if (active_ids_.size() == 0) {
active_ids_.push_back(0);
}
break;
case LITE_POWER_RAND_LOW:
active_ids_.clear();
if (small_core_size > 0) {
mode_ = LITE_POWER_RAND_LOW;
if (threads > small_core_size) {
LOGW("threads: %d, exceed the little cores size: %d\n", threads,
small_core_size);
active_ids_ = dev.little_core_ids_;
} else {
for (int i = 0; i < threads; ++i) {
active_ids_.push_back(
dev.little_core_ids_[(i + shift_num) % small_core_size]);
}
}
} else {
mode_ = LITE_POWER_HIGH;
LOGW("LOW POWER MODE is not support, switch to big cores\n");
if (threads > big_core_size) {
active_ids_ = dev.big_core_ids_;
} else {
for (int i = 0; i < threads; ++i) {
active_ids_.push_back(dev.big_core_ids_[i]);
}
}
}
if (active_ids_.size() == 0) {
active_ids_.push_back(0);
}
break;
}
//! fix multi-threads LITE_POWER_HIGH mode
if (mode_ == LITE_POWER_NO_BIND || threads > 1) {
int threads = active_ids_.size();
omp_set_num_threads(threads);
} else {
if (check_online(active_ids_)) {
BindDev();
} else {
LOG(ERROR) << "core id " << active_ids_[0]
<< " is offline, switch to NO BIND MODE";
int threads = active_ids_.size();
omp_set_num_threads(threads);
}
}
#else
if (big_core_size > 0) {
active_ids_ = {dev.big_core_ids_[0]};
} else {
active_ids_ = {0};
}
#endif
//! alloc memory for sgemm in this context
int temp_mem_size =
DeviceInfo::Global().L2_cache_[active_ids_[0]] / sizeof(float);
workspace_.Resize({temp_mem_size});
arch_ = DeviceInfo::Global().archs_[active_ids_[0]];
}
ARMArch ARMContext::arch() const { return arch_; }
void ARMContext::SetArch(ARMArch arch) { arch_ = arch; }
int ARMContext::l1_cache_size() const {
DeviceInfo& dev = DeviceInfo::Global();
return dev.L1_cache_[active_ids_[0]];
}
int ARMContext::l2_cache_size() const {
DeviceInfo& dev = DeviceInfo::Global();
return dev.L2_cache_[active_ids_[0]];
}
int ARMContext::l3_cache_size() const {
DeviceInfo& dev = DeviceInfo::Global();
return dev.L3_cache_[active_ids_[0]];
}
bool ARMContext::ExtendWorkspace(DDimLite dims) {
auto count = dims.product();
auto old = workspace_.dims();
if (count == old.product()) {
return false;
}
workspace_.Resize(
{static_cast<int64_t>(count + l2_cache_size() / sizeof(float))});
return true;
}
#endif // LITE_WITH_ARM
} // namespace lite
} // namespace paddle
......@@ -26,6 +26,8 @@
#include <memory>
#include <set>
#include <vector>
#include "paddle/fluid/lite/core/cpu_info.h"
#include "paddle/fluid/lite/core/lite_tensor.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
namespace paddle {
......@@ -34,7 +36,44 @@ namespace lite {
struct HostContext {};
#ifdef LITE_WITH_ARM
struct ARMContext {};
struct ARMContext {
public:
ARMContext();
ARMContext(PowerMode mode, int threads);
ARMContext(const ARMContext& ctx);
ARMContext& operator=(const ARMContext& ctx);
void SetRunMode(PowerMode mode, int threads);
void SetCache(int l1size, int l2size, int l3size);
void SetArch(ARMArch arch);
void BindDev();
PowerMode mode() const;
int threads() const;
ARMArch arch() const;
template <typename T>
T* workspace_data() {
return workspace_.mutable_data<T>();
}
int l1_cache_size() const;
int l2_cache_size() const;
int l3_cache_size() const;
bool ExtendWorkspace(DDimLite dims);
private:
// LITE_POWER_HIGH stands for using big cores,
// LITE_POWER_LOW stands for using small core,
// LITE_POWER_FULL stands for using all cores
ARMArch arch_;
PowerMode mode_;
std::vector<int> active_ids_;
TensorLite workspace_;
int64_t count_{0};
};
#endif
#ifdef LITE_WITH_CUDA
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/cpu_info.h"
#include <cstdarg>
namespace paddle {
namespace lite {
#ifdef LITE_WITH_ARM
void DeviceInfo::get_info(DeviceInfo* dev) {
set_default_cache(dev);
dev->compute_core_num_ = arm_get_cpucount();
dev->max_memory_ = arm_get_meminfo();
// get max freq
#ifdef LITE_WITH_ANDROID
std::vector<int> max_freq(dev->compute_core_num_);
for (int i = 0; i < dev->compute_core_num_; ++i) {
max_freq[i] = get_max_freq_khz(i) / 1000;
}
std::string cpu_name = arm_get_cpu_name();
if (get_cpu_info_from_name(dev, cpu_name) != true) {
arm_sort_cpuid_by_max_frequency(dev->compute_core_num_, &dev->core_ids_,
max_freq, &dev->cluster_ids_);
dev->big_core_ids_.clear();
dev->little_core_ids_.clear();
for (int i = 0; i < dev->cluster_ids_.size(); ++i) {
if (dev->cluster_ids_[i] == 0) {
dev->big_core_ids_.push_back(dev->core_ids_[i]);
} else {
dev->little_core_ids_.push_back(dev->core_ids_[i]);
}
}
arm_get_cpu_arch(&dev->archs_);
}
LOG(INFO) << "ARM multiprocessors number: " << dev->compute_core_num_;
for (int i = 0; i < dev->compute_core_num_; ++i) {
LOG(INFO) << "ARM multiprocessors ID: " << dev->core_ids_[i]
<< ", frequence: " << max_freq[i]
<< ", cluster ID: " << dev->cluster_ids_[dev->core_ids_[i]]
<< ", CPU ARCH: A" << dev->archs_[i];
}
LOG(INFO) << "L1 DataCache size is: ";
for (int i = 0; i < dev->compute_core_num_; ++i) {
LOG(INFO) << dev->L1_cache_[i] / 1024 << " KB";
}
LOG(INFO) << "L2 Cache size is: ";
for (int i = 0; i < dev->compute_core_num_; ++i) {
LOG(INFO) << dev->L2_cache_[i] / 1024 << " KB";
}
LOG(INFO) << "Total memory: " << dev->max_memory_ << "KB";
dev->max_freq_ = max_freq[0];
for (int j = 1; j < dev->compute_core_num_; ++j) {
if (dev->max_freq_ < max_freq[j]) {
dev->max_freq_ = max_freq[j];
}
}
#elif defined(TARGET_IOS)
arm_get_cpu_arch(&dev->archs_);
#endif
}
// cache_id : 0 -> L1, 1 -> L2, 2 -> L3
void set_cache_info(DeviceInfo* cpu_info, int cache_id, int argc, ...) {
va_list arg_ptr;
va_start(arg_ptr, argc);
std::vector<int>* cache;
switch (cache_id) {
case 0:
cache = &cpu_info->L1_cache_;
break;
case 1:
cache = &cpu_info->L2_cache_;
break;
case 2:
cache = &cpu_info->L3_cache_;
break;
default:
break;
}
int core_num = cpu_info->compute_core_num_;
cache->resize(core_num);
if (argc == 1) {
int cache_size = va_arg(arg_ptr, int);
for (int i = 0; i < core_num; ++i) {
(*cache)[i] = cache_size;
}
} else {
int big_core_num = cpu_info->big_core_ids_.size();
int little_core_num = cpu_info->little_core_ids_.size();
int big_core_cache_size = va_arg(arg_ptr, int);
int little_core_cache_size = va_arg(arg_ptr, int);
for (int i = 0; i < big_core_num; ++i) {
(*cache)[cpu_info->big_core_ids_[i]] = big_core_cache_size;
}
for (int i = 0; i < little_core_num; ++i) {
(*cache)[cpu_info->little_core_ids_[i]] = little_core_cache_size;
}
}
va_end(arg_ptr);
}
void set_arch_info(DeviceInfo* cpu_info, int argc, ...) {
va_list arg_ptr;
va_start(arg_ptr, argc);
int core_num = cpu_info->compute_core_num_;
cpu_info->archs_.resize(core_num);
if (argc == 1) {
ARMArch arch = (ARMArch)va_arg(arg_ptr, int);
for (int i = 0; i < core_num; ++i) {
cpu_info->archs_[i] = arch;
}
} else {
ARMArch big_core_arch = (ARMArch)va_arg(arg_ptr, int);
ARMArch little_core_arch = (ARMArch)va_arg(arg_ptr, int);
int big_core_num = cpu_info->big_core_ids_.size();
int little_core_num = cpu_info->little_core_ids_.size();
for (int i = 0; i < big_core_num; ++i) {
cpu_info->archs_[cpu_info->big_core_ids_[i]] = big_core_arch;
}
for (int i = 0; i < little_core_num; ++i) {
cpu_info->archs_[cpu_info->little_core_ids_[i]] = little_core_arch;
}
}
va_end(arg_ptr);
}
bool get_cpu_info_from_name(DeviceInfo* cpu_info, std::string hardware_name) {
/* Snapdragon */
if (hardware_name.find("SDM845") != std::string::npos) { // 845
cpu_info->compute_core_num_ = 8;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->big_core_ids_ = {4, 5, 6, 7};
cpu_info->little_core_ids_ = {0, 1, 2, 3};
cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0};
set_arch_info(cpu_info, 2, kA75, kA55);
set_cache_info(cpu_info, 0, 1, 32 * 1024);
set_cache_info(cpu_info, 1, 2, 256 * 1024, 128 * 1024);
set_cache_info(cpu_info, 2, 1, 2048 * 1024);
return true;
} else if (hardware_name.find("SDM710") != std::string::npos) { // 710
cpu_info->compute_core_num_ = 8;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->big_core_ids_ = {6, 7};
cpu_info->little_core_ids_ = {0, 1, 2, 3, 4, 5};
cpu_info->cluster_ids_ = {1, 1, 1, 1, 1, 1, 0, 0};
set_arch_info(cpu_info, 2, kA75, kA55);
return true;
} else if (hardware_name.find("MSM8998") != std::string::npos) { // 835
cpu_info->compute_core_num_ = 8;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->big_core_ids_ = {4, 5, 6, 7};
cpu_info->little_core_ids_ = {0, 1, 2, 3};
cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0};
set_arch_info(cpu_info, 2, kA73, kA53);
set_cache_info(cpu_info, 0, 2, 64 * 1024);
set_cache_info(cpu_info, 1, 2, 1024 * 1024,
/*real cache size is 2M, while that will get bad performace
on conv3x3s1 or gemm, set to 1M or 512K*/
1024 * 1024);
return true;
} else if (hardware_name.find("MSM8996") != std::string::npos) { // 820
cpu_info->compute_core_num_ = 4;
cpu_info->core_ids_ = {0, 1, 2, 3};
cpu_info->big_core_ids_ = {2, 3};
cpu_info->little_core_ids_ = {0, 1};
cpu_info->cluster_ids_ = {1, 1, 0, 0};
set_arch_info(cpu_info, 1, kA72);
set_cache_info(cpu_info, 0, 1, 24 * 1024);
set_cache_info(cpu_info, 1, 2, 1024 * 1024, 512 * 1024);
return true;
} else if (hardware_name.find("SDM660") != std::string::npos ||
hardware_name.find("SDM636") != std::string::npos) { // 660, 636
cpu_info->compute_core_num_ = 8;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->big_core_ids_ = {4, 5, 6, 7};
cpu_info->little_core_ids_ = {0, 1, 2, 3};
cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0};
set_arch_info(cpu_info, 1, kA73);
set_cache_info(cpu_info, 0, 2, 64 * 1024, 32 * 1024);
set_cache_info(cpu_info, 1, 1, 1024 * 1024);
return true;
} else if (hardware_name.find("MSM8976") != std::string::npos) { // 652,653
cpu_info->compute_core_num_ = 8;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->big_core_ids_ = {4, 5, 6, 7};
cpu_info->little_core_ids_ = {0, 1, 2, 3};
cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0};
set_arch_info(cpu_info, 2, kA72, kA53);
set_cache_info(cpu_info, 0, 1, 32 * 1024);
set_cache_info(cpu_info, 1, 2, 1024 * 1024, 512 * 1024);
return true;
} else if (hardware_name.find("MSM8953") != std::string::npos) { // 625
cpu_info->compute_core_num_ = 8;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->big_core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->little_core_ids_ = {};
cpu_info->cluster_ids_ = {0, 0, 0, 0, 0, 0, 0, 0};
set_arch_info(cpu_info, 1, kA53);
set_cache_info(cpu_info, 0, 1, 32 * 1024);
set_cache_info(cpu_info, 1, 1, 1024 * 1024);
return true;
} else if (hardware_name.find("MSM8939") != std::string::npos) { // 615
cpu_info->compute_core_num_ = 8;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->big_core_ids_ = {0, 1, 2, 3};
cpu_info->little_core_ids_ = {4, 5, 6, 7};
cpu_info->cluster_ids_ = {0, 0, 0, 0, 1, 1, 1, 1};
set_arch_info(cpu_info, 1, kA53);
set_cache_info(cpu_info, 0, 1, 32 * 1024);
set_cache_info(cpu_info, 1, 2, 512 * 1024, 256 * 1024);
return true;
/* MediaTek */
} else if (hardware_name.find("MT6797") !=
std::string::npos) { // X20/X23/X25/X27
cpu_info->compute_core_num_ = 10;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
cpu_info->big_core_ids_ = {8, 9};
cpu_info->little_core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->cluster_ids_ = {1, 1, 1, 1, 1, 1, 1, 1, 0, 0};
set_arch_info(cpu_info, 2, kA72, kA53);
set_cache_info(cpu_info, 0, 1, 32 * 1024);
set_cache_info(cpu_info, 1, 2, 1024 * 1024, 512 * 1024);
return true;
} else if (hardware_name.find("MT6799") != std::string::npos) { // X30
cpu_info->compute_core_num_ = 10;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
cpu_info->big_core_ids_ = {8, 9};
cpu_info->little_core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->cluster_ids_ = {1, 1, 1, 1, 1, 1, 1, 1, 0, 0};
set_arch_info(cpu_info, 2, kA73, kA53);
return true;
} else if (hardware_name.find("MT6795") != std::string::npos ||
hardware_name.find("MT6762") != std::string::npos ||
hardware_name.find("MT6755T") != std::string::npos ||
hardware_name.find("MT6755S") != std::string::npos ||
hardware_name.find("MT6753") != std::string::npos ||
hardware_name.find("MT6752") != std::string::npos ||
hardware_name.find("MT6750") != std::string::npos) {
// X10, P22, P15/P18, MT6753, MT6752/MT6752M, MT6750
cpu_info->compute_core_num_ = 8;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->big_core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->little_core_ids_ = {};
cpu_info->cluster_ids_ = {0, 0, 0, 0, 0, 0, 0, 0};
set_arch_info(cpu_info, 1, kA53);
return true;
} else if (hardware_name.find("MT6758") != std::string::npos ||
hardware_name.find("MT6757") != std::string::npos ||
hardware_name.find("MT6763") != std::string::npos ||
hardware_name.find("MT6755M") != std::string::npos ||
hardware_name.find("MT6755") !=
std::string::npos) { // P30, P20/P25, P23, P10
cpu_info->compute_core_num_ = 8;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->big_core_ids_ = {4, 5, 6, 7};
cpu_info->little_core_ids_ = {0, 1, 2, 3};
cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0};
set_arch_info(cpu_info, 1, kA53);
return true;
} else if (hardware_name.find("MT6771") != std::string::npos) { // P60
cpu_info->compute_core_num_ = 8;
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
cpu_info->big_core_ids_ = {4, 5, 6, 7};
cpu_info->little_core_ids_ = {0, 1, 2, 3};
cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0};
set_arch_info(cpu_info, 2, kA73, kA53);
return true;
} else if (hardware_name.find("MT6765") != std::string::npos ||
hardware_name.find("MT6739") != std::string::npos ||
hardware_name.find("MT6738") != std::string::npos ||
hardware_name.find("MT6737") !=
std::string::npos) { // A22, MT6739, MT6738, MT6767
cpu_info->compute_core_num_ = 4;
cpu_info->core_ids_ = {0, 1, 2, 3};
cpu_info->big_core_ids_ = {0, 0, 0, 0};
cpu_info->little_core_ids_ = {};
cpu_info->cluster_ids_ = {0, 0, 0, 0};
set_arch_info(cpu_info, 1, kA53);
return true;
}
return false;
}
size_t arm_get_meminfo() {
#ifdef LITE_WITH_ANDROID
// get cpu count from /proc/cpuinfo
FILE* fp = fopen("/proc/meminfo", "rb");
if (!fp) {
return 1;
}
size_t memsize = 0;
char line[1024];
while (!feof(fp)) {
char* s = fgets(line, 1024, fp);
if (!s) {
break;
}
sscanf(s, "MemTotal: %d kB", &memsize);
}
fclose(fp);
return memsize;
#elif defined(TARGET_IOS)
// to be implemented
printf("not implemented\n");
return 0;
#endif
}
int arm_get_cpucount() {
#ifdef LITE_WITH_ANDROID
// get cpu count from /sys/devices/system/cpu/cpunum/uevent
int max_cpu_count = 20;
int count = 0;
for (int i = 0; i < max_cpu_count; ++i) {
char path[256];
snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%d/uevent", i);
FILE* fp = fopen(path, "rb");
if (!fp) {
break;
}
count++;
fclose(fp);
}
if (count < 1) {
count = 1;
}
return count;
#elif defined(TARGET_IOS)
int count = 0;
size_t len = sizeof(count);
sysctlbyname("hw.ncpu", &count, &len, NULL, 0);
if (count < 1) {
count = 1;
}
return count;
#else
return 1;
#endif
}
void arm_get_cpu_arch(std::vector<ARMArch>* archs) {
#ifdef LITE_WITH_ANDROID
archs->clear();
//! get CPU ARCH
FILE* fp = fopen("/proc/cpuinfo", "rb");
if (!fp) {
return;
}
char line[1024];
while (!feof(fp)) {
char* s = fgets(line, 1024, fp);
if (!s) {
break;
}
if (strstr(line, "part") != NULL) {
int arch_id = 0;
sscanf(s, "CPU part\t: %x", &arch_id);
switch (arch_id) {
case 0xd03:
archs->push_back(kA53);
break;
case 0xd05:
archs->push_back(kA55);
break;
case 0xd07:
archs->push_back(kA57);
break;
case 0xd08:
archs->push_back(kA72);
break;
case 0xd09:
archs->push_back(kA73);
break;
case 0xd0a:
archs->push_back(kA75);
break;
case 0x800:
// 835
archs->push_back(kA73);
break;
case 0x205:
// 820
archs->push_back(kA72);
break;
default:
LOG(ERROR) << "unknow type";
archs->push_back(kARMArch_UNKOWN);
}
}
}
fclose(fp);
int cpu_count = arm_get_cpucount();
if (archs->size() < cpu_count) {
for (int i = archs->size(); i < cpu_count; ++i) {
archs->push_back(archs->at(i - 1));
}
}
#endif
#ifdef TARGET_IOS
int cpu_count = arm_get_cpucount();
for (int i = 0; i < cpu_count; ++i) {
archs->push_back(APPLE);
}
#endif
}
#ifdef LITE_WITH_ANDROID
void set_default_cache(DeviceInfo* dev) {
int cpu_count = arm_get_cpucount();
dev->L1_cache_.resize(cpu_count);
dev->L2_cache_.resize(cpu_count);
dev->L3_cache_.resize(cpu_count);
#ifdef TARGET_IOS
for (int i = 0; i < cpu_count; ++i) {
dev->L1_cache_[i] = 64 * 1024;
dev->L2_cache_[i] = 2048 * 1024;
dev->L3_cache_[i] = 0;
}
#else
for (int i = 0; i < cpu_count; ++i) {
dev->L1_cache_[i] = 32 * 1024;
dev->L2_cache_[i] = 512 * 1024;
dev->L3_cache_[i] = 0;
}
#endif
}
std::string arm_get_cpu_name() {
FILE* fp = fopen("/proc/cpuinfo", "rb");
if (!fp) {
return "";
}
char line[1024];
while (!feof(fp)) {
char* s = fgets(line, 1024, fp);
if (!s) {
break;
}
if (strstr(line, "Hardware") != NULL) {
fclose(fp);
return std::string(line);
}
}
fclose(fp);
return "";
}
int get_max_freq_khz(int cpuid) {
// first try, for all possible cpu
char path[256];
snprintf(path, sizeof(path),
"/sys/devices/system/cpu/cpufreq/stats/cpu%d/time_in_state", cpuid);
FILE* fp = fopen(path, "rb");
if (!fp) {
// second try, for online cpu
snprintf(path, sizeof(path),
"/sys/devices/system/cpu/cpu%d/cpufreq/stats/time_in_state",
cpuid);
fp = fopen(path, "rb");
if (!fp) {
// third try, for online cpu
snprintf(path, sizeof(path),
"/sys/devices/system/cpu/cpu%d/cpufreq/cpuinfo_max_freq", cpuid);
fp = fopen(path, "rb");
if (!fp) {
return -1;
}
int max_freq_khz = -1;
fscanf(fp, "%d", &max_freq_khz);
fclose(fp);
return max_freq_khz;
}
}
int max_freq_khz = 0;
while (!feof(fp)) {
int freq_khz = 0;
int nscan = fscanf(fp, "%d %*d", &freq_khz);
if (nscan != 1) {
break;
}
if (freq_khz > max_freq_khz) {
max_freq_khz = freq_khz;
}
}
fclose(fp);
return max_freq_khz;
}
int arm_sort_cpuid_by_max_frequency(int cpu_count, std::vector<int>* cpuids,
const std::vector<int>& cpu_freq,
std::vector<int>* cluster_ids) {
if (cpu_count == 0) {
return 0;
}
cpuids->resize(cpu_count);
cluster_ids->resize(cpu_count);
for (int i = 0; i < cpu_count; i++) {
cpuids->at(i) = i;
}
// sort cpuid as big core first
// simple bubble sort
for (int i = 0; i < cpu_count; i++) {
for (int j = i + 1; j < cpu_count; j++) {
if (cpu_freq[i] < cpu_freq[j]) {
// swap
int tmp = cpuids->at(i);
cpuids->at(i) = cpuids->at(j);
cpuids->at(j) = tmp;
}
}
}
// SMP
int mid_max_freq_khz =
(cpu_freq[cpuids->at(0)] + cpu_freq[cpuids->at(cpu_count - 1)]) / 2;
for (int i = 0; i < cpu_count; i++) {
cpuids->at(i) = i;
if (cpu_freq[i] >= mid_max_freq_khz) {
cluster_ids->at(i) = 0;
} else {
cluster_ids->at(i) = 1;
}
}
return 0;
}
int check_online(const std::vector<int>& core_ids) {
if (core_ids.size() == 0) {
return 0;
}
char path[256];
int online = 1;
for (int i = 0; i < core_ids.size(); ++i) {
snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%d/online",
core_ids[i]);
FILE* fp = fopen(path, "rb");
if (!fp) {
return 0;
}
int cur_online = 0;
fscanf(fp, "%d", &cur_online);
online &= cur_online;
fclose(fp);
}
return online;
}
int set_sched_affinity(const std::vector<int>& cpuids) {
// #define CPU_SETSIZE 1024
// #define __NCPUBITS (8 * sizeof (unsigned long))
// typedef struct
// {
// unsigned long __bits[CPU_SETSIZE / __NCPUBITS];
// } cpu_set_t;
// set affinity for thread
#ifdef __GLIBC__
pid_t pid = syscall(SYS_gettid);
#else
pid_t pid = gettid();
#endif
cpu_set_t mask;
CPU_ZERO(&mask);
for (int i = 0; i < cpuids.size(); i++) {
CPU_SET(cpuids[i], &mask);
}
int syscallret = syscall(__NR_sched_setaffinity, pid, sizeof(mask), &mask);
if (syscallret) {
LOG(ERROR) << "syscall error " << syscallret;
return -1;
}
return 0;
}
#endif // LITE_WITH_ANDROID
#endif // LITE_WITH_ARM
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/utils/cp_logging.h"
#ifdef LITE_WITH_ANDROID
#include <sys/syscall.h>
#include <unistd.h>
#endif
#if __APPLE__
#include "TargetConditionals.h"
#if TARGET_OS_IPHONE
#include <mach/machine.h>
#include <sys/sysctl.h>
#include <sys/types.h>
#endif // TARGET_OS_IPHONE
#endif // __APPLE__
namespace paddle {
namespace lite {
#ifdef LITE_WITH_ARM
typedef enum {
LITE_POWER_HIGH = 0,
LITE_POWER_LOW = 1,
LITE_POWER_FULL = 2,
LITE_POWER_NO_BIND = 3,
LITE_POWER_RAND_HIGH = 4,
LITE_POWER_RAND_LOW = 5
} PowerMode;
typedef enum {
kAPPLE = 0,
kA53 = 53,
kA55 = 55,
kA57 = 57,
kA72 = 72,
kA73 = 73,
kA75 = 75,
kA76 = 76,
kARMArch_UNKOWN = -1
} ARMArch;
class DeviceInfo {
public:
int idx_;
int max_freq_;
int min_freq_;
int generate_arch_;
int compute_core_num_;
int max_memory_;
int sharemem_size_;
std::string device_name_;
std::string compute_ability_;
std::vector<int> L1_cache_;
std::vector<int> L2_cache_;
std::vector<int> L3_cache_;
std::vector<int> core_ids_;
std::vector<int> big_core_ids_;
std::vector<int> little_core_ids_;
std::vector<int> cluster_ids_;
std::vector<ARMArch> archs_;
static DeviceInfo& Global() {
static auto* x = new DeviceInfo;
return *x;
}
static void init_info() {
auto& info = Global();
get_info(&info);
}
private:
DeviceInfo() = default;
static void get_info(DeviceInfo* dev);
};
size_t arm_get_meminfo();
int arm_get_cpucount();
void arm_get_cpu_arch(std::vector<ARMArch>* archs);
bool get_cpu_info_from_name(DeviceInfo* cpu_info, std::string hardware_name);
#ifdef LITE_WITH_ANDROID
void set_default_cache(DeviceInfo* dev);
std::string arm_get_cpu_name();
int get_max_freq_khz(int cpuid);
int arm_sort_cpuid_by_max_frequency(int cpu_count, std::vector<int>* cpuids,
const std::vector<int>& cpu_freq,
std::vector<int>* cluster_ids);
int check_online(const std::vector<int>& core_ids);
int set_sched_affinity(const std::vector<int>& cpuids);
#endif // LITE_WITH_ANDROID
#endif // LITE_WITH_ARM
} // namespace lite
} // namespace paddle
......@@ -44,7 +44,7 @@ class KernelBase {
virtual void Run() = 0;
void SetContext(std::unique_ptr<KernelContext>&& ctx) {
context_ = std::move(ctx);
ctx_ = std::move(ctx);
}
template <typename T>
void SetParam(T param) {
......@@ -86,7 +86,7 @@ class KernelBase {
virtual TargetType target() const = 0;
virtual PrecisionType precision() const = 0;
virtual DataLayoutType layout() const = 0;
const KernelContext* context() const { return context_.get(); }
const KernelContext* context() const { return ctx_.get(); }
virtual std::string name() const = 0;
// Short human-readable document.
......@@ -134,7 +134,7 @@ class KernelBase {
void Torch() {}
protected:
std::unique_ptr<KernelContext> context_;
std::unique_ptr<KernelContext> ctx_;
mutable operators::param_t param_;
// The corresponding op type.
std::string op_type_{};
......@@ -152,9 +152,6 @@ template <TargetType Target, PrecisionType Precision,
DataLayoutType DataLayout = DataLayoutType::kNCHW>
class KernelLite : public KernelBase {
public:
// Set runtime context.
void SetContext(std::unique_ptr<KernelContext>&& ctx) { ctx_ = ctx; }
// Run the kernel.
virtual void Run() { CHECK(false) << "Not Implemented"; }
......@@ -168,9 +165,6 @@ class KernelLite : public KernelBase {
KernelLite() = default;
virtual ~KernelLite() = default;
protected:
std::unique_ptr<KernelContext> ctx_;
};
template <TargetType Target, PrecisionType Precision, DataLayoutType DataLayout>
......
......@@ -14,6 +14,7 @@
#pragma once
#include <algorithm>
#include <functional> // for multiplies
#include <memory>
#include <numeric>
#include <vector>
......@@ -40,6 +41,10 @@ class DDimLite : public DDimBase<DDimLite> {
size_t size() const { return data_.size(); }
bool empty() const { return data_.empty(); }
value_type product() const {
return std::accumulate(std::begin(data_), std::end(data_), 1,
std::multiplies<value_type>());
}
const std::vector<value_type> &data() const { return data_; }
private:
......@@ -61,8 +66,10 @@ class TensorLite : public TensorBase<TensorLite> {
}
void Resize(const DDimLite &ddim) { dims_ = ddim; }
void Resize(const std::vector<int64_t> &x) { dims_ = DDimLite(x); }
const DDimLite &dims() const { return dims_; }
int64_t numel() const { return dims_.product(); }
const LoD &lod() const { return lod_; }
LoD *mutable_lod() { return &lod_; }
......
......@@ -32,7 +32,6 @@ class RuntimeContextAssignPass : public StmtPass {
if (!node.IsStmt()) continue;
auto& inst = node.AsStmt();
switch (inst.picked_kernel().target()) {
case TARGET(kHost):
case TARGET(kX86):
......@@ -42,6 +41,11 @@ class RuntimeContextAssignPass : public StmtPass {
case TARGET(kCUDA):
inst.picked_kernel().SetContext(NewCudaContext());
break;
#endif
#ifdef LITE_WITH_ARM
case TARGET(kARM):
inst.picked_kernel().SetContext(NewARMContext());
break;
#endif
default:
LOG(FATAL) << "unsupported target "
......@@ -54,9 +58,18 @@ class RuntimeContextAssignPass : public StmtPass {
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<HostContext>();
// Some initialization here.
return ctx;
}
#ifdef LITE_WITH_ARM
std::unique_ptr<KernelContext> NewARMContext() {
DeviceInfo::init_info();
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
return ctx;
}
#endif
#ifdef LITE_WITH_CUDA
std::unique_ptr<KernelContext> NewCudaContext() {
std::unique_ptr<KernelContext> ctx(new KernelContext);
......@@ -66,9 +79,7 @@ class RuntimeContextAssignPass : public StmtPass {
cuda.blas_fp32 = cublas_fp32_;
return ctx;
}
#endif
#ifdef LITE_WITH_CUDA
void InitCudaBlas() {
cublas_fp32_ = std::make_shared<lite::cuda::Blas<float>>();
}
......
message(STATUS "add lite kernels")
set(lite_kernel_deps type_system kernel_lite op_lite op_registry_lite ${tensor_lite})
set(lite_kernel_deps type_system kernel_lite op_lite op_registry_lite context_lite ${tensor_lite})
add_subdirectory(host)
add_subdirectory(arm)
add_subdirectory(cuda)
......
......@@ -4,11 +4,13 @@ endif()
message(STATUS "compile with lite ARM kernels")
cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} eigen3)
cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps})
cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3)
cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} eigen3)
lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm eigen3)
set(arm_kernels
fc_compute_arm
relu_compute_arm
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/fc_compute.h"
#include <Eigen/Core>
#include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
......@@ -22,24 +22,42 @@ namespace lite {
namespace kernels {
namespace arm {
// NOTE should use pure std C++ implementation.
void FcCompute::Run() {
auto& param = this->Param<operators::FcParam>();
auto x_dims = param.input->dims();
auto w_dims = param.w->dims();
CHECK_GE(param.input->dims().size(), 2UL);
CHECK_GE(x_dims.size(), 2UL);
CHECK_EQ(w_dims.size(), 2UL);
CHECK_EQ(param.output->dims().size(), 2UL);
fc_compute_eigen(
param.input->data<float>(), // x
param.input->dims().Slice(0, param.in_num_col_dims).production(),
param.input->dims()
.Slice(param.in_num_col_dims, param.input->dims().size())
.production(),
param.w->data<float>(), // w
param.w->dims()[1], // w_w
param.w->dims()[0], // w_h
param.bias->data<float>(), // b
param.output->mutable_data<float>());
const auto* i_data = param.input->data<float>();
const auto* w_data = param.w->data<float>();
const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
auto* o_data = param.output->mutable_data<float>();
int x_h = x_dims.Slice(0, param.in_num_col_dims).production();
int x_w = x_dims.Slice(param.in_num_col_dims, x_dims.size()).production();
int n = w_dims[1];
CHECK_EQ(x_w, static_cast<int>(w_dims[0]));
auto& ctx = this->ctx_->template As<ARMContext>();
if (x_h > 1) {
float* packed_in = static_cast<float*>(ctx.workspace_data<float>()) +
ctx.l2_cache_size() / sizeof(float);
lite::arm::math::prepackA(packed_in, i_data, x_w, 0, x_h, 0, x_w, false,
&ctx);
lite::arm::math::sgemm_prepack(packed_in, w_data, b_data, o_data, x_h, n,
x_w, false, false, false, &ctx);
if (param.bias) {
CHECK_EQ(param.bias->numel(), n);
lite::arm::math::fill_bias_fc(o_data, b_data, x_h, n);
}
} else {
// use sgemmv
// sgemv((const float*)weights, (const float*)din, (float*)dout,
// false, n, x_w, _param->_flag_bias, (float*)bias, false);
}
}
TargetType FcCompute::target() const { return TARGET(kARM); }
......
......@@ -13,7 +13,6 @@
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/operators/fc_op.h"
......@@ -34,52 +33,6 @@ class FcCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
virtual ~FcCompute() = default;
};
template <typename T>
void fc_compute_eigen(const T* x, int x_w, int x_h, //
const T* w, int w_w, int w_h, //
const T* b, //
T* out) {
using matrix_t =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
Eigen::Map<const matrix_t> X(x, x_h, x_w);
Eigen::Map<const matrix_t> W(w, w_h, w_w);
Eigen::Map<matrix_t> Out(out, x_h, w_h);
Out = X * W.transpose();
if (b) {
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> B(b, w_h);
Out = Out.array().rowwise() + B.transpose().array();
}
}
template <typename T>
__attribute__((optimize("unroll-loops"))) //
T dot(const T* x, const T* y, int dim) {
T out{};
for (int i = 0; i < dim; i++) {
out += x[i] * y[i];
}
return out;
}
template <typename T>
void fc_compute_naive(const T* x, int x_w, int x_h, //
const T* w, int w_w, int w_h, //
const T* b, //
T* out) {
CHECK_EQ(x_w, w_w);
// out shape: (x_h, w_w)
memset(out, 0, x_h * w_h * sizeof(T));
for (int r = 0; r < x_h; r++) {
for (int c = 0; c < w_h; c++) {
out[r * w_h + c] = dot(&x[r * x_w], &w[c * w_w], w_w) + b[c];
}
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/lite/kernels/arm/fc_compute.h"
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
......@@ -22,60 +23,79 @@ namespace lite {
namespace kernels {
namespace arm {
TEST(fc_compute_naive, test) {
lite::Tensor x, w, b, out, out1;
const int batch_size = 2;
TEST(fc_arm, retrive_op) {
auto fc =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("fc");
ASSERT_FALSE(fc.empty());
ASSERT_TRUE(fc.front());
}
TEST(fc_arm, init) {
FcCompute fc;
ASSERT_EQ(fc.precision(), PRECISION(kFloat));
ASSERT_EQ(fc.target(), TARGET(kARM));
}
TEST(fc_arm, compare_test) {
lite::Tensor x, w, b, out, ref;
constexpr int batch_size = 2;
x.Resize({batch_size, 3});
w.Resize({4, 3});
w.Resize({3, 4});
b.Resize({1, 4});
out.Resize({batch_size, 4});
out1.Resize({batch_size, 4});
ref.Resize({batch_size, 4});
auto x_data = x.mutable_data<float>();
auto w_data = w.mutable_data<float>();
auto b_data = b.mutable_data<float>();
auto out_data = out.mutable_data<float>();
auto out_data1 = out1.mutable_data<float>();
auto ref_data = ref.mutable_data<float>();
for (int i = 0; i < product(x.dims()); i++) x_data[i] = i;
for (int i = 0; i < product(w.dims()); i++) w_data[i] = i;
for (int i = 0; i < product(b.dims()); i++) b_data[i] = i;
for (int64_t i = 0; i < x.dims().product(); i++) {
x_data[i] = static_cast<float>(i);
}
for (int64_t i = 0; i < w.dims().product(); i++) {
w_data[i] = static_cast<float>(i);
}
for (int64_t i = 0; i < b.dims().product(); i++) {
b_data[i] = static_cast<float>(i);
}
fc_compute_naive(x_data, 3, batch_size, //
// TODO(TJ): enable bias soon
b_data = nullptr;
lite::arm::math::fc_compute_eigen(x_data, batch_size, 3, //
w_data, 3, 4, //
b_data, out_data);
fc_compute_eigen(x_data, 3, batch_size, //
w_data, 3, 4, //
b_data, out_data1);
for (int i = 0; i < product(out.dims()); i++) {
EXPECT_NEAR(out_data[0], out_data1[0], 1e-6);
}
}
b_data, ref_data);
TEST(fc_arm, init) {
// fc compute kernel
FcCompute fc;
ASSERT_EQ(fc.precision(), PRECISION(kFloat));
ASSERT_EQ(fc.target(), TARGET(kARM));
}
operators::FcParam param;
TEST(fc_arm, algorithm) {
using matrix_t = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>;
using matrix_map_t = Eigen::Map<matrix_t>;
param.in_num_col_dims = 1;
param.input = &x;
param.w = &w;
param.bias = nullptr;
param.output = &out;
param.in_mat_dims = x.dims();
// dim 10, 20
std::vector<float> input(10 * 20);
std::vector<float> w(20 * 20);
std::vector<float> output(10 * 20);
DeviceInfo::init_info();
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
fc.SetParam(param);
fc.SetContext(std::move(ctx));
fc.Run();
Eigen::Map<const matrix_t> input_mat(input.data(), 10, 20);
Eigen::Map<const matrix_t> weight_mat(w.data(), 20, 20);
matrix_map_t output_mat(output.data(), 10, 20);
VLOG(3) << "output vs ref";
for (int i = 0; i < out.dims().product(); i++) {
VLOG(3) << out_data[i] << " vs " << ref_data[i];
}
output_mat = weight_mat.transpose() * input_mat;
for (int i = 0; i < out.dims().product(); ++i) {
EXPECT_NEAR(out_data[i], ref_data[i], 1e-5);
}
}
TEST(fc_arm, compute) {
TEST(fc_arm, num_col_dims) {
FcCompute fc;
operators::FcParam param;
......@@ -84,20 +104,28 @@ TEST(fc_arm, compute) {
lite::Tensor bias;
lite::Tensor output;
x.Resize(DDim(std::vector<int64_t>({1, 10, 20})));
w.Resize(DDim(std::vector<int64_t>({20, 20})));
bias.Resize(DDim(std::vector<int64_t>({1, 10})));
output.Resize(DDim(std::vector<int64_t>({10, 20})));
x.Resize({1, 2, 3});
w.Resize({3, 4});
bias.Resize({1, 4});
output.Resize({2, 4});
auto* x_data = x.mutable_data<float>();
auto* w_data = w.mutable_data<float>();
auto* bias_data = bias.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
for (int i = 0; i < 10 * 20; i++) x_data[i] = i;
for (int i = 0; i < 20 * 20; i++) w_data[i] = i;
for (int i = 0; i < 10; i++) bias_data[i] = i;
for (int i = 0; i < 10 * 20; i++) output_data[i] = 0;
for (int64_t i = 0; i < x.dims().product(); i++) {
x_data[i] = static_cast<float>(i);
}
for (int64_t i = 0; i < w.dims().product(); i++) {
w_data[i] = static_cast<float>(i);
}
for (int64_t i = 0; i < bias.dims().product(); i++) {
bias_data[i] = static_cast<float>(i);
}
for (int64_t i = 0; i < output.dims().product(); i++) {
output_data[i] = static_cast<float>(i);
}
param.in_num_col_dims = 2;
param.input = &x;
......@@ -106,20 +134,13 @@ TEST(fc_arm, compute) {
param.output = &output;
param.in_mat_dims = x.dims();
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
DeviceInfo::init_info();
fc.SetParam(param);
fc.SetContext(std::move(ctx));
fc.Run();
LOG(INFO) << "x";
for (int i = 0; i < 10 * 20; i++) LOG(INFO) << x_data[i];
LOG(INFO) << "output:";
for (int i = 0; i < 10 * 20; i++) LOG(INFO) << output.data<float>()[i];
}
TEST(fc, retrive_op) {
auto fc =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("fc");
ASSERT_TRUE(fc);
}
} // namespace arm
......
......@@ -35,8 +35,8 @@ class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
using param_t = operators::MulParam;
void Run() override {
CHECK(context_) << "running context should be set first";
auto& context = context_->As<CUDAContext>();
CHECK(ctx_) << "running context should be set first";
auto& context = ctx_->As<CUDAContext>();
CHECK(context.blas_fp32) << "blas should init first";
/*
auto& blas = *context.blas_fp32;
......
......@@ -60,7 +60,7 @@ class SquareCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
using param_t = operators::ActivationParam;
void Run() override {
auto& context = context_->As<X86Context>();
auto& context = ctx_->As<X86Context>();
auto& param = *param_.get_mutable<operators::ActivationParam>();
CHECK(context.x86_device_context);
......@@ -82,7 +82,7 @@ class SquareGradCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
using param_t = operators::ActivationGradParam;
void Run() override {
auto& context = context_->As<X86Context>();
auto& context = ctx_->As<X86Context>();
auto& param = *param_.get_mutable<operators::ActivationGradParam>();
CHECK(context.x86_device_context);
param.X_grad->template mutable_data<T>();
......
......@@ -38,7 +38,7 @@ class ElementwiseSubCompute
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = context_->As<X86Context>();
auto& context = ctx_->As<X86Context>();
CHECK(context.x86_device_context);
param.Out->template mutable_data<T>();
......
......@@ -22,7 +22,9 @@ function cmake_arm {
-DLITE_WITH_CUDA=OFF \
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \
-DWITH_TESTING=ON \
-DWITH_MKL=OFF \
-DWITH_MKLDNN=OFF
make cxx_api_lite_bin -j8
}
function build {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册