diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index ea023bc134033aee6577ebf06c95f2a762d08bca..8498992fcecbcb2c9a773fba874e108c013a04fc 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -209,12 +209,18 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, int32_t lda, int8_t *buffer); void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, int32_t ldb, int8_t *buffer); + void PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, + int32_t lda, int8_t *buffer); + void PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, + int32_t ldb, int8_t *buffer); // 8 bits int matrix product void Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, int32_t *C, int32_t ldc, bool relu, int8_t *bias); - + void Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, + int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, + int32_t *C, int32_t ldc, bool relu, int8_t *bias); // 8 bits int write back // C = alpha * A * B + beta * C void WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C, diff --git a/src/operators/math/gemm_int8.cpp b/src/operators/math/gemm_int8.cpp index 5dd8a7c3131543f426f32e258efb3181be9b2f61..b16db7fe6acf0c3c7fb2902c9fb3f6e3dc81a65f 100644 --- a/src/operators/math/gemm_int8.cpp +++ b/src/operators/math/gemm_int8.cpp @@ -30,7 +30,7 @@ void Gemm::AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc) { #if __ARM_NEON #if __aarch64__ -// TODO +// TODO(wzzju) #else const int8_t *a_ptr, *b_ptr; a_ptr = a; @@ -246,7 +246,7 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc) { #if __ARM_NEON #if __aarch64__ -// TODO +// TODO(wzzju) #else const int8_t *a_ptr, *b_ptr; a_ptr = a; @@ -546,8 +546,12 @@ void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha, #pragma omp parallel for for (int32_t j = 0; j < nc; j += NR) { for (int32_t i = 0; i < mc; i += MR_INT8) { +#if __aarch64__ + // TODO(wzzju) +#else // AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); +#endif // __aarch64__ } } if (alpha != 1) { @@ -682,7 +686,7 @@ void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, const int8_t *b0 = &B(i, j); #if __ARM_NEON #if __aarch64__ - // TODO + // TODO(wzzju) #else asm volatile( // "pld [%[b0]] \n\t" @@ -791,7 +795,7 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc) { #if __ARM_NEON #if __aarch64__ -// TODO +// TODO(wzzju) #else int32_t nc1 = nc >> 4; int32_t _nc1 = nc & 15; diff --git a/src/operators/math/gemm_omp_int8.cpp b/src/operators/math/gemm_omp_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..21256cccfcc6dcc647f34a2129616b70804d398f --- /dev/null +++ b/src/operators/math/gemm_omp_int8.cpp @@ -0,0 +1,235 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "common/log.h" +#include "memory/t_malloc.h" +#include "operators/math/gemm.h" +#if __ARM_NEON +#include +#endif +#ifdef _OPENMP +#include +#endif + +namespace paddle_mobile { +namespace operators { +namespace math { + +// 8 bits int matrix product (m*k x k*n) +void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, + const int8_t *A, int32_t lda, const int8_t *B, int32_t ldb, + int8_t beta, int32_t *C, int32_t ldc, bool relu, + int8_t *bias) { +#ifdef _OPENMP + int32_t max_threads = omp_get_max_threads(); +#else + int32_t max_threads = 1; +#endif + + int32_t L1 = 64 / max_threads * 1024; + KC = k; + zero_int8 = + static_cast(paddle_mobile::memory::Alloc(sizeof(int8_t) * KC)); + memset(static_cast(zero_int8), 0, sizeof(int8_t) * KC); + if (m > n) { + // 对 A 分块 + MC = L1 / (KC * sizeof(int8_t)); + if (MC == 0) { + MC = MR_INT8; + } else { + int32_t mblock_num = (m + MC - 1) / MC; + MC = (m + mblock_num - 1) / mblock_num; + MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8; + } + // 补齐 B + NC = (n + NR - 1) / NR * NR; + + packedB_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC)); +#if __aarch64__ + // TODO(wzzju) +#else + PackMatrixB_omp_8c(KC, n, n % NR, B, ldb, packedB_int8); +#endif + packedA_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC * max_threads)); + } else { + // 对 B 分块 + NC = L1 / (KC * sizeof(int8_t)); + if (NC == 0) { + NC = NR; + } else { + int32_t nblock_num = (n + NC - 1) / NC; + NC = (n + nblock_num - 1) / nblock_num; + NC = (NC + NR - 1) / NR * NR; + } + // 补齐 A + MC = (m + MR_INT8 - 1) / MR_INT8 * MR_INT8; + + packedA_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC)); +#if __aarch64__ + // TODO(wzzju) +#else + PackMatrixA_omp_4r(m, KC, m % MR_INT8, A, lda, packedA_int8); +#endif + packedB_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC * max_threads)); + } + packedC_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC * max_threads)); + + if (m > n) { +#pragma omp parallel for + for (int32_t i = 0; i < m; i += MC) { +#ifdef _OPENMP + int32_t local_threads = omp_get_thread_num(); +#else + int32_t local_threads = 0; +#endif + + int32_t mc; + mc = s_min(m - i, MC); + int8_t *local_A = packedA_int8 + MC * KC * local_threads; + int32_t *local_C = packedC_int8 + MC * NC * local_threads; +#if __aarch64__ + // TODO(wzzju) +#else + PackMatrixA_4r(mc, KC, mc % MR_INT8, &A(i, 0), lda, local_A); +#endif + InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta, local_C, + &C(i, 0), ldc, relu, bias + i); + } + } else { +#pragma omp parallel for + for (int32_t j = 0; j < n; j += NC) { +#ifdef _OPENMP + int32_t local_threads = omp_get_thread_num(); +#else + int32_t local_threads = 0; +#endif + int32_t nc; + nc = s_min(n - j, NC); + int8_t *local_B = packedB_int8 + KC * NC * local_threads; + int32_t *local_C = packedC_int8 + MC * NC * local_threads; +#if __aarch64__ + // TODO(wzzju) +#else + PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, local_B); +#endif + InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta, local_C, + &C(0, j), ldc, relu, bias); + } + } + + paddle_mobile::memory::Free(packedA_int8); + paddle_mobile::memory::Free(packedB_int8); + paddle_mobile::memory::Free(packedC_int8); + paddle_mobile::memory::Free(zero_int8); +} + +void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, + const int8_t *B, int32_t ldb, int8_t *buffer) { + const int32_t j_length = n - n_tail; +#pragma omp parallel for + for (int32_t j = 0; j < j_length; j += NR) { + int8_t *local_buffer = buffer + j * k; + for (int32_t i = 0; i < k; ++i) { + const int8_t *b0 = &B(i, j); +#if __ARM_NEON +#if __aarch64__ + // TODO(wzzju) +#else + asm volatile( + // "pld [%[b0]] \n\t" + "vld1.s8 {d0}, [%[b0]] \n\t" + "vst1.s8 {d0}, [%[local_buffer]]! \n\t" + : [local_buffer] "+r"(local_buffer) + : [b0] "r"(b0) + : "memory", "q0"); +#endif // __aarch64__ +#else + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; +#endif // __ARM_NEON + } + } + if (n_tail != 0) { + int8_t *local_buffer = buffer + j_length * k; + for (int32_t i = 0; i < k; ++i) { + const int8_t *b0 = &B(i, j_length); + for (int32_t j = j_length; j < n; ++j) { + *local_buffer++ = *b0++; + } + for (int32_t j = n; j < j_length + NR; ++j) { + *local_buffer++ = 0; + } + } + } +} + +void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, + const int8_t *A, int32_t lda, int8_t *buffer) { + const int i_length = m - m_tail; +#pragma omp parallel for + for (int32_t i = 0; i < i_length; i += MR_INT8) { + const int8_t *a0 = A + i * lda; + const int8_t *a1 = A + (i + 1) * lda; + const int8_t *a2 = A + (i + 2) * lda; + const int8_t *a3 = A + (i + 3) * lda; + int8_t *local_buffer = buffer + i * k; + for (int32_t j = 0; j < k; ++j) { + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + } + } + + if (m_tail != 0) { + const int8_t *a0 = &A(i_length, 0); + const int8_t *a1 = a0 + lda; + const int8_t *a2 = a0 + 2 * lda; + const int8_t *a3 = a0 + 3 * lda; + int8_t *local_buffer = buffer + i_length * k; + switch (m_tail) { + case 1: + a1 = zero_int8; + case 2: + a2 = zero_int8; + case 3: + a3 = zero_int8; + break; + default: + break; + } + for (int j = 0; j < k; ++j) { + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + } + } +} + +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/math_function_int8.cpp b/src/operators/math/math_function_int8.cpp index 70677223d12ded2da07ab53bc371f1e8da9fe293..e02824b290ebc0080613e2ae2365626d79576c9e 100644 --- a/src/operators/math/math_function_int8.cpp +++ b/src/operators/math/math_function_int8.cpp @@ -51,12 +51,23 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a, } } +#ifdef _OPENMP + gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); +#else gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data(), N, beta, matrix_out->data(), N, relu, bias); +#endif } else { +#ifdef _OPENMP + gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); +#else gemm.Sgemm(M, N, K, alpha, matrix_a.data(), K, matrix_b.data(), N, beta, matrix_out->data(), N, relu, bias); +#endif } } } // namespace math diff --git a/test/common/test_gemm_int8_accuracy.cpp b/test/common/test_gemm_int8_accuracy.cpp index 80ddd40e121c81032c903955bd7116cf52695569..87f8d945648577ef1414417b57f4013d288dc043 100644 --- a/test/common/test_gemm_int8_accuracy.cpp +++ b/test/common/test_gemm_int8_accuracy.cpp @@ -20,6 +20,9 @@ limitations under the License. */ #include "common/log.h" #include "memory/t_malloc.h" #include "operators/math/gemm.h" +#ifdef _OPENMP +#include +#endif // _OPENMP #define a(i, j) a[(i)*lda + (j)] #define b(i, j) b[(i)*ldb + (j)] @@ -84,8 +87,13 @@ int do_sgemm(int m, int n, int k, bool relu, int pr) { } paddle_mobile::operators::math::Gemm gemm; +#ifdef _OPENMP + gemm.Sgemm_omp(m, n, k, static_cast(1), a, lda, b, ldb, + static_cast(0), c, ldc, relu, nullptr); +#else gemm.Sgemm(m, n, k, static_cast(1), a, lda, b, ldb, static_cast(0), c, ldc, relu, nullptr); +#endif int eq = 0; int neq = 0; for (int i = 0; i < m * n; ++i) { @@ -119,12 +127,17 @@ int do_sgemm(int m, int n, int k, bool relu, int pr) { } int main() { - do_sgemm(9, 9, 9, false, 10); +#ifdef _OPENMP + omp_set_num_threads(8); +#endif + do_sgemm(9, 9, 9, false, 1); do_sgemm(10, 6, 12, false, 0); do_sgemm(512, 256, 384, false, 0); do_sgemm(1366, 768, 256, false, 0); do_sgemm(1255, 755, 333, false, 0); - do_sgemm(555, 777, 999, false, 0); + do_sgemm(599, 1133, 393, false, 0); + do_sgemm(777, 555, 999, false, 0); + do_sgemm(333, 797, 939, false, 0); do_sgemm(1024, 1024, 1024, false, 0); return 0; diff --git a/test/common/test_gemm_perf.cpp b/test/common/test_gemm_perf.cpp index 89f0012ae8effaab383719c1b85748c24eb2bf73..14da4ba284b5ac7b0660bd15de871fdf5ed04cdd 100644 --- a/test/common/test_gemm_perf.cpp +++ b/test/common/test_gemm_perf.cpp @@ -28,7 +28,7 @@ limitations under the License. */ int main() { paddle_mobile::PaddleMobile paddle_mobile; - paddle_mobile.SetThreadNum(1); + paddle_mobile.SetThreadNum(8); Tensor aa, bb, cc; auto aaptr = aa.mutable_data({m, k}); auto bbptr = bb.mutable_data({k, n}); diff --git a/test/operators/test_mul_op.cpp b/test/operators/test_mul_op.cpp index 10dab2cda1b3c692f42cf8760eb2b48ae6451f39..262ee960e1c777d369d3b510eb31e5ed47b3493c 100644 --- a/test/operators/test_mul_op.cpp +++ b/test/operators/test_mul_op.cpp @@ -93,6 +93,8 @@ int TestMulOP() { } // namespace paddle_mobile int main() { + paddle_mobile::PaddleMobile paddle_mobile; + paddle_mobile.SetThreadNum(8); paddle_mobile::TestMulOP(); paddle_mobile::TestMulOP(); return 0;