From 5e882269f818b7272a61f83de42f1b7e694551f3 Mon Sep 17 00:00:00 2001
From: ZhenWang
Date: Tue, 4 Dec 2018 17:33:57 +0800
Subject: [PATCH] add int8_t type sgemm_omp

---
 src/operators/math/gemm.cpp                |   1 +
 src/operators/math/gemm.h                  | 145 ++++++++++++++++++++--
 src/operators/math/gemm_omp_int8.cpp       | 124 ------------------
 src/operators/math/math_function_int8.cpp  |  12 +-
 test/common/test_gemm_int8_accuracy.cpp    |   5 +-
 test/operators/test_mul_op.cpp             |   2 +-
 6 files changed, 145 insertions(+), 144 deletions(-)

diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp
index 297ca2538d..ae324dbfd3 100644
--- a/src/operators/math/gemm.cpp
+++ b/src/operators/math/gemm.cpp
@@ -3147,6 +3147,7 @@ void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda,
 }
 
 // 32位 float 矩阵乘法
+template <>
 void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
                      const float *B, int ldb, float beta, float *C, int ldc,
                      bool relu, float *bias) {
diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h
index bccddffa56..61e957100b 100644
--- a/src/operators/math/gemm.h
+++ b/src/operators/math/gemm.h
@@ -16,6 +16,9 @@ limitations under the License. */
 #include <string>
 #include "common/log.h"
 #include "memory/t_malloc.h"
+#ifdef _OPENMP
+#include <omp.h>
+#endif
 
 // 矩阵取值运算宏,假设矩阵按行存储
 #define A(i, j) A[(i)*lda + (j)]
@@ -172,11 +175,6 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
                         const float *B, int ldb, float *C, int ldc, float *p,
                         std::string mode, float *bias, float *bias1);
 
-  // 32位 float 矩阵乘法(openmp 多线程版本)
-  void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
-                 const float *B, int ldb, float beta, float *C, int ldc,
-                 bool relu, float *bias);
-
   // 32位 float 矩阵乘法, 并对结果进行 batchnrom(openmp 多线程版本)
   void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A,
                        int lda, const float *B, int ldb, float beta, float *C,
@@ -228,6 +226,14 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
 
   // 8 bits int matrix product
   template <typename Itype, typename Btype, typename Otype>
+  void Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, const Itype *A,
+                 int32_t lda, const Itype *B, int32_t ldb, float beta, Otype *C,
+                 int32_t ldc, bool relu, Btype *bias);
+  template <typename Otype>
+  void Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
+                 int32_t lda, const int8_t *B, int32_t ldb, float beta,
+                 Otype *C, int32_t ldc, bool relu, int32_t *bias);
+  template <typename Itype, typename Btype, typename Otype>
   void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const Itype *A,
              int32_t lda, const Itype *B, int32_t ldb, float beta, Otype *C,
              int32_t ldc, bool relu, Btype *bias);
@@ -235,10 +241,6 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
   void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
              int32_t lda, const int8_t *B, int32_t ldb, float beta, Otype *C,
              int32_t ldc, bool relu, int32_t *bias);
-
-  void Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
-                 int32_t lda, const int8_t *B, int32_t ldb, float beta,
-                 int32_t *C, int32_t ldc, bool relu, int32_t *bias);
   // 8 bits int write back
   // C = A * B
   void WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc);
@@ -332,6 +334,131 @@ void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
   paddle_mobile::memory::Free(zero_int8);
 }
 
+// 8 bits int matrix product (m*k x k*n), omp version
+template <typename Otype>
+void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
+                     const int8_t *A, int32_t lda, const int8_t *B, int32_t ldb,
+                     float beta, Otype *C, int32_t ldc, bool relu,
+                     int32_t *bias) {
+#ifdef _OPENMP
+  int32_t max_threads = omp_get_max_threads();
+#else
+  int32_t max_threads = 1;
+#endif
+
+  int32_t L1 = 64 / max_threads * 1024;
+  const int32_t k_complete = (k + 15) - ((k + 15) & 15);
+  KC = k_complete;
+  zero_int8 =
+      static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * k));
+  memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * k);
+  if (m > n) {
+    // 对 A 分块
+    MC = L1 / (KC * sizeof(int8_t));
+    if (MC == 0) {
+      MC = MR_INT8;
+    } else {
+      int32_t mblock_num = (m + MC - 1) / MC;
+      MC = (m + mblock_num - 1) / mblock_num;
+      MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8;
+    }
+    // 补齐 B
+    NC = (n + NR_INT8 - 1) / NR_INT8 * NR_INT8;
+
+    packedB_int8 = static_cast<int8_t *>(
+        paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
+#if __aarch64__
+    // TODO()
+#else
+    PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8);
+#endif
+    packedA_int8 = static_cast<int8_t *>(
+        paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC * max_threads));
+  } else {
+    // 对 B 分块
+    NC = L1 / (KC * sizeof(int8_t));
+    if (NC == 0) {
+      NC = NR_INT8;
+    } else {
+      int32_t nblock_num = (n + NC - 1) / NC;
+      NC = (n + nblock_num - 1) / nblock_num;
+      NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8;
+    }
+    // 补齐 A
+    MC = (m + MR_INT8 - 1) / MR_INT8 * MR_INT8;
+
+    packedA_int8 = static_cast<int8_t *>(
+        paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC));
+#if __aarch64__
+    // TODO()
+#else
+    PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8);
+#endif
+    packedB_int8 = static_cast<int8_t *>(
+        paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC * max_threads));
+  }
+  packedC_int32 = static_cast<int32_t *>(
+      paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC * max_threads));
+
+  if (m > n) {
+#pragma omp parallel for
+    for (int32_t i = 0; i < m; i += MC) {
+#ifdef _OPENMP
+      int32_t local_threads = omp_get_thread_num();
+#else
+      int32_t local_threads = 0;
+#endif
+
+      int32_t mc;
+      mc = s_min(m - i, MC);
+      int8_t *local_A = packedA_int8 + MC * KC * local_threads;
+      int32_t *local_C = packedC_int32 + MC * NC * local_threads;
+#if __aarch64__
+      // TODO()
+#else
+      PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A);
+#endif
+      if (bias == nullptr) {
+        InnerKernel(mc, n, alpha, local_A, packedB_int8, beta, local_C,
+                    &C(i, 0), ldc, relu);
+      } else {
+        InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta, local_C,
+                            &C(i, 0), ldc, relu, bias + i);
+      }
+    }
+  } else {
+#pragma omp parallel for
+    for (int32_t j = 0; j < n; j += NC) {
+#ifdef _OPENMP
+      int32_t local_threads = omp_get_thread_num();
+#else
+      int32_t local_threads = 0;
+#endif
+      int32_t nc;
+      nc = s_min(n - j, NC);
+      int8_t *local_B = packedB_int8 + KC * NC * local_threads;
+      int32_t *local_C = packedC_int32 + MC * NC * local_threads;
+#if __aarch64__
+      // TODO()
+#else
+      PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B);
+#endif
+      if (bias == nullptr) {
+        InnerKernel(m, nc, alpha, packedA_int8, local_B, beta, local_C,
+                    &C(0, j), ldc, relu);
+      } else {
+        InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta, local_C,
+                            &C(0, j), ldc, relu, bias);
+      }
+    }
+  }
+
+  paddle_mobile::memory::Free(packedA_int8);
+  paddle_mobile::memory::Free(packedB_int8);
+  paddle_mobile::memory::Free(packedC_int32);
+  paddle_mobile::memory::Free(zero_int8);
+}
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle_mobile
diff --git a/src/operators/math/gemm_omp_int8.cpp b/src/operators/math/gemm_omp_int8.cpp
index 8203d27594..61f0be418f 100644
--- a/src/operators/math/gemm_omp_int8.cpp
+++ b/src/operators/math/gemm_omp_int8.cpp
@@ -27,130 +27,6 @@ namespace paddle_mobile {
 namespace operators {
 namespace math {
 
-// 8 bits int matrix product (m*k x k*n)
-void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
-                     const int8_t *A, int32_t lda, const int8_t *B, int32_t ldb,
-                     float beta, int32_t *C, int32_t ldc, bool relu,
-                     int32_t *bias) {
-#ifdef _OPENMP
-  int32_t max_threads = omp_get_max_threads();
-#else
-  int32_t max_threads = 1;
-#endif
-
-  int32_t L1 = 64 / max_threads * 1024;
-  const int32_t k_complete = (k + 15) - ((k + 15) & 15);
-  KC = k_complete;
-  zero_int8 =
-      static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * k));
-  memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * k);
-  if (m > n) {
-    // 对 A 分块
-    MC = L1 / (KC * sizeof(int8_t));
-    if (MC == 0) {
-      MC = MR_INT8;
-    } else {
-      int32_t mblock_num = (m + MC - 1) / MC;
-      MC = (m + mblock_num - 1) / mblock_num;
-      MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8;
-    }
-    // 补齐 B
-    NC = (n + NR_INT8 - 1) / NR_INT8 * NR_INT8;
-
-    packedB_int8 = static_cast<int8_t *>(
-        paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
-#if __aarch64__
-    // TODO
-#else
-    PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8);
-#endif
-    packedA_int8 = static_cast<int8_t *>(
-        paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC * max_threads));
-  } else {
-    // 对 B 分块
-    NC = L1 / (KC * sizeof(int8_t));
-    if (NC == 0) {
-      NC = NR_INT8;
-    } else {
-      int32_t nblock_num = (n + NC - 1) / NC;
-      NC = (n + nblock_num - 1) / nblock_num;
-      NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8;
-    }
-    // 补齐 A
-    MC = (m + MR_INT8 - 1) / MR_INT8 * MR_INT8;
-
-    packedA_int8 = static_cast<int8_t *>(
-        paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC));
-#if __aarch64__
-    // TODO
-#else
-    PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8);
-#endif
-    packedB_int8 = static_cast<int8_t *>(
-        paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC * max_threads));
-  }
-  packedC_int32 = static_cast<int32_t *>(
-      paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC * max_threads));
-
-  if (m > n) {
-#pragma omp parallel for
-    for (int32_t i = 0; i < m; i += MC) {
-#ifdef _OPENMP
-      int32_t local_threads = omp_get_thread_num();
-#else
-      int32_t local_threads = 0;
-#endif
-
-      int32_t mc;
-      mc = s_min(m - i, MC);
-      int8_t *local_A = packedA_int8 + MC * KC * local_threads;
-      int32_t *local_C = packedC_int32 + MC * NC * local_threads;
-#if __aarch64__
-      // TODO
-#else
-      PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A);
-#endif
-      // InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta,
-      // local_C,
-      // &C(i, 0), ldc, relu, bias + i);
-      if (bias == nullptr) {
-        InnerKernel(mc, n, alpha, local_A, packedB_int8, beta, local_C,
-                    &C(i, 0), ldc, relu);
-      }
-    }
-  } else {
-#pragma omp parallel for
-    for (int32_t j = 0; j < n; j += NC) {
-#ifdef _OPENMP
-      int32_t local_threads = omp_get_thread_num();
-#else
-      int32_t local_threads = 0;
-#endif
-      int32_t nc;
-      nc = s_min(n - j, NC);
-      int8_t *local_B = packedB_int8 + KC * NC * local_threads;
-      int32_t *local_C = packedC_int32 + MC * NC * local_threads;
-#if __aarch64__
-      // TODO
-#else
-      PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B);
-#endif
-      // InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta,
-      // local_C,
-      // &C(0, j), ldc, relu, bias);
-      if (bias == nullptr) {
-        InnerKernel(m, nc, alpha, packedA_int8, local_B, beta, local_C,
-                    &C(0, j), ldc, relu);
-      }
-    }
-  }
-
-  paddle_mobile::memory::Free(packedA_int8);
-  paddle_mobile::memory::Free(packedB_int8);
-  paddle_mobile::memory::Free(packedC_int32);
-  paddle_mobile::memory::Free(zero_int8);
-}
-
 void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
                               const int8_t *B, int32_t ldb, int8_t *buffer) {
   const int32_t j_length = n - n_tail;
diff --git a/src/operators/math/math_function_int8.cpp b/src/operators/math/math_function_int8.cpp
index fe6b05ae1a..a407a2915d 100644
--- a/src/operators/math/math_function_int8.cpp
+++ b/src/operators/math/math_function_int8.cpp
@@ -54,9 +54,8 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a,
 #ifdef _OPENMP
     if (bias != nullptr) {
-      // TODO(wzzju):gemm.Sgemm_omp_with_bias, now use single thread instead.
-      gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
-                 matrix_out->data<int8_t>(), N, relu, bias);
+      gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
+                     matrix_out->data<int8_t>(), N, relu, bias);
     } else {
       gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
                      matrix_out->data<int32_t>(), N, relu, bias);
@@ -73,10 +72,9 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a,
   } else {
 #ifdef _OPENMP
     if (bias != nullptr) {
-      // TODO(wzzju):gemm.Sgemm_omp_with_bias, now use single thread instead.
-      gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
-                 matrix_b.data<int8_t>(), N, beta, matrix_out->data<int8_t>(),
-                 N, relu, bias);
+      gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<int8_t>(), K,
+                     matrix_b.data<int8_t>(), N, beta,
+                     matrix_out->data<int8_t>(), N, relu, bias);
     } else {
       gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<int8_t>(), K,
                      matrix_b.data<int8_t>(), N, beta,
diff --git a/test/common/test_gemm_int8_accuracy.cpp b/test/common/test_gemm_int8_accuracy.cpp
index 6e2d838955..a1920ba2bb 100644
--- a/test/common/test_gemm_int8_accuracy.cpp
+++ b/test/common/test_gemm_int8_accuracy.cpp
@@ -201,9 +201,8 @@ int do_sgemm_with_bias(int m, int n, int k, bool relu, int pr) {
 
   paddle_mobile::operators::math::Gemm gemm;
 #ifdef _OPENMP
-  // TODO(wzzju):gemm.Sgemm_omp_with_bias, now use single thread instead.
-  gemm.Sgemm(m, n, k, scale, a, lda, b, ldb, static_cast<float>(0), c, ldc,
-             relu, bias);
+  gemm.Sgemm_omp(m, n, k, scale, a, lda, b, ldb, static_cast<float>(0), c, ldc,
+                 relu, bias);
 #else
   gemm.Sgemm(m, n, k, scale, a, lda, b, ldb, static_cast<float>(0), c, ldc,
              relu, bias);
diff --git a/test/operators/test_mul_op.cpp b/test/operators/test_mul_op.cpp
index 2734bbeace..99a2219749 100644
--- a/test/operators/test_mul_op.cpp
+++ b/test/operators/test_mul_op.cpp
@@ -95,7 +95,7 @@ int TestMulOP() {
 
 int main() {
   paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
-  paddle_mobile.SetThreadNum(8);
+  paddle_mobile.SetThreadNum(4);
   paddle_mobile::TestMulOP<int8_t, int32_t>();
   paddle_mobile::TestMulOP<float, float>();
   return 0;
-- 
GitLab
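
Note: the new omp Sgemm_omp keeps the cache-blocking scheme of the single-threaded int8 Sgemm: k is padded up to a multiple of 16 for the 16-wide packing routines, a per-thread share of a 64 KiB L1 budget decides how to tile A (when m > n) or B (otherwise), and the tile is rounded up to the register-block sizes MR_INT8/NR_INT8 before the per-thread pack buffers are allocated. The standalone sketch below reproduces just that block-size arithmetic for illustration; MR_INT8 = 4 and NR_INT8 = 2 are assumed values inferred from the PackMatrixA_omp_4r_16 / PackMatrixB_omp_2c_16 routine names and may differ from the actual definitions in gemm.h.

```cpp
// Standalone sketch of the block-size selection used by the omp int8 GEMM path.
// MR_INT8 / NR_INT8 are assumed register-block sizes, not taken from gemm.h.
#include <cstdint>
#include <cstdio>

constexpr int32_t MR_INT8 = 4;
constexpr int32_t NR_INT8 = 2;

struct BlockSizes {
  int32_t KC, MC, NC;
};

BlockSizes ChooseBlocks(int32_t m, int32_t n, int32_t k, int32_t max_threads) {
  BlockSizes b{};
  // Per-thread share of a 64 KiB L1 budget, in bytes.
  const int32_t L1 = 64 / max_threads * 1024;
  // Round k up to the next multiple of 16 (the pack routines consume 16-wide strips).
  b.KC = (k + 15) - ((k + 15) & 15);
  if (m > n) {
    // Tile A: fit an MC x KC int8 panel into the per-thread budget,
    // then round MC up to the MR_INT8 register block.
    b.MC = L1 / (b.KC * static_cast<int32_t>(sizeof(int8_t)));
    if (b.MC == 0) {
      b.MC = MR_INT8;
    } else {
      const int32_t mblock_num = (m + b.MC - 1) / b.MC;
      b.MC = (m + mblock_num - 1) / mblock_num;
      b.MC = (b.MC + MR_INT8 - 1) / MR_INT8 * MR_INT8;
    }
    // B is packed once for the whole (padded) n.
    b.NC = (n + NR_INT8 - 1) / NR_INT8 * NR_INT8;
  } else {
    // Tile B symmetrically; A is packed once for the padded m.
    b.NC = L1 / (b.KC * static_cast<int32_t>(sizeof(int8_t)));
    if (b.NC == 0) {
      b.NC = NR_INT8;
    } else {
      const int32_t nblock_num = (n + b.NC - 1) / b.NC;
      b.NC = (n + nblock_num - 1) / nblock_num;
      b.NC = (b.NC + NR_INT8 - 1) / NR_INT8 * NR_INT8;
    }
    b.MC = (m + MR_INT8 - 1) / MR_INT8 * MR_INT8;
  }
  return b;
}

int main() {
  const BlockSizes b = ChooseBlocks(512, 256, 100, 4);
  std::printf("KC=%d MC=%d NC=%d\n", b.KC, b.MC, b.NC);  // KC=112 MC=128 NC=256
  return 0;
}
```

With m = 512, n = 256, k = 100 and four threads this yields KC=112, MC=128, NC=256: each thread packs a 128 x 112 panel of A against a fully packed B, and the per-thread packedC_int32 scratch covers an MC x NC tile, matching the MC * KC * max_threads and MC * NC * max_threads allocations in the patch.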