未验证 提交 f72a124b 编写于 作者: J Jiaying Zhao 提交者: GitHub

add gemm_int8 arm64 version without openmp (#1708)

* add gemm_int8 arm64 version without openmp

* add gemm_int8 arm64 version with openmp
上级 65a18a0f
......@@ -28,7 +28,7 @@ limitations under the License. */
#if __aarch64__
#define MR_INT8 4
#define NR_INT8 2
#define NR_INT8 4
#define MR 6
#define NR 16
#else
......@@ -181,12 +181,15 @@ class Gemm {
std::string mode, float *bias, float *bias1);
// 8 bits function cluster begins
// 8 bits int small block inner product
// 8 bits int small block inner product, data packed k = 1
void AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc);
void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc);
// 8 bits int small block inner product, data packed k = 16
void AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc);
void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
void AddDot4x4(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc);
// 8 bits int inner product
......@@ -203,14 +206,16 @@ class Gemm {
// 8 bits int pack function
void PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer);
void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer);
void PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer);
void PackMatrixB_4c_16(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer);
void PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
......@@ -219,6 +224,8 @@ class Gemm {
const int8_t *A, int32_t lda, int8_t *buffer);
void PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer);
void PackMatrixB_omp_4c_16(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer);
// 8 bits int matrix product
template <typename Itype, typename Btype, typename Otype>
......@@ -314,7 +321,11 @@ void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
int32_t mc, nc;
for (int32_t j = 0; j < n; j += NC) {
nc = s_min(n - j, NC);
#if __aarch64__
PackMatrixB_4c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, packedB_int8);
#else
PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, packedB_int8);
#endif
for (int32_t i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, packedA_int8);
......@@ -375,7 +386,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
#if __aarch64__
// TODO(paddle mobile)
PackMatrixB_omp_4c_16(k, n, n % NR_INT8, B, ldb, packedB_int8);
#else
PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8);
#endif
......@@ -397,7 +408,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
packedA_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC));
#if __aarch64__
// TODO(paddle mobile)
PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8);
#else
PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8);
#endif
......@@ -421,7 +432,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
int8_t *local_A = packedA_int8 + MC * KC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads;
#if __aarch64__
// TODO(paddle mobile)
PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A);
#else
PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A);
#endif
......@@ -451,7 +462,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
int8_t *local_B = packedB_int8 + KC * NC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads;
#if __aarch64__
// TODO(paddle mobile)
PackMatrixB_4c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B);
#else
PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B);
#endif
......
此差异已折叠。
......@@ -37,7 +37,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
const int8_t *b0 = &B(i, j);
#if __ARM_NEON
#if __aarch64__
// TODO
// PackMatrixB_omp_8c used only for aarch32
#else
asm volatile(
// "pld [%[b0]] \n\t"
......@@ -133,7 +133,19 @@ void Gemm::PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail,
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO
asm volatile(
"ld1 {v0.16b}, [%[a0]], #16 \n\t"
"ld1 {v1.16b}, [%[a1]], #16 \n\t"
"ld1 {v2.16b}, [%[a2]], #16 \n\t"
"ld1 {v3.16b}, [%[a3]], #16 \n\t"
"st1 {v0.16b}, [%[local_buffer]], #16 \n\t"
"st1 {v1.16b}, [%[local_buffer]], #16 \n\t"
"st1 {v2.16b}, [%[local_buffer]], #16 \n\t"
"st1 {v3.16b}, [%[local_buffer]], #16 \n\t"
: [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1),
[a2] "+r"(a2), [a3] "+r"(a3)
:
: "memory", "v0", "v1", "v2", "v3");
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
......@@ -213,7 +225,19 @@ void Gemm::PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail,
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO
asm volatile(
"ld1 {v0.16b}, [%[a0]], #16 \n\t"
"ld1 {v1.16b}, [%[a1]], #16 \n\t"
"ld1 {v2.16b}, [%[a2]], #16 \n\t"
"ld1 {v3.16b}, [%[a3]], #16 \n\t"
"st1 {v0.16b}, [%[local_buffer]], #16 \n\t"
"st1 {v1.16b}, [%[local_buffer]], #16 \n\t"
"st1 {v2.16b}, [%[local_buffer]], #16 \n\t"
"st1 {v3.16b}, [%[local_buffer]], #16 \n\t"
: [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1),
[a2] "+r"(a2), [a3] "+r"(a3)
:
: "memory", "v0", "v1", "v2", "v3");
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
......@@ -343,6 +367,87 @@ void Gemm::PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail,
}
}
// 8 bits int PackMatrixB
void Gemm::PackMatrixB_omp_4c_16(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer) {
const int32_t j_length = n - n_tail;
const int32_t k_count = k >> 4;
const int32_t k_tail = k & 15;
#pragma omp parallel for
for (int32_t j = 0; j < n; j += 4) {
int8_t *local_buffer = buffer + j * KC;
const int8_t *b0 = &B(0, j);
const int8_t *b1 = b0 + 1;
const int8_t *b2 = b0 + 2;
const int8_t *b3 = b0 + 3;
if (j > j_length) {
switch (n_tail) {
case 1:
b1 = zero_int8;
case 2:
b2 = zero_int8;
case 3:
b3 = zero_int8;
break;
default:
break;
}
}
for (int32_t i = 0; i < k_count; ++i) {
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b1;
b1 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b2;
b2 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b3;
b3 += ldb;
}
}
if (k_tail != 0) {
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b1;
b1 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b2;
b2 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b3;
b3 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -174,7 +174,7 @@ int do_sgemm_with_bias(int m, int n, int k, bool relu, int pr,
int lda = k;
int ldb = n;
int ldc = n;
float scale = 0.00628f;
float scale = 1;
default_random_engine e;
uniform_int_distribution<int8_t> pixel(-127, 127);
int8_t *a = static_cast<int8_t *>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册