Commit 25fa2b86 authored by Jiaying Zhao, committed by GitHub

add gemm_int8 arm64 version without openmp (#1708)

* add gemm_int8 arm64 version without openmp

* add gemm_int8 arm64 version with openmp
Parent 6a31b740
......@@ -28,7 +28,7 @@ limitations under the License. */
#if __aarch64__
#define MR_INT8 4
#define NR_INT8 2
#define NR_INT8 4
#define MR 6
#define NR 16
#else
......@@ -181,12 +181,15 @@ class Gemm {
std::string mode, float *bias, float *bias1);
// 8 bits function cluster begins
// 8 bits int small block inner product
// 8 bits int small block inner product, data packed k = 1
void AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc);
void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc);
// 8 bits int small block inner product, data packed k = 16
void AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc);
void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
void AddDot4x4(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc);
// 8 bits int inner product
......@@ -203,13 +206,15 @@ class Gemm {
// 8 bits int pack function
void PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer);
void PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer);
void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
void PackMatrixB_4c_16(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer);
void PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
......@@ -219,6 +224,8 @@ class Gemm {
const int8_t *A, int32_t lda, int8_t *buffer);
void PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer);
void PackMatrixB_omp_4c_16(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer);
// 8 bits int matrix product
template <typename Itype, typename Btype, typename Otype>
......@@ -314,7 +321,11 @@ void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
int32_t mc, nc;
for (int32_t j = 0; j < n; j += NC) {
nc = s_min(n - j, NC);
#if __aarch64__
PackMatrixB_4c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, packedB_int8);
#else
PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, packedB_int8);
#endif
for (int32_t i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, packedA_int8);
......@@ -375,7 +386,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
#if __aarch64__
// TODO(paddle mobile)
PackMatrixB_omp_4c_16(k, n, n % NR_INT8, B, ldb, packedB_int8);
#else
PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8);
#endif
......@@ -397,7 +408,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
packedA_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC));
#if __aarch64__
// TODO(paddle mobile)
PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8);
#else
PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8);
#endif
......@@ -421,7 +432,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
int8_t *local_A = packedA_int8 + MC * KC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads;
#if __aarch64__
// TODO(paddle mobile)
PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A);
#else
PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A);
#endif
......@@ -451,7 +462,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
int8_t *local_B = packedB_int8 + KC * NC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads;
#if __aarch64__
// TODO(paddle mobile)
PackMatrixB_4c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B);
#else
PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B);
#endif
......
......@@ -31,7 +31,7 @@ void Gemm::AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc) {
#if __ARM_NEON
#if __aarch64__
// TODO()
// AddDot4x8 used only for aarch32
#else
const int8_t *a_ptr, *b_ptr;
a_ptr = a;
......@@ -249,7 +249,7 @@ void Gemm::AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc) {
#if __ARM_NEON
#if __aarch64__
// TODO
// AddDot4x2 used only for aarch32
#else
#define PADDLE_LABEL_LOOP "1"
#define PADDLE_LABEL_AFTER_LOOP "2"
......@@ -371,12 +371,226 @@ void Gemm::AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
#endif // __ARM_NEON
}
void Gemm::AddDot4x4(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc) {
#if __ARM_NEON
#if __aarch64__
#define PADDLE_LABEL_LOOP "1"
#define PADDLE_LABEL_AFTER_LOOP "2"
asm volatile(
// load data from matrices a and b, and zero the result registers
"ld1 {v0.16b}, [%[b]], #16\n"
"dup v16.4s, wzr\n"
"ld1 {v4.16b}, [%[a]], #16\n"
"dup v17.4s, wzr\n"
"ld1 {v1.16b}, [%[b]], #16\n"
"dup v18.4s, wzr\n"
"ld1 {v5.16b}, [%[a]], #16\n"
"dup v19.4s, wzr\n"
"ld1 {v2.16b}, [%[b]], #16\n"
"dup v20.4s, wzr\n"
"ld1 {v3.16b}, [%[b]], #16\n"
"dup v21.4s, wzr\n"
"ld1 {v6.16b}, [%[a]], #16\n"
"dup v22.4s, wzr\n"
"ld1 {v7.16b}, [%[a]], #16\n"
"dup v23.4s, wzr\n"
"dup v24.4s, wzr\n"
"dup v25.4s, wzr\n"
"dup v26.4s, wzr\n"
"dup v27.4s, wzr\n"
"dup v28.4s, wzr\n"
"dup v29.4s, wzr\n"
"dup v30.4s, wzr\n"
"dup v31.4s, wzr\n"
// Multiply ldc by 4 == sizeof(int32_t)
"lsl %[ldc], %[ldc], #2\n"
// first half
"smull v8.8h, v0.8b, v4.8b\n"
"smull v9.8h, v1.8b, v4.8b\n"
"smull v10.8h, v2.8b, v4.8b\n"
"smull v11.8h, v3.8b, v4.8b\n"
"smull v12.8h, v0.8b, v5.8b\n"
"smull v13.8h, v1.8b, v5.8b\n"
"smull v14.8h, v2.8b, v5.8b\n"
"smull v15.8h, v3.8b, v5.8b\n"
// Multiply-accumulate second-half
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v1.16b, v4.16b\n"
"smlal2 v10.8h, v2.16b, v4.16b\n"
"smlal2 v11.8h, v3.16b, v4.16b\n"
"smlal2 v12.8h, v0.16b, v5.16b\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"
"smlal2 v14.8h, v2.16b, v5.16b\n"
"smlal2 v15.8h, v3.16b, v5.16b\n"
"subs %[k], %[k], #16\n"
// skip the loop
"beq " PADDLE_LABEL_AFTER_LOOP "f\n"
// loop
PADDLE_LABEL_LOOP
":\n"
// first half
"sadalp v16.4s, v8.8h\n"
"ld1 {v4.16b}, [%[a]], #16\n"
"smull v8.8h, v0.8b, v6.8b\n"
"sadalp v17.4s, v9.8h\n"
"ld1 {v5.16b}, [%[a]], #16\n"
"smull v9.8h, v1.8b, v6.8b\n"
"sadalp v18.4s, v10.8h\n"
"smull v10.8h, v2.8b, v6.8b\n"
"sadalp v19.4s, v11.8h\n"
"smull v11.8h, v3.8b, v6.8b\n"
"sadalp v20.4s, v12.8h\n"
"smull v12.8h, v0.8b, v7.8b\n"
"sadalp v21.4s, v13.8h\n"
"smull v13.8h, v1.8b, v7.8b\n"
"sadalp v22.4s, v14.8h\n"
"smull v14.8h, v2.8b, v7.8b\n"
"sadalp v23.4s, v15.8h\n"
"smull v15.8h, v3.8b, v7.8b\n"
// Multiply-accumulate second-half
"smlal2 v8.8h, v0.16b, v6.16b\n"
"smlal2 v9.8h, v1.16b, v6.16b\n"
"smlal2 v10.8h, v2.16b, v6.16b\n"
"smlal2 v11.8h, v3.16b, v6.16b\n"
"ld1 {v6.16b}, [%[a]], #16\n"
"smlal2 v12.8h, v0.16b, v7.16b\n"
"ld1 {v0.16b}, [%[b]], #16\n"
"smlal2 v13.8h, v1.16b, v7.16b\n"
"ld1 {v1.16b}, [%[b]], #16\n"
"smlal2 v14.8h, v2.16b, v7.16b\n"
"ld1 {v2.16b}, [%[b]], #16\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"
"ld1 {v3.16b}, [%[b]], #16\n"
// first half
"sadalp v24.4s, v8.8h\n"
"smull v8.8h, v0.8b, v4.8b\n"
"sadalp v25.4s, v9.8h\n"
"ld1 {v7.16b}, [%[a]], #16\n"
"smull v9.8h, v1.8b, v4.8b\n"
"sadalp v26.4s, v10.8h\n"
"smull v10.8h, v2.8b, v4.8b\n"
"sadalp v27.4s, v11.8h\n"
"smull v11.8h, v3.8b, v4.8b\n"
"sadalp v28.4s, v12.8h\n"
"smull v12.8h, v0.8b, v5.8b\n"
"sadalp v29.4s, v13.8h\n"
"smull v13.8h, v1.8b, v5.8b\n"
"sadalp v30.4s, v14.8h\n"
"smull v14.8h, v2.8b, v5.8b\n"
"sadalp v31.4s, v15.8h\n"
"smull v15.8h, v3.8b, v5.8b\n"
// Multiply-accumulate second-half
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v1.16b, v4.16b\n"
"smlal2 v10.8h, v2.16b, v4.16b\n"
"smlal2 v11.8h, v3.16b, v4.16b\n"
// Loop
"subs %[k], %[k], #16\n"
"smlal2 v12.8h, v0.16b, v5.16b\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"
"smlal2 v14.8h, v2.16b, v5.16b\n"
"smlal2 v15.8h, v3.16b, v5.16b\n"
"bne " PADDLE_LABEL_LOOP "b\n"
// Final
PADDLE_LABEL_AFTER_LOOP
":\n"
// first half
"sadalp v16.4s, v8.8h\n"
"smull v8.8h, v0.8b, v6.8b\n"
"sadalp v17.4s, v9.8h\n"
"smull v9.8h, v1.8b, v6.8b\n"
"sadalp v18.4s, v10.8h\n"
"smull v10.8h, v2.8b, v6.8b\n"
"sadalp v19.4s, v11.8h\n"
"smull v11.8h, v3.8b, v6.8b\n"
"sadalp v20.4s, v12.8h\n"
"smull v12.8h, v0.8b, v7.8b\n"
"sadalp v21.4s, v13.8h\n"
"smull v13.8h, v1.8b, v7.8b\n"
"sadalp v22.4s, v14.8h\n"
"smull v14.8h, v2.8b, v7.8b\n"
"sadalp v23.4s, v15.8h\n"
"smull v15.8h, v3.8b, v7.8b\n"
// Multiply-accumulate second-half
"smlal2 v8.8h, v0.16b, v6.16b\n"
"smlal2 v9.8h, v1.16b, v6.16b\n"
"smlal2 v10.8h, v2.16b, v6.16b\n"
"smlal2 v11.8h, v3.16b, v6.16b\n"
"smlal2 v12.8h, v0.16b, v7.16b\n"
"smlal2 v13.8h, v1.16b, v7.16b\n"
"smlal2 v14.8h, v2.16b, v7.16b\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"
"sadalp v24.4s, v8.8h\n"
"sadalp v25.4s, v9.8h\n"
"sadalp v26.4s, v10.8h\n"
"sadalp v27.4s, v11.8h\n"
"sadalp v28.4s, v12.8h\n"
"sadalp v29.4s, v13.8h\n"
"sadalp v30.4s, v14.8h\n"
"sadalp v31.4s, v15.8h\n"
// Reduce 32bit accumulators horizontally.
"addp v0.4s, v16.4s, v17.4s\n"
"addp v1.4s, v18.4s, v19.4s\n"
"addp v2.4s, v20.4s, v21.4s\n"
"addp v3.4s, v22.4s, v23.4s\n"
"addp v4.4s, v24.4s, v25.4s\n"
"addp v5.4s, v26.4s, v27.4s\n"
"addp v6.4s, v28.4s, v29.4s\n"
"addp v7.4s, v30.4s, v31.4s\n"
// Reduce 32bit accumulators horizontally, second pass
// (each pass adds pairwise; we need to add 4-wise).
"addp v12.4s, v0.4s, v1.4s\n"
"addp v13.4s, v2.4s, v3.4s\n"
"addp v14.4s, v4.4s, v5.4s\n"
"addp v15.4s, v6.4s, v7.4s\n"
"st1 {v12.4s}, [%[c]], %[ldc] \n\t"
"st1 {v13.4s}, [%[c]], %[ldc] \n\t"
"st1 {v14.4s}, [%[c]], %[ldc] \n\t"
"st1 {v15.4s}, [%[c]] \n\t"
: [k] "+r"(k), [a] "+r"(a), [b] "+r"(b), [c] "+r"(c) // outputs
: [ldc] "r"(ldc) // inputs
: "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
"v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
"v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
"v28", "v29", "v30", "v31"); // clobbers
#undef PADDLE_LABEL_AFTER_LOOP
#undef PADDLE_LABEL_LOOP
#else
// AddDot4x4 used only for aarch64
#endif // __aarch64__
#endif // __ARM_NEON
}
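// For reference, a scalar sketch of what the arm64 4x4 micro-kernel above
// computes over the k = 16 packed layouts produced by PackMatrixA_4r_16 and
// PackMatrixB_4c_16. It is only an illustration for cross-checking: the name
// AddDot4x4_reference is hypothetical, not part of the library, and it
// assumes k is a multiple of 16, since the asm consumes 16 values of k per step.
static void AddDot4x4_reference(int32_t k, const int8_t *a, const int8_t *b,
                                int32_t *c, int32_t ldc) {
  for (int32_t r = 0; r < 4; ++r) {
    for (int32_t col = 0; col < 4; ++col) {
      int32_t acc = 0;
      for (int32_t p = 0; p < k; ++p) {
        // A panel: 64-byte blocks, 16 consecutive k values per row.
        const int8_t av = a[(p / 16) * 64 + r * 16 + (p % 16)];
        // B panel: 64-byte blocks, 16 consecutive k values per column.
        const int8_t bv = b[(p / 16) * 64 + col * 16 + (p % 16)];
        acc += static_cast<int32_t>(av) * static_cast<int32_t>(bv);
      }
      // The asm kernel overwrites C; it does not accumulate into it.
      c[r * ldc + col] = acc;
    }
  }
}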
// 8 bits int small block inner product
void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc) {
#if __ARM_NEON
#if __aarch64__
// TODO
// AddDot6x8 used only for aarch32
#else
const int8_t *a_ptr, *b_ptr;
a_ptr = a;
......@@ -681,10 +895,8 @@ void Gemm::InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a,
for (int32_t j = 0; j < nc; j += NR_INT8) {
for (int32_t i = 0; i < mc; i += MR_INT8) {
#if __aarch64__
// TODO
AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#else
// AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
// AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot4x2(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#endif // __aarch64__
}
......@@ -704,10 +916,8 @@ void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, float alpha,
for (int32_t j = 0; j < nc; j += NR_INT8) {
for (int32_t i = 0; i < mc; i += MR_INT8) {
#if __aarch64__
// TODO
AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#else
// AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
// AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot4x2(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#endif // __aarch64__
}
......@@ -730,6 +940,149 @@ void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, float alpha,
int32_t *c, int32_t *C, int32_t ldc, bool relu,
int32_t *bias, bool addOnRow) {}
// 8 bits int PackMatrixA_4r
void Gemm::PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer) {
const int8_t *a0, *a1, *a2, *a3;
for (int32_t i = 0; i < m - m_tail; i += 4) {
a0 = A + i * lda;
a1 = A + (i + 1) * lda;
a2 = A + (i + 2) * lda;
a3 = A + (i + 3) * lda;
for (int32_t j = 0; j < k; ++j) {
*buffer++ = *a0++;
*buffer++ = *a1++;
*buffer++ = *a2++;
*buffer++ = *a3++;
}
}
if (m_tail != 0) {
a0 = &A(m - m_tail, 0);
a1 = a0 + lda;
a2 = a0 + 2 * lda;
a3 = a0 + 3 * lda;
switch (m_tail) {
case 1:
a1 = zero_int8;
case 2:
a2 = zero_int8;
case 3:
a3 = zero_int8;
break;
default:
break;
}
for (int j = 0; j < k; ++j) {
*buffer++ = *a0++;
*buffer++ = *a1++;
*buffer++ = *a2++;
*buffer++ = *a3++;
}
}
}
// 8 bits int PackMatrixA_6r
void Gemm::PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer) {
const int32_t i_length = m - m_tail;
for (int32_t i = 0; i < i_length; i += 6) {
const int8_t *a0 = A + i * lda;
const int8_t *a1 = A + (i + 1) * lda;
const int8_t *a2 = A + (i + 2) * lda;
const int8_t *a3 = A + (i + 3) * lda;
const int8_t *a4 = A + (i + 4) * lda;
const int8_t *a5 = A + (i + 5) * lda;
int8_t *local_buffer = buffer + i * k;
for (int32_t j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
*local_buffer++ = *a3++;
*local_buffer++ = *a4++;
*local_buffer++ = *a5++;
}
}
if (m_tail != 0) {
const int8_t *a0 = &A(i_length, 0);
const int8_t *a1 = a0 + lda;
const int8_t *a2 = a0 + 2 * lda;
const int8_t *a3 = a0 + 3 * lda;
const int8_t *a4 = a0 + 4 * lda;
const int8_t *a5 = a0 + 5 * lda;
int8_t *local_buffer = buffer + i_length * k;
switch (m_tail) {
case 1:
a1 = zero_int8;
case 2:
a2 = zero_int8;
case 3:
a3 = zero_int8;
case 4:
a4 = zero_int8;
case 5:
a5 = zero_int8;
break;
default:
break;
}
for (int32_t j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
*local_buffer++ = *a3++;
*local_buffer++ = *a4++;
*local_buffer++ = *a5++;
}
}
}
// 8 bits int PackMatrixB
void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer) {
const int32_t j_length = n - n_tail;
for (int32_t j = 0; j < j_length; j += 8) {
int8_t *local_buffer = buffer + j * k;
for (int32_t i = 0; i < k; ++i) {
const int8_t *b0 = &B(i, j);
#if __ARM_NEON
#if __aarch64__
// PackMatrixB_8c used only for aarch32
#else
asm volatile(
// "pld [%[b0]] \n\t"
"vld1.s8 {d0}, [%[b0]] \n\t"
"vst1.s8 {d0}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0)
: "memory", "q0");
#endif // __aarch64__
#else
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
#endif // __ARM_NEON
}
}
if (n_tail != 0) {
int8_t *local_buffer = buffer + j_length * k;
for (int32_t i = 0; i < k; ++i) {
const int8_t *b0 = &B(i, j_length);
for (int32_t j = j_length; j < n; ++j) {
*local_buffer++ = *b0++;
}
for (int32_t j = n; j < j_length + 8; ++j) {
*local_buffer++ = 0;
}
}
}
}
// 8 bits int PackMatrixA_4r
void Gemm::PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail,
const int8_t *A, int32_t lda, int8_t *buffer) {
......@@ -746,7 +1099,19 @@ void Gemm::PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail,
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO
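// arm64: copy 16 consecutive k values from each of the four rows into the
// packed buffer, one 16-byte vector per row.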
asm volatile(
"ld1 {v0.16b}, [%[a0]], #16 \n\t"
"ld1 {v1.16b}, [%[a1]], #16 \n\t"
"ld1 {v2.16b}, [%[a2]], #16 \n\t"
"ld1 {v3.16b}, [%[a3]], #16 \n\t"
"st1 {v0.16b}, [%[local_buffer]], #16 \n\t"
"st1 {v1.16b}, [%[local_buffer]], #16 \n\t"
"st1 {v2.16b}, [%[local_buffer]], #16 \n\t"
"st1 {v3.16b}, [%[local_buffer]], #16 \n\t"
: [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1),
[a2] "+r"(a2), [a3] "+r"(a3)
:
: "memory", "v0", "v1", "v2", "v3");
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
......@@ -826,7 +1191,19 @@ void Gemm::PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail,
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO
asm volatile(
"ld1 {v0.16b}, [%[a0]], #16 \n\t"
"ld1 {v1.16b}, [%[a1]], #16 \n\t"
"ld1 {v2.16b}, [%[a2]], #16 \n\t"
"ld1 {v3.16b}, [%[a3]], #16 \n\t"
"st1 {v0.16b}, [%[local_buffer]], #16 \n\t"
"st1 {v1.16b}, [%[local_buffer]], #16 \n\t"
"st1 {v2.16b}, [%[local_buffer]], #16 \n\t"
"st1 {v3.16b}, [%[local_buffer]], #16 \n\t"
: [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1),
[a2] "+r"(a2), [a3] "+r"(a3)
:
: "memory", "v0", "v1", "v2", "v3");
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
......@@ -887,103 +1264,6 @@ void Gemm::PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail,
}
}
// 8 bits int PackMatrixA_4r
void Gemm::PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer) {
const int8_t *a0, *a1, *a2, *a3;
for (int32_t i = 0; i < m - m_tail; i += 4) {
a0 = A + i * lda;
a1 = A + (i + 1) * lda;
a2 = A + (i + 2) * lda;
a3 = A + (i + 3) * lda;
for (int32_t j = 0; j < k; ++j) {
*buffer++ = *a0++;
*buffer++ = *a1++;
*buffer++ = *a2++;
*buffer++ = *a3++;
}
}
if (m_tail != 0) {
a0 = &A(m - m_tail, 0);
a1 = a0 + lda;
a2 = a0 + 2 * lda;
a3 = a0 + 3 * lda;
switch (m_tail) {
case 1:
a1 = zero_int8;
case 2:
a2 = zero_int8;
case 3:
a3 = zero_int8;
break;
default:
break;
}
for (int j = 0; j < k; ++j) {
*buffer++ = *a0++;
*buffer++ = *a1++;
*buffer++ = *a2++;
*buffer++ = *a3++;
}
}
}
// 8 bits int PackMatrixA_6r
void Gemm::PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer) {
const int32_t i_length = m - m_tail;
for (int32_t i = 0; i < i_length; i += 6) {
const int8_t *a0 = A + i * lda;
const int8_t *a1 = A + (i + 1) * lda;
const int8_t *a2 = A + (i + 2) * lda;
const int8_t *a3 = A + (i + 3) * lda;
const int8_t *a4 = A + (i + 4) * lda;
const int8_t *a5 = A + (i + 5) * lda;
int8_t *local_buffer = buffer + i * k;
for (int32_t j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
*local_buffer++ = *a3++;
*local_buffer++ = *a4++;
*local_buffer++ = *a5++;
}
}
if (m_tail != 0) {
const int8_t *a0 = &A(i_length, 0);
const int8_t *a1 = a0 + lda;
const int8_t *a2 = a0 + 2 * lda;
const int8_t *a3 = a0 + 3 * lda;
const int8_t *a4 = a0 + 4 * lda;
const int8_t *a5 = a0 + 5 * lda;
int8_t *local_buffer = buffer + i_length * k;
switch (m_tail) {
case 1:
a1 = zero_int8;
case 2:
a2 = zero_int8;
case 3:
a3 = zero_int8;
case 4:
a4 = zero_int8;
case 5:
a5 = zero_int8;
break;
default:
break;
}
for (int32_t j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
*local_buffer++ = *a3++;
*local_buffer++ = *a4++;
*local_buffer++ = *a5++;
}
}
}
// 8 bits int PackMatrixB
void Gemm::PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer) {
......@@ -1052,46 +1332,79 @@ void Gemm::PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail,
}
}
// 8 bits int PackMatrixB
void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer) {
void Gemm::PackMatrixB_4c_16(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer) {
const int32_t j_length = n - n_tail;
for (int32_t j = 0; j < j_length; j += 8) {
int8_t *local_buffer = buffer + j * k;
for (int32_t i = 0; i < k; ++i) {
const int8_t *b0 = &B(i, j);
#if __ARM_NEON
#if __aarch64__
// TODO
#else
asm volatile(
// "pld [%[b0]] \n\t"
"vld1.s8 {d0}, [%[b0]] \n\t"
"vst1.s8 {d0}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0)
: "memory", "q0");
#endif // __aarch64__
#else
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
#endif // __ARM_NEON
const int32_t k_count = k >> 4;
const int32_t k_tail = k & 15;
for (int32_t j = 0; j < n; j += 4) {
int8_t *local_buffer = buffer + j * KC;
const int8_t *b0 = &B(0, j);
const int8_t *b1 = b0 + 1;
const int8_t *b2 = b0 + 2;
const int8_t *b3 = b0 + 3;
if (j >= j_length) {
switch (n_tail) {
case 1:
b1 = zero_int8;
case 2:
b2 = zero_int8;
case 3:
b3 = zero_int8;
break;
default:
break;
}
}
if (n_tail != 0) {
int8_t *local_buffer = buffer + j_length * k;
for (int32_t i = 0; i < k; ++i) {
const int8_t *b0 = &B(i, j_length);
for (int32_t j = j_length; j < n; ++j) {
*local_buffer++ = *b0++;
for (int32_t i = 0; i < k_count; ++i) {
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int32_t j = n; j < j_length + 8; ++j) {
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b1;
b1 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b2;
b2 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b3;
b3 += ldb;
}
}
if (k_tail != 0) {
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b1;
b1 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b2;
b2 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b3;
b3 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
......@@ -1104,7 +1417,35 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc) {
#if __ARM_NEON
#if __aarch64__
// TODO
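// arm64: copy the int32 result block from c into C four values per
// iteration, with lane-by-lane stores for the 1-3 trailing columns.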
int32_t nc1 = nc / 4;
int32_t _nc1 = nc % 4;
int32_t *c_ptr, *C_ptr;
int32x4_t cv;
for (int32_t i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
for (int32_t j = 0; j < nc1; ++j) {
cv = vld1q_s32(c_ptr);
vst1q_s32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
if (_nc1 != 0) {
cv = vld1q_s32(c_ptr);
if (_nc1 >= 1) {
vst1q_lane_s32(C_ptr, cv, 0);
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_s32(C_ptr, cv, 1);
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_s32(C_ptr, cv, 2);
}
}
}
#else
int32_t nc1 = nc >> 4;
int32_t _nc1 = nc & 15;
......@@ -1168,7 +1509,67 @@ void Gemm::WriteWithAddScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
int32_t ldc, int32_t *bias, float scale) {
#if __ARM_NEON
#if __aarch64__
// TODO
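// arm64: broadcast the per-row bias, add it with saturation, narrow
// int32 -> int16 -> int8 with saturation, and clamp the result at -127.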
int32_t nc1 = nc >> 3;
int32_t _nc1 = nc & 7;
int32_t *c_ptr;
int8_t *C_ptr;
int32x4_t cv0;
int32x4_t cv1;
int16x8_t cv_h;
int8x8_t cv_b;
int32x4_t biasv;
int8_t min = -127;
int8x8_t minv = vdup_n_s8(min);
for (int32_t i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
biasv = vld1q_dup_s32(bias + i);
for (int32_t j = 0; j < nc1; ++j) {
cv0 = vld1q_s32(c_ptr);
cv1 = vld1q_s32(c_ptr + 4);
cv0 = vqaddq_s32(cv0, biasv);
cv1 = vqaddq_s32(cv1, biasv);
cv_h = vcombine_s16(vqmovn_s32(cv0), vqmovn_s32(cv1));
cv_b = vqmovn_s16(cv_h);
cv_b = vmax_s8(cv_b, minv);
vst1_s8(C_ptr, cv_b);
c_ptr += 8;
C_ptr += 8;
}
if (_nc1 != 0) {
cv0 = vld1q_s32(c_ptr);
cv1 = vld1q_s32(c_ptr + 4);
cv0 = vqaddq_s32(cv0, biasv);
cv1 = vqaddq_s32(cv1, biasv);
cv_h = vcombine_s16(vqmovn_s32(cv0), vqmovn_s32(cv1));
cv_b = vqmovn_s16(cv_h);
cv_b = vmax_s8(cv_b, minv);
switch (_nc1) {
case 7:
vst1_lane_s8(C_ptr + 6, cv_b, 6);
case 6:
vst1_lane_s8(C_ptr + 5, cv_b, 5);
case 5:
vst1_lane_s8(C_ptr + 4, cv_b, 4);
case 4:
vst1_lane_s8(C_ptr + 3, cv_b, 3);
case 3:
vst1_lane_s8(C_ptr + 2, cv_b, 2);
case 2:
vst1_lane_s8(C_ptr + 1, cv_b, 1);
case 1:
vst1_lane_s8(C_ptr, cv_b, 0);
default:
break;
}
}
}
#else
int8_t narrow = -128;
int32_t nc1 = nc >> 3;
......@@ -1291,7 +1692,74 @@ void Gemm::WriteWithAddScaleT(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
int32_t ldc, int32_t *bias, float scale) {
#if __ARM_NEON
#if __aarch64__
// TODO
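// arm64: same write-back as WriteWithAddScale, except the bias is loaded per
// column (bias_ptr advances along the row) instead of being broadcast per row.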
int32_t nc1 = nc >> 3;
int32_t _nc1 = nc & 7;
int32_t *c_ptr;
int8_t *C_ptr;
int32x4_t cv0;
int32x4_t cv1;
int16x8_t cv_h;
int8x8_t cv_b;
int32_t *bias_ptr;
int32x4_t biasv0;
int32x4_t biasv1;
int8_t min = -127;
int8x8_t minv = vdup_n_s8(min);
for (int32_t i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
bias_ptr = bias;
for (int32_t j = 0; j < nc1; ++j) {
cv0 = vld1q_s32(c_ptr);
cv1 = vld1q_s32(c_ptr + 4);
biasv0 = vld1q_s32(bias_ptr);
biasv1 = vld1q_s32(bias_ptr + 4);
cv0 = vqaddq_s32(cv0, biasv0);
cv1 = vqaddq_s32(cv1, biasv1);
cv_h = vcombine_s16(vqmovn_s32(cv0), vqmovn_s32(cv1));
cv_b = vqmovn_s16(cv_h);
cv_b = vmax_s8(cv_b, minv);
vst1_s8(C_ptr, cv_b);
c_ptr += 8;
C_ptr += 8;
bias_ptr += 8;
}
if (_nc1 != 0) {
cv0 = vld1q_s32(c_ptr);
cv1 = vld1q_s32(c_ptr + 4);
biasv0 = vld1q_s32(bias_ptr);
biasv1 = vld1q_s32(bias_ptr + 4);
cv0 = vqaddq_s32(cv0, biasv0);
cv1 = vqaddq_s32(cv1, biasv1);
cv_h = vcombine_s16(vqmovn_s32(cv0), vqmovn_s32(cv1));
cv_b = vqmovn_s16(cv_h);
cv_b = vmax_s8(cv_b, minv);
switch (_nc1) {
case 7:
vst1_lane_s8(C_ptr + 6, cv_b, 6);
case 6:
vst1_lane_s8(C_ptr + 5, cv_b, 5);
case 5:
vst1_lane_s8(C_ptr + 4, cv_b, 4);
case 4:
vst1_lane_s8(C_ptr + 3, cv_b, 3);
case 3:
vst1_lane_s8(C_ptr + 2, cv_b, 2);
case 2:
vst1_lane_s8(C_ptr + 1, cv_b, 1);
case 1:
vst1_lane_s8(C_ptr, cv_b, 0);
default:
break;
}
}
}
#else
int8_t narrow = -128;
int32_t nc1 = nc >> 3;
......@@ -1414,7 +1882,67 @@ void Gemm::WriteWithAddReluScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
int32_t ldc, int32_t *bias, float scale) {
#if __ARM_NEON
#if __aarch64__
// TODO
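// arm64: add the per-row bias with saturation, apply relu via max with zero,
// then narrow int32 -> int16 -> int8 with saturation.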
int32_t nc1 = nc >> 3;
int32_t _nc1 = nc & 7;
int32_t *c_ptr;
int8_t *C_ptr;
int32x4_t cv0;
int32x4_t cv1;
int16x8_t cv_h;
int8x8_t cv_b;
int32x4_t biasv;
int32x4_t zero = vdupq_n_s32(0);
for (int32_t i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
biasv = vld1q_dup_s32(bias + i);
for (int32_t j = 0; j < nc1; ++j) {
cv0 = vld1q_s32(c_ptr);
cv1 = vld1q_s32(c_ptr + 4);
cv0 = vqaddq_s32(cv0, biasv);
cv1 = vqaddq_s32(cv1, biasv);
cv0 = vmaxq_s32(cv0, zero);
cv1 = vmaxq_s32(cv1, zero);
cv_h = vcombine_s16(vqmovn_s32(cv0), vqmovn_s32(cv1));
cv_b = vqmovn_s16(cv_h);
vst1_s8(C_ptr, cv_b);
c_ptr += 8;
C_ptr += 8;
}
if (_nc1 != 0) {
cv0 = vld1q_s32(c_ptr);
cv1 = vld1q_s32(c_ptr + 4);
cv0 = vqaddq_s32(cv0, biasv);
cv1 = vqaddq_s32(cv1, biasv);
cv0 = vmaxq_s32(cv0, zero);
cv1 = vmaxq_s32(cv1, zero);
cv_h = vcombine_s16(vqmovn_s32(cv0), vqmovn_s32(cv1));
cv_b = vqmovn_s16(cv_h);
switch (_nc1) {
case 7:
vst1_lane_s8(C_ptr + 6, cv_b, 6);
case 6:
vst1_lane_s8(C_ptr + 5, cv_b, 5);
case 5:
vst1_lane_s8(C_ptr + 4, cv_b, 4);
case 4:
vst1_lane_s8(C_ptr + 3, cv_b, 3);
case 3:
vst1_lane_s8(C_ptr + 2, cv_b, 2);
case 2:
vst1_lane_s8(C_ptr + 1, cv_b, 1);
case 1:
vst1_lane_s8(C_ptr, cv_b, 0);
default:
break;
}
}
}
#else
int32_t zero = 0;
int32_t nc1 = nc >> 3;
......
......@@ -37,7 +37,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
const int8_t *b0 = &B(i, j);
#if __ARM_NEON
#if __aarch64__
// TODO
// PackMatrixB_omp_8c used only for aarch32
#else
asm volatile(
// "pld [%[b0]] \n\t"
......@@ -133,7 +133,19 @@ void Gemm::PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail,
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO
asm volatile(
"ld1 {v0.16b}, [%[a0]], #16 \n\t"
"ld1 {v1.16b}, [%[a1]], #16 \n\t"
"ld1 {v2.16b}, [%[a2]], #16 \n\t"
"ld1 {v3.16b}, [%[a3]], #16 \n\t"
"st1 {v0.16b}, [%[local_buffer]], #16 \n\t"
"st1 {v1.16b}, [%[local_buffer]], #16 \n\t"
"st1 {v2.16b}, [%[local_buffer]], #16 \n\t"
"st1 {v3.16b}, [%[local_buffer]], #16 \n\t"
: [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1),
[a2] "+r"(a2), [a3] "+r"(a3)
:
: "memory", "v0", "v1", "v2", "v3");
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
......@@ -213,7 +225,19 @@ void Gemm::PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail,
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO
asm volatile(
"ld1 {v0.16b}, [%[a0]], #16 \n\t"
"ld1 {v1.16b}, [%[a1]], #16 \n\t"
"ld1 {v2.16b}, [%[a2]], #16 \n\t"
"ld1 {v3.16b}, [%[a3]], #16 \n\t"
"st1 {v0.16b}, [%[local_buffer]], #16 \n\t"
"st1 {v1.16b}, [%[local_buffer]], #16 \n\t"
"st1 {v2.16b}, [%[local_buffer]], #16 \n\t"
"st1 {v3.16b}, [%[local_buffer]], #16 \n\t"
: [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1),
[a2] "+r"(a2), [a3] "+r"(a3)
:
: "memory", "v0", "v1", "v2", "v3");
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
......@@ -343,6 +367,87 @@ void Gemm::PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail,
}
}
// 8 bits int PackMatrixB
void Gemm::PackMatrixB_omp_4c_16(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer) {
const int32_t j_length = n - n_tail;
const int32_t k_count = k >> 4;
const int32_t k_tail = k & 15;
#pragma omp parallel for
for (int32_t j = 0; j < n; j += 4) {
int8_t *local_buffer = buffer + j * KC;
const int8_t *b0 = &B(0, j);
const int8_t *b1 = b0 + 1;
const int8_t *b2 = b0 + 2;
const int8_t *b3 = b0 + 3;
if (j >= j_length) {
switch (n_tail) {
case 1:
b1 = zero_int8;
case 2:
b2 = zero_int8;
case 3:
b3 = zero_int8;
break;
default:
break;
}
}
for (int32_t i = 0; i < k_count; ++i) {
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b1;
b1 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b2;
b2 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b3;
b3 += ldb;
}
}
if (k_tail != 0) {
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b1;
b1 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b2;
b2 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b3;
b3 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
}
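// Illustrative sketch of the panel layout written by PackMatrixB_4c_16 /
// PackMatrixB_omp_4c_16 above: the panel for columns [j, j + 4) starts at
// buffer + j * KC; full 16-row blocks occupy 64 bytes each, and the tail
// block stores each column's remaining rows zero padded up to KC. The helper
// below is hypothetical (not part of the library) and only documents how to
// read element (row, col) of one panel back out; kc stands for the KC
// blocking constant used when packing, and row is assumed to be less than k.
static int8_t PackedB4c16At(const int8_t *panel, int32_t row, int32_t col,
                            int32_t k, int32_t kc) {
  const int32_t full_blocks = k >> 4;  // number of complete 16-row blocks
  if (row < full_blocks * 16) {
    return panel[(row / 16) * 64 + col * 16 + (row % 16)];
  }
  const int32_t tail_len = kc - full_blocks * 16;  // tail rows plus padding
  return panel[full_blocks * 64 + col * tail_len + (row - full_blocks * 16)];
}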
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -174,7 +174,7 @@ int do_sgemm_with_bias(int m, int n, int k, bool relu, int pr,
int lda = k;
int ldb = n;
int ldc = n;
float scale = 0.00628f;
float scale = 1;
default_random_engine e;
uniform_int_distribution<int8_t> pixel(-127, 127);
int8_t *a = static_cast<int8_t *>(
......