Commit 5e882269 authored by: ZhenWang

add int8_t type sgemm_omp

Parent 23571a48
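The new entry point is the OpenMP int8 overload of Gemm::Sgemm_omp declared in gemm.h in the diff below. A minimal usage sketch follows; it is a hypothetical caller, assuming the operators/math/gemm.h include path and pre-filled row-major buffers, and the output type Otype is deduced from the output pointer:

#include <cstdint>
#include "operators/math/gemm.h"  // assumed include path, not part of this diff

void RunInt8GemmOmp(int32_t M, int32_t N, int32_t K,
                    const int8_t *A, const int8_t *B, int32_t *C,
                    int32_t *bias, bool relu) {
  paddle_mobile::operators::math::Gemm gemm;
  // Multi-threaded int8 path: alpha/beta are float, accumulation is int32.
  gemm.Sgemm_omp(M, N, K, /*alpha=*/1.0f, A, /*lda=*/K, B, /*ldb=*/N,
                 /*beta=*/0.0f, C, /*ldc=*/N, relu, bias);
}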
@@ -3147,6 +3147,7 @@ void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda,
}
// 32位 float 矩阵乘法
+template <>
void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
                     const float *B, int ldb, float beta, float *C, int ldc,
                     bool relu, float *bias) {
......
@@ -16,6 +16,9 @@ limitations under the License. */
#include <string>
#include "common/log.h"
#include "memory/t_malloc.h"
+#ifdef _OPENMP
+#include <omp.h>
+#endif
// 矩阵取值运算宏,假设矩阵按行存储
#define A(i, j) A[(i)*lda + (j)]
@@ -172,11 +175,6 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
                const float *B, int ldb, float *C, int ldc, float *p,
                std::string mode, float *bias, float *bias1);
-// 32位 float 矩阵乘法(openmp 多线程版本)
-void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
-               const float *B, int ldb, float beta, float *C, int ldc,
-               bool relu, float *bias);
// 32位 float 矩阵乘法, 并对结果进行 batchnrom(openmp 多线程版本)
void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A,
                     int lda, const float *B, int ldb, float beta, float *C,
@@ -228,6 +226,14 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
// 8 bits int matrix product
template <typename Itype, typename Btype, typename Otype>
+void Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, const Itype *A,
+               int32_t lda, const Itype *B, int32_t ldb, float beta, Otype *C,
+               int32_t ldc, bool relu, Btype *bias);
+template <typename Otype>
+void Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
+               int32_t lda, const int8_t *B, int32_t ldb, float beta,
+               Otype *C, int32_t ldc, bool relu, int32_t *bias);
+template <typename Itype, typename Btype, typename Otype>
void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const Itype *A,
           int32_t lda, const Itype *B, int32_t ldb, float beta, Otype *C,
           int32_t ldc, bool relu, Btype *bias);
@@ -235,10 +241,6 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
           int32_t lda, const int8_t *B, int32_t ldb, float beta, Otype *C,
           int32_t ldc, bool relu, int32_t *bias);
-void Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
-               int32_t lda, const int8_t *B, int32_t ldb, float beta,
-               int32_t *C, int32_t ldc, bool relu, int32_t *bias);
// 8 bits int write back
// C = A * B
void WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc);
@@ -332,6 +334,131 @@ void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
  paddle_mobile::memory::Free(zero_int8);
}
// 8 bits int matrix product (m*k x k*n), omp version
template <typename Otype>
void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
const int8_t *A, int32_t lda, const int8_t *B, int32_t ldb,
float beta, Otype *C, int32_t ldc, bool relu,
int32_t *bias) {
#ifdef _OPENMP
int32_t max_threads = omp_get_max_threads();
#else
int32_t max_threads = 1;
#endif
int32_t L1 = 64 / max_threads * 1024;
const int32_t k_complete = (k + 15) - ((k + 15) & 15);
KC = k_complete;
zero_int8 =
static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * k));
memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * k);
if (m > n) {
// 对 A 分块
MC = L1 / (KC * sizeof(int8_t));
if (MC == 0) {
MC = MR_INT8;
} else {
int32_t mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8;
}
// 补齐 B
NC = (n + NR_INT8 - 1) / NR_INT8 * NR_INT8;
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
#if __aarch64__
// TODO()
#else
PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8);
#endif
packedA_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC * max_threads));
} else {
// 对 B 分块
NC = L1 / (KC * sizeof(int8_t));
if (NC == 0) {
NC = NR_INT8;
} else {
int32_t nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8;
}
// 补齐 A
MC = (m + MR_INT8 - 1) / MR_INT8 * MR_INT8;
packedA_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC));
#if __aarch64__
// TODO()
#else
PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8);
#endif
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC * max_threads));
}
packedC_int32 = static_cast<int32_t *>(
paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC * max_threads));
if (m > n) {
#pragma omp parallel for
for (int32_t i = 0; i < m; i += MC) {
#ifdef _OPENMP
int32_t local_threads = omp_get_thread_num();
#else
int32_t local_threads = 0;
#endif
int32_t mc;
mc = s_min(m - i, MC);
int8_t *local_A = packedA_int8 + MC * KC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads;
#if __aarch64__
// TODO()
#else
PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A);
#endif
if (bias == nullptr) {
InnerKernel(mc, n, alpha, local_A, packedB_int8, beta, local_C,
&C(i, 0), ldc, relu);
} else {
InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta, local_C,
&C(i, 0), ldc, relu, bias + i);
}
}
} else {
#pragma omp parallel for
for (int32_t j = 0; j < n; j += NC) {
#ifdef _OPENMP
int32_t local_threads = omp_get_thread_num();
#else
int32_t local_threads = 0;
#endif
int32_t nc;
nc = s_min(n - j, NC);
int8_t *local_B = packedB_int8 + KC * NC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads;
#if __aarch64__
// TODO()
#else
PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B);
#endif
if (bias == nullptr) {
InnerKernel(m, nc, alpha, packedA_int8, local_B, beta, local_C,
&C(0, j), ldc, relu);
} else {
InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta, local_C,
&C(0, j), ldc, relu, bias);
}
}
}
paddle_mobile::memory::Free(packedA_int8);
paddle_mobile::memory::Free(packedB_int8);
paddle_mobile::memory::Free(packedC_int32);
paddle_mobile::memory::Free(zero_int8);
}
}  // namespace math
}  // namespace operators
}  // namespace paddle_mobile
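For reference, a self-contained sketch of the blocking arithmetic used by the new Sgemm_omp above: KC is padded up to a multiple of 16, and in the m > n branch MC is fitted to a per-thread L1 budget, spread evenly over blocks, then rounded up to the int8 register tile height. MR_INT8 = 4 is an illustrative assumption inferred from the PackMatrixA_4r_16 name, not something this diff states.

#include <cstdint>
#include <cstdio>

constexpr int32_t MR_INT8 = 4;  // assumed tile height (4-row packing)

// Round k up to the next multiple of 16, same as (k + 15) - ((k + 15) & 15).
int32_t RoundUpK16(int32_t k) { return (k + 15) - ((k + 15) & 15); }

// MC selection for the m > n branch: fit the per-thread L1 budget, split m
// into equal blocks, then pad the block height to a multiple of MR_INT8.
int32_t ChooseMC(int32_t m, int32_t KC, int32_t L1) {
  int32_t MC = L1 / (KC * static_cast<int32_t>(sizeof(int8_t)));
  if (MC == 0) return MR_INT8;
  int32_t mblock_num = (m + MC - 1) / MC;
  MC = (m + mblock_num - 1) / mblock_num;
  return (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8;
}

int main() {
  const int32_t max_threads = 4;
  const int32_t L1 = 64 / max_threads * 1024;  // per-thread cache budget, bytes
  const int32_t KC = RoundUpK16(50);           // 50 -> 64
  printf("KC=%d MC=%d\n", KC, ChooseMC(1000, KC, L1));  // KC=64 MC=252
  return 0;
}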
@@ -27,130 +27,6 @@ namespace paddle_mobile {
namespace operators {
namespace math {
// 8 bits int matrix product (m*k x k*n)
void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
const int8_t *A, int32_t lda, const int8_t *B, int32_t ldb,
float beta, int32_t *C, int32_t ldc, bool relu,
int32_t *bias) {
#ifdef _OPENMP
int32_t max_threads = omp_get_max_threads();
#else
int32_t max_threads = 1;
#endif
int32_t L1 = 64 / max_threads * 1024;
const int32_t k_complete = (k + 15) - ((k + 15) & 15);
KC = k_complete;
zero_int8 =
static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * k));
memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * k);
if (m > n) {
// 对 A 分块
MC = L1 / (KC * sizeof(int8_t));
if (MC == 0) {
MC = MR_INT8;
} else {
int32_t mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8;
}
// 补齐 B
NC = (n + NR_INT8 - 1) / NR_INT8 * NR_INT8;
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
#if __aarch64__
// TODO
#else
PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8);
#endif
packedA_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC * max_threads));
} else {
// 对 B 分块
NC = L1 / (KC * sizeof(int8_t));
if (NC == 0) {
NC = NR_INT8;
} else {
int32_t nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8;
}
// 补齐 A
MC = (m + MR_INT8 - 1) / MR_INT8 * MR_INT8;
packedA_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC));
#if __aarch64__
// TODO
#else
PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8);
#endif
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC * max_threads));
}
packedC_int32 = static_cast<int32_t *>(
paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC * max_threads));
if (m > n) {
#pragma omp parallel for
for (int32_t i = 0; i < m; i += MC) {
#ifdef _OPENMP
int32_t local_threads = omp_get_thread_num();
#else
int32_t local_threads = 0;
#endif
int32_t mc;
mc = s_min(m - i, MC);
int8_t *local_A = packedA_int8 + MC * KC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads;
#if __aarch64__
// TODO
#else
PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A);
#endif
// InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta,
// local_C,
// &C(i, 0), ldc, relu, bias + i);
if (bias == nullptr) {
InnerKernel(mc, n, alpha, local_A, packedB_int8, beta, local_C,
&C(i, 0), ldc, relu);
}
}
} else {
#pragma omp parallel for
for (int32_t j = 0; j < n; j += NC) {
#ifdef _OPENMP
int32_t local_threads = omp_get_thread_num();
#else
int32_t local_threads = 0;
#endif
int32_t nc;
nc = s_min(n - j, NC);
int8_t *local_B = packedB_int8 + KC * NC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads;
#if __aarch64__
// TODO
#else
PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B);
#endif
// InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta,
// local_C,
// &C(0, j), ldc, relu, bias);
if (bias == nullptr) {
InnerKernel(m, nc, alpha, packedA_int8, local_B, beta, local_C,
&C(0, j), ldc, relu);
}
}
}
paddle_mobile::memory::Free(packedA_int8);
paddle_mobile::memory::Free(packedB_int8);
paddle_mobile::memory::Free(packedC_int32);
paddle_mobile::memory::Free(zero_int8);
}
void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
                              const int8_t *B, int32_t ldb, int8_t *buffer) {
  const int32_t j_length = n - n_tail;
......
@@ -54,8 +54,7 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a,
#ifdef _OPENMP
  if (bias != nullptr) {
-    // TODO(wzzju):gemm.Sgemm_omp_with_bias, now use single thread instead.
-    gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
+    gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
               matrix_out->data<int8_t>(), N, relu, bias);
  } else {
    gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
@@ -73,10 +72,9 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a,
  } else {
#ifdef _OPENMP
    if (bias != nullptr) {
-      // TODO(wzzju):gemm.Sgemm_omp_with_bias, now use single thread instead.
-      gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
-                 matrix_b.data<int8_t>(), N, beta, matrix_out->data<int8_t>(),
-                 N, relu, bias);
+      gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<int8_t>(), K,
+                     matrix_b.data<int8_t>(), N, beta,
+                     matrix_out->data<int8_t>(), N, relu, bias);
    } else {
      gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<int8_t>(), K,
                     matrix_b.data<int8_t>(), N, beta,
......
@@ -201,8 +201,7 @@ int do_sgemm_with_bias(int m, int n, int k, bool relu, int pr) {
  paddle_mobile::operators::math::Gemm gemm;
#ifdef _OPENMP
-  // TODO(wzzju):gemm.Sgemm_omp_with_bias, now use single thread instead.
-  gemm.Sgemm(m, n, k, scale, a, lda, b, ldb, static_cast<float>(0), c, ldc,
+  gemm.Sgemm_omp(m, n, k, scale, a, lda, b, ldb, static_cast<float>(0), c, ldc,
             relu, bias);
#else
  gemm.Sgemm(m, n, k, scale, a, lda, b, ldb, static_cast<float>(0), c, ldc,
......
@@ -95,7 +95,7 @@ int TestMulOP() {
int main() {
  paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
-  paddle_mobile.SetThreadNum(8);
+  paddle_mobile.SetThreadNum(4);
  paddle_mobile::TestMulOP<int8_t, int32_t>();
  paddle_mobile::TestMulOP<float, float>();
  return 0;
......